本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节讲解在pytorch中历遍DataLoader多种实现方式,包括直接历遍、enumerate、iter等等
通过本文,可以选择和借鉴适用的DataLoader历遍方法,满足pytorch编程中的简洁和便利
本节为下节讲解DataLoader的多种迭代方法时,先统一准备一个可用的DataLoder
1. DataLoader数据准备
在展示DataLoader的多种迭代方法之前,我们先准备一个DataLoader数据
用于讲解各种迭代方法的DataLoader示例如下:
# -*- coding: utf-8 -*-
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
# 路径,用于存放下载的数据(需要改成自己的目录)
root= 'D:\pytorch\data'
# 获取训练数据
img_data = datasets.FashionMNIST(
root = root, # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
train = True, # true为下载训练数据集,false为下载测试数据集
download = True, # 是否下载,选为True,就下载到root下面
transform = ToTensor() # 转换为tensor数据
)
#-----------------随机提取一些样本来看看-------------------------------------------------
# 将图片数据装载到DataLoader中使用
dataloader = DataLoader(img_data, batch_size=6, shuffle=False) # 将数据装载到DataLoader
本节讲解pytorch中历遍DataLoder的多种方法,包括enumerate、iter等多种方法
1.直接历遍DataLoader
使用in语法,直接提出Dataloader的数据来使用,示例如下:
t = 0 # 为方便控制打印的数量,记录迭代次数
for imgs,labels in dataloader: # 从dataloader中逐批提取出img和label
if t<3 : # 打印几个信息看一看
print('第'+str(t)+'批数据的标签:',labels) # 打印t和label,img太大就不打印了
t = t+1 # 更新迭代次数
运行结果如下:
第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])
✍️解说:该方法的好处是可以直接使用数据,但缺点是需要另起变量来记录索引
2.用enumerate历遍DataLoader
通过enumerate迭代DataLoader可以在每次迭代时同时把数据与索引提取出来
示例如下:
for i,data in enumerate(dataloader): # i是批次编号,data里存放img和label
imgs, labels = data # 提取出img和label
if i<3 : # 打印几个信息看一看
print('第'+str(i)+'批数据的标签:',labels) # 打印i和label,img太大就不打印了
运行结果如下:
第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])
✍️解说:该方法的好处是可以直接使用索引,但缺点是需要进一步具体地提出数据
3. 迭代器的方式根据批数历遍DataLoader
先将DataLoader转换为迭代器类型,再根据数据的批数对DataLoader进行迭代
示例如下:
dataloader_iter = iter(dataloader) # 将dataloader转为迭代对象
batch_num = len(dataloader_iter) # 获取数据的批数
# -----对dataloader进行历遍,并打印出每批数据的第一条数据(图片)-----
for i in range(batch_num):
imgs, labels = next(dataloader_iter) # 获取本批data(img)和label
if i<3 : # 打印几个信息看一看
print('第'+str(i)+'批数据的标签:',labels) # 打印t和label,img太大就不打印了
运行结果如下:
第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])
4.传统迭代器迭代方式历遍DataLoader
将DataLoader转为迭代器,再用传统迭代器的历遍方式进行迭代
示例如下:
loader_iterator = iter(dataloader) # 将dataloader转为迭代对象
t = 0 # 为方便控制打印的数量,记录迭代次数
try:
while True:
imgs, labels = next(loader_iterator) # 获取本批img和label
if t<3 : # 打印几个信息看一看
print('第'+str(t)+'批数据的标签:',labels) # 打印t和label,img太大就不打印了
t = t+1 # 迭代次数+1
except StopIteration:
pass
运行结果如下:
第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])
拓展:关于只提取Dataloader的首批数据
特别地,如果只想把DataLoader的第一批数据提出来看一看,则不需要for循环,只需用iter的方式就可以
示例如下:
dataloader_iter = iter(dataloader) # 将dataloader转为迭代对象
imgs, labels = next(dataloader_iter) # 获取首批data(img)和label
print('首批数据的标签:',labels) # 打印t和label,img太大就不打印了
运行结果如下:
首批数据的标签: tensor([9, 0, 0, 3, 0, 2])
解说:用iter历遍DataLoader虽然不太简洁,但用于提取首批数据却是最简洁的,因为避免了繁琐的for循环
所以iter的方式在数据探索时往往是最常用、最好用的
好了,以上就是pytorch中DataLoader的多种历遍方法了~
End