lstm_crf


code

import torch

# Hyper-parameters shared by the snippets below.
embedding_dim = 300
hidden_dim = 256
vocab_size = 4833

# Tag vocabulary; <BEG> and <END> are the artificial start/stop tags.
dic_label = {
    '<BEG>': 0,
    'B-ORG': 1,
    'B-LOC': 2,
    'B-PER': 3,
    'I-PER': 4,
    'I-ORG': 5,
    'I-LOC': 6,
    'O': 7,
    '<END>': 8,
}
tagset_size = len(dic_label)

# Word ids of one example sentence.
x = torch.tensor([
    4445, 1021, 1759, 825, 8, 481, 3763, 2985, 976, 3416,
    1894, 843, 1478, 2044, 3033, 3802, 1756, 3080, 2240, 1459,
    2285, 1220, 4090, 1478, 3246, 348, 1756, 3520, 2430, 2453,
    2490, 4301, 3839, 2004, 2985, 2826, 3256, 406, 3764, 1756,
    3220, 405, 3197, 924, 3256, 646, 2522, 4445, 4427, 1065,
])

# Gold tag ids for that sentence: every token carries id 7 ('O').
y = torch.tensor([
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
])

# Fix the RNG so the randomly initialised layers below are reproducible.
torch.manual_seed(1)
<torch._C.Generator at 0x7f38583be750>

1.这里和LSTM一样

import torch.nn as nn
def init_hidden(num_layers=2, num_directions=2, batch_size=1):
    """Return a zero-initialised ``(h_0, c_0)`` pair for the LSTM.

    There is no hidden state before the first token, so one has to be
    created. Both tensors have shape
    ``(num_layers * num_directions, batch_size, hidden_dim // 2)``; see the
    PyTorch ``nn.LSTM`` documentation for why the state is laid out this way.

    The defaults reproduce the previously hard-coded ``(4, 1, hidden_dim//2)``
    state, i.e. a two-layer bidirectional LSTM fed one sentence at a time.

    Args:
        num_layers: number of stacked LSTM layers.
        num_directions: 2 for a bidirectional LSTM, 1 otherwise.
        batch_size: number of sequences processed together.

    Returns:
        Tuple ``(h_0, c_0)`` of zero tensors.
    """
    shape = (num_layers * num_directions, batch_size, hidden_dim // 2)
    return (torch.zeros(shape), torch.zeros(shape))
# Embedding table: one embedding_dim-sized vector per word id.
word_embeddings = nn.Embedding(vocab_size, embedding_dim)
# Two-layer bidirectional LSTM; each direction has hidden_dim//2 units.
lstm = nn.LSTM(embedding_dim, hidden_dim//2,num_layers=2,dropout=0.5,bidirectional=True)
# Linear projection from the LSTM output to one score per tag.
hidden2tag = nn.Linear(hidden_dim, tagset_size)
hidden=init_hidden()
embed=word_embeddings(x)
# Reshape to (len(x), 1, -1): the sentence is fed as a batch containing one sequence.
lstm_out,hidden=lstm(embed.view(len(x),1,-1),hidden)
# Flatten back to one row per token before projecting to tag scores.
lstm_feats=hidden2tag(lstm_out.view(len(x), -1))
print(lstm_feats)
tensor([[ 3.8183e-02,  1.5836e-02,  3.7862e-02,  7.0418e-02,  2.2438e-02,
          4.1844e-02,  6.3254e-02, -6.2891e-02, -2.4113e-02],
        [ 1.7222e-02,  2.5335e-02,  4.9524e-02,  2.8871e-02,  1.0654e-02,
          3.5806e-02,  5.0367e-02, -8.1280e-03, -1.3964e-02],
        [ 1.3208e-02,  4.9184e-02,  3.3503e-02,  4.9016e-02,  1.7465e-02,
          1.4404e-02,  2.6906e-02, -8.3263e-03,  1.3732e-02],
        [ 3.4434e-02,  6.7054e-02, -3.4955e-03,  3.3051e-02,  5.5812e-02,
          1.8879e-02, -2.3393e-02, -2.8981e-02, -2.1825e-03],
        [ 3.9222e-02,  5.2818e-02, -8.8901e-03,  6.4605e-02,  9.6511e-03,
          4.4643e-02, -4.2248e-02, -3.2466e-02, -3.0473e-02],
        [ 1.7703e-02,  5.7643e-02, -8.6769e-04,  3.7427e-02,  4.2789e-02,
         -3.0154e-03, -4.8294e-03, -6.7323e-02, -3.1352e-02],
        [ 7.6252e-03,  4.4999e-02,  2.2395e-02,  8.7719e-02,  5.0058e-02,
         -2.8418e-02, -1.7231e-02, -9.2161e-02,  2.4956e-02],
        [ 1.0476e-02,  5.4486e-02,  5.4586e-02,  1.3326e-01,  3.2672e-02,
         -5.3690e-02, -2.1257e-02, -1.1390e-01,  3.1813e-02],
        [ 5.0284e-02,  8.0629e-02,  7.8598e-02,  9.5313e-02,  6.1923e-02,
         -1.5688e-02, -4.8931e-02, -6.9627e-02,  2.5013e-03],
        [ 4.7074e-02,  3.6267e-02,  1.2090e-01,  1.5534e-01,  1.1135e-01,
         -4.7217e-02,  2.2335e-02, -5.7714e-02, -3.9817e-03],
        [ 7.1382e-02,  1.6969e-02,  1.1157e-01,  1.7124e-01,  8.8421e-02,
         -8.4020e-03,  7.9469e-02, -5.2841e-02,  2.3255e-02],
        [ 2.6767e-02,  8.4420e-03,  9.5907e-02,  1.7361e-01,  8.9725e-02,
         -2.3339e-02,  4.5094e-02, -1.0153e-01,  1.3048e-03],
        [ 4.7910e-03,  1.6833e-02,  7.6657e-02,  1.0251e-01,  8.0264e-02,
         -8.3014e-03,  4.4257e-02, -7.3046e-02,  1.4148e-02],
        [-6.8932e-03,  1.4081e-02,  8.9506e-03,  1.2938e-01,  2.6720e-02,
         -5.6110e-03, -9.9000e-03, -9.3737e-02, -9.7825e-03],
        [ 1.4149e-02,  1.1808e-02,  3.3058e-02,  6.6578e-02,  2.2146e-02,
         -5.0391e-02,  3.6378e-02, -4.2137e-02, -8.5403e-03],
        [ 1.7245e-02, -2.5705e-03,  4.0062e-02,  3.1326e-02,  2.9868e-02,
         -1.4757e-02,  1.9054e-02, -6.2589e-02, -5.0268e-02],
        [ 6.3002e-02,  1.1055e-02,  7.9968e-02,  3.7050e-03,  5.0645e-02,
         -2.2947e-02,  5.4943e-03, -6.3421e-02, -5.3466e-02],
        [ 1.9044e-02,  4.9453e-02,  5.6323e-02,  4.7003e-02,  3.0991e-02,
         -1.3294e-02,  1.6698e-02, -6.4374e-02,  2.8829e-02],
        [-3.2599e-02,  3.2138e-02,  9.9127e-02,  5.4674e-02,  1.7155e-05,
         -4.1551e-02, -3.7157e-04, -7.8810e-02, -8.2693e-03],
        [-2.6327e-02,  4.8945e-03,  2.5955e-02,  4.5786e-02,  2.5402e-02,
         -3.7257e-02, -8.9031e-03, -7.9090e-02,  1.3885e-02],
        [-1.1430e-02, -1.5008e-02,  7.0533e-02,  1.9002e-02,  7.2340e-02,
         -5.5851e-02,  5.1460e-02, -3.9105e-02, -1.2614e-02],
        [-3.7780e-02, -7.6263e-04,  6.7537e-02,  4.5665e-02,  7.4211e-02,
         -5.4651e-02,  8.8694e-02, -3.1336e-02, -2.2369e-02],
        [-4.8749e-02,  5.0722e-03,  3.4647e-02,  8.2097e-02,  7.9723e-02,
         -9.1170e-02,  1.0117e-01, -2.8669e-02, -1.0937e-02],
        [ 7.2655e-03,  1.7515e-02,  4.3003e-02,  8.7258e-02,  7.3814e-02,
         -1.3298e-02,  7.8785e-02,  9.9551e-03, -9.1298e-03],
        [-3.5876e-02,  3.6467e-03,  4.5647e-02,  1.2269e-01,  8.9789e-02,
         -2.7733e-02,  2.2605e-02, -2.6139e-02, -3.1497e-02],
        [ 7.9253e-03,  5.0785e-02,  6.3591e-02,  1.0623e-01,  9.4893e-02,
         -7.0320e-02,  4.7295e-03, -6.9957e-02, -2.6867e-02],
        [ 4.0818e-02,  6.3439e-02,  6.3788e-02,  5.9888e-02,  8.5492e-02,
         -5.3867e-02,  4.7728e-02,  1.9959e-02, -2.2501e-02],
        [ 1.7481e-02,  4.3535e-02,  1.1841e-01,  9.6153e-02,  1.0129e-01,
         -8.3599e-02,  5.0522e-03, -2.6142e-03, -4.8033e-02],
        [ 5.4343e-02, -1.3246e-02,  7.0181e-02,  6.6330e-02,  4.9628e-02,
         -8.1740e-02,  3.2154e-02, -3.0047e-03, -7.4529e-02],
        [-2.1330e-02, -1.2100e-02,  8.6243e-02,  1.0439e-01,  4.5414e-02,
         -5.5776e-02,  5.7161e-02, -3.2444e-02, -5.6948e-02],
        [ 3.9036e-03, -2.4023e-03,  1.1202e-01,  1.3427e-01,  5.9969e-02,
         -2.4287e-02,  4.6610e-02, -8.8279e-02, -4.2456e-02],
        [-1.5046e-02,  2.5931e-02,  4.3004e-02,  8.1989e-02,  3.6611e-02,
          3.5708e-02,  4.1205e-02, -4.7704e-02, -3.2832e-02],
        [-4.1585e-03,  1.6047e-02,  4.4101e-02,  5.1798e-02,  1.7038e-02,
          1.9604e-02,  5.2224e-02, -3.2164e-02, -9.1968e-03],
        [-1.0024e-02,  1.5265e-02,  6.6991e-02,  6.1134e-02,  4.9887e-02,
          2.6940e-02,  1.0078e-02, -7.3716e-02, -2.9959e-02],
        [ 2.9316e-02,  1.8044e-02,  5.4106e-02,  9.5514e-02,  4.2175e-02,
          4.0722e-02, -2.6533e-02, -9.7807e-02, -3.0428e-02],
        [ 3.5950e-02,  2.1914e-02,  2.2790e-02,  1.1776e-01,  8.5006e-02,
          3.3220e-03, -5.4721e-02, -1.0637e-01, -1.4969e-02],
        [ 2.8662e-04,  5.7606e-03, -5.1026e-02,  9.8858e-02,  4.9239e-02,
          6.3886e-03, -9.1599e-02, -1.1911e-01, -9.8223e-03],
        [-7.4165e-03, -9.9517e-03,  2.9346e-03,  1.0618e-01,  4.2462e-02,
         -1.5746e-02, -9.9808e-02, -9.0239e-02, -2.4547e-02],
        [ 2.3751e-02,  2.7337e-02,  5.2005e-02,  9.1845e-02,  3.9141e-02,
         -5.1640e-02, -6.5797e-02, -7.0130e-02, -3.7505e-02],
        [ 6.5887e-02,  3.8870e-02,  6.4941e-02,  5.1979e-02, -1.3211e-03,
         -4.3425e-02, -1.5248e-02, -1.6603e-04, -2.2244e-02],
        [ 1.0810e-01,  1.3062e-02,  6.5309e-02,  5.6089e-02,  2.5537e-02,
         -6.5674e-02,  1.7194e-02, -4.3231e-02,  6.0457e-02],
        [ 8.2087e-02,  2.0487e-02,  1.4045e-02,  4.9639e-02,  4.8768e-02,
         -3.1174e-02,  5.2010e-02,  2.8771e-03,  7.0507e-02],
        [ 1.8444e-02,  1.1632e-02,  4.5379e-02,  6.8511e-02,  5.6063e-02,
          1.2600e-04,  7.5738e-02, -2.0299e-02,  3.9262e-02],
        [-2.4086e-02, -2.7776e-02,  6.4662e-02,  9.4318e-02,  4.0032e-02,
         -2.0971e-02,  8.4450e-02, -5.6561e-02,  8.8655e-02],
        [-1.3338e-02, -4.5115e-02, -5.8770e-04,  1.0142e-01,  1.0268e-02,
         -1.2930e-02, -2.0746e-02, -7.1666e-02,  6.0637e-02],
        [-1.5192e-02, -4.2696e-02, -6.9170e-04,  6.6469e-02,  2.4795e-03,
         -8.2365e-02, -1.4042e-02, -4.1713e-02,  3.0981e-02],
        [ 2.4108e-02,  1.0243e-02,  2.7134e-02,  5.2960e-02,  1.6884e-03,
         -6.3587e-02,  2.4258e-02,  1.1909e-02, -2.1302e-02],
        [-4.6295e-03,  4.9394e-02,  4.3290e-02,  6.0374e-02,  3.6289e-02,
         -5.6213e-02,  5.7276e-02, -4.6284e-02, -4.6490e-02],
        [-2.3926e-02,  6.8226e-02,  1.1496e-02,  5.5457e-02,  2.7757e-02,
         -5.2626e-02,  1.2686e-02, -1.0490e-01,  8.4274e-03],
        [ 4.0195e-02, -6.3513e-03,  1.9505e-02,  4.9050e-02,  2.1933e-02,
         -7.4236e-02, -3.8088e-02, -7.7404e-02, -5.0512e-03]],
       grad_fn=<AddmmBackward>)

2.维特比

lstm_feats:可以看作状态(发射)矩阵(只是这么看),形状为(序列长,tagset_size),本例中为(50,9)

transition:是转移矩阵

$$\delta(t)_j=\max_i\big(\delta(t-1)_i\,a_{i,j}\big)\cdot b(j)$$

代码中在对数域计算,乘法变为加法(A 的行为 t 时刻的标签,列为 t-1 时刻的标签):

$$\delta(t)_j=\max_i\big(\delta(t-1)_i+A_{j,i}\big)+\text{lstm\_feats}_{t,j}$$

print(lstm_feats.shape)
torch.Size([50, 9])
# Transition scores: row = tag at step t, column = tag at step t-1.
A = nn.Parameter(torch.randn(tagset_size, tagset_size))
# A is a leaf tensor that requires grad, so autograd rejects a plain in-place
# write; apply the hard constraints with gradient tracking switched off.
with torch.no_grad():
    A[0, :] = -1000                # never transition *to* <BEG> (row 0)
    A[:, tagset_size - 1] = -1000  # never transition *from* <END> (last column)
sentence_len = lstm_feats.shape[0]

# Viterbi scores before the first token: only <BEG> (index 0) is reachable.
delta = torch.full((1, tagset_size), -1000.)
delta[0][0] = 0
forward = []
forward.append(delta)
i = 0
gamma_r_l = forward[i]
# (1, n) + (n, n) broadcasts over rows; the max over dim 1 selects, for every
# current tag (row), the best previous tag (column).
delta0, indice = torch.max(gamma_r_l + A, dim=1)
# Same computation with the broadcast written out explicitly via stack.
g = torch.stack([forward[i]] * tagset_size)
print(torch.max(torch.squeeze(g) + A, dim=1))
torch.return_types.max(
values=tensor([-1.0000e+03, -2.0581e-01, -5.9801e-01, -1.6440e-01,  2.3199e-01,
         1.0266e+00,  8.5233e-01,  1.1755e+00, -6.3664e-01],
       grad_fn=<MaxBackward0>),
indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0]))
t_r1_k=torch.unsqueeze(lstm_feats[i],0)
print(t_r1_k)
tensor([[ 0.0382,  0.0158,  0.0379,  0.0704,  0.0224,  0.0418,  0.0633, -0.0629,
         -0.0241]], grad_fn=<UnsqueezeBackward0>)
delta=torch.unsqueeze(delta0,0)+t_r1_k
print(delta)
tensor([[-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
          1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01]],
       grad_fn=<AddBackward0>)
print(delta0+lstm_feats[i])
tensor([-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
         1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01],
       grad_fn=<AddBackward0>)
print(lstm_feats[i])
print(delta0)
tensor([ 0.0382,  0.0158,  0.0379,  0.0704,  0.0224,  0.0418,  0.0633, -0.0629,
        -0.0241], grad_fn=<SelectBackward>)
tensor([-1.0000e+03, -2.0581e-01, -5.9801e-01, -1.6440e-01,  2.3199e-01,
         1.0266e+00,  8.5233e-01,  1.1755e+00, -6.3664e-01],
       grad_fn=<MaxBackward0>)
indices=[]
forward.append(delta)
indices.append(indice.tolist())
print(forward)
[tensor([[    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.]]), tensor([[-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
          1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01]],
       grad_fn=<AddBackward0>), tensor([[-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
          1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01]],
       grad_fn=<AddBackward0>), tensor([[-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
          1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01]],
       grad_fn=<AddBackward0>), tensor([[-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
          1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01]],
       grad_fn=<AddBackward0>)]
def Viterbi_M(features):
    """Decode the highest-scoring tag sequence with the Viterbi algorithm.

    Uses the module-level transition matrix ``A`` (row = current tag,
    column = previous tag) and ``tagset_size``. Tag 0 is the start tag and
    tag ``tagset_size - 1`` the stop tag.

    Args:
        features: per-token tag scores, one row per token.

    Returns:
        ``(bestpath, best_score)``: the list of tag ids (start tag removed)
        and the score of that path as a 0-dim tensor.
    """
    # Before the first token only the start tag is reachable.
    scores = torch.full((1, tagset_size), -1000.)
    scores[0][0] = 0
    backpointers = []
    for emission in features:
        # For every current tag, pick the best previous tag.
        best_prev_score, best_prev_tag = torch.max(scores + A, dim=1)
        scores = (emission + best_prev_score).reshape(1, tagset_size)
        backpointers.append(best_prev_tag.tolist())

    # Close the path with the transition into the stop tag.
    final_scores = scores + A[tagset_size - 1]
    tag = torch.argmax(final_scores).tolist()
    best_score = final_scores[0][tag]

    # Walk the back-pointers from the last token to the first.
    bestpath = [tag]
    for pointers in reversed(backpointers):
        tag = pointers[tag]
        bestpath.append(tag)
    bestpath.pop()  # drop the start tag reached at the end of the walk
    bestpath.reverse()
    return bestpath, best_score
    
bestpath,best_score=Viterbi_M(lstm_feats)
print(bestpath)
print(best_score)
[7, 5, 7, 5, 7, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 2, 3, 6, 7, 5]
tensor(60.5767, grad_fn=<SelectBackward>)

3.neg-log-loss

$$
\begin{aligned}
\hat{\theta} &= \arg\max_{\theta}\prod_{i=1}^{N} p(y^{(i)}\mid x^{(i)})\\
\hat{\lambda},\hat{\eta} &= \arg\max_{\lambda,\eta}\prod_{i=1}^{N} p(y^{(i)}\mid x^{(i)})\\
L &= \sum_{i=1}^{N}\log p(y^{(i)}\mid x^{(i)})
   = \sum_{i=1}^{N}\Big(-\log Z+\sum_{t=1}^{T}\big(\lambda^{T}f(y_{t-1},y_t,x^{(i)})+\eta^{T}g(y_t,x^{(i)})\big)\Big)\\
\frac{\partial L}{\partial\lambda} &= \frac{\partial}{\partial\lambda}\sum_{i=1}^{N}\log p(y^{(i)}\mid x^{(i)})
   = \sum_{i=1}^{N}\Big(-\frac{\partial}{\partial\lambda}\log Z+\sum_{t=1}^{T}f(y_{t-1},y_t,x^{(i)})\Big)
\end{aligned}
$$

log-partition function 的梯度(积分就是期望):

$$
\begin{aligned}
\frac{\partial}{\partial\lambda}\log Z
&= E\Big(\sum_{t=1}^{T}f(y_{t-1},y_t,x^{(i)})\Big)\\
&= \sum_{y}P(y\mid x^{(i)})\sum_{t=1}^{T}f(y_{t-1},y_t,x^{(i)})\\
&= \sum_{t=1}^{T}\sum_{y}P(y\mid x^{(i)})\,f(y_{t-1},y_t,x^{(i)})\\
&= \sum_{t=1}^{T}\sum_{y_1,\dots,y_{t-2}}\sum_{y_{t-1}}\sum_{y_t}\sum_{y_{t+1},\dots,y_T}P(y\mid x^{(i)})\,f(y_{t-1},y_t,x^{(i)})\\
&= \sum_{t=1}^{T}\sum_{y_{t-1}}\sum_{y_t}\Big(\sum_{y_1,\dots,y_{t-2}}\sum_{y_{t+1},\dots,y_T}P(y\mid x^{(i)})\Big)f(y_{t-1},y_t,x^{(i)})\\
&= \sum_{t=1}^{T}\sum_{y_{t-1}}\sum_{y_t}P(y_{t-1},y_t\mid x^{(i)})\,f(y_{t-1},y_t,x^{(i)})
\end{aligned}
$$
其中相邻两个位置的边缘概率可由前向、后向变量得到(这里的下标 i 表示序列中的位置):

$$p(y_{i-1},y_i\mid x)=\frac{\alpha_{i-1}(y_{i-1}\mid x)\,M_i(y_{i-1},y_i\mid x)\,\beta_i(y_i\mid x)}{Z(x)}$$

一次只对一个句子算就可以了
$$negL=\sum_{i=1}^{N}\Big(\log Z-\sum_{t=1}^{T}\big(\lambda^{T}f(y_{t-1},y_t,x)+\eta^{T}g(y_t,x)\big)\Big)=-L$$

#这是crf中的
# y=[0,1,1]
# delta=torch.sum(self.f[:,len(y),[0]+y,y+[9]],axis=(1))-torch.sum(self.f* self.p_y12_x_condition_alpha_beta(alpha, beta),axis=(1,2,3))

3.1求logZ(前向算法)

$$
\begin{aligned}
\alpha_0(y_i) &= (0,-1000,-1000,\dots,-1000)\\
\alpha_1(y_{i+1}) &= \log\sum_{y_i}\exp\big(\alpha_0(y_i)+\lambda^{T}f(y_i,y_{i+1})+\mu^{T}g(y_{i+1},x_d)\big)\\
\mu^{T}g(y_{i+1},x_d) &= \text{lstm\_feats}\\
\lambda^{T}f(y_i,y_{i+1}) &= A
\end{aligned}
$$

def _forward_alg(feats):
    """Compute the log partition function with the forward algorithm.

    Uses the module-level transition matrix ``A`` (row = current tag,
    column = previous tag) and ``tagset_size``; tag 0 is the start tag and
    tag ``tagset_size - 1`` the stop tag.

    Args:
        feats: per-token tag scores, one row per token.

    Returns:
        0-dim tensor: log of the summed exponentiated scores of all paths.
    """
    # All of the initial score sits on the start tag.
    alphas = torch.full([tagset_size], -10000.)
    alphas[0] = 0.

    for emission in feats:
        # One copy of the previous alphas per row (per current tag).
        previous = alphas.unsqueeze(0).expand(feats.shape[1], -1)
        # Emission of the current tag as a column so it broadcasts over columns.
        emission_col = emission.reshape(-1, 1)
        scores = previous + emission_col + A
        # Summing over columns marginalises out the previous tag.
        alphas = torch.logsumexp(scores, dim=1)

    # Finish every path with the transition into the stop tag.
    closing = (alphas + A[tagset_size - 1]).unsqueeze(0)
    return torch.logsumexp(closing, dim=1)[0]
_forward_alg(lstm_feats)
tensor(114.9550, grad_fn=<SelectBackward>)
def alpha_alg(feats):
    """Forward algorithm written with plain broadcasting.

    Relies on the module-level ``A`` (row = current tag, column = previous
    tag) and ``tagset_size``; tag 0 is the start tag and the last tag the
    stop tag.

    Args:
        feats: per-token tag scores, one row per token.

    Returns:
        0-dim tensor holding log Z.
    """
    current = torch.full([tagset_size], -1000.)
    current[0] = 0.  # only the start tag is reachable initially
    for emission in feats:
        # Row r: scores of arriving at tag r from every previous tag.
        scores = current + A + emission.reshape(tagset_size, 1)
        current = torch.logsumexp(scores, dim=1)
    # Add the transition into the stop tag and sum over the last tag.
    closing = current + A[tagset_size - 1]
    return torch.logsumexp(closing, dim=0)
        
        
alpha_alg(lstm_feats)
tensor([-888.8337,  110.8548,  111.7476,  112.9284,  110.7119,  114.4487,
         113.1612,  111.3721, -887.2665], grad_fn=<AddBackward0>)
tensor(114.9550, grad_fn=<LogsumexpBackward>)
init_alpha=torch.full([tagset_size],-1000.)
init_alpha[0]=0.
print(init_alpha)
tensor([    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.])
alpha=[]
alpha.append(init_alpha)
print(alpha)
[tensor([    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.])]
i=0
gammar=torch.stack([alpha[i]]*tagset_size)
print(gammar)

tensor([[    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.],
        [    0., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.]])
print(lstm_feats[i])
print(torch.unsqueeze(lstm_feats[i],0))
print(torch.unsqueeze(lstm_feats[i],0).transpose(0,1))
tensor([ 0.0382,  0.0158,  0.0379,  0.0704,  0.0224,  0.0418,  0.0633, -0.0629,
        -0.0241], grad_fn=<SelectBackward>)
tensor([[ 0.0382,  0.0158,  0.0379,  0.0704,  0.0224,  0.0418,  0.0633, -0.0629,
         -0.0241]], grad_fn=<UnsqueezeBackward0>)
tensor([[ 0.0382],
        [ 0.0158],
        [ 0.0379],
        [ 0.0704],
        [ 0.0224],
        [ 0.0418],
        [ 0.0633],
        [-0.0629],
        [-0.0241]], grad_fn=<TransposeBackward0>)
t=lstm_feats[i].reshape(tagset_size,1)
print(t)
tensor([[ 0.0382],
        [ 0.0158],
        [ 0.0379],
        [ 0.0704],
        [ 0.0224],
        [ 0.0418],
        [ 0.0633],
        [-0.0629],
        [-0.0241]], grad_fn=<AsStridedBackward>)
aa=gammar+t+A
print(aa)
tensor([[-9.9996e+02, -2.0000e+03, -2.0000e+03, -2.0000e+03, -2.0000e+03,
         -2.0000e+03, -2.0000e+03, -2.0000e+03, -2.0000e+03],
        [-1.8997e-01, -1.0012e+03, -9.9866e+02, -9.9967e+02, -1.0013e+03,
         -1.0007e+03, -1.0006e+03, -1.0009e+03, -2.0000e+03],
        [-5.6015e-01, -1.0014e+03, -9.9917e+02, -9.9963e+02, -9.9776e+02,
         -9.9906e+02, -9.9860e+02, -1.0004e+03, -2.0000e+03],
        [-9.3982e-02, -9.9949e+02, -9.9901e+02, -9.9943e+02, -1.0008e+03,
         -9.9923e+02, -1.0001e+03, -1.0012e+03, -1.9999e+03],
        [ 2.5443e-01, -1.0014e+03, -1.0001e+03, -1.0005e+03, -9.9904e+02,
         -1.0002e+03, -1.0009e+03, -1.0018e+03, -2.0000e+03],
        [ 1.0684e+00, -9.9965e+02, -1.0008e+03, -1.0014e+03, -1.0022e+03,
         -1.0009e+03, -9.9986e+02, -9.9887e+02, -2.0000e+03],
        [ 9.1558e-01, -1.0010e+03, -1.0002e+03, -9.9890e+02, -9.9977e+02,
         -1.0006e+03, -1.0010e+03, -9.9912e+02, -1.9999e+03],
        [ 1.1126e+00, -1.0009e+03, -1.0013e+03, -9.9953e+02, -1.0025e+03,
         -9.9881e+02, -9.9909e+02, -9.9945e+02, -2.0001e+03],
        [-6.6076e-01, -1.0009e+03, -1.0006e+03, -9.9899e+02, -1.0004e+03,
         -9.9686e+02, -9.9844e+02, -1.0003e+03, -2.0000e+03]],
       grad_fn=<AddBackward0>)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HrvP3XSv-1581078252705)(attachment:image.png)]

torch.logsumexp(aa,dim=1)#一行求一个值(t+1时刻的)
tensor([-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
         1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01],
       grad_fn=<LogsumexpBackward>)
torch.logsumexp(aa[0],axis=0)
tensor(-999.9618, grad_fn=<LogsumexpBackward>)
torch.logsumexp(aa[:,0],axis=0)
tensor([-9.9996e+02, -1.8997e-01, -5.6015e-01, -9.3982e-02,  2.5443e-01,
         1.0684e+00,  9.1558e-01,  1.1126e+00, -6.6076e-01],
       grad_fn=<SelectBackward>)
torch.log(torch.sum(torch.exp(aa),axis=1))
tensor([   -inf, -0.1900, -0.5601, -0.0940,  0.2544,  1.0684,  0.9156,  1.1126,
        -0.6608], grad_fn=<LogBackward>)
alpha.append(torch.logsumexp(aa,dim=1))

3.2 求给定标注序列的得分(gold score): $\sum_{t=1}^{T}\big(\lambda^{T}f(y_{t-1},y_t,x)+\eta^{T}g(y_t,x)\big)$

print(y)
tensor([4301, 2826,  375, 3802, 3197, 2874, 3016, 2453, 1389, 2284, 2490, 4301,
        3992, 3726,  981, 2985, 2557, 2218, 2264,  471, 1756,  397, 2874, 4154,
         535, 1244, 2406,  545, 2411, 2985,  348, 3489, 4586, 3551,  473, 3462,
        2401])
y=torch.tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
def gold_score(feats, y):
    """Score the tag sequence ``y`` under the emissions ``feats``.

    Sums, along the path, the transition score ``A[current, previous]`` and
    the emission score ``feats[t, current]``. The path starts at tag 0 (the
    start tag) and is closed with the transition into tag
    ``tagset_size - 1`` (the stop tag), so the result is directly comparable
    with the log partition function, which includes the same closing
    transition.

    Args:
        feats: per-token tag scores, one row per token.
        y: tag ids, one per token (may be empty).

    Returns:
        0-dim tensor with the path score.
    """
    score = A.new_zeros(())
    previous = 0  # start tag
    for t, tag in enumerate(y):
        # Out-of-place add keeps the autograd graph free of in-place edits.
        score = score + A[tag, previous] + feats[t, tag]
        previous = tag
    # Closing transition into the stop tag.
    return score + A[tagset_size - 1, previous]
print(sum)
tensor(28.4129, grad_fn=<AddBackward0>)

4.整体

import torch
import torch.nn as nn
import torch.optim as optim
from processData import *
from tqdm import tqdm


torch.manual_seed(1)

def prepare_sequence(seq, to_ix):
    """Convert a sequence of tokens into a 1-D int64 tensor of indices.

    Args:
        seq: iterable of tokens.
        to_ix: mapping from token to integer index.

    Returns:
        ``torch.long`` tensor with one index per token, in order.

    Raises:
        KeyError: if a token is missing from ``to_ix``.
    """
    indices = []
    for token in seq:
        indices.append(to_ix[token])
    return torch.tensor(indices, dtype=torch.long)

class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM encoder followed by a linear-chain CRF layer.

    The LSTM maps a sentence of word indices to per-token emission scores;
    a learned transition matrix scores every pair of adjacent tags.

    Args:
        vocab_size: number of rows of the embedding table.
        tag_to_ix: mapping from tag name to tag index; must contain
            ``start_tag`` and ``stop_tag``.
        embedding_dim: size of each word embedding.
        hidden_dim: size of the concatenated forward/backward LSTM output.
        start_tag: name of the artificial start tag in ``tag_to_ix``.
        stop_tag: name of the artificial stop tag in ``tag_to_ix``.
    """

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim,
                 start_tag="<BEG>", stop_tag="<END>"):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        # Resolve the special tags once so the class does not depend on
        # module-level START_TAG / STOP_TAG globals being defined.
        self.start_ix = tag_to_ix[start_tag]
        self.stop_ix = tag_to_ix[stop_tag]

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters. Entry i,j is the score of
        # transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # Enforce that we never transition to the start tag and never
        # transition from the stop tag. Done without gradient tracking
        # because the parameter is a leaf tensor.
        with torch.no_grad():
            self.transitions[self.start_ix, :] = -10000
            self.transitions[:, self.stop_ix] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a random ``(h_0, c_0)`` pair for one sentence.

        Each tensor has shape ``(2, 1, hidden_dim // 2)``.
        """
        return (torch.randn(2, 1, self.hidden_dim // 2),
                torch.randn(2, 1, self.hidden_dim // 2))

    def _forward_alg(self, feats):
        """Return log Z, the log-sum-exp of the scores of all tag paths.

        Args:
            feats: emission scores, one row per token and one column per tag.

        Returns:
            0-dim tensor.
        """
        # The start tag holds all of the initial score.
        alphas = feats.new_full((self.tagset_size,), -10000.)
        alphas[self.start_ix] = 0.

        for feat in feats:
            # scores[i, j] = alphas[j] + feat[i] + transitions[i, j]:
            # arrive at tag i from previous tag j.
            scores = alphas.unsqueeze(0) + feat.unsqueeze(1) + self.transitions
            # Marginalise out the previous tag.
            alphas = torch.logsumexp(scores, dim=1)

        terminal = alphas + self.transitions[self.stop_ix]
        return torch.logsumexp(terminal, dim=0)

    def _get_lstm_features(self, sentence):
        """Run embedding, BiLSTM and projection to obtain emission scores.

        Args:
            sentence: 1-D tensor of word indices.

        Returns:
            Tensor with one row per token and one tanh-squashed score per tag.
        """
        # A fresh random hidden state is drawn for every sentence.
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = torch.tanh(self.hidden2tag(lstm_out))
        return lstm_feats

    def _score_sentence(self, feats, tags):
        """Return the score of the provided tag sequence.

        Args:
            feats: emission scores, one row per token.
            tags: tag indices, one per token.

        Returns:
            Tensor of shape ``(1,)``.

        Raises:
            ValueError: if ``tags`` and ``feats`` differ in length.
        """
        tags = torch.as_tensor(tags, dtype=torch.long)
        if len(tags) != len(feats):
            raise ValueError(
                "expected %d tags, got %d" % (len(feats), len(tags)))
        score = torch.zeros(1)
        # Prepend the start tag so tags[i] is the predecessor of tags[i + 1].
        tags = torch.cat(
            [torch.tensor([self.start_ix], dtype=torch.long), tags])
        for i, feat in enumerate(feats):
            score = score + \
                self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.stop_ix, tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        """Find the highest-scoring tag path with the Viterbi algorithm.

        Args:
            feats: emission scores, one row per token.

        Returns:
            ``(path_score, best_path)``: a 0-dim tensor and a list of tag
            indices, one per token.
        """
        backpointers = []
        # Initialise the viterbi variables in log space.
        viterbi = feats.new_full((self.tagset_size,), -10000.)
        viterbi[self.start_ix] = 0

        for feat in feats:
            # candidates[i, j]: best path ending in tag j, then moving to i.
            candidates = viterbi.unsqueeze(0) + self.transitions
            best_scores, best_prev = torch.max(candidates, dim=1)
            # The emission does not depend on the previous tag, so it is
            # added after the max.
            viterbi = best_scores + feat
            backpointers.append(best_prev.tolist())

        # Transition to the stop tag.
        terminal = viterbi + self.transitions[self.stop_ix]
        best_tag_id = torch.argmax(terminal).item()
        path_score = terminal[best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for pointers in reversed(backpointers):
            best_tag_id = pointers[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we do not want to return it to the caller).
        start = best_path.pop()
        assert start == self.start_ix  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        """Return the CRF loss ``log Z - score(tags)`` for one sentence.

        Args:
            sentence: 1-D tensor of word indices.
            tags: gold tag indices, one per token.

        Returns:
            Tensor of shape ``(1,)``.
        """
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence):  # do not confuse this with _forward_alg above.
        """Tag a sentence.

        Args:
            sentence: 1-D tensor of word indices.

        Returns:
            ``(score, tag_seq)`` as produced by ``_viterbi_decode``.
        """
        # Get the emission scores from the BiLSTM.
        lstm_feats = self._get_lstm_features(sentence)
        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

def measure(predict, y, num_classes=7):
    """Compute accuracy, micro-F1 and macro-F1 for flat label sequences.

    Labels are expected in ``1..num_classes``; label ``k`` is counted in
    slot ``k - 1``. A mismatch counts as a false positive for the predicted
    class and a false negative for the gold class.

    Per-class F1 is ``2*TP / (2*TP + FP + FN)``. Classes that occur neither
    in ``predict`` nor in ``y`` have no defined F1 and are left out of the
    macro average instead of turning it into NaN; a class that occurs but
    has no true positive contributes 0.

    Args:
        predict: predicted label ids.
        y: gold label ids, same length as ``predict``.
        num_classes: number of label slots to track.

    Returns:
        ``(acc, micro_F1, macro_F1)`` as 0-dim tensors.

    Raises:
        ValueError: if the inputs differ in length or contain a label
            outside ``1..num_classes``.
    """
    predict = torch.as_tensor(predict).reshape(-1)
    y = torch.as_tensor(y).reshape(-1)
    if predict.shape != y.shape:
        raise ValueError("predict and y must have the same length")
    total = y.numel()
    if total:
        lowest = min(int(predict.min()), int(y.min()))
        highest = max(int(predict.max()), int(y.max()))
        # Label 0 would silently wrap around to the last slot and larger
        # labels would index past the counters, so reject both explicitly.
        if lowest < 1 or highest > num_classes:
            raise ValueError(
                "labels must lie in 1..%d, got range %d..%d"
                % (num_classes, lowest, highest))

    hit = torch.eq(predict, y)
    acc = torch.sum(hit).type(torch.FloatTensor) / float(max(total, 1))

    miss = ~hit
    TP = torch.bincount(y[hit] - 1, minlength=num_classes).to(torch.float64)
    FP = torch.bincount(predict[miss] - 1, minlength=num_classes).to(torch.float64)
    FN = torch.bincount(y[miss] - 1, minlength=num_classes).to(torch.float64)
    print(TP)

    zero = torch.zeros((), dtype=torch.float64)

    # micro: pool the counts of all classes.
    micro_denominator = 2 * TP.sum() + FP.sum() + FN.sum()
    if micro_denominator > 0:
        micro_F1 = 2 * TP.sum() / micro_denominator
    else:
        micro_F1 = zero

    # macro: per-class F1, averaged over the classes that actually occur.
    denominator = 2 * TP + FP + FN
    seen = denominator > 0
    # clamp only guards the unseen slots; seen slots have denominator >= 1.
    class_F1 = torch.where(seen, 2 * TP / denominator.clamp(min=1.0),
                           torch.zeros_like(TP))
    print(class_F1)
    macro_F1 = class_F1[seen].mean() if bool(seen.any()) else zero

    print(acc, micro_F1, macro_F1)
    return acc, micro_F1, macro_F1
if __name__== '__main__':
    START_TAG = "<BEG>"
    STOP_TAG = "<END>"
    EMBEDDING_DIM = 300
    HIDDEN_DIM = 256
    # Number of trailing sentences excluded from training and used for
    # evaluation.
    NUM_HELDOUT = 2000
    training_data, dic_word_list, dic_label_list, word_to_ix, tag_to_ix = getAllTrain()
    # training_data[0] is iterated as sentences and training_data[1] as the
    # matching tag sequences, both already encoded as indices.
    sentences, tag_seqs = training_data[0], training_data[1]

    model = BiLSTM_CRF(len(dic_word_list), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

    # Check predictions before training
    with torch.no_grad():
        precheck_sent = torch.tensor(sentences[0])
        print(model(precheck_sent))

    train_sentences = sentences[:-NUM_HELDOUT]
    train_tags = tag_seqs[:-NUM_HELDOUT]
    for epoch in range(1):
        for sentence, tags in tqdm(zip(train_sentences, train_tags),
                                   total=len(train_sentences)):
            # Step 1. PyTorch accumulates gradients, so clear them before
            # each instance.
            model.zero_grad()
            # Step 2. Turn the inputs into tensors of indices.
            sentence_in = torch.tensor(sentence)
            targets = torch.tensor(tags)
            # Step 3. Forward pass: the CRF negative log-likelihood. No
            # Viterbi decode is needed for training.
            loss = model.neg_log_likelihood(sentence_in, targets)
            # Step 4. Compute the gradients and update the parameters.
            loss.backward()
            optimizer.step()

    # Check predictions after training, on exactly the held-out sentences
    # (each one counted once).
    with torch.no_grad():
        predicted_chunks = []
        gold_chunks = []
        for sentence, tags in zip(sentences[-NUM_HELDOUT:], tag_seqs[-NUM_HELDOUT:]):
            # Prepare the network input as a tensor of word indices.
            sentence_in = torch.tensor(sentence)
            tag_scores, predict1 = model(sentence_in)

            predicted_chunks.append(torch.tensor(predict1))
            gold_chunks.append(torch.tensor(tags))

            x0 = [dic_word_list[s] for s in sentence]
            y0 = [dic_label_list[t] for t in tags]
            predict0 = [dic_label_list[t] for t in predict1]
            print(x0)
            print(y0)
            print(predict0)
        # Concatenate once at the end instead of growing a tensor per sentence.
        measure(torch.cat(predicted_chunks), torch.cat(gold_chunks))


问题:倾向于全标注O

  • lstm后使用tanh层
发布了74 篇原创文章 · 获赞 2 · 访问量 4885

猜你喜欢

转载自blog.csdn.net/weixin_40485502/article/details/104215235