ネットワーク縮退現象と残留ネットワーク効果

導入

最近、ニューラルネットワークを利用して「文字列のすべての文字を後続の文字に置き換える(たとえば、a を b に置き換え、b を c に置き換える)」機能を実現するコードをインターネット上で目にしましたそこに残差ネットワークが追加されているのを見て、残差ネットワークに関連する概念を調べてみました。この記事では次のようになります。

残差ネットワークは何を解決しますか?また、なぜ効果的ですか?

ネットワークの劣化現象(つまり、深いネットワークの効率が浅いネットワークほど良くない)について言及されていたので、実験してみたいと思いました。

この実験では 3 つのネットワークが使用されました。

  1. 26 → 64 → 26 26 \rightarrow 64 \rightarrow 262 66 42 6完全に接続されたニューラル ネットワーク
  2. 26 → 64 → 26 → 64 → 64 → 26 26 \rightarrow 64 \rightarrow 26 \rightarrow 64 \rightarrow 64 \rightarrow 262 66 42 66 46 42 6完全に接続されたニューラル ネットワーク
  3. → 26 \rightarrow 26 ( 2 番目のネットワーク内)残差を追加した2 6ニューラル ネットワーク

上記 3 つのモデルのトレーニングに使用されるトレーニング セットはまったく同じです。

環境

  1. pytorch==1.10
  2. テンソルボード==2.7.0

コード

import string

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import trange
from torch.utils.tensorboard import SummaryWriter

# 驱动选择
device = "cuda" if torch.cuda.is_available() else "cpu"

# tensorboard
writer_rn = SummaryWriter(log_dir = "runs/loss_rn")
writer_nn = SummaryWriter(log_dir = "runs/loss_nn")
writer_n = SummaryWriter(log_dir = "runs/loss_n")


print(f"Using {
      
      device} devive")

# 残差网络
class RN(nn.Module):
    def __init__(self):
        super(RN, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(26, 64),
            nn.Hardsigmoid(),
            nn.Linear(64, 26),
            nn.Hardsigmoid(),
        )
        
        self.linear_stack_2 = nn.Sequential(
            nn.Linear(26, 64),
            nn.Hardsigmoid(),
            nn.Linear(64, 64),
            nn.Hardsigmoid(),
        )
        
        self.output_layer = nn.Linear(64, 26)
        
    def forward(self, x):
        y = self.linear_stack(x)
        # 残差
        y = y+x
        y = self.linear_stack_2(y)
        y = self.output_layer(y)
        
        return y

# 没加残差、其他结构完全一致的神经网络
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(26, 64),
            nn.Hardsigmoid(),
            nn.Linear(64, 26),
            nn.Hardsigmoid(),
        )
        
        self.linear_stack_2 = nn.Sequential(
            nn.Linear(26, 64),
            nn.Hardsigmoid(),
            nn.Linear(64, 64),
            nn.Hardsigmoid(),
        )
        
        self.output_layer = nn.Linear(64, 26)
        
    def forward(self, x):
        y = self.linear_stack(x)
        # 此处没有参擦
        # x = y+x
        y = self.linear_stack_2(y)
        y = self.output_layer(y)
        
        return y
    
# 只有一层的神经网络
class N(nn.Module):
    def __init__(self):
        super(N, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(26, 64),            
            nn.Hardsigmoid(),
        )
        
        
        self.output_layer = nn.Linear(64, 26)
        
    def forward(self, x):
        y = self.linear_stack(x)
        y = self.output_layer(y)
        
        return y
    
# Dataset类
class Data(Dataset):
    def __init__(self, x):
        """
        x:[1...26, 0]
        """
        self.data = list(

                zip(x, list(range(1, 26)) + [0])
        )
        
    def __len__(self):
        return 26
    
    def __getitem__(self, idx):
        return self.data[idx]

# 输出替换结果 
def trans_word(model, word):
    return "".join(
        alphabet_digit_map_reverse[model(alphabet_digit_map[w]).argmax().item()]
        for w in word
    )

# 生成两个模型实例
rNetwork = RN().to(device=device)
nNetwork = NN().to(device=device)
n = N().to(device=device)

"""
生成数据集
a:[1, 0, ..., 0]
b:[0, 1, ..., 0]
...
z:[0, 0, ..., 1]
"""
x = torch.zeros((26, 26), dtype=torch.float32).to(device=device)
for i in range(26):
    x[i][i] = 1

# 生成数据集对象
data = Data(x)
dataloader = DataLoader(data, batch_size=1, shuffle=True)

# 定义损失函数和优化器
loss_rn = nn.CrossEntropyLoss()
loss_nn = nn.CrossEntropyLoss()
loss_n = nn.CrossEntropyLoss()
optimizer_rn = torch.optim.Adam(rNetwork.parameters(), lr=1e-3)
optimizer_nn = torch.optim.Adam(nNetwork.parameters(), lr=1e-3)
optimizer_n = torch.optim.Adam(n.parameters(), lr=1e-3)

# 训练, 三个个模型使用完全相同的数据
for epoch in trange(500):
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # 训练 rn
        pred = rNetwork(X)
        loss = loss_rn(pred, y)
        writer_rn.add_scalar("loss", loss.item(), global_step=epoch)
        optimizer_rn.zero_grad()
        loss.backward()
        optimizer_rn.step()
        
        # 训练 nn
        pred = nNetwork(X)
        loss = loss_nn(pred, y)
        writer_nn.add_scalar("loss", loss.item(), global_step=epoch)
        optimizer_nn.zero_grad()
        loss.backward()
        optimizer_nn.step()
        
        # 训练 n
        pred = n(X)
        loss = loss_n(pred, y)
        writer_n.add_scalar("loss", loss.item(), global_step=epoch)
        optimizer_n.zero_grad()
        loss.backward()
        optimizer_n.step()

# 关闭tensorboard流,保证信息所有输出完毕
writer_rn.close()
writer_nn.close()
writer_n.close()

# 定义字母表到数字的映射
alphabet_digit_map = dict(zip(string.ascii_lowercase, x))
# 数字到字母的映射
alphabet_digit_map_reverse = dict(zip(range(26), string.ascii_lowercase))

# a-z
my_word = string.ascii_lowercase
# 输出结果
print(trans_word(rNetwork, my_word))
print(trans_word(nNetwork, my_word))
print(trans_word(n, my_word))

出力結果

bcdefghijklmnopqrstuvwxyza
bidtfgtiiiidttptittitwiyia
bcdefghijklmnopqrstuvwxyza

テンソルボードの結果出力

ここに画像の説明を挿入
ここに画像の説明を挿入

このうち、オレンジ色は残留したディープ ニューラル ネットワーク、青色は残留物のないディープ ニューラル ネットワーク、赤色は浅いニューラル ネットワークです。

要約する

全体として、赤は収束速度と最終結果の点で 3 つのモデルの中で最も優れており、構造が最も単純なモデルでもあります。(ディープ ニューラル ネットワーク) は最も拡張されたモデルです。

青と赤を比較すると、実際にネットワークの劣化現象が発生していることがわかります。赤の収束速度は青よりもはるかに速く、青の損失は大きく変動しています(背景の水色が実際の損失で、濃い青色の線は圧縮されています)。

青とオレンジを比較すると、残差のみを追加することでネットワークの効率が大幅に向上し、ネットワークの劣化の問題を解決するのに残差が非常に有効であることがわかります。

赤とオレンジを比較すると、残差によってネットワーク劣化の問題がいくつか解決されていますが、収束速度と最終的な効果はまだ浅いネットワークほど良くないことがわかります。ただし、これはこの単純なモデルと単純な問題の結果にすぎません。モデル (Transformer など) には多くのアプリケーションがあり、非常に良い結果が達成されています。

オッカムの剃刀の原理と同様に、ニューラル ネットワークの深さをやみくもに増やしても必ずしも良い結果が得られるわけではないことがわかります。必要がない場合はエンティティを増加させず、問題の本質により注意を払います。 。

おすすめ

転載: blog.csdn.net/qq_42464569/article/details/122584736