Pytorch教程

【例子】pytorch导入FashionMNIST数据

作者 : 老饼 发表日期 : 2023-07-28 10:48:58 更新日期 : 2023-12-04 04:40:52
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com


本文通过展示一个pytorch导入FashionMNIST数据的例子

来讲解pytorch导入torchvision自带数据集的流程和注意事项



 一、pytorch导入自带的FashionMNIST数据  



本节展示如何在pytorch中导入torchvision的自带FashionMNIST数据  



1. pytorch导入FashionMNIST数据的代码示例

下面直接展示一个在pytorch中加载FashionMNIST数据

示例代码如下:

# -*- coding: utf-8 -*-
"""
本代码展示如何下载pytorch的torchvision自带的FashionMNIST数据,
并展示下载后的数据内容
"""
# ------------下载torchvision自带的FashionMNIST图片样本---------------
from torchvision.transforms import ToTensor

# 获取训练数据
import torchvision
img_data = torchvision.datasets.FashionMNIST(
    root       = 'D:\pytorch\data'                    # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,train     = True                                 # true为下载训练数据集,false为下载测试数据集
    ,download  = True                                 # 是否下载,选为True,就下载到root下面
    ,transform = torchvision.transforms.ToTensor()    # 转换为tensor数据
)
# 获取测试数据
test_data = torchvision.datasets.FashionMNIST(
    root      = 'D:\pytorch\data',
    train     = False,
    download  = True,
    transform = ToTensor()
)

#-----------------随机提取一些样本来看看-------------------------------------------------
import random
import matplotlib.pyplot as plt

figure     = plt.figure(figsize=(8, 8))                   # 初始化图片,figsize可以指定图片的大小
sample_num = len(img_data)                                # 样本个数
cols, rows = 3, 3                                         # 展示3*3个样本
pic_num    = cols * rows                                  # 展示样本个数
sample_idx = random.sample(range(0,sample_num),pic_num)   # 随机抽取样本
labels_map = img_data.classes                             # 提取出标签名称   
for i in range(0, pic_num):                               # 绘画各个样本
    img, label = img_data[sample_idx[i]]                  # 本次要绘画的样本
    figure.add_subplot(rows, cols, i+1)                   # 添加到figure
    plt.title(labels_map[label])                          # 样本标签 
    plt.axis("off")                                       # 不展示坐标轴
    plt.imshow(img.squeeze(), cmap="gray")                # 绘图
plt.show()

运行上述代码,下载后的图片如下

  





 二、pytorch导入torchvision自带数据的详细说明  



本节讲解并剖析上节例子中的一些细节,进一步了解pytorch中如何导入torchvision自带数据



01. 关于torchvision自带数据的接口请求参数


torchvision数据中的每个数据接口入参并不是完全一样的,甚至相同的入参名,所代表的意义也不同

所以对于具体数据的下载,必须根据具体的数据接口说明来进设设置

所有图片数据的数据列表地址为: https://pytorch.org/vision/stable/datasets.html 

本例中下载的是FashionMNIST的数据,它的说明页面为: FashionMNIST

根据FashionMNIST的数据说明,配置的参数如下:

img_data = torchvision.datasets.FashionMNIST(
    root       = 'D:\pytorch\data'                    # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,train     = True                                 # true为下载训练数据集,false为下载测试数据集
    ,download  = True                                 # 是否下载,选为True,就下载到root下面
    ,transform = torchvision.transforms.ToTensor()    # 转换为tensor数据
)

其中各个参数的意义如下:

root          :数据的下载、保存路径

                     如果路径上已经有数据,就直接从路径加载,如果没有,再到pytorch上下载

train         :是否下载训练数据,

                    如果为True,则下载训练数据,如果为False,就下载测试数据

dowload  :是否下载数据

                    一般都选择是,因为下载了数据之后,第二次再运行时就会直接从本地获取,而不需要联网下载

transform:数据的转换,一般设为ToTensor,这样在使用数据时就是Tensor数据类型


02. 关于torchvision自带数据的返回数据


在获取数据后,返回的是数据集DataSet,可以直接装载到DataLoader中使用

但需要注意的是,每个数据集返回的具体数据内容是不同的,

可以点击变量名来查看返回的数据内容,配合数据的文档说明来了解返回的具体数据








 End 







联系老饼