本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本文通过展示一个pytorch导入FashionMNIST数据的例子
来讲解pytorch导入torchvision自带数据集的流程和注意事项
本节展示如何在pytorch中导入torchvision的自带FashionMNIST数据
1. pytorch导入FashionMNIST数据的代码示例
下面直接展示一个在pytorch中加载FashionMNIST数据
示例代码如下:
# -*- coding: utf-8 -*-
"""
本代码展示如何下载pytorch的torchvision自带的FashionMNIST数据,
并展示下载后的数据内容
"""
# ------------下载torchvision自带的FashionMNIST图片样本---------------
from torchvision.transforms import ToTensor
# 获取训练数据
import torchvision
img_data = torchvision.datasets.FashionMNIST(
root = 'D:\pytorch\data' # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
,train = True # true为下载训练数据集,false为下载测试数据集
,download = True # 是否下载,选为True,就下载到root下面
,transform = torchvision.transforms.ToTensor() # 转换为tensor数据
)
# 获取测试数据
test_data = torchvision.datasets.FashionMNIST(
root = 'D:\pytorch\data',
train = False,
download = True,
transform = ToTensor()
)
#-----------------随机提取一些样本来看看-------------------------------------------------
import random
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8, 8)) # 初始化图片,figsize可以指定图片的大小
sample_num = len(img_data) # 样本个数
cols, rows = 3, 3 # 展示3*3个样本
pic_num = cols * rows # 展示样本个数
sample_idx = random.sample(range(0,sample_num),pic_num) # 随机抽取样本
labels_map = img_data.classes # 提取出标签名称
for i in range(0, pic_num): # 绘画各个样本
img, label = img_data[sample_idx[i]] # 本次要绘画的样本
figure.add_subplot(rows, cols, i+1) # 添加到figure
plt.title(labels_map[label]) # 样本标签
plt.axis("off") # 不展示坐标轴
plt.imshow(img.squeeze(), cmap="gray") # 绘图
plt.show()
运行上述代码,下载后的图片如下
本节讲解并剖析上节例子中的一些细节,进一步了解pytorch中如何导入torchvision自带数据
01. 关于torchvision自带数据的接口请求参数
torchvision数据中的每个数据接口入参并不是完全一样的,甚至相同的入参名,所代表的意义也不同
所以对于具体数据的下载,必须根据具体的数据接口说明来进设设置
所有图片数据的数据列表地址为: https://pytorch.org/vision/stable/datasets.html
本例中下载的是FashionMNIST的数据,它的说明页面为: FashionMNIST
根据FashionMNIST的数据说明,配置的参数如下:
img_data = torchvision.datasets.FashionMNIST(
root = 'D:\pytorch\data' # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
,train = True # true为下载训练数据集,false为下载测试数据集
,download = True # 是否下载,选为True,就下载到root下面
,transform = torchvision.transforms.ToTensor() # 转换为tensor数据
)
其中各个参数的意义如下:
root :数据的下载、保存路径
如果路径上已经有数据,就直接从路径加载,如果没有,再到pytorch上下载
train :是否下载训练数据,
如果为True,则下载训练数据,如果为False,就下载测试数据
dowload :是否下载数据
一般都选择是,因为下载了数据之后,第二次再运行时就会直接从本地获取,而不需要联网下载
transform:数据的转换,一般设为ToTensor,这样在使用数据时就是Tensor数据类型
02. 关于torchvision自带数据的返回数据
在获取数据后,返回的是数据集DataSet,可以直接装载到DataLoader中使用
但需要注意的是,每个数据集返回的具体数据内容是不同的,
可以点击变量名来查看返回的数据内容,配合数据的文档说明来了解返回的具体数据
End