Coloración de imágenes basada en aprendizaje profundo (Opencv, Pytorch, CNN)

1. Introducción

La dirección de descarga del código fuente se adjunta al final del artículo.

Colorea automáticamente imágenes en escala de grises

2. Formato de imagen (RGB, HSV, Laboratorio)

2.1 RGB

Si desea colorear una imagen en escala de grises, primero debe comprender el formato de la imagen. Para una imagen normal, generalmente es en formato RGB, es decir, tres canales de rojo, verde y azul. Puede usar opencv para separar Los tres canales de la imagen. El código es el siguiente. Se muestra:

import cv2

img=cv2.imread('pic/7.jpg')
B,G,R=cv2.split(img)
cv2.imshow('img',img)
cv2.imshow('B',B)
cv2.imshow('G',G)
cv2.imshow('R',R)
cv2.waitKey(0)

El resultado de ejecutar el código es el siguiente.
Insertar descripción de la imagen aquí

2.2 hsv

HSV es otro formato de imagen, donde h representa el tono de la imagen, s representa la saturación y v representa el brillo de la imagen. El tono, la saturación, el brillo y otra información de la imagen se pueden cambiar ajustando los valores. ​​de h, s y v.
También puede utilizar opencv para convertir imágenes del formato RGB al formato hsv. Luego puede separar los canales h, s y v y mostrar el código de imagen de la siguiente manera:

import cv2

img=cv2.imread('pic/7.jpg')
hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
h,s,v=cv2.split(hsv)
cv2.imshow('hsv',hsv)
cv2.imshow('h',h)
cv2.imshow('s',s)
cv2.imshow('v',v)
cv2.waitKey(0)

Los resultados de ejecución son los siguientes:
Insertar descripción de la imagen aquí

2.3 Laboratorio

Lab es otro formato de imágenes, y también es el formato utilizado en este artículo. L representa la imagen en escala de grises, y a y b representan canales de color. Este artículo utiliza la imagen en escala de grises del canal L como entrada y los canales de color ab como salida para entrenar para generar confrontación.Network, el código para convertir la imagen del formato RGB al formato Lab es el siguiente:

import cv2

img=cv2.imread('pic/7.jpg')
Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab)
L,a,b=cv2.split(Lab)
cv2.imshow('Lab',Lab)
cv2.imshow('L',L)
cv2.imshow('a',a)
cv2.imshow('b',b)
cv2.waitKey(0)

Insertar descripción de la imagen aquí

3. Red Adversaria Generativa (GAN)

生成对抗网络主要包含两部分,分别是生成网络和判别网络。
生成网络负责生成图像,判别网络负责鉴定生成图像的好坏,二者相辅相成,相互博弈。
本文使用U-net作为生成网络,使用ResNet18作为判别网络。U-net网络的结构图如下所示:

3.1 Red generadora (Unet)

Insertar descripción de la imagen aquí

El código de pytorch para construir la red unet es el siguiente:

class DownsampleLayer(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DownsampleLayer, self).__init__()
        self.Conv_BN_ReLU_2=nn.Sequential(
            nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.downsample=nn.Sequential(
            nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self,x):
        """
        :param x:
        :return: out输出到深层,out_2输入到下一层,
        """
        out=self.Conv_BN_ReLU_2(x)
        out_2=self.downsample(out)
        return out,out_2
class UpSampleLayer(nn.Module):
	def __init__(self,in_ch,out_ch):
	   # 512-1024-512
	   # 1024-512-256
	   # 512-256-128
	   # 256-128-64
	   super(UpSampleLayer, self).__init__()
   self.Conv_BN_ReLU_2 = nn.Sequential(
       nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU(),
       nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU()
   )
   self.upsample=nn.Sequential(
       nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1),
       nn.BatchNorm2d(out_ch),
       nn.ReLU()
   )

	def forward(self,x,out):
	   '''
	   :param x: 输入卷积层
	   :param out:与上采样层进行cat
	   :return:
	   '''
	   x_out=self.Conv_BN_ReLU_2(x)
	   x_out=self.upsample(x_out)
	   cat_out=torch.cat((x_out,out),dim=1)
	   return cat_out
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024]
        #下采样
        self.d1=DownsampleLayer(3,out_channels[0])#3-64
        self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128
        self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256
        self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512
        #上采样
        self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512
        self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256
        self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128
        self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64
        #输出
        self.o=nn.Sequential(
            nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0],3,3,1,1),
            nn.Sigmoid(),
            # BCELoss
        )
    def forward(self,x):
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)
        out=self.o(out8)
        return out


3.2 Red discriminante (resnet18)

El diagrama de estructura de resnet18 es el siguiente:
Insertar descripción de la imagen aquí
Pytorch viene con su propio modelo resnet18. Puede construir el modelo resnet18 con solo una línea de código. Luego debe eliminar la capa final completamente conectada de la red. El código es el siguiente:

from torchvision import models

resnet18=models.resnet18(pretrained=False)
del resnet18.fc

print(resnet18)

4. Conjunto de datos

Este artículo utiliza imágenes de datos de paisajes naturales. Rastreamos más de 1000 imágenes de datos en el sitio web. Algunas de las imágenes son las siguientes
Insertar descripción de la imagen aquí

5. Diagrama de flujo de predicción y entrenamiento del modelo

5.1 Diagrama de flujo de capacitación

Como se muestra en la figura siguiente, primero convierta la imagen RGB en una imagen de laboratorio y luego use el canal L como entrada de la red de generación. La salida de la red de generación es el nuevo canal ab. Luego, el canal ab original del Se ingresan la imagen y el canal ab generado por la red de generación red discriminante.
Insertar descripción de la imagen aquí

5.2 Diagrama de flujo de pronóstico

La siguiente figura muestra el proceso de predicción del modelo. En el proceso de predicción, la red discriminante no tiene ningún papel. Primero, la imagen RGB se convierte en una imagen de laboratorio y luego la imagen L en escala de grises se ingresa en la red de generación para obtener una nueva imagen del canal ab. Luego el canal L es La imagen se concatena con la imagen del canal ab generada. Después de la concatenación, se puede obtener una nueva imagen Lab y luego se convierte al formato RGB. En este momento, la imagen es la imagen en color.
Insertar descripción de la imagen aquí

6. Efecto de predicción del modelo

下图为模型的预测效果。左侧的为灰度图像,中间的为原始的彩色图像,右侧的是模型上色以后的图像。整体上看,网络的上色效果还不错。

Insertar descripción de la imagen aquí
Insertar descripción de la imagen aquí
Insertar descripción de la imagen aquí
Insertar descripción de la imagen aquí

7. Producción de interfaz GUI

Para que el uso del modelo sea más conveniente, este artículo utiliza pyqt5 para crear la interfaz de operación, que es como se muestra a continuación: Primero, puede cargar la imagen desde la computadora y también puede cambiar a la imagen anterior o siguiente. y podrá mostrar la imagen en escala de grises. Se puede colorear y luego se puede ajustar la información H, S y V de la imagen en color. Finalmente, se admite la exportación de imágenes y la imagen en color se puede guardar localmente.
Insertar descripción de la imagen aquí
Insertar descripción de la imagen aquí

8. Descarga de código

El enlace contiene código de entrenamiento, código de prueba y código de interfaz. Además, también contiene más de 1000 conjuntos de datos y la interfaz de operación aparecerá al ejecutar el programa main.py directamente.
Descarga de código: Descargar lista de direcciones 1

Supongo que te gusta

Origin blog.csdn.net/2302_82079084/article/details/135126761
Recomendado
Clasificación