Pytorch教程

【介绍】torchvision.transforms的基本使用

作者 : 老饼 发表日期 : 2023-07-28 10:47:19 更新日期 : 2024-01-26 22:45:05
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



本文初步介绍pytorch的图片处理模块torchvision.transforms的功能

并展示相关实例,说明如何使用torchvision.transforms对图片进行处理



前言

本文介绍torchvision.transforms的基础使用,以此了解torchvision.transforms的整体使用

关于torchvision.transforms的更细节的使用在后续文章再进行补充

torchvision.transforms的介绍主要来自官方文档:https://pytorch.org/vision/stable/transforms.html 



一、torchvision.transforms介绍  



本节先简单介绍torchvision.transforms是什么,有什么用


01. torchvision.transforms是什么

torchvision.transforms是pytorch提供的一些常见的图象处理的变换方法

既支持单个变换处理,也支持一系列变换处理,

目前,torchvision.transforms共提供了两个版本,

torchvision.transforms下的API是V1版本,torchvision.transforms.V2下的API是V2版本


02.  torchvision.transforms提供的变换

transforms提供的变换主要包括改变尺寸、改变颜色、裁剪、自动增加等等

基本深度学习中常用的图片处理与样本增强方法都包含在torchvision.transforms之中

例如:

Resize-缩放:缩放Resize是将图片修改成指定尺寸

RandomCrop-随机裁剪:在原图像中随机裁剪指定尺寸的子图

CenterCrop-中心裁剪:在图片中心指定尺寸进行裁剪

ColorJitter-颜色抖动:随机修改图片的亮度、对比度、饱和度和色相

...等等

详细可见:《transforms提供的各种转换》




二、torchvision.transforms的使用  



本节展示如何使用torchvision.transforms对图片进行处理


01. torchvision.transforms单个变换的使用示例

下面以改变图片的Size为例,展示如何通过torchvision.transforms.v2.Resize进行处理,

原图如下:

 

通过torchvision.transforms改变图片Size的具体示例代码如下:

import matplotlib.pyplot as plt
import torchvision
from torchvision.transforms import v2

# ---------用transforms对图片对进行转换-------
img       = torchvision.io.read_image('img_01.png' )     # 读取图片
transform = v2.Resize((100,100))                         # 初始化转换
img_trans = transform(img)                               # 对图片进行转换
											            
# ----------展示结果----------                          
axs     = plt.figure().subplots(1, 2)                    # 初始化画布
axs[0].imshow(v2.ToPILImage()(img))                      # 展示原始图片
axs[0].set_title('src')                                  # 展示原始图片标题
axs[1].imshow(v2.ToPILImage()(img_trans))                # 展示转换后的图片
axs[1].set_title('Resize')                               # 展示转换后的图片标题

运行结果如下:


可以看到,图片的Size已经改为了100*100



02. torchvision.transforms多个变换的使用示例

👉关于Compose

在pytorch的torchvision.transforms中,可以通过Compose将多个变换合并为一个变换

如下,用ComposeResizeRandomInvert两个基础变换并合为一个变换,

使用Compose后transforms时,就会按顺序依次进行Resize和RandomInvert

transforms = v2.Compose([
    v2.Resize((100,100)),                                # 改变图片大小
    v2.RandomInvert(p=1.0),                              # 颜色反转
])


👉Compose的完整使用示例

Compose的完整使用示例如下:

import matplotlib.pyplot as plt
import torchvision
from torchvision.transforms import v2

# ---------用transforms对图片对进行转换-------
img       = torchvision.io.read_image('img_01.jpg' )     # 读取图片
transforms = v2.Compose([
    v2.Resize((100,100)),                                # 改变图片大小
    v2.RandomInvert(p=1.0),                              # 颜色反转
])

img_trans = transforms(img)                               # 对图片进行转换
											            
# ----------展示结果----------                          
axs     = plt.figure(figsize=(8,3)).subplots(1, 2)       # 初始化画布
axs[0].imshow(v2.ToPILImage()(img))                      # 展示原始图片
axs[0].set_title('src')                                  # 展示原始图片标题
axs[1].imshow(v2.ToPILImage()(img_trans))                # 展示转换后的图片
axs[1].set_title('transforms')                           # 展示转换后的图片标题

运行结果如下:

 
可以看到,将图片的Size改为100*100之后,又把图片的颜色进行了反转


👉更多的Compose相关API

除了Compose,pytorch还提供了其它相关的API,如下

Compose-组成 :将多个变换组合为一个变换

RandomApply-随机应用 :以一定的概率应用变换列表

RandomChoice-随机选择:从变换列表中随机选择一个变换

RandomOrder-随机顺序:对变换列表中的变换按随机顺序进行变换


03. torchvision.transforms的数据转换

pytorch的torchvision.transforms还提供了相关的数据格式转换API

相关API如下:

ToImage: 将张量、ndarray或PIL图像转换为图像,这不会缩放值

ToPureTensor: 将所有tv_tensor转换为纯张量,如果有相关的元数据则删除相关元数据

PILToTensor:将PIL图像转换为相同类型的张量,这不会缩放值

ToPILImage:将张量或ndarray转换为PIL图像

ToDtype:将输入转换为特定的数据类型,可以选择缩放图像或视频的值

ConvertBoundingBoxFormat:将边界框坐标转换为给定格式,例如从“CXCYWH”转换为“XYXY”




三、关于torchvision.transforms的版本



本节拓展性地简单介绍一下关于pytorch的torchvision.transforms版本


01. V1与V2的区别

torchvision.transforms共有两个版本:V1和V2

V1的API在torchvision.transforms之下,V2的API在torchvision.transforms.v2之下

pytorch官方基本推荐使用V2,V2兼容V1版本,但V2的功能更多性能更好


02. 如何令V2获得更好的性能

据官方说明,在torch.utils.data.DataLoader 的num_workers > 0时,

V2对于tensor类型且为uint8的数据类型性能最好,

如果不是tensor、uint8的数据,可以先进行转换(当然不转也可以,只是没发挥最好的性能)

from torchvision.transforms import v2
import torch
transforms = v2.Compose([
    v2.ToPureTensor(),                  # 转换为tensor
    v2.ToDtype(torch.uint8),        # 转换为uint8
])










 End 






联系老饼