文章目录
Learned ISTA 基于学习的快速迭代软阈值方法
前置:
大佬的论文
LISTA与ISTA
相同点
思想
- 希望通过软阈值函数来保证Z稀疏
- 希望通过基于梯度下降(GD)算法实现对Z的优化
方法
- 通过迭代提高Z的稀疏性和保真程度
期望
- 希望通过优化使得 $E_{W_d}(X,Z)=\frac{1}{2}||X-W_dZ||_2^2+\alpha||Z||_1$ 尽可能小
$W_d$
- Wd都是最初确定的字典,都是过完备集
LISTA想要如何进一步优化
1.对迭代公式进行变形
$Z=h_{\alpha/L}(Z-\frac{1}{L}W_d^T(W_dZ-X))$
$Z=h_\theta(SZ+W_eX)$
其中 $W_e=\frac{1}{L}W_d^T,\ S=I-\frac{1}{L}W_d^TW_d,\ \theta=\alpha/L$
2. 选取可训练参数集
$\phi=\{\theta,W_e,S\}$
3. 通过BPTT进行训练
LISTA Pytorch 实现
from abc import ABC
import numpy as np
import torch
import torch.nn as nn
from torch.nn import modules
from random import randint
from scipy.linalg import orth
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Debug aid: torch.cuda.get_device_name(0) reports the name of the selected GPU.
class Lista(modules.Module, ABC):
    """Learned ISTA (LISTA) sparse encoder.

    Unrolls the recursion ``Z <- h_theta(W_e X + S Z)`` for a bounded number
    of steps, where ``h_theta`` is soft-thresholding with threshold
    ``theta = alpha / lipchitz``.  ``W_e`` and ``S`` are trainable and are
    initialised to the plain ISTA values ``W_e = W_d^T / L`` and
    ``S = I - W_d^T W_d / L``.  The threshold itself is a fixed constant,
    not a trainable parameter.

    Args:
        dictionary: 2-D dictionary ``W_d`` of shape ``(n, m)``; inputs have
            ``n`` features and codes have ``m`` entries.
        alpha: Weight of the L1 penalty (numerator of the threshold).
        lipchitz: Constant ``L``; ``1 / L`` is the step size.
        iterate_times: Maximum number of unrolled iterations.

    The initial weights are created on the dictionary's device; call
    ``.to(...)`` on the module to move them.
    """

    def __init__(self, dictionary: torch.Tensor, alpha: float, lipchitz: float, iterate_times: int):
        super().__init__()
        self.dictionary = dictionary
        self._theta = alpha / lipchitz
        self._iterate_time = iterate_times
        self._n, self._m = dictionary.shape
        # W_e maps an n-dimensional input to an m-dimensional code, so the
        # layer is n -> m (its weight has shape (m, n)).
        self._w = modules.Linear(self._n, self._m, bias=False)
        self._s = modules.Linear(self._m, self._m, bias=False)
        self._sh = modules.Softshrink(self._theta)

        # Start from the ISTA solution: W_e = W_d^T / L, S = I - W_d^T W_d / L.
        d = dictionary.detach().float()
        gram = torch.mm(d.t(), d)
        identity = torch.eye(self._m, dtype=d.dtype, device=d.device)
        self._w.weight = nn.Parameter((d.t() / lipchitz).contiguous())
        self._s.weight = nn.Parameter(identity - gram / lipchitz)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (shape ``(n,)`` or ``(batch, n)``) into a sparse code.

        Iterates at most ``iterate_times`` times and stops early once the
        update between two consecutive codes is small.  Prints the number of
        iterations that were actually run.
        """
        bias = self._w(x)
        Z = self._sh(bias)
        steps = 0
        for _ in range(self._iterate_time):
            steps += 1
            # One LISTA step: Z <- h_theta(W_e x + S Z); the bias term enters once.
            t = self._sh(bias + self._s(Z))
            with torch.no_grad():
                delta = t - Z
                if delta.dim() == 1:
                    delta = delta.reshape(-1, 1)
                # Squared update size summed over the first axis.  This is the
                # diagonal of delta^T delta, which is where that Gram matrix
                # attains its maximum, so the full matrix is never needed.
                energy = (delta * delta).sum(dim=0)
                converged = not bool((energy >= 1e-2).any())
            Z = t
            if converged:
                break
        print(steps)
        return Z
def lista_train(x: np.ndarray, dictionary: np.ndarray, lipchitz: float, alpha: float = 0.1, Lambda: float = 0.1,
                iterate_times: int = 1000, learning_rate: float = 0.01, batch: int = 128, epoch: int = 128):
    """Train a ``Lista`` encoder with SGD and return the trained model.

    Each epoch draws ``batch * (len(x) // batch)`` distinct rows of ``x`` at
    random and processes them in consecutive mini-batches of ``batch`` rows.
    The loss is the MSE between the input and its reconstruction from the
    code through ``dictionary``, plus ``Lambda`` times the mean absolute
    value of the code.

    Args:
        x: Training samples, one per row.
        dictionary: 2-D dictionary; its second dimension is the code length.
        lipchitz: Constant ``L`` forwarded to ``Lista``.
        alpha: L1 weight forwarded to ``Lista``.
        Lambda: Weight of the sparsity term in the training loss.
        iterate_times: Maximum number of unrolled iterations in the model.
        learning_rate: SGD learning rate.
        batch: Mini-batch size.
        epoch: Number of passes over the sampled data.
    """
    # Data set-up.
    _, code_len = dictionary.shape
    sample_total = x.shape[0]
    samples = torch.from_numpy(x).float().to(device)
    batches_per_epoch = sample_total // batch
    rows_per_epoch = batch * batches_per_epoch
    # All-zero target: L1Loss against it gives the mean absolute code value.
    zero_codes = torch.zeros(batch, code_len).to(device)

    # Model set-up.
    dict_tensor = torch.from_numpy(dictionary).float().to(device)
    net = Lista(dict_tensor, alpha, lipchitz, iterate_times).float().to(device)
    fidelity_loss = modules.MSELoss()
    sparsity_loss = modules.L1Loss()
    sgd = torch.optim.SGD(net.parameters(), lr=learning_rate)

    # Training loop.
    for _ in range(epoch):
        picked = np.random.choice(a=sample_total, size=rows_per_epoch, replace=False, p=None)
        shuffled = samples[picked]
        for start in range(0, rows_per_epoch, batch):
            chunk = shuffled[start:start + batch, :]
            sgd.zero_grad()
            codes = net(chunk)
            # Reconstruct the input from the code with the fixed dictionary.
            rebuilt = torch.mm(codes, dict_tensor.t())
            fidelity = fidelity_loss(rebuilt.float(), chunk.float())
            sparsity = Lambda * sparsity_loss(codes.float(), zero_codes.float())
            total = fidelity + sparsity
            total.backward()
            sgd.step()
    return net
def train():
    """Build a synthetic sparse-coding data set from the global ``W_d`` and train LISTA on it.

    Draws a random number of samples (between 4000 and 6000).  Each sample is
    a code with at most 5 non-zero entries, each ``5 * N(0, 1)``, at random
    positions; its measurement is ``W_d @ code``.  The measurements are then
    passed to ``lista_train`` with ``lipchitz=2`` and ``alpha=0.1``.

    Returns:
        The model returned by ``lista_train``.
    """
    # Take the code length from the dictionary itself so this function cannot
    # drift out of sync with whatever shape W_d was created with.
    _, m = W_d.shape
    k = min(5, m)
    # Number of training examples.
    N = randint(4000, 6000)

    # Sparse codes Z: k random non-zero positions per row.
    Z = np.zeros((N, m))
    for i in range(N):
        index_k = np.random.choice(a=m, size=k, replace=False, p=None)
        Z[i, index_k] = 5 * np.random.randn(k)
    # Measurements X: row i is W_d @ Z[i].
    X = Z @ W_d.T

    return lista_train(X, W_d, 2, 0.1, iterate_times=1000)
def work(model: modules.Module, x: np.ndarray):
    """Apply ``model`` to the NumPy array ``x`` and return the model's output.

    ``x`` is converted to a float32 tensor and moved to the module-level
    ``device`` before the call.
    """
    as_tensor = torch.from_numpy(x)
    as_tensor = as_tensor.float().to(device)
    return model(as_tensor)
if __name__ == '__main__':
    # Sizes: sparse-code length, number of measurements, non-zeros per code.
    m, n, k = 1024, 256, 5

    # Dictionary W_d = Phi @ Psi: Psi is the identity basis and Phi is a
    # random Gaussian matrix passed through scipy's orth (applied to its
    # transpose, then transposed back).
    Psi = np.eye(m)
    Phi = orth(np.random.randn(n, m).T).T
    W_d = Phi.dot(Psi)

    model = train()

    # Evaluate on 100 fresh k-sparse codes and print the squared error
    # between each true code and the code returned by the model.
    for _ in range(100):
        support = np.random.choice(a=m, size=k, replace=False, p=None)
        Z = np.zeros(m)
        Z[support] = 5 * np.random.randn(k, 1).reshape([-1, ])
        X = W_d.dot(Z)
        z = work(model, X)
        err = Z - z.cpu().detach().numpy()
        print(err.dot(err))
不会画图明天学!
评价有问题
我计算输出的err为:
$err=||Z-Z^*||_2^2$
误差应该使用:
$E_{W_d}(X,Z)=\frac{1}{2}||X-W_dZ||_2^2+\alpha||Z||_1$