本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
在使用pytorch玩深度学习时,一般会需要导入自己的图片数据
本文讲解,如何在pytorch中将自己的数据封装到DataSet中
本节讲解在pytorch导入自己的图片数据的步骤
第一次理解如何在pytorch中将图片数据封装成DataSet相对较为麻烦,第二次就不难了
在pytorch中导入自己的图片数据到DataSet中,只需要如下四步:
👉1.将自己的图片保存在一个文件夹
👉2.给每个图片写上标注
👉3.编写数据类函数xxxDataset.py
该文件需要定义一个数据类,用于读取数据
其中,类里必须实现__init__、__len__、__getitem__这三个方法
👉4.将图片与标注数据装载到xxxDataset.py中
✍️笔者语:只看上述步骤是有些抽象的,依照下述实例具体操作一遍就熟悉了
本节展示一个在pytorch导入自己的图片数据到DataSet中的实例
下面我们通过一个实例,展示具体如何在pytorch导入自己的图片数据
1. 图片与标签数据文件
新建一个img文件夹,img里存放了要导入的图片,
新建一个img_label.csv,img_label.csv里编写了各张图片对应的标签名称,
最终结果如下:
👉图片素材与img_label内容见文末附件
2.编码CustomImageDataset类
编写数据类CustomImageDataset的代码(CustomImageDataset只是一个命名,也可以修改成其它名称)
CustomImageDataset里要实现__init__、__len__、__getitem__这三个方法
具体代码如下:
# -*- coding: utf-8 -*-
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file, header=None) # 从CSV中读取图象标签
self.img_dir = img_dir # 存放图片的文件夹
self.transform = transform # 图片的转换函数
self.target_transform = target_transform # 标签的转换函数
def __len__(self):
return len(self.img_labels) # 标签的长度就是样本个数
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # 图片路径
image = read_image(img_path) # 读取图片
label = self.img_labels.iloc[idx, 1] # 读取标签
if self.transform: # 如果有图片的转换函数
image = self.transform(image) # 就对图片进行转换
if self.target_transform: # 如果有标签的转换函数
label = self.target_transform(label) # 就对标签进行转换
return image, label # 返回图片和标签
3.装载数据并使用数据
在CustomImageDataset中传入需要读取的图像路径和相关参数,就可以加载数据
这里为了展示数据的可用性,我们除了装载数据外,还进一步将数据封装进DataLoader中使用演示
具体示例代码如下:
# -*- coding: utf-8 -*-
from CustomImageDataset import CustomImageDataset
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
# 通过数据类读取图片数据
img_dir = 'D:\pytorch\data\mydata\img' # 图象文件夹
label_file = 'D:\pytorch\data\mydata\img_label.csv' # 标签文件
myDataSet = CustomImageDataset(label_file,img_dir,transform=None,target_transform=None) # 初始化数据类
# 将图片数据装载到DataLoader中使用
train_dataloader = DataLoader(myDataSet, batch_size=2, shuffle=False) # 将数据装载到DataLoader
dataloader_len = len(train_dataloader) # dataloader的批数
batch_size = train_dataloader.batch_size # dataloader每批的大小
# ---------从DataLoader中获取图片并进行打印-------------
figure = plt.figure(figsize=(8, 8)) # 初始化图片,figsize可以指定图片的大小
for i,data in enumerate(train_dataloader):
img, label = data # 读取本批的图象与标签
for j in range(len(label)): # 对本批图象进行逐个打印
figure.add_subplot(dataloader_len, batch_size, i*batch_size+j+1) # 添加到figure
plt.imshow(transforms.ToPILImage()(img[j]), cmap="gray") # 绘图
plt.show()
运行后结果显示如下:
可以看到,已经可以把图片打印进来了
附件
如果要复现笔者的实例,可以复制以下内容:
1.img图片(记得给图片修改对应名称):
2.img_label.csv的内容:
img_01.png,1
img_02.png,2
img_03.png,2
img_04.png,1
End