Pytorch教程

【介绍】pytorch提供的各种图片数据集

作者 : 老饼 发表日期 : 2023-07-28 10:49:50 更新日期 : 2024-10-28 07:59:49
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com




前言

pytorch的torchvision提供了许多深度学习的图片数据集

为了方便使用时挑选,本文展示各个数据集的下载代码及下载后的图片样例 

附:pytorch提供的图片数据链接: https://pytorch.org/vision/stable/datasets.html 



  一、pytorch提供的各种图片数据集    



本节展示pytorch中下载各个图片数据集的代码,以及展示各个数据集的图片样例



01.手写数字MNIST


数据示例如下:

数 据 描 述:
手写数字0-9
数 据 接 口 说 明: https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html
数据接口示例:
import torchvision
img_data  = torchvision.datasets.MNIST(
    root       = 'D:\pytorch\data'                        # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,train     = True
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True                                     # 是否下载,选为True,就下载到root下面
    ,target_transform= None)



02. 手写字体EMNIST


数据示例如下:

数 据 描 述:
手写数字0-9,英语字母大小写
数 据 接 口 说 明:  https://pytorch.org/vision/stable/generated/torchvision.datasets.EMNIST.html
数 据 接 口 示例:
import torchvision
img_data = torchvision.datasets.EMNIST(
    root               = 'D:\pytorch\data'     
    ,train             = True
    ,split             = 'byclass'
    ,transform         = torchvision.transforms.ToTensor()
    ,target_transform  = None
    ,download          = True)




03. 手写数字USPS


数据示例如下:
 
数 据 描 述:手写数字0-9
数 据 接 口 说 明:  https://pytorch.org/vision/stable/generated/torchvision.datasets.USPS.html#torchvision.datasets.USPS 
数 据 接 口 示例:
import torchvision
img_data  = torchvision.datasets.USPS(
    root       = 'D:\pytorch\data'                        # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,train     =True
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True                                     # 是否下载,选为True,就下载到root下面
    ,target_transform= None)




04. 衣服款式Fashion-MNIST


数据示例如下:

数 据 接 口 说 明:
 https://pytorch.org/vision/stable/generated/torchvision.datasets.FashionMNIST.html
数 据 接 口 示 例:
import torchvision
img_data = torchvision.datasets.FashionMNIST(
    root       = 'D:\pytorch\data'                    # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,train     = True                                 # true为下载训练数据集,false为下载测试数据集
    ,download  = True                                 # 是否下载,选为True,就下载到root下面
    ,transform = torchvision.transforms.ToTensor()    # 转换为tensor数据
)




05. 花朵Flowers102


数据示例如下:

数 据 接 口 说 明:
 https://pytorch.org/vision/stable/generated/torchvision.datasets.Flowers102.html 
数 据 接 口 示 例:
import torchvision
img_data = torchvision.datasets.Flowers102(
    root       = 'D:\pytorch\data'               # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,split     ='train'
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True              # 是否下载,选为True,就下载到root下面
    ,target_transform= None)




06. 人脸识别LFWPeople

数据示例如下:

数 据 接 口 说 明:
 https://pytorch.org/vision/stable/generated/torchvision.datasets.LFWPeople.html 
数 据 接 口 示 例:
import torchvision
img_data = torchvision.datasets.LFWPeople(
    root       = 'D:\pytorch\data'               # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
   ,split     ='train'
   ,image_set = 'original'
   ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
   ,download  = True              # 是否下载,选为True,就下载到root下面
   ,target_transform= None )




07. 街景门牌号码SVHN


数据示例如下:

数 据 描 述:SVHN(Street View House Number)Dateset 来源于谷歌街景门牌号码
数 据 接 口 说 明: https://pytorch.org/vision/stable/generated/torchvision.datasets.SVHN.html#torchvision.datasets.SVHN 
数 据 接 口 示 例:
import torchvision
from torch.utils.data   import DataLoader
img_data  = torchvision.datasets.SVHN(
    root       = 'D:\pytorch\data'                        # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,split     ='train'
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True                                     # 是否下载,选为True,就下载到root下面
    ,target_transform= None)




08. Food101


数据示例如下:

数 据 接 口 说 明:
  https://pytorch.org/vision/stable/generated/torchvision.datasets.Food101.html#torchvision.datasets.Food101 
数 据 接 口 示 例:
import torchvision
img_data  = torchvision.datasets.Food101(
    root       = 'D:\pytorch\data'                        # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,split      ='train'
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True                                     # 是否下载,选为True,就下载到root下面
    ,target_transform= None)




09. STL10


数据示例如下:

数 据 接 口 说 明:
 https://pytorch.org/vision/stable/generated/torchvision.datasets.STL10.html#torchvision.datasets.STL10 
数 据 接 口 示 例:
import torchvision
img_data  = torchvision.datasets.STL10(
    root       = 'D:\pytorch\data'                        # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,split     = 'train'
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True                                     # 是否下载,选为True,就下载到root下面
    ,target_transform= None)






10. CIFAR10


数据示例如下:

数 据 接 口 说 明:
  https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10 
数 据 接 口 示 例:
import torchvision
img_data  = torchvision.datasets.CIFAR10(
    root       = 'D:\pytorch\data'                        # 路径,如果路径有,就直接从路径中加载,如果没有,就联网获取
    ,train     = True
    ,transform = torchvision.transforms.ToTensor()        # 转换为tensor数据
    ,download  = True                                     # 是否下载,选为True,就下载到root下面
    ,target_transform= None)





二. pytorch图片数据集的样本画图代码  



本节展示在下载pytorch的图片数据集后,如何查看图片样例



01. 绘画pytorch图片数据集的样例


在下载pytorch的图片数据后,可以使用下述代码画出图片样例


import random
import matplotlib.pyplot as plt

figure     = plt.figure(figsize=(8, 8))                   # 初始化图片,figsize可以指定图片的大小
sample_num = len(img_data)                                # 样本个数
cols, rows = 3, 3                                         # 展示3*3个样本
pic_num    = cols * rows                                  # 展示样本个数
sample_idx = random.sample(range(0,sample_num),pic_num)   # 随机抽取样本
for i in range(0, pic_num):                               # 绘画各个样本
    img, label = img_data[sample_idx[i]]                  # 本次要绘画的样本
    figure.add_subplot(rows, cols, i+1)                   # 添加到figure

    plt.axis("off")                                       # 不展示坐标轴
    plt.imshow(torchvision.transforms.ToPILImage()(img), cmap="gray")    # 绘图
plt.show()











 End 








联系老饼