序文
よく遭遇しますが、ちょっとした補足をするたびに忘れてしまうので、記事を開いてゆっくり記録しておこうかなと思います。
VAE -> VQVAE、主にベクトル量子化を追加
VQ-VAEについては、生成された
この記事は継続的に更新されます…
この記事はよく書かれていますので、時間があるときによく見てください。変分オートエンコーダ VAE: これは何ですか | 添付のオープン ソース コード、Su Jianlin 氏の記事。
仮説
トレーニング段階
率直に言うと、サンプル (画像) を入力すると、エンコーダーはこのサンプルから特徴を抽出した後、2 つの量 (1 つは平均値、もう 1 つは分散) を学習します。
feature = encoder(img)
mu, var = w_mu(feature), w_var(feature)
次に、平均と分散に従って隠れ変数 z をサンプリングできます。
eps = torch.rand_like(mu)
z = mu + (var) ** 0.5 * eps
次に、この隠れた変数 z に従って、画像を復号化できます。
img_generate = デコーダ(z)
寸法が間違っている場合は、完全に接続されたレイヤーを追加して寸法をマッピングできます。
予測段階
まずランダムな隠し変数 z を生成し、z の次元に注意してから、
それをデコーダに詰め込みます。
コード
次のコードのソース: https://zhuanlan.zhihu.com/p/151587288
トレーニング コードはhttps://shenxiaohai.me/2018/10/20/pytorch-tutorial-advanced-02/
で参照できます。グラウンド トゥルースと KLloss を使用して CrossEntropyloss を実行します。
class VAE(nn.Module):
def __init__(self, latent_dim):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256),
nn.ReLU(),
nn.Linear(256, 128))
self.mu = nn.Linear(128, latent_dim)
self.logvar = nn.Linear(128, latent_dim)
self.latent_mapping = nn.Linear(latent_dim, 128)
self.decoder = nn.Sequential(nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 28 * 28))
def encode(self, x):
x = x.view(x.size(0), -1)
encoder = self.encoder(x)
mu, logvar = self.mu(encoder), self.logvar(encoder)
return mu, logvar
def sample_z(self, mu, logvar):
eps = torch.rand_like(mu)
return mu + eps * torch.exp(0.5 * logvar)
def decode(self, z,x):
latent_z = self.latent_mapping(z)
out = self.decoder(latent_z)
reshaped_out = torch.sigmoid(out).view(x.shape[0],1, 28,28)
return reshaped_out
def forward(self, x):
mu, logvar = self.encode(x)
z = self.sample_z(mu, logvar)
output = self.decode(z,x)
return output
# 创建优化器
num_epochs = 10
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
for i, (x, _) in enumerate(data_loader):
# 获取样本,并前向传播
x = x.to(device).view(-1, 28 * 28)
x_predict = model(x)
# 计算重构损失和KL散度(KL散度用于衡量两种分布的相似程度)
# KL散度的计算可以参考论文或者文章开头的链接
reconst_loss = F.binary_cross_entropy(x_predict, x, size_average=False)
kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# 反向传播和优化
loss = reconst_loss + kl_div
optimizer.zero_grad()
loss.backward()
optimizer.step()