Converting images between the RGB and HSV color spaces in PyTorch: usable directly inside a neural network, supports backpropagation and CUDA.
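Today, while designing a PyTorch neural network, I needed to convert RGB images to the HSV color space. HSV is better suited to gradient effects that need smooth color transitions. So I needed PyTorch functions that convert between RGB and HSV, and they had to be differentiable, i.e. gradients must be able to flow through them during backpropagation. On GitHub I found "Differentiable-RGB-to-HSV-convertion-pytorch", but its HSV-to-RGB part does not work, so I wrote that half myself. I'm sharing the complete version here. It also serves as a note to myself, so I can find it easily the next time I need it.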
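For reference, the conversion follows the standard hexcone formulas. They are summarized here as I read them from the code in section 1, with $M=\max(R,G,B)$, $m=\min(R,G,B)$ and $\Delta = M - m$. All values are in [0, 1], and the code additionally adds a small eps to the denominators. When several channels tie for the maximum, the code takes the R case first, then G, then B; the tied cases give the same hue anyway.

$$
V = M,\qquad
S = \begin{cases} \Delta / M, & M > 0\\ 0, & M = 0 \end{cases},\qquad
H = \frac{1}{6}\begin{cases}
0, & \Delta = 0\\
\left(\dfrac{G-B}{\Delta}\right) \bmod 6, & M = R\\[4pt]
2 + \dfrac{B-R}{\Delta}, & M = G\\[4pt]
4 + \dfrac{R-G}{\Delta}, & M = B
\end{cases}
$$

The inverse splits the hue circle into six sectors:

$$
i = \lfloor 6H \rfloor,\quad f = 6H - i,\quad p = V(1-S),\quad q = V(1-fS),\quad t = V\bigl(1-(1-f)S\bigr)
$$

$$
(R,G,B) = \begin{cases}
(V,t,p), & i=0\\ (q,V,p), & i=1\\ (p,V,t), & i=2\\ (p,q,V), & i=3\\ (t,p,V), & i=4\\ (V,p,q), & i=5
\end{cases}
$$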
1. Code
```python
"""
PyTorch implementation of RGB -> HSV and HSV -> RGB conversion.
Shape of RGB or HSV tensors: (B, C, H, W), with C = 3
Range of every channel: [0, 1] (hue lies in [0, 1))
"""
import torch
from torch import nn


class RGB_HSV(nn.Module):
    def __init__(self, eps=1e-8):
        super(RGB_HSV, self).__init__()
        self.eps = eps  # guards against division by zero

    def rgb_to_hsv(self, img):
        # Per-pixel max / min over the channel dimension, shape (B, H, W)
        cmax = img.max(1)[0]
        cmin = img.min(1)[0]
        delta = cmax - cmin + self.eps
        # Start from zeros with the input's dtype and device
        # (torch.Tensor(...) returns uninitialized float32 memory)
        hue = torch.zeros_like(cmax)
        # Later assignments overwrite earlier ones, so ties resolve red > green > blue
        hue[img[:, 2] == cmax] = 4.0 + ((img[:, 0] - img[:, 1]) / delta)[img[:, 2] == cmax]
        hue[img[:, 1] == cmax] = 2.0 + ((img[:, 2] - img[:, 0]) / delta)[img[:, 1] == cmax]
        hue[img[:, 0] == cmax] = (((img[:, 1] - img[:, 2]) / delta)[img[:, 0] == cmax]) % 6
        hue[cmin == cmax] = 0.0  # gray pixels: hue is undefined, use 0
        hue = hue / 6

        saturation = (cmax - cmin) / (cmax + self.eps)
        saturation[cmax == 0] = 0
        value = cmax

        return torch.stack([hue, saturation, value], dim=1)

    def hsv_to_rgb(self, hsv):
        h, s, v = hsv[:, 0, :, :], hsv[:, 1, :, :], hsv[:, 2, :, :]
        # Bring out-of-range values back into range
        h = h % 1
        s = torch.clamp(s, 0, 1)
        v = torch.clamp(v, 0, 1)

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)

        hi = torch.floor(h * 6)  # hue sector 0..5
        f = h * 6 - hi           # fractional position inside the sector
        # h % 1 can round up to exactly 1.0 for tiny negative hues in float32,
        # which would give hi == 6 and leave the pixel black; wrap it to sector 0
        hi = hi % 6
        p = v * (1 - s)
        q = v * (1 - (f * s))
        t = v * (1 - ((1 - f) * s))

        hi0, hi1, hi2 = hi == 0, hi == 1, hi == 2
        hi3, hi4, hi5 = hi == 3, hi == 4, hi == 5
        r[hi0], g[hi0], b[hi0] = v[hi0], t[hi0], p[hi0]
        r[hi1], g[hi1], b[hi1] = q[hi1], v[hi1], p[hi1]
        r[hi2], g[hi2], b[hi2] = p[hi2], v[hi2], t[hi2]
        r[hi3], g[hi3], b[hi3] = p[hi3], q[hi3], v[hi3]
        r[hi4], g[hi4], b[hi4] = t[hi4], p[hi4], v[hi4]
        r[hi5], g[hi5], b[hi5] = v[hi5], p[hi5], q[hi5]

        return torch.stack([r, g, b], dim=1)
```
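`hsv_to_rgb` above writes into `r`, `g`, `b` through six boolean masks. Autograd handles this, but there is also a branch-free alternative. The sketch below uses the closed-form "alternative" HSV-to-RGB formula listed in the Wikipedia article on HSL and HSV. The function name `hsv_to_rgb_closed_form` is hypothetical and is not part of the code above. Because of the `mod 6`, a hue of exactly 1.0 maps to the same color as 0, so no sector index or wrap-around guard is needed.

```python
import colorsys

import torch


def hsv_to_rgb_closed_form(hsv: torch.Tensor) -> torch.Tensor:
    """Branch-free HSV -> RGB for (B, 3, H, W) tensors with values in [0, 1].

    Per output channel n (R: n=5, G: n=3, B: n=1):
        k   = (n + 6H) mod 6
        C_n = V - V * S * clamp(min(k, 4 - k), 0, 1)
    """
    h = hsv[:, 0:1] % 1            # slicing keeps the channel dim for broadcasting
    s = hsv[:, 1:2].clamp(0, 1)
    v = hsv[:, 2:3].clamp(0, 1)
    # n has shape (1, 3, 1, 1) so it broadcasts over batch and pixels
    n = torch.tensor([5.0, 3.0, 1.0], dtype=hsv.dtype, device=hsv.device).view(1, 3, 1, 1)
    k = (n + 6 * h) % 6            # shape (B, 3, H, W)
    w = torch.minimum(k, 4 - k).clamp(0, 1)
    return v - v * s * w


# Usage: gradients reach the HSV input through plain tensor ops
hsv = torch.rand(1, 3, 2, 2, requires_grad=True)
rgb = hsv_to_rgb_closed_form(hsv)
rgb.sum().backward()
print(hsv.grad.shape)  # torch.Size([1, 3, 2, 2])

# Spot-check one pixel against Python's standard colorsys module
print(rgb.detach()[0, :, 0, 0].tolist())
print(colorsys.hsv_to_rgb(*hsv.detach()[0, :, 0, 0].tolist()))
```

The two printed triples should agree up to float32 precision.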
2. Verification
matplotlib.colors also provides RGB/HSV conversion functions (rgb_to_hsv and hsv_to_rgb), so let's compare my code against them:
```python
import torch
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from rgb_hsv import RGB_HSV  # the class above, saved as rgb_hsv.py

img = cv2.imread('../images/0.jpg')
assert img is not None, 'image not found'
rgb = img[:, :, ::-1]  # note: OpenCV loads images in BGR order, so convert to RGB
rgb = rgb / 255
rgb_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float()

convertor = RGB_HSV()
hsv_tensor = convertor.rgb_to_hsv(rgb_tensor)
rgb1 = convertor.hsv_to_rgb(hsv_tensor)
hsv_arr = hsv_tensor[0].permute(1, 2, 0).numpy()
rgb1_arr = rgb1[0].permute(1, 2, 0).numpy()

# Reference round trip with matplotlib (works on (..., 3) arrays in [0, 1])
hsv_m = mcolors.rgb_to_hsv(rgb)
rgb1_m = mcolors.hsv_to_rgb(hsv_m)
print('round-trip mse of my code vs. original:', ((rgb1_arr - rgb) ** 2).mean())
print('mse of my code vs. matplotlib:', ((rgb1_arr - rgb1_m) ** 2).mean())

plt.figure()
plt.imshow(rgb)
plt.title('original image')
plt.figure()
plt.imshow(hsv_arr)  # H, S, V channels displayed as if they were R, G, B
plt.title('HSV visualized as RGB')
plt.figure()
plt.imshow(rgb1_arr)
plt.title('convert back: my code')
plt.figure()
plt.imshow(rgb1_m)
plt.title('convert back: matplotlib method')
plt.show()
```
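The printed MSE is a very small floating-point number close to 0. The residual comes from the eps added to guard against division by zero and from float32 rounding. The plotted conversion results are shown below. They match matplotlib's, which confirms that the code works correctly.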
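After adding this code to a neural network, I also confirmed that gradients backpropagate through it and that it runs on CUDA. That code is omitted here.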
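Since the network code is omitted, here is a minimal sketch of such a check. To keep the snippet self-contained, it uses `rgb_to_hsv_where`. This is a hypothetical `torch.where`-based rewrite that follows the same formulas and the same tie order (red, then green, then blue) as `rgb_to_hsv`, but avoids in-place masked writes. It is not the code from section 1. The same steps should carry over to `RGB_HSV().rgb_to_hsv` unchanged: put the input on the device, compute a loss on the HSV output, and call `backward()`.

```python
import colorsys

import torch


def rgb_to_hsv_where(img: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """(B, 3, H, W) RGB in [0, 1] -> (B, 3, H, W) HSV in [0, 1].

    Hypothetical torch.where variant of rgb_to_hsv: same formulas and
    tie order (red, then green, then blue), no in-place writes.
    """
    r, g, b = img[:, 0], img[:, 1], img[:, 2]
    cmax = img.max(dim=1).values
    cmin = img.min(dim=1).values
    delta = cmax - cmin
    d = delta + eps  # eps keeps every branch finite, even where it is not selected
    hue = torch.where(
        cmax == r,
        ((g - b) / d) % 6,
        torch.where(cmax == g, 2.0 + (b - r) / d, 4.0 + (r - g) / d),
    )
    hue = torch.where(delta == 0, torch.zeros_like(hue), hue) / 6
    sat = torch.where(cmax == 0, torch.zeros_like(cmax), delta / (cmax + eps))
    return torch.stack([hue, sat, cmax], dim=1)


# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(0)
x = torch.rand(2, 3, 8, 8, device=device, requires_grad=True)
hsv = rgb_to_hsv_where(x)
loss = hsv[:, 1].mean()  # example loss: mean saturation
loss.backward()
print('device:', x.grad.device, 'finite grads:', torch.isfinite(x.grad).all().item())

# Spot-check one pixel against Python's standard colorsys module
print('mine:    ', hsv.detach()[0, :, 0, 0].tolist())
print('colorsys:', colorsys.rgb_to_hsv(*x.detach()[0, :, 0, 0].tolist()))
```

With `torch.where`, each pixel takes its value, and its gradient, only from the selected branch. The small eps in the denominators keeps the unselected branches finite, so no NaNs leak into the backward pass.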