Pytorch教程

【说明】初识pytorch的DataSet与DataLoader

作者 : 老饼 发表日期 : 2023-07-28 10:49:18 更新日期 : 2024-01-19 08:32:37
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



pytorch是为深度学习量身打造的框架,但由于深度学习的数据量一般较大,

因此,pytorch特地为深度学习提供了一套数据类与数据方法来管理数据集

通过本文可以初步了解在pytorch中是以什么形式来使用数据集的,并简单认识DataLoder和DataSet是什么




  一、初识DataLoader   



本节简单讲解和认识pytorch中是以什么形式使用数据的,并初步了解DataLoader和DataSet是什么



1. pytorch中的DataLoader


一般来说,pytorch并不直接使用表格形式的数据进行建模,而是将数据装到DataLoader中再进行使用,

DataLoader提供了一些常用数据处理API,通过DataLoader来使用数据在建模过程中会更加便利

DataLoader在某种意义上,充当了数据中心的角色,把数据装载到DataLoader后,再由DataLoader来提供各种数据服务


2. pytorch中的数据集DataSet


在pytorch中,由于数据集一般都加载到DataLoader中进行使用,

所以,pytorch中所说的数据集也并非一个表格数据(例如numpy或DataFrame之类的表格数据)

而是一个由DataLoader约定形式的DataSet数据类,

✍️备注:DataSet长怎么样子,我们在后续文章中再具体进行讲解


3. pytorch使用数据的简单总结


总的来说,在pytorch中一般需要按如下形式使用数据:

1. 先要将原始数据(原始数据可能是csv,也可能是numpy之类的表格数据)按指定方式封装成DataSet

2. 然后把DataSet装载到DataLoader中

3. 最后建模时通过DataLoader来调用数据





二、pytorch中的数据使用示例



本节通过一个示例,展示在pytorch中是怎么通过DataSet和DataLoader来使用数据集的



1.DataSe-DataLoder的使用示例


下面通过一个示例来认识pytorch是以什么方式使用数据集的

示例如下:

# -*- coding: utf-8 -*-
"DataLoader的一个简单用例"

from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data   import DataLoader
import matplotlib.pyplot as plt

# ---------获取自带数据FashionMNIST-----------------------------
# 路径,用于存放下载的数据(需要改成自己的目录)
root= 'D:\pytorch\data'

# 获取训练数据
img_data = datasets.FashionMNIST(
    root      = root,                                                  # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    train     = True,                                                  # true为下载训练数据集,false为下载测试数据集
    download  = True,                                                  # 是否下载,选为True,就下载到root下面
    transform = ToTensor()                                             # 转换为tensor数据
)

# ---------将图片数据装载到DataLoader中使用------------------------
dataloader = DataLoader(img_data, batch_size=20000, shuffle=False)     # 将数据装载到DataLoader,每批20000条数据
batch_num  = len(dataloader)                                           # 获取数据的批数
dataloader_iter = iter(dataloader)                                     # 将dataloader转为迭代对象

# -----对dataloader进行历遍,并打印出每批数据的第一条数据(图片)-----
figure     = plt.figure(figsize=(4, 4))                                # 初始化图片,figsize可以指定图片的大小
for i in range(batch_num):
    # 获取本批data(img)和label
    imgs, labels = next(dataloader_iter)                               # 获取本批data(img)和label
                       
    # 将该批的第一条数据打印出来
    figure.add_subplot(1,batch_num,i+1)                                # 添加到figure
    plt.title(labels[0].item())                                        # 样本标签 
    plt.axis("off")                                                    # 不展示坐标轴
    plt.imshow(imgs[0].squeeze(), cmap="gray")                         # 展示本批图片的第一张
plt.show()

运行结果如下:


2. 示例逐步解说


上述示例共包含了三大块,下面我们进行逐块解说


👉1.数据下载

# ---------获取自带数据FashionMNIST-----------------------------
# 路径,用于存放下载的数据(需要改成自己的目录)
root= 'D:\pytorch\data'

# 获取训练数据
img_data = datasets.FashionMNIST(
    root      = root,                                                  # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    train     = True,                                                  # true为下载训练数据集,false为下载测试数据集
    download  = True,                                                  # 是否下载,选为True,就下载到root下面
    transform = ToTensor()                                             # 转换为tensor数据
)

该部分通过datasets去下载pytorch官方提供的FashionMNIST数据集,

注意,这里下载到的img_data(FashionMNIST数据集)是一个DataSet类,并不是一个DataFrame之类的表格数据

关于怎么下载官方数据、或者将自己的数据封装成一个DataSet,在后续文章中再进行详细解说


👉2.数据装载

# ---------将图片数据装载到DataLoader中使用------------------------
dataloader = DataLoader(img_data, batch_size=20000, shuffle=False)     # 将数据装载到DataLoader,每批20000条数据
batch_num  = len(dataloader)                                           # 获取数据的批数
dataloader_iter = iter(dataloader)                                     # 将dataloader转为迭代对象

该部分将上述下载到的DataSet数据集img_data按20000每批装载到DataLoader,

其中shuffle=False表示在装载过程中不打乱原数据集的顺序

从这里可以知道,DataLoader是把数据按一批一批的形式存放的,

这是因为在深度学习中,并不是一次性使用所有的数据对模型进行训练,而是逐批数据训练模型

在装载完成后,把dataloader转换为迭代器,方便后续的迭代使用


👉3.数据使用

# -----对dataloader进行历遍,并打印出每批数据的第一条数据(图片)-----
figure     = plt.figure(figsize=(4, 4))                                # 初始化图片,figsize可以指定图片的大小
for i in range(batch_num):
    # 获取本批data(img)和label
    imgs, labels = next(dataloader_iter)                               # 获取本批data(img)和label
                       
    # 将该批的第一条数据打印出来
    figure.add_subplot(1,batch_num,i+1)                                # 添加到figure
    plt.title(labels[0].item())                                        # 样本标签 
    plt.axis("off")                                                    # 不展示坐标轴
    plt.imshow(imgs[0].squeeze(), cmap="gray")                         # 展示本批图片的第一张
plt.show()

该部分通过迭代DataLoader来逐批使用数据集

先通过imgs, labels = next(dataloader_iter) 将迭代器指向下一批,并提出其中的数据imgs和labels

在本示例中,我们只是简单地将每批数据的第一条数据通过图象imshow展示出来,

在实际中就应切换为具体的操作,例如用于模型训练等等




  结束语    


本文我们只是简单地了解与认识pytorch中是以什么形式来使用数据集的,
更多的细节与认识在后续文章再逐步深入讲解,千里之行,始于足下~






好了,以上就是pytorch的DataSet与DataLoader的全部内容了~







 End 






联系老饼