ECBSR(リアルタイムMM21_ECBSR用エッジ指向畳み込みブロック)
1. 著者の目的は、モバイル端末に適した効率的な超解像ネットワークを開発することです。
マルチブランチ構造と高密度接続により、特徴の抽出と表現が強化され、あまり多くの FLOP は発生しませんが、並列化速度が犠牲になり、DDR の低帯域幅の影響を受けます。
delite conv などの他の畳み込み手法もネットワーク パフォーマンスを向上させるために提案されていますが、GPU や NPU では十分に最適化されていない可能性があります。
したがって、フラットなネットワーク構造と従来の畳み込み手法を使用する予定です。
2. 著者はプレーンネットを使用することにしましたが、効果が良くなかったので、特徴表現を豊かにするために重いパラメータ化方法を使用しました。
主な構造は次の図に示されています。
-
単一の conv-3x3
-
conv-1x1 + conv-3x3: 拡張と圧縮
-
conv-1x1 + sobelx
-
conv-1x1 + 地味に(画像はコードと矛盾しています)
-
conv-1x1 + laplasian は抽出された画像のエッジ特徴を表示します
トレーニング中、ネットワークは右側の 5 つのブランチで構成されます。推論中に、再パラメータ化テクノロジーを使用してそれらを conv-3x3 にマージできます。このようにして、推論の速度と効率が向上し、基本的に精度が損なわれることはありません。
3. 効率的な推論のための再パラメータ化
全体的なネットワーク構造: ECB モジュールとピクセル シャッフル
## parameters for ecbsr
scale: 2
colors: 1
m_ecbsr: 4
c_ecbsr: 16
idt_ecbsr: 0
act_type: 'prelu'
pretrain: null
1 + 4 个 conv
1 个 pixel shuffle
class ECBSR(nn.Module):
def __init__(self, module_nums, channel_nums, with_idt, act_type, scale, colors):
super(ECBSR, self).__init__()
self.module_nums = module_nums
self.channel_nums = channel_nums
self.scale = scale
self.colors = colors
self.with_idt = with_idt
self.act_type = act_type
self.backbone = None
self.upsampler = None
backbone = []
backbone += [ECB(self.colors, self.channel_nums, depth_multiplier=2.0, act_type=self.act_type, with_idt = self.with_idt)]
for i in range(self.module_nums):
backbone += [ECB(self.channel_nums, self.channel_nums, depth_multiplier=2.0, act_type=self.act_type, with_idt = self.with_idt)]
backbone += [ECB(self.channel_nums, self.colors*self.scale*self.scale, depth_multiplier=2.0, act_type='linear', with_idt = self.with_idt)]
self.backbone = nn.Sequential(*backbone)
self.upsampler = nn.PixelShuffle(self.scale)
def forward(self, x):
y = self.backbone(x) + x
y = self.upsampler(y)
return y
ecb モジュール: 5 つの畳み込みブランチの定義が含まれます
class ECB(nn.Module):
def __init__(self, inp_planes, out_planes, depth_multiplier, act_type='prelu', with_idt = False):
super(ECB, self).__init__()
self.depth_multiplier = depth_multiplier
self.inp_planes = inp_planes
self.out_planes = out_planes
self.act_type = act_type
if with_idt and (self.inp_planes == self.out_planes):
self.with_idt = True
else:
self.with_idt = False
self.conv3x3 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)
self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier)
self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1)
self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1)
self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1)
if self.act_type == 'prelu':
self.act = nn.PReLU(num_parameters=self.out_planes)
elif self.act_type == 'relu':
self.act = nn.ReLU(inplace=True)
elif self.act_type == 'rrelu':
self.act = nn.RReLU(lower=-0.05, upper=0.05)
elif self.act_type == 'softplus':
self.act = nn.Softplus()
elif self.act_type == 'linear':
pass
else:
raise ValueError('The type of activation if not support!')
def forward(self, x):
if self.training:
y = self.conv3x3(x) + \
self.conv1x1_3x3(x) + \
self.conv1x1_sbx(x) + \
self.conv1x1_sby(x) + \
self.conv1x1_lpl(x)
if self.with_idt:
y += x
else:
RK, RB = self.rep_params()
y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
if self.act_type != 'linear':
y = self.act(y)
return y
def rep_params(self):
K0, B0 = self.conv3x3.weight, self.conv3x3.bias
K1, B1 = self.conv1x1_3x3.rep_params()
K2, B2 = self.conv1x1_sbx.rep_params()
K3, B3 = self.conv1x1_sby.rep_params()
K4, B4 = self.conv1x1_lpl.rep_params()
RK, RB = (K0+K1+K2+K3+K4), (B0+B1+B2+B3+B4)
if self.with_idt:
device = RK.get_device()
if device < 0:
device = None
K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device)
for i in range(self.out_planes):
K_idt[i, i, 1, 1] = 1.0
B_idt = 0.0
RK, RB = RK + K_idt, RB + B_idt
return RK, RB
重いパラメータ化の具体的な実装について
class SeqConv3x3(nn.Module):
def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
super(SeqConv3x3, self).__init__()
self.type = seq_type
self.inp_planes = inp_planes
self.out_planes = out_planes
if self.type == 'conv1x1-conv3x3':
self.mid_planes = int(out_planes * depth_multiplier)
conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
self.k1 = conv1.weight
self.b1 = conv1.bias
elif self.type == 'conv1x1-sobelx':
conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
# init scale & bias
scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(scale)
# bias = 0.0
# bias = [bias for c in range(self.out_planes)]
# bias = torch.FloatTensor(bias)
bias = torch.randn(self.out_planes) * 1e-3
bias = torch.reshape(bias, (self.out_planes,))
self.bias = nn.Parameter(bias)
# init mask
self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
for i in range(self.out_planes):
self.mask[i, 0, 0, 0] = 1.0
self.mask[i, 0, 1, 0] = 2.0
self.mask[i, 0, 2, 0] = 1.0
self.mask[i, 0, 0, 2] = -1.0
self.mask[i, 0, 1, 2] = -2.0
self.mask[i, 0, 2, 2] = -1.0
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
elif self.type == 'conv1x1-sobely':
conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
# init scale & bias
scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(torch.FloatTensor(scale))
# bias = 0.0
# bias = [bias for c in range(self.out_planes)]
# bias = torch.FloatTensor(bias)
bias = torch.randn(self.out_planes) * 1e-3
bias = torch.reshape(bias, (self.out_planes,))
self.bias = nn.Parameter(torch.FloatTensor(bias))
# init mask
self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
for i in range(self.out_planes):
self.mask[i, 0, 0, 0] = 1.0
self.mask[i, 0, 0, 1] = 2.0
self.mask[i, 0, 0, 2] = 1.0
self.mask[i, 0, 2, 0] = -1.0
self.mask[i, 0, 2, 1] = -2.0
self.mask[i, 0, 2, 2] = -1.0
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
elif self.type == 'conv1x1-laplacian':
conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
# init scale & bias
scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(torch.FloatTensor(scale))
# bias = 0.0
# bias = [bias for c in range(self.out_planes)]
# bias = torch.FloatTensor(bias)
bias = torch.randn(self.out_planes) * 1e-3
bias = torch.reshape(bias, (self.out_planes,))
self.bias = nn.Parameter(torch.FloatTensor(bias))
# init mask
self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
for i in range(self.out_planes):
self.mask[i, 0, 0, 1] = 1.0
self.mask[i, 0, 1, 0] = 1.0
self.mask[i, 0, 1, 2] = 1.0
self.mask[i, 0, 2, 1] = 1.0
self.mask[i, 0, 1, 1] = -4.0
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
else:
raise ValueError('the type of seqconv is not supported!')
def forward(self, x):
if self.type == 'conv1x1-conv3x3':
# conv-1x1
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
# explicitly padding with bias
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
b0_pad = self.b0.view(1, -1, 1, 1)
y0[:, :, 0:1, :] = b0_pad
y0[:, :, -1:, :] = b0_pad
y0[:, :, :, 0:1] = b0_pad
y0[:, :, :, -1:] = b0_pad
# conv-3x3
y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
else:
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
# explicitly padding with bias
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
b0_pad = self.b0.view(1, -1, 1, 1)
y0[:, :, 0:1, :] = b0_pad
y0[:, :, -1:, :] = b0_pad
y0[:, :, :, 0:1] = b0_pad
y0[:, :, :, -1:] = b0_pad
# conv-3x3
y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)
return y1
def rep_params(self):
device = self.k0.get_device()
if device < 0:
device = None
if self.type == 'conv1x1-conv3x3':
# re-param conv kernel
RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
# re-param conv bias
RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
RB = F.conv2d(input=RB, weight=self.k1).view(-1,) + self.b1
else:
tmp = self.scale * self.mask
k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
for i in range(self.out_planes):
k1[i, i, :, :] = tmp[i, 0, :, :]
b1 = self.bias
# re-param conv kernel
RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
# re-param conv bias
RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
RB = F.conv2d(input=RB, weight=k1).view(-1,) + b1
return RK, RB
4. 結果
エッジSR
1. 転置コンボリューションアップサンプリングとピクセルシャッフルの違い
2.プーリングまたはダウンサンプリングにはエイリアシングアーティファクトが発生する可能性があります
アンチエイリアス ローパス フィルターを使用して、画像をダウンサンプリングします。
このプロセスは、カーネル パラメータまたは重みパラメータがローパス フィルタ係数に対応する、ストライド畳み込み
層を備えたテンソル処理フレームワークに実装されます。
3.単層ネットワークeSR-MAX
コンボリューション 1 つ、ピクセル シャッフル 1 つ、最大 1 つ
コンボリューションによって出力されるチャンネル数: sxsxchannel
out_channels=self.stride[0]*self.stride[1]*self.channels,
4.eSR-TM、eSR-TR、eSR-CNN
コードを直接見て理解することをお勧めします。
class edgeSR_TM(nn.Module):
def __init__(self, model_id):
self.model_id = model_id
super().__init__()
assert self.model_id.startswith('eSR-TM_')
parse = self.model_id.split('_')
self.channels = int([s for s in parse if s.startswith('C')][0][1:])
self.kernel_size = (int([s for s in parse if s.startswith('K')][0][1:]), ) * 2
self.stride = (int([s for s in parse if s.startswith('s')][0][1:]), ) * 2
self.pixel_shuffle = nn.PixelShuffle(self.stride[0])
self.softmax = nn.Softmax(dim=1)
self.filter = nn.Conv2d(
in_channels=1,
out_channels=2*self.stride[0]*self.stride[1]*self.channels,
kernel_size=self.kernel_size,
stride=1,
padding=(
(self.kernel_size[0]-1)//2,
(self.kernel_size[1]-1)//2
),
groups=1,
bias=False,
dilation=1
)
nn.init.xavier_normal_(self.filter.weight, gain=1.)
self.filter.weight.data[:, 0, self.kernel_size[0]//2, self.kernel_size[0]//2] = 1.
def forward(self, input):
filtered = self.pixel_shuffle(self.filter(input))
value, key = torch.split(filtered, [self.channels, self.channels], dim=1)
return torch.sum(
value * self.softmax(key),
dim=1, keepdim=True
)
class edgeSR_TR(nn.Module):
def __init__(self, model_id):
self.model_id = model_id
super().__init__()
assert self.model_id.startswith('eSR-TR_')
parse = self.model_id.split('_')
self.channels = int([s for s in parse if s.startswith('C')][0][1:])
self.kernel_size = (int([s for s in parse if s.startswith('K')][0][1:]), ) * 2
self.stride = (int([s for s in parse if s.startswith('s')][0][1:]), ) * 2
self.pixel_shuffle = nn.PixelShuffle(self.stride[0])
self.softmax = nn.Softmax(dim=1)
self.filter = nn.Conv2d(
in_channels=1,
out_channels=3*self.stride[0]*self.stride[1]*self.channels,
kernel_size=self.kernel_size,
stride=1,
padding=(
(self.kernel_size[0]-1)//2,
(self.kernel_size[1]-1)//2
),
groups=1,
bias=False,
dilation=1
)
nn.init.xavier_normal_(self.filter.weight, gain=1.)
self.filter.weight.data[:, 0, self.kernel_size[0]//2, self.kernel_size[0]//2] = 1.
def forward(self, input):
filtered = self.pixel_shuffle(self.filter(input))
value, query, key = torch.split(filtered, [self.channels, self.channels, self.channels], dim=1)
return torch.sum(
value * self.softmax(query*key),
dim=1, keepdim=True
)