Pytorch教程

【历遍】pytorch-DataLoader的多种历遍方法

作者 : 老饼 发表日期 : 2023-07-28 10:49:08 更新日期 : 2024-01-19 08:33:19
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



本节讲解在pytorch中历遍DataLoader多种实现方式,包括直接历遍、enumerate、iter等等

通过本文,可以选择和借鉴适用的DataLoader历遍方法,满足pytorch编程中的简洁和便利





  一、DataLoader的多种迭代方法-数据准备    



本节为下节讲解DataLoader的多种迭代方法时,先统一准备一个可用的DataLoder



1. DataLoader数据准备


在展示DataLoader的多种迭代方法之前,我们先准备一个DataLoader数据

用于讲解各种迭代方法的DataLoader示例如下:

# -*- coding: utf-8 -*-
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data   import DataLoader

# 路径,用于存放下载的数据(需要改成自己的目录)
root= 'D:\pytorch\data'

# 获取训练数据
img_data = datasets.FashionMNIST(
    root      = root,              # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    train     = True,              # true为下载训练数据集,false为下载测试数据集
    download  = True,              # 是否下载,选为True,就下载到root下面
    transform = ToTensor()         # 转换为tensor数据
)

#-----------------随机提取一些样本来看看-------------------------------------------------
# 将图片数据装载到DataLoader中使用
dataloader = DataLoader(img_data, batch_size=6, shuffle=False)     # 将数据装载到DataLoader





二、DataLoader的多种迭代方法



本节讲解pytorch中历遍DataLoder的多种方法,包括enumerate、iter等多种方法



1.直接历遍DataLoader


使用in语法,直接提出Dataloader的数据来使用,示例如下:

t = 0                                                    # 为方便控制打印的数量,记录迭代次数
for imgs,labels in dataloader:                           # 从dataloader中逐批提取出img和label
    if t<3 :                                             # 打印几个信息看一看
        print('第'+str(t)+'批数据的标签:',labels)         # 打印t和label,img太大就不打印了
    t = t+1                                              # 更新迭代次数

运行结果如下:

第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])

✍️解说:该方法的好处是可以直接使用数据,但缺点是需要另起变量来记录索引



2.用enumerate历遍DataLoader


通过enumerate迭代DataLoader可以在每次迭代时同时把数据与索引提取出来

示例如下:

for i,data in enumerate(dataloader):                     # i是批次编号,data里存放img和label
    imgs, labels = data                                  # 提取出img和label
    if i<3 :                                             # 打印几个信息看一看
        print('第'+str(i)+'批数据的标签:',labels)         # 打印i和label,img太大就不打印了

运行结果如下:

第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])

✍️解说:该方法的好处是可以直接使用索引,但缺点是需要进一步具体地提出数据



3. 迭代器的方式根据批数历遍DataLoader

先将DataLoader转换为迭代器类型,再根据数据的批数对DataLoader进行迭代

示例如下:

dataloader_iter = iter(dataloader)                                     # 将dataloader转为迭代对象
batch_num  = len(dataloader_iter)                                      # 获取数据的批数
# -----对dataloader进行历遍,并打印出每批数据的第一条数据(图片)-----
for i in range(batch_num):
    imgs, labels = next(dataloader_iter)                               # 获取本批data(img)和label    
    if i<3 :                                                           # 打印几个信息看一看
            print('第'+str(i)+'批数据的标签:',labels)                   # 打印t和label,img太大就不打印了

运行结果如下:

第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])



4.传统迭代器迭代方式历遍DataLoader


将DataLoader转为迭代器,再用传统迭代器的历遍方式进行迭代

示例如下:

loader_iterator = iter(dataloader)                          # 将dataloader转为迭代对象
t = 0                                                        # 为方便控制打印的数量,记录迭代次数
try:
    while True:
        imgs, labels = next(loader_iterator)                 # 获取本批img和label
        if t<3 :                                             # 打印几个信息看一看
            print('第'+str(t)+'批数据的标签:',labels)         # 打印t和label,img太大就不打印了
        t = t+1                                              # 迭代次数+1
except StopIteration:
    pass

运行结果如下:

第0批数据的标签: tensor([9, 0, 0, 3, 0, 2])
第1批数据的标签: tensor([7, 2, 5, 5, 0, 9])
第2批数据的标签: tensor([5, 5, 7, 9, 1, 0])





    拓展:关于只提取Dataloader的首批数据     



特别地,如果只想把DataLoader的第一批数据提出来看一看,则不需要for循环,只需用iter的方式就可以

示例如下:

dataloader_iter = iter(dataloader)                                     # 将dataloader转为迭代对象
imgs, labels = next(dataloader_iter)                                   # 获取首批data(img)和label    
print('首批数据的标签:',labels)                                         # 打印t和label,img太大就不打印了

运行结果如下:

首批数据的标签: tensor([9, 0, 0, 3, 0, 2])

解说:用iter历遍DataLoader虽然不太简洁,但用于提取首批数据却是最简洁的,因为避免了繁琐的for循环

          所以iter的方式在数据探索时往往是最常用、最好用的






好了,以上就是pytorch中DataLoader的多种历遍方法了~








 End 






联系老饼