将tensor转换为image

 将图片转换为tensor的时候,图片发生了几个变化;

  • 首先图片的通道维度发生了变化,从h,w,c变化成了b,c,h,w
  • 图像的像素范围发生了变化,从0-255变化成了0-1
  • 图片的数据类型发生了变化,从np.uint8变成了torch.float32

所以要想将tensor变成image也需要从以上三个维度进行考虑。

image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")
def tensor_to_image(tensor):
    if tensor.dim()==4:
        tensor=tensor.squeeze(0)  ###去掉batch维度
    tensor=tensor.permute(1,2,0) ##将c,h,w 转换为h,w,c
    tensor=tensor.mul(255).clamp(0,255)  ###将像素值转换为0-255之间
    tensor=tensor.cpu().numpy().astype('uint8')  ###
    return tensor

在torchvision中已经封装了transform,其中包含transforms.ToTensor,transforms.ToPILImage函数可以直接进行转换。但是需要注意输入的数据类型。这个具体使用的时候可以看一下官方文档

 

猜你喜欢

转载自blog.csdn.net/qq_40107571/article/details/127331796
今日推荐