本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
pytorch是为深度学习量身打造的框架,但由于深度学习的数据量一般较大,
因此,pytorch特地为深度学习提供了一套数据类与数据方法来管理数据集
通过本文可以初步了解在pytorch中是以什么形式来使用数据集的,并简单认识DataLoder和DataSet是什么
本节简单讲解和认识pytorch中是以什么形式使用数据的,并初步了解DataLoader和DataSet是什么
1. pytorch中的DataLoader
一般来说,pytorch并不直接使用表格形式的数据进行建模,而是将数据装到DataLoader中再进行使用,
DataLoader提供了一些常用数据处理API,通过DataLoader来使用数据在建模过程中会更加便利
DataLoader在某种意义上,充当了数据中心的角色,把数据装载到DataLoader后,再由DataLoader来提供各种数据服务
2. pytorch中的数据集DataSet
在pytorch中,由于数据集一般都加载到DataLoader中进行使用,
所以,pytorch中所说的数据集也并非一个表格数据(例如numpy或DataFrame之类的表格数据)
而是一个由DataLoader约定形式的DataSet数据类,
✍️备注:DataSet长怎么样子,我们在后续文章中再具体进行讲解
3. pytorch使用数据的简单总结
总的来说,在pytorch中一般需要按如下形式使用数据:
1. 先要将原始数据(原始数据可能是csv,也可能是numpy之类的表格数据)按指定方式封装成DataSet
2. 然后把DataSet装载到DataLoader中
3. 最后建模时通过DataLoader来调用数据
本节通过一个示例,展示在pytorch中是怎么通过DataSet和DataLoader来使用数据集的
1.DataSe-DataLoder的使用示例
下面通过一个示例来认识pytorch是以什么方式使用数据集的
示例如下:
# -*- coding: utf-8 -*-
"DataLoader的一个简单用例"
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ---------获取自带数据FashionMNIST-----------------------------
# 路径,用于存放下载的数据(需要改成自己的目录)
root= 'D:\pytorch\data'
# 获取训练数据
img_data = datasets.FashionMNIST(
root = root, # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
train = True, # true为下载训练数据集,false为下载测试数据集
download = True, # 是否下载,选为True,就下载到root下面
transform = ToTensor() # 转换为tensor数据
)
# ---------将图片数据装载到DataLoader中使用------------------------
dataloader = DataLoader(img_data, batch_size=20000, shuffle=False) # 将数据装载到DataLoader,每批20000条数据
batch_num = len(dataloader) # 获取数据的批数
dataloader_iter = iter(dataloader) # 将dataloader转为迭代对象
# -----对dataloader进行历遍,并打印出每批数据的第一条数据(图片)-----
figure = plt.figure(figsize=(4, 4)) # 初始化图片,figsize可以指定图片的大小
for i in range(batch_num):
# 获取本批data(img)和label
imgs, labels = next(dataloader_iter) # 获取本批data(img)和label
# 将该批的第一条数据打印出来
figure.add_subplot(1,batch_num,i+1) # 添加到figure
plt.title(labels[0].item()) # 样本标签
plt.axis("off") # 不展示坐标轴
plt.imshow(imgs[0].squeeze(), cmap="gray") # 展示本批图片的第一张
plt.show()
运行结果如下:
2. 示例逐步解说
上述示例共包含了三大块,下面我们进行逐块解说
👉1.数据下载
# ---------获取自带数据FashionMNIST-----------------------------
# 路径,用于存放下载的数据(需要改成自己的目录)
root= 'D:\pytorch\data'
# 获取训练数据
img_data = datasets.FashionMNIST(
root = root, # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
train = True, # true为下载训练数据集,false为下载测试数据集
download = True, # 是否下载,选为True,就下载到root下面
transform = ToTensor() # 转换为tensor数据
)
该部分通过datasets去下载pytorch官方提供的FashionMNIST数据集,
注意,这里下载到的img_data(FashionMNIST数据集)是一个DataSet类,并不是一个DataFrame之类的表格数据
关于怎么下载官方数据、或者将自己的数据封装成一个DataSet,在后续文章中再进行详细解说
👉2.数据装载
# ---------将图片数据装载到DataLoader中使用------------------------
dataloader = DataLoader(img_data, batch_size=20000, shuffle=False) # 将数据装载到DataLoader,每批20000条数据
batch_num = len(dataloader) # 获取数据的批数
dataloader_iter = iter(dataloader) # 将dataloader转为迭代对象
该部分将上述下载到的DataSet数据集img_data按20000每批装载到DataLoader,
其中shuffle=False表示在装载过程中不打乱原数据集的顺序
从这里可以知道,DataLoader是把数据按一批一批的形式存放的,
这是因为在深度学习中,并不是一次性使用所有的数据对模型进行训练,而是逐批数据训练模型
在装载完成后,把dataloader转换为迭代器,方便后续的迭代使用
👉3.数据使用
# -----对dataloader进行历遍,并打印出每批数据的第一条数据(图片)-----
figure = plt.figure(figsize=(4, 4)) # 初始化图片,figsize可以指定图片的大小
for i in range(batch_num):
# 获取本批data(img)和label
imgs, labels = next(dataloader_iter) # 获取本批data(img)和label
# 将该批的第一条数据打印出来
figure.add_subplot(1,batch_num,i+1) # 添加到figure
plt.title(labels[0].item()) # 样本标签
plt.axis("off") # 不展示坐标轴
plt.imshow(imgs[0].squeeze(), cmap="gray") # 展示本批图片的第一张
plt.show()
该部分通过迭代DataLoader来逐批使用数据集
先通过imgs, labels = next(dataloader_iter) 将迭代器指向下一批,并提出其中的数据imgs和labels
在本示例中,我们只是简单地将每批数据的第一条数据通过图象imshow展示出来,
在实际中就应切换为具体的操作,例如用于模型训练等等
结束语
本文我们只是简单地了解与认识pytorch中是以什么形式来使用数据集的,
更多的细节与认识在后续文章再逐步深入讲解,千里之行,始于足下~
好了,以上就是pytorch的DataSet与DataLoader的全部内容了~
End