PyTorch 实现 CNN_LSTM_Attention_DNN 网络模型

模型图：（原文此处的模型结构图在转载时未能保留）

 

 

import numpy as np
import random
import math
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class CNN_LSTM_ATT_DNN_Net(nn.Module):
    """CNN + LSTM + attention + DNN network (excerpt: LSTM stage only).

    Only the stacked LSTM created in the constructor is present in this
    excerpt. The original comment described the model as
    "cnn + lstm + lstm + Dense", while the class name says
    CNN + LSTM + Attention + DNN. The other stages and ``forward`` are
    presumably defined in the full source but are not shown here.

    Args:
        input_size: Number of features per time step fed to the LSTM. It is
            also used as the LSTM hidden size. Defaults to 31, the value the
            original code hard-coded.
        num_layers: Number of stacked LSTM layers. Defaults to 2, as in the
            original code.
    """

    def __init__(self, input_size=31, num_layers=2):
        super().__init__()
        # Initial parameters -------
        self.input_size = input_size
        # LSTM: the hidden size equals the input size, so each time step's
        # output has the same feature width as its input.
        # Because batch_first=True, the LSTM expects input of shape
        # [batch_size, seq_length, input_size] (NOT
        # [seq_length, batch_size, input_size]) and returns output of shape
        # [batch_size, seq_length, input_size]. The (h_n, c_n) states are not
        # affected by batch_first: [num_layers, batch_size, input_size].
        self.cell_LSTM = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.input_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # NOTE(review): the source excerpt ends here mid-comment; the CNN,
        # attention and dense stages and forward() are not shown.

猜你喜欢

转载自：https://blog.csdn.net/qq_38735017/article/details/130303715