#!/usr/bin/env python3
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional
# Example adapted from https://pytorch.org/docs/master/nn.html#torch.nn.LSTM
#
# Runs a 2-layer LSTM over a random input sequence, then saves the final
# hidden state `hn` to hn.csv and prints it.

SEQ_LEN = 5         # number of time steps in the input sequence
BATCH_SIZE = 3      # number of sequences processed together
INPUT_SIZE = 10     # feature dimension of each input vector
HIDDEN_SIZE = 20    # dimension of the hidden / cell state
NUM_LAYERS = 2      # stacked LSTM layers (nn.LSTM uses 1 if not given)
NUM_DIRECTIONS = 1  # unidirectional: bidirectional is not enabled below

rnn = nn.LSTM(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
)

# Input layout is (seq_len, batch, input_size) because batch_first is not set.
# Named `inputs` (not `input`) so the builtin input() is not shadowed.
inputs = torch.randn(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)

# Initial states are (num_layers * num_directions, batch, hidden_size).
h0 = torch.randn(NUM_LAYERS * NUM_DIRECTIONS, BATCH_SIZE, HIDDEN_SIZE)
c0 = torch.randn(NUM_LAYERS * NUM_DIRECTIONS, BATCH_SIZE, HIDDEN_SIZE)

output, (hn, cn) = rnn(inputs, (h0, c0))

# Save hn. detach() is required because hn is part of the autograd graph
# (it carries a grad_fn), and .numpy() refuses tensors that require grad.
a = hn.detach().numpy()
# np.savetxt only writes 1-D or 2-D arrays, so the 3-D state is flattened
# into rows of length hidden_size (row order: layer-major, then batch).
# Using a.shape[-1] keeps this correct if HIDDEN_SIZE is changed.
b = a.reshape(-1, a.shape[-1])
np.savetxt('hn.csv', b, delimiter=',')
print(hn)
print(b)
输出如下:
C:\Users\Penghj\Anaconda3\envs\pytorch-py36\python.exe D:/PycharmProjects/LSTMprocess/LSTM_test.py
tensor([[[-7.3971e-02, -1.2780e-01, -1.4474e-01, 1.8253e-01, -1.9796e-01,
1.0784e-01, -2.3064e-01, -3.4146e-01, 1.8185e-01, -8.9133e-04,
4.0730e-02, -1.2608e-01, 5.8794e-02, -7.0834e-02, -1.8286e-01,
-1.0784e-02, 1.9347e-01, -1.5460e-01, 8.6555e-02, -1.6798e-02],
[ 9.6833e-02, -5.8035e-02, 1.7099e-01, 2.2732e-01, 1.3198e-02,
3.5267e-03, 4.5223e-02, -1.5988e-01, -9.0247e-02, 7.3956e-02,
7.6391e-02, -4.5348e-02, 7.6712e-02, 2.2441e-03, -1.1218e-01,
7.3470e-02, 1.0081e-01, -3.8649e-02, 1.3708e-01, 2.1249e-01],
[ 5.5192e-02, 1.4865e-02, 1.6241e-02, -9.0010e-03, -1.9215e-01,
-4.5265e-02, 1.0673e-01, -1.2473e-01, -1.9864e-04, -1.5129e-01,
-8.4610e-02, -4.1374e-02, 1.1179e-01, -7.7280e-02, -9.2104e-04,
4.8305e-02, 1.4308e-02, -6.9500e-03, 2.7160e-01, 1.0379e-01]],
[[-1.4114e-01, 2.4604e-02, -3.0190e-02, 9.0903e-02, 1.7423e-02,
-4.3247e-02, -1.7116e-01, -1.2302e-01, -1.6867e-01, 6.4440e-03,
-1.3868e-01, -4.2504e-02, -8.0390e-02, 4.2705e-03, 3.0945e-02,
-7.8335e-02, 6.4501e-02, 5.7008e-03, -5.9228e-02, -3.5933e-02],
[-6.2894e-02, 6.8849e-02, 1.5310e-02, 3.9452e-02, -4.1810e-02,
-1.6745e-01, -2.1690e-01, -1.1189e-02, -1.0120e-01, -5.4169e-02,
-1.2677e-01, -8.9171e-02, -5.1950e-02, -8.6439e-02, 7.8760e-03,
-1.0244e-01, 6.7068e-02, 8.7209e-02, -1.7350e-02, -1.0928e-02],
[-1.4949e-01, 4.2771e-02, -2.9070e-02, 4.4891e-02, -4.9142e-02,
-1.6704e-01, -1.4158e-01, -1.0561e-02, -1.2310e-01, -1.1926e-02,
-1.1779e-01, -1.2357e-02, -7.4506e-03, -2.5705e-02, 1.0918e-02,
-7.2432e-02, 7.4882e-02, 5.3149e-02, -7.8296e-03, -4.9892e-02]]],
grad_fn=<StackBackward>)
[[-7.39708394e-02 -1.27802983e-01 -1.44737855e-01 1.82530552e-01
-1.97959319e-01 1.07843086e-01 -2.30639011e-01 -3.41463804e-01
1.81853518e-01 -8.91325413e-04 4.07299958e-02 -1.26077503e-01
5.87935671e-02 -7.08335489e-02 -1.82858154e-01 -1.07838241e-02
1.93470061e-01 -1.54596388e-01 8.65550786e-02 -1.67978760e-02]
[ 9.68326107e-02 -5.80345728e-02 1.70993403e-01 2.27316663e-01
1.31975403e-02 3.52669228e-03 4.52234782e-02 -1.59884900e-01
-9.02474746e-02 7.39560351e-02 7.63912499e-02 -4.53478247e-02
7.67116323e-02 2.24413490e-03 -1.12175845e-01 7.34703988e-02
1.00805841e-01 -3.86485495e-02 1.37081578e-01 2.12489918e-01]
[ 5.51915988e-02 1.48648266e-02 1.62407588e-02 -9.00098588e-03
-1.92146197e-01 -4.52653170e-02 1.06727183e-01 -1.24732286e-01
-1.98639274e-04 -1.51294112e-01 -8.46096203e-02 -4.13742140e-02
1.11793645e-01 -7.72797391e-02 -9.21042752e-04 4.83049154e-02
1.43076414e-02 -6.94996910e-03 2.71601051e-01 1.03792876e-01]
[-1.41142428e-01 2.46040653e-02 -3.01896222e-02 9.09029543e-02
1.74230635e-02 -4.32474427e-02 -1.71161547e-01 -1.23024724e-01
-1.68667182e-01 6.44403743e-03 -1.38677344e-01 -4.25040163e-02
-8.03901553e-02 4.27051587e-03 3.09446547e-02 -7.83353224e-02
6.45011663e-02 5.70079219e-03 -5.92282042e-02 -3.59332785e-02]
[-6.28941953e-02 6.88489452e-02 1.53099252e-02 3.94519903e-02
-4.18104976e-02 -1.67447835e-01 -2.16900289e-01 -1.11890212e-02
-1.01196185e-01 -5.41693345e-02 -1.26771182e-01 -8.91706571e-02
-5.19499257e-02 -8.64390358e-02 7.87597708e-03 -1.02441035e-01
6.70683980e-02 8.72091055e-02 -1.73503663e-02 -1.09276902e-02]
[-1.49493888e-01 4.27707098e-02 -2.90701501e-02 4.48911339e-02
-4.91419397e-02 -1.67038664e-01 -1.41577035e-01 -1.05611794e-02
-1.23095192e-01 -1.19260075e-02 -1.17788106e-01 -1.23572880e-02
-7.45061925e-03 -2.57053599e-02 1.09177465e-02 -7.24323615e-02
7.48823658e-02 5.31486273e-02 -7.82960933e-03 -4.98918593e-02]]
Process finished with exit code 0