pytorch中torch.nn.Parameter()

1.公式チュートリアル:

URL:

torch.nn.parameter —PyTorch1.11.0のドキュメント

第二に、コードの解釈:

torch.nn.Parameter(torch.Tensor)は、torch.Tensorクラスを継承するクラスであり、次の2つのパラメーターがあります。

  • data(Tensor):与えられたテンソル;
  • require_grad:グラデーションが必要かどうかを指定します。デフォルトはTrueです。

簡単な説明:

まず、この関数は型変換関数として理解できます。この関数は、トレーニング不可能な型トレーニング可能な型に変換し、これにバインドしますこのバインドが含まれているため、パラメーターの最適化中に最適化できます)。型変換は、モデルの一部になり、トレーニングに応じて変更できるモデルのパラメーターになります。Tensorparameterparametermodulenet.parameter()parameterこの関数を使用する目的は、最適化を達成するために、学習プロセス中にいくつかの変数がそれらの値を継続的に変更できるようにすることです。

3.実用的なアプリケーション:

  • たとえば、GSTでは、注意を計算するために複数のトークンをKおよびVとして定義する必要があります。ここでは、モデルの一部として継続的に変更および最適化するためにtorch.nn.Parameter()が使用されます。
  • 主にモデルクラスの__init__()で、それを宣言して標準化します。
self.embed = nn.Parameter(torch.FloatTensor(8, 64))
init.normal_(self.embed, mean=0, std=0.5)
        
  • 詳細は次のとおりです。 
class STL(nn.Module):
    '''
    inputs --- [N, E//2]
    '''

    def __init__(self,model_config):

        super().__init__()
        self.embed = nn.Parameter(torch.FloatTensor(model_config["gst"]["n_style_token"], model_config["gst"]["E"] // model_config["gst"]["attn_head"]))
        d_q = model_config["gst"]["E"] // 2
        d_k = model_config["gst"]["E"] // model_config["gst"]["attn_head"]
        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=model_config["gst"]["E"], num_heads=model_config["gst"]["attn_head"])

        init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, inputs):
        N = inputs.size(0)
        query = inputs.unsqueeze(1)  # [N, 1, E//2]
        keys = F.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
        style_embed = self.attention(query, keys)

        return style_embed
  • さらに、注意を使用するときに、Qをカスタマイズしてランダムに初期化する必要がある場合も、同じことが言えます。 

いくつかの参考文献:

torch.nn.Parameter()_chenzy_hustのブログ-CSDNblog_nn.parameter()

PyTorchのtorch.nn.Parameter()-プログラマーが求めた

おすすめ

転載: blog.csdn.net/m0_46483236/article/details/124020902