Pytorch教程

【数据】pytorch导入自定义图片数据到DataSet

作者 : 老饼 发表日期 : 2023-07-28 10:47:15 更新日期 : 2023-12-12 10:40:17
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



在使用pytorch玩深度学习时,一般会需要导入自己的图片数据

本文讲解,如何在pytorch中将自己的数据封装到DataSet中




一、pytorch导入自己的图片数据-步骤与概述   



本节讲解在pytorch导入自己的图片数据的步骤



01. pytorch自实现DataSet的步骤

第一次理解如何在pytorch中将图片数据封装成DataSet相对较为麻烦,第二次就不难了


在pytorch中导入自己的图片数据到DataSet中,只需要如下四步:

👉1.将自己的图片保存在一个文件夹

👉2.给每个图片写上标注

👉3.编写数据类函数xxxDataset.py

     该文件需要定义一个数据类,用于读取数据

     其中,类里必须实现__init__、__len__、__getitem__这三个方法

👉4.将图片与标注数据装载到xxxDataset.py中


✍️笔者语:只看上述步骤是有些抽象的,依照下述实例具体操作一遍就熟悉了





02.pytorch导入自己的图片数据-实例讲解   



本节展示一个在pytorch导入自己的图片数据到DataSet中的实例



01. pytorch自实现DataSet实例


下面我们通过一个实例,展示具体如何在pytorch导入自己的图片数据


1. 图片与标签数据文件

新建一个img文件夹,img里存放了要导入的图片,

新建一个img_label.csv,img_label.csv里编写了各张图片对应的标签名称,

最终结果如下:

  

👉图片素材与img_label内容见文末附件


2.编码CustomImageDataset类

编写数据类CustomImageDataset的代码(CustomImageDataset只是一个命名,也可以修改成其它名称)

CustomImageDataset里要实现__init__、__len__、__getitem__这三个方法

具体代码如下:

# -*- coding: utf-8 -*-
import os
import pandas as pd
from torchvision.io   import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, header=None)            # 从CSV中读取图象标签
        self.img_dir    = img_dir                                               # 存放图片的文件夹
        self.transform  = transform                                             # 图片的转换函数
        self.target_transform = target_transform                                # 标签的转换函数
                                                                                
    def __len__(self):                                                          
        return len(self.img_labels)                                             # 标签的长度就是样本个数

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])     # 图片路径
        image    = read_image(img_path)                                         # 读取图片
        label    = self.img_labels.iloc[idx, 1]                                 # 读取标签
        if self.transform:                                                      # 如果有图片的转换函数
            image = self.transform(image)                                       # 就对图片进行转换
        if self.target_transform:                                               # 如果有标签的转换函数
            label = self.target_transform(label)                                # 就对标签进行转换
        return image, label                                                     # 返回图片和标签



3.装载数据并使用数据

在CustomImageDataset中传入需要读取的图像路径和相关参数,就可以加载数据

这里为了展示数据的可用性,我们除了装载数据外,还进一步将数据封装进DataLoader中使用演示

具体示例代码如下:

# -*- coding: utf-8 -*-
from CustomImageDataset import CustomImageDataset
from torch.utils.data   import DataLoader
from torchvision        import transforms
import matplotlib.pyplot as plt

# 通过数据类读取图片数据
img_dir          = 'D:\pytorch\data\mydata\img'                                                 # 图象文件夹
label_file       = 'D:\pytorch\data\mydata\img_label.csv'                                       # 标签文件
myDataSet        = CustomImageDataset(label_file,img_dir,transform=None,target_transform=None)  # 初始化数据类

# 将图片数据装载到DataLoader中使用
train_dataloader = DataLoader(myDataSet, batch_size=2, shuffle=False)                           # 将数据装载到DataLoader
dataloader_len   = len(train_dataloader)                                                        # dataloader的批数
batch_size       = train_dataloader.batch_size                                                  # dataloader每批的大小

# ---------从DataLoader中获取图片并进行打印-------------
figure = plt.figure(figsize=(8, 8))                                                            # 初始化图片,figsize可以指定图片的大小
for i,data in enumerate(train_dataloader):
    img, label = data                                                                           # 读取本批的图象与标签
    for j in range(len(label)):                                                                 # 对本批图象进行逐个打印
        figure.add_subplot(dataloader_len, batch_size, i*batch_size+j+1)                        # 添加到figure
        plt.imshow(transforms.ToPILImage()(img[j]), cmap="gray")                                # 绘图
plt.show()      

运行后结果显示如下:

  
可以看到,已经可以把图片打印进来了





  附件    


如果要复现笔者的实例,可以复制以下内容:
1.img图片(记得给图片修改对应名称):

2.img_label.csv的内容:
img_01.png,1
img_02.png,2
img_03.png,2
img_04.png,1












 End 



联系老饼