LISTA: further improving the efficiency of sparse-vector generation through learning (Python/PyTorch)

Learned ISTA: a learning-based fast iterative soft-thresholding method

Prerequisites:

The original paper: Gregor and LeCun, "Learning Fast Approximations of Sparse Coding" (ICML 2010)

LISTA vs. ISTA

Similarities

Idea

  • Use a soft-thresholding function to keep Z sparse (see the sketch after this list)
  • Optimize Z with a gradient-descent (GD) based algorithm
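
As a minimal sketch (not in the original post), the soft-threshold operator $h_\alpha$ mentioned above can be written in a few lines; it is exactly what torch.nn.Softshrink computes:

import torch

def soft_threshold(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # h_alpha(x) = sign(x) * max(|x| - alpha, 0), applied element-wise
    return torch.sign(x) * torch.clamp(torch.abs(x) - alpha, min=0.0)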

Method

  • Improve the sparsity and fidelity of Z through iteration

Objective

  • Through optimization, make $E_{W_d}(X,Z)=\frac{1}{2}\|X-W_dZ\|_2^2+\alpha\|Z\|_1$ as small as possible

$W_d$

  • In both methods, $W_d$ is the dictionary fixed at the outset, and it is an overcomplete basis

How LISTA optimizes further

1. Rewrite the iteration formula

$$Z = h_{\alpha/L}\left(Z - \frac{1}{L}W_d^T(W_dZ - X)\right)$$
$$Z = h_\theta(SZ + W_eX)$$
where $W_e = \frac{1}{L}W_d^T$, $S = I - \frac{1}{L}W_d^TW_d$, and $\theta = \alpha/L$.
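
As a sanity check (the helper names ista_step and lista_form_step are mine, not from the post), the two forms of the update agree before any learning, since the rewriting is only an algebraic regrouping:

import torch

def ista_step(W_d, X, Z, alpha, L):
    # original form: Z <- h_{alpha/L}(Z - (1/L) W_d^T (W_d Z - X))
    shrink = torch.nn.Softshrink(alpha / L)
    return shrink(Z - W_d.t() @ (W_d @ Z - X) / L)

def lista_form_step(W_d, X, Z, alpha, L):
    # rewritten form: Z <- h_theta(S Z + W_e X), theta = alpha / L
    W_e = W_d.t() / L                                  # W_e = (1/L) W_d^T
    S = torch.eye(W_d.shape[1]) - W_d.t() @ W_d / L    # S = I - (1/L) W_d^T W_d
    shrink = torch.nn.Softshrink(alpha / L)
    return shrink(S @ Z + W_e @ X)                     # matches ista_step up to rounding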

2. Choose the trainable parameter set

$\phi=\{\theta, W_e, S\}$

3. Train via BPTT (backpropagation through time)
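
Concretely, BPTT here means unrolling the update $Z = h_\theta(SZ + W_eX)$ for a fixed number of steps and backpropagating the loss through every step. A minimal sketch under that reading (with $\theta$ kept as a fixed float, as in the implementation below, since torch.nn.Softshrink takes a constant threshold):

import torch

def unrolled_lista(x, w_e, s, theta, steps=3):
    # x: (batch, n) inputs; w_e: (m, n); s: (m, m); every step stays in the autograd graph
    shrink = torch.nn.Softshrink(theta)
    b = x @ w_e.t()              # W_e X for a batch of row vectors
    z = shrink(b)
    for _ in range(steps):
        z = shrink(b + z @ s.t())
    return z                     # a loss on z backpropagates into w_e and s through all steps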

LISTA PyTorch implementation

from abc import ABC

import numpy as np
import torch
import torch.nn as nn
from torch.nn import modules
from random import randint
from scipy.linalg import orth

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# print('using device:' + torch.cuda.get_device_name(0))


class Lista(modules.Module, ABC):
    def __init__(self, dictionary: torch.Tensor, alpha: float, lipchitz: float, iterate_times: int):
        super(Lista, self).__init__()
        self.dictionary = dictionary
        self._theta = alpha / lipchitz
        self._iterate_time = iterate_times
        self._n, self._m = dictionary.shape
        # W_e maps the input X (dim n) to the code space (dim m); S maps the code space to itself
        self._w = modules.Linear(self._n, self._m, bias=False)
        self._s = modules.Linear(self._m, self._m, bias=False)
        self._sh = modules.Softshrink(self._theta)
        # initialize W_e = (1/L) W_d^T and S = I - (1/L) W_d^T W_d from the dictionary
        d = dictionary.cpu().numpy()
        w = np.transpose(d).dot(1 / lipchitz)
        s = np.eye(self._m) - np.dot(np.transpose(d), d).dot(1 / lipchitz)
        w = torch.from_numpy(w).float().to(device)
        s = torch.from_numpy(s).float().to(device)
        self._w.weight = nn.Parameter(w)
        self._s.weight = nn.Parameter(s)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Z^(0) = h_theta(W_e X); then iterate Z^(k+1) = h_theta(W_e X + S Z^(k))
        bias = self._w(x)
        Z = self._sh(bias)
        i = 0
        for _ in range(self._iterate_time):
            i += 1
            c = bias + self._s(Z)
            t = self._sh(c)
            e = t.sub(Z)
            if len(e.shape) == 1:
                e = torch.reshape(e, (e.shape[0], 1))
            # stop early once the change between consecutive iterates is negligible
            if torch.max(torch.mm(e.t(), e)).item() < 1e-2:
                Z = t
                break
            Z = t
        print(i)  # debug output: number of iterations actually executed
        return Z


def lista_train(x: np.ndarray, dictionary: np.ndarray, lipchitz: float, alpha: float = 0.1, Lambda: float = 0.1,
                iterate_times: int = 1000, learning_rate: float = 0.01, batch: int = 128, epoch: int = 128):
    # dataset initialization
    n, m = dictionary.shape
    simple_size = x.shape[0]
    x = torch.from_numpy(x).float().to(device)
    step_size = simple_size // batch
    zero_m = torch.zeros(batch, m).to(device)
    # model initialization
    dictionary = torch.from_numpy(dictionary).float().to(device)
    model = Lista(dictionary, alpha, lipchitz, iterate_times).float().to(device)
    criterion_similar = modules.MSELoss()
    criterion_sparse = modules.L1Loss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    # model training: reconstruction (MSE) loss plus an L1 sparsity penalty on the code
    for _epoch in range(epoch):
        sample_index = np.random.choice(a=simple_size, size=batch * step_size, replace=False, p=None)
        x_simple = x[sample_index]
        for _batch in range(step_size):
            x_batch = x_simple[_batch * batch:(_batch + 1) * batch, :]
            # print(x_batch.shape)
            optimizer.zero_grad()
            _z = model(x_batch)
            _x = torch.mm(_z, dictionary.t())
            loss_1 = criterion_similar(_x.float(), x_batch.float())
            loss_2 = Lambda * criterion_sparse(_z.float(), zero_m.float())
            loss = loss_1 + loss_2
            loss.backward()
            optimizer.step()
    return model


def train():
    # dimensions of the sparse signal, measurement and sparsity
    m, n, k = 1024, 256, 5
    # number of test examples
    N = randint(4000, 6000)
    # N = 128
    global W_d
    # generate sparse signal Z and measurement X
    Z = np.zeros((N, m))
    X = np.zeros((N, n))
    for i in range(N):
        index_k = np.random.choice(a=m, size=k, replace=False, p=None)
        Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1, ])
        X[i] = np.dot(W_d, Z[i, :])

    # train a LISTA model on the generated data and return it
    return lista_train(X, W_d, 2, 0.1, iterate_times=1000)


def work(model: modules.Module, x: np.ndarray):
    return model(torch.from_numpy(x).float().to(device))


if __name__ == '__main__':
    # dimensions of the sparse signal, measurement and sparsity
    m, n, k = 1024, 256, 5
    Psi = np.eye(m)
    Phi = np.random.randn(n, m)
    Phi = np.transpose(orth(np.transpose(Phi)))
    global W_d
    W_d = np.dot(Phi, Psi)
    model = train()
    for _ in range(100):
        Z = np.zeros(m)
        index_k = np.random.choice(a=m, size=k, replace=False, p=None)
        Z[index_k] = 5 * np.random.randn(k, 1).reshape([-1, ])
        X = np.dot(W_d, Z)
        z = work(model, X)
        # print the squared L2 error ||Z - Z*||_2^2 between the true and the recovered code
        err = (Z - z.cpu().detach().numpy())
        print(np.transpose(err).dot(err))

I don't know how to plot the results yet; I'll learn that tomorrow!

Problems with the evaluation

The err I compute and print is:
$err = \|Z - Z^*\|_2^2$
The error measure should instead be:
$E_{W_d}(X,Z)=\frac{1}{2}\|X-W_dZ\|_2^2+\alpha\|Z\|_1$
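
A minimal sketch (the function name and default alpha are my own, not from the post) of scoring a recovered code with this objective:

import numpy as np

def lasso_objective(X, Z, W_d, alpha=0.1):
    # E_{W_d}(X, Z) = 0.5 * ||X - W_d Z||_2^2 + alpha * ||Z||_1
    residual = X - W_d @ Z
    return 0.5 * residual @ residual + alpha * np.sum(np.abs(Z))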

Reposted from blog.csdn.net/goes_on/article/details/109095976