将图片转换为tensor的时候,图片发生了几个变化;
- 首先图片的通道维度发生了变化,从h,w,c变化成了b,c,h,w
- 图像的像素范围发生了变化,从0-255变化成了0-1
- 图片的数据类型发生了变化,从np.uint8变成了torch.float32
所以要想将tensor变成image也需要从以上三个维度进行考虑。
image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")
def tensor_to_image(tensor):
if tensor.dim()==4:
tensor=tensor.squeeze(0) ###去掉batch维度
tensor=tensor.permute(1,2,0) ##将c,h,w 转换为h,w,c
tensor=tensor.mul(255).clamp(0,255) ###将像素值转换为0-255之间
tensor=tensor.cpu().numpy().astype('uint8') ###
return tensor
在torchvision中已经封装了transform,其中包含transforms.ToTensor,transforms.ToPILImage函数可以直接进行转换。但是需要注意输入的数据类型。这个具体使用的时候可以看一下官方文档