Coloração de imagens baseada em aprendizagem profunda (Opencv, Pytorch, CNN)

1. Introdução

O endereço para download do código-fonte está anexado no final do artigo.

Colorir automaticamente imagens em tons de cinza

2. Formato de imagem (RGB, HSV, Laboratório)

2.1 RGB

Se você deseja colorir uma imagem em tons de cinza, primeiro você deve entender o formato da imagem. Para uma imagem comum, geralmente está no formato RGB, ou seja, três canais de vermelho, verde e azul. Você pode usar opencv para separar os três canais da imagem. O código é o seguinte:

import cv2

img=cv2.imread('pic/7.jpg')
B,G,R=cv2.split(img)
cv2.imshow('img',img)
cv2.imshow('B',B)
cv2.imshow('G',G)
cv2.imshow('R',R)
cv2.waitKey(0)

O resultado da execução do código é o seguinte.
Insira a descrição da imagem aqui

2,2 hsv

HSV é outro formato de imagem, onde h representa o matiz da imagem, s representa a saturação e v representa o brilho da imagem. O matiz, saturação, brilho e outras informações da imagem podem ser alterados ajustando os valores de h, s e v.
Você também pode usar opencv para converter imagens do formato RGB para o formato hsv. Então você pode separar os canais h, s e v e exibir o código da imagem da seguinte forma:

import cv2

img=cv2.imread('pic/7.jpg')
hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
h,s,v=cv2.split(hsv)
cv2.imshow('hsv',hsv)
cv2.imshow('h',h)
cv2.imshow('s',s)
cv2.imshow('v',v)
cv2.waitKey(0)

Os resultados da execução são os seguintes:
Insira a descrição da imagem aqui

2.3 Laboratório

Lab é outro formato de imagem e também é o formato usado neste artigo. L representa a imagem em tons de cinza, e a e b representam os canais de cores. Este artigo usa a imagem em tons de cinza do canal L como entrada e os canais de cores ab como saída para treinar para gerar confronto.Rede, o código para converter a imagem do formato RGB para o formato Lab é o seguinte:

import cv2

img=cv2.imread('pic/7.jpg')
Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab)
L,a,b=cv2.split(Lab)
cv2.imshow('Lab',Lab)
cv2.imshow('L',L)
cv2.imshow('a',a)
cv2.imshow('b',b)
cv2.waitKey(0)

Insira a descrição da imagem aqui

3. Rede Adversarial Generativa (GAN)

生成对抗网络主要包含两部分,分别是生成网络和判别网络。
生成网络负责生成图像,判别网络负责鉴定生成图像的好坏,二者相辅相成,相互博弈。
本文使用U-net作为生成网络,使用ResNet18作为判别网络。U-net网络的结构图如下所示:

3.1 Rede geradora (Unet)

Insira a descrição da imagem aqui

O código do pytorch para construir a rede unet é o seguinte:

class DownsampleLayer(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DownsampleLayer, self).__init__()
        self.Conv_BN_ReLU_2=nn.Sequential(
            nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.downsample=nn.Sequential(
            nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self,x):
        """
        :param x:
        :return: out输出到深层,out_2输入到下一层,
        """
        out=self.Conv_BN_ReLU_2(x)
        out_2=self.downsample(out)
        return out,out_2
class UpSampleLayer(nn.Module):
	def __init__(self,in_ch,out_ch):
	   # 512-1024-512
	   # 1024-512-256
	   # 512-256-128
	   # 256-128-64
	   super(UpSampleLayer, self).__init__()
   self.Conv_BN_ReLU_2 = nn.Sequential(
       nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU(),
       nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU()
   )
   self.upsample=nn.Sequential(
       nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1),
       nn.BatchNorm2d(out_ch),
       nn.ReLU()
   )

	def forward(self,x,out):
	   '''
	   :param x: 输入卷积层
	   :param out:与上采样层进行cat
	   :return:
	   '''
	   x_out=self.Conv_BN_ReLU_2(x)
	   x_out=self.upsample(x_out)
	   cat_out=torch.cat((x_out,out),dim=1)
	   return cat_out
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024]
        #下采样
        self.d1=DownsampleLayer(3,out_channels[0])#3-64
        self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128
        self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256
        self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512
        #上采样
        self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512
        self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256
        self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128
        self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64
        #输出
        self.o=nn.Sequential(
            nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0],3,3,1,1),
            nn.Sigmoid(),
            # BCELoss
        )
    def forward(self,x):
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)
        out=self.o(out8)
        return out


3.2 Rede discriminante (resnet18)

O diagrama de estrutura do resnet18 é o seguinte:
Insira a descrição da imagem aqui
Pytorch vem com seu próprio modelo resnet18. Você pode construir o modelo resnet18 com apenas uma linha de código. Em seguida, você precisa remover a camada final totalmente conectada da rede. O código é o seguinte:

from torchvision import models

resnet18=models.resnet18(pretrained=False)
del resnet18.fc

print(resnet18)

4. Conjunto de dados

Este artigo usa imagens de dados de paisagens naturais. Rastreamos mais de 1.000 imagens de dados no site. Algumas das imagens são as seguintes
Insira a descrição da imagem aqui

5. Treinamento de modelo e fluxograma de previsão

5.1 Fluxograma de treinamento

Conforme mostrado na figura abaixo, primeiro converta a imagem RGB em uma imagem Lab e, em seguida, use o canal L como entrada da rede de geração. A saída da rede de geração é o novo canal ab. Em seguida, o canal ab original do A imagem e o canal ab gerado pela rede de geração são de entrada.
Insira a descrição da imagem aqui

5.2 Fluxograma de previsão

A figura a seguir mostra o processo de predição do modelo. No processo de predição, a rede discriminante não tem função. Primeiro, a imagem RGB é convertida em uma imagem Lab e, em seguida, a imagem L em tons de cinza é inserida na rede de geração para obter um nova imagem do canal ab. Em seguida, o canal L é A imagem é concatenada com a imagem do canal ab gerada. Após a concatenação, uma nova imagem Lab pode ser obtida e depois convertida para o formato RGB. Neste momento, a imagem é a imagem colorida.
Insira a descrição da imagem aqui

6. Efeito de previsão do modelo

下图为模型的预测效果。左侧的为灰度图像,中间的为原始的彩色图像,右侧的是模型上色以后的图像。整体上看,网络的上色效果还不错。

Insira a descrição da imagem aqui
Insira a descrição da imagem aqui
Insira a descrição da imagem aqui
Insira a descrição da imagem aqui

7. Produção de interface GUI

Para tornar o uso do modelo mais conveniente, este artigo utiliza pyqt5 para criar a interface de operação, a interface é mostrada abaixo: Primeiro, você pode carregar a imagem do computador, você também pode alternar para a imagem anterior ou seguinte , e você pode exibir a imagem em escala de cinza. Ele pode ser colorido e, em seguida, as informações H, S e V da imagem colorida podem ser ajustadas.Finalmente, a exportação de imagens é suportada e a imagem colorida pode ser salva localmente.
Insira a descrição da imagem aqui
Insira a descrição da imagem aqui

8. Download de código

O link contém código de treinamento, código de teste e código de interface. Além disso, ele também contém mais de 1.000 conjuntos de dados, e a interface de operação aparecerá executando o programa main.py diretamente.
Download de código: Baixar lista de endereços 1

Acho que você gosta

Origin blog.csdn.net/2302_82079084/article/details/135126761
Recomendado
Clasificación