ディープラーニング(Opencv、Pytorch、CNN)に基づく画像の色付け

1 はじめに

ソースコードのダウンロードアドレスは記事の最後に添付されています。

グレースケール画像に自動的に色を付ける

2. 画像フォーマット(RGB、HSV、Lab)

2.1 RGB

グレースケール画像に色を付けたい場合は、まず画像の形式を理解する必要があります。通常の画像の場合、通常は RGB 形式、つまり赤、緑、青の 3 チャンネルです。opencv を使用して分離できます。画像の 3 つのチャネルのコードは次のとおりです。

import cv2

img=cv2.imread('pic/7.jpg')
B,G,R=cv2.split(img)
cv2.imshow('img',img)
cv2.imshow('B',B)
cv2.imshow('G',G)
cv2.imshow('R',R)
cv2.waitKey(0)

コードを実行した結果は次のようになります。
ここに画像の説明を挿入します

2.2hv

HSV は画像の別の形式で、h は画像の色相、s は画像の彩度、v は画像の明るさを表し、値を調整することで画像の色相、彩度、明るさなどの情報を変更できます。 h、s、vの。
opencv を使用して、イメージを RGB 形式から hsv 形式に変換することもできます。次に、次のように h、s、v チャネルを分離して画像コードを表示できます。

import cv2

img=cv2.imread('pic/7.jpg')
hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
h,s,v=cv2.split(hsv)
cv2.imshow('hsv',hsv)
cv2.imshow('h',h)
cv2.imshow('s',s)
cv2.imshow('v',v)
cv2.waitKey(0)

実行結果は次のとおりです。
ここに画像の説明を挿入します

2.3 ラボ

Lab はイメージの別の形式であり、この記事で使用されている形式でもあります。L はグレースケール イメージを表し、a と b はカラー チャネルを表します。この記事では、L チャネルのグレースケール イメージを入力として使用し、ab カラー チャネルを入力として使用します。対立を生成するためにトレーニングする出力ネットワーク、画像を RGB 形式から Lab 形式に変換するコードは次のとおりです。

import cv2

img=cv2.imread('pic/7.jpg')
Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab)
L,a,b=cv2.split(Lab)
cv2.imshow('Lab',Lab)
cv2.imshow('L',L)
cv2.imshow('a',a)
cv2.imshow('b',b)
cv2.waitKey(0)

ここに画像の説明を挿入します

3. 敵対的生成ネットワーク (GAN)

生成对抗网络主要包含两部分,分别是生成网络和判别网络。
生成网络负责生成图像,判别网络负责鉴定生成图像的好坏,二者相辅相成,相互博弈。
本文使用U-net作为生成网络,使用ResNet18作为判别网络。U-net网络的结构图如下所示:

3.1 ネットワーク(Unet)の生成

ここに画像の説明を挿入します

pytorch が unet ネットワークを構築するためのコードは次のとおりです。

class DownsampleLayer(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DownsampleLayer, self).__init__()
        self.Conv_BN_ReLU_2=nn.Sequential(
            nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.downsample=nn.Sequential(
            nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self,x):
        """
        :param x:
        :return: out输出到深层,out_2输入到下一层,
        """
        out=self.Conv_BN_ReLU_2(x)
        out_2=self.downsample(out)
        return out,out_2
class UpSampleLayer(nn.Module):
	def __init__(self,in_ch,out_ch):
	   # 512-1024-512
	   # 1024-512-256
	   # 512-256-128
	   # 256-128-64
	   super(UpSampleLayer, self).__init__()
   self.Conv_BN_ReLU_2 = nn.Sequential(
       nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU(),
       nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU()
   )
   self.upsample=nn.Sequential(
       nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1),
       nn.BatchNorm2d(out_ch),
       nn.ReLU()
   )

	def forward(self,x,out):
	   '''
	   :param x: 输入卷积层
	   :param out:与上采样层进行cat
	   :return:
	   '''
	   x_out=self.Conv_BN_ReLU_2(x)
	   x_out=self.upsample(x_out)
	   cat_out=torch.cat((x_out,out),dim=1)
	   return cat_out
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024]
        #下采样
        self.d1=DownsampleLayer(3,out_channels[0])#3-64
        self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128
        self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256
        self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512
        #上采样
        self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512
        self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256
        self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128
        self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64
        #输出
        self.o=nn.Sequential(
            nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0],3,3,1,1),
            nn.Sigmoid(),
            # BCELoss
        )
    def forward(self,x):
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)
        out=self.o(out8)
        return out


3.2 判別ネットワーク (resnet18)

resnet18 の構造図は次のとおりです:
ここに画像の説明を挿入します
Pytorch には独自の resnet18 モデルが付属しています。たった 1 行のコードで resnet18 モデルを構築できます。その後、ネットワークの最後の完全に接続された層を削除する必要があります。コードは次のとおりです:

from torchvision import models

resnet18=models.resnet18(pretrained=False)
del resnet18.fc

print(resnet18)

4. データセット

この記事では、Webサイト上にある1,000枚以上の自然風景のデータをクロールして使用しています。
ここに画像の説明を挿入します

5. モデルのトレーニングと予測のフローチャート

5.1 トレーニングのフローチャート

以下の図に示すように、最初に RGB 画像を Lab 画像に変換し、次に L チャネルを生成ネットワークの入力として使用します。生成ネットワークの出力は新しい ab チャネルです。その後、元の ab チャネルが生成されます。画像と生成ネットワークで生成したabチャンネルを入力とする判別ネットワーク。
ここに画像の説明を挿入します

5.2 予測フローチャート

次の図は、モデルの予測プロセスを示しています。予測プロセスでは、判別ネットワークは役割を持ちません。まず、RGB 画像が Lab 画像に変換され、次に L グレースケール画像が生成ネットワークに入力されて、新しい ab チャネル画像が生成され、L チャネルは生成された ab チャネル画像と連結されます。連結後、新しい Lab 画像が取得され、RGB 形式に変換されます。このとき、画像はカラー画像になります。
ここに画像の説明を挿入します

6. モデル予測効果

下图为模型的预测效果。左侧的为灰度图像,中间的为原始的彩色图像,右侧的是模型上色以后的图像。整体上看,网络的上色效果还不错。

ここに画像の説明を挿入します
ここに画像の説明を挿入します
ここに画像の説明を挿入します
ここに画像の説明を挿入します

7. GUIインターフェースの制作

モデルをより便利に使用するために、この記事では pyqt5 を使用して操作インターフェイスを作成します。インターフェイスは次のようになります: まず、コンピューターから画像をロードし、前後の画像に切り替えることもできますを選択すると、画像をグレースケールで表示できます。カラー画像の H、S、V 情報を調整し、画像のエクスポートをサポートし、カラー画像をローカルに保存できます。
ここに画像の説明を挿入します
ここに画像の説明を挿入します

8. コードのダウンロード

リンクには、トレーニング コード、テスト コード、インターフェイス コードが含まれています。さらに、1,000 を超えるデータセットも含まれており、main.py プログラムを直接実行することで操作インターフェイスがポップアップ表示されます。
コードダウンロード:ダウンロードアドレスリスト1

おすすめ

転載: blog.csdn.net/2302_82079084/article/details/135126761