pytorch的 torchvision transforms

技术·学习 · 2023-02-13 · 43 人浏览

pytorch的 torchvision transforms

torchvision是pytorch的数据集,也包含常用数据处理工具,包含几个模块:

  • datasets(包含常用的数据集:minist,COCO等)
  • models(包含常用的著名网络结构:AlexNet,VGG,ResNet等等,你可以使用随机初始化的网络结构,也可以使用已经训练好的网络)
  • transforms(对PIL.Image进行变换处理:Scale(缩放)、CenterCrop(中心切割)、Pad(填充)等),PIL(Python Imaging Library)是python对图形处理的库。

下面具体讲一下transforms中常用函数的使用

transforms.Scale(size)

将输入的PIL.Image重新改变大小成给定的size,size是最小边的边长。举个例子,如果原图的height>width,那么改变大小后的图片大小是(size*height/width, size),若是height<width,那么就是(size, size*width/height)
例:

from PIL import Image
from torchvision import transforms
crop=transforms.Scale(12)
img=Image.open('test.jpg')
print(type(img))
print(img.size)
print(crop(img).size)


输出:
<class ‘PIL.JpegImagePlugin.JpegImageFile’>
(261, 230)
(13, 12)
transforms.ToTensor()

把一个取值范围是[0,255]的PIL.Image或者shape(Height,Width,Channel)numpy.ndarray,转换成形状为[Channel,Height,Width],(也就是把通道数放第一维度了)且取值范围是[0,1.0]的torch.FloadTensor
例:

from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
im_arry=np.asarray(im)
print(im_arry.shape)
im_tensor=transforms.ToTensor()(im)#或者用im_arry也是可以的
print(im_tensor)
print(im_tensor.shape)


输出:
(230, 261, 3)
tensor([[[0.2314, 0.2392, 0.2392, …, 0.2314, 0.2314, 0.2392],
[0.2314, 0.2314, 0.2314, …, 0.2314, 0.2314, 0.2314],
[0.2314, 0.2314, 0.2314, …, 0.2314, 0.2314, 0.2314],
…,

torch.Size([3, 230, 261])

可以看出通道数的确放前面去了,且取值范围都在0-1之间,而且transforms.ToTensor()是直接处理PIL image也可以是image array

transforms.ToPILImage

与前面的相反:将shape为(C,H,W)的Tensor或shape为(H,W,C)numpy.ndarray转换成PIL.Image,值不变 。

transforms.Normalize(mean, std)

给定均值:(R,G,B) ,方差:(R,G,B),将会把Tensor正则化。即:Normalized_image=(image-mean)/std
例:

from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
im_arry=np.asarray(im)
print(im_arry.shape)
im_tensor=transforms.ToTensor()(im)

im_Normal=transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(im_tensor)
print(im_tensor)
print(im_tensor.shape)


输出:
(230, 261, 3)
tensor([[[-0.5373, -0.5216, -0.5216, …, -0.5373, -0.5373, -0.5216],
[-0.5373, -0.5373, -0.5373, …, -0.5373, -0.5373, -0.5373],
[-0.5373, -0.5373, -0.5373, …, -0.5373, -0.5373, -0.5373],
…,
[-0.6627, -0.6627, -0.6627, …, -0.6627, -0.6627, -0.6627],
[-0.6627, -0.6627, -0.6627, …, -0.6627, -0.6627, -0.6627],
[-0.6627, -0.6627, -0.6627, …, -0.6627, -0.6627, -0.6627]],

torch.Size([3, 230, 261])

一定要把图像先转换为tensor,在用此函数。

transforms.Pad(padding, fill=0)

将给定的PIL.Image的所有边用给定的pad value填充。 padding:要填充多少像素 fill:用什么值填充
例:

from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
print(im)
im_re=transforms.Resize((3, 3))(im)
print(np.asarray(im_re))
print(np.asarray(im_re).shape)
im_pad=transforms.Pad(padding=1,fill=0)(im_re)
print(np.asarray(im_pad))
print(np.asarray(im_pad).shape)


输出:
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=261x230 at 0x20FBDA02EB8>
[[[60 61 61]
[61 62 62]
[62 64 64]]

[[53 54 53]
[53 53 52]
[51 51 50]]

[[45 45 45]
[45 45 45]
[45 45 45]]]
(3, 3, 3)
[[[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]]

[[ 0 0 0]
[60 61 61]
[61 62 62]
[62 64 64]
[ 0 0 0]]

[[ 0 0 0]
[53 54 53]
[53 53 52]
[51 51 50]
[ 0 0 0]]

[[ 0 0 0]
[45 45 45]
[45 45 45]
[45 45 45]
[ 0 0 0]]

[[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]]]
(5, 5, 3)

可以看出图片填充在每一个通道上都进行了填充,可以把初始[3,3,3]想像成一个333的立方体,然后上下两个面不动,周围4个面各向外推出2,就得到553的立方体。

transforms.Resize((height, width))

resize图像,例子见上

transforms.Compose()

就是把多个transforms组合起来.
例子

from PIL import Image
from torchvision import transforms
import numpy as np
im = Image.open('test.jpg')
print(im)
com=transforms.Compose([
     transforms.Resize((3,4)),
     transforms.ToTensor(),
 ])
im_com=com(im)
print(im_com)


输出:
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=261x230 at 0x1BDDCFC2E10>
tensor([[[0.2353, 0.2314, 0.2471, 0.2392],
[0.2078, 0.2078, 0.2039, 0.1961],
[0.1765, 0.1765, 0.1765, 0.1765]],…
transforms.Lambda()

用户可以用transforms.Lambda()函数自行定义transform操作,该操作不是由torchvision库所拥有的,其中参数是lambda表示的是自定义函数。

举例说明:

比如当我们想要截取图像,但并不想在随机位置截取,而是希望在一个自己指定的位置去截取

那么你就需要自定义一个截取函数,然后使用transforms.Lambda去封装它即可,如:

# coding:utf-8
from torchvision import transforms as T

def __crop(img, pos, size):
    """
    :param img: 输入的图像
    :param pos: 图像截取的位置,类型为元组,包含(x, y)
    :param size: 图像截取的大小
    :return: 返回截取后的图像
    """
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    # 有足够的大小截取
    # img.crop坐标表示 (left, upper, right, lower)
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1+tw, y1+th))
    return img

# 然后使用transforms.Lambda封装其为transforms策略
# 然后定义新的transforms为
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = T.Compose([
    T.Lambda(lambda img: __crop(img, (5,5), 224)),
    T.RandomHorizontalFlip(),  # 随机水平翻转给定的PIL.Image,翻转概率为0.5
    T.ToTensor(),  # 转成Tensor格式,大小范围为[0,1]
    normalize
])
python pytorch
Theme Jasmine by Kent Liao