PyTorch Notes (08): Writing Classical Chinese Poems with an LSTM

Copyright notice: This is an original article by the blogger; reproduction without permission is prohibited. https://blog.csdn.net/jiangpeng59/article/details/81003058

Test environment:
CentOS 7 + Python 3.6 + PyTorch 0.4 + CUDA 9

Below is an acrostic poem generated by the model; the first characters of the four lines spell 深度学习 ("deep learning"):

深宫昔时见,古貌多自有。
度日不相容,年年生一目。
学者若为霖,百姓贻忧厄。
习坎与天聪,优游宁敢屡。

Training data

The dataset contains 57,580 poems. The author of the book 《PyTorch入门与实践》 preprocessed each poem to a fixed length of 125 characters (shorter poems are padded with spaces; anything beyond 125 characters is dropped).
The data.py file below loads the data:

import numpy as np
import os

def get_data(conf):
    '''
    Load the preprocessed poem data.
    :param conf: configuration options, a Config object
    :return: word2ix: index id of each character, e.g. u'月' -> 100
    :return: ix2word: character of each index id, e.g. 100 -> u'月'
    :return: data: each row is one poem as a sequence of character ids
    '''
    if os.path.exists(conf.data_path):
        # npz archive; note that newer NumPy versions need allow_pickle=True here
        datas = np.load(conf.data_path)
        data = datas['data']
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()
        return data, word2ix, ix2word
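
A quick sanity check of what get_data returns (a minimal sketch, assuming tang.npz is in the working directory and using the Config class from the next section):

conf = Config()
data, word2ix, ix2word = get_data(conf)
print(data.shape)                                  # (57580, 125): one row per poem
first_poem = ''.join(ix2word[int(ix)] for ix in data[0])
print(first_poem)                                  # raw padded 125-character text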

Configuration file

class Config(object):
    """Base configuration class. For custom configurations, create a
    sub-class that inherits from this one and override the properties
    that need to be changed.
    """
    # prefix for saved model checkpoints (written every few epochs)
    model_prefix='checkpoints/tang'

    # model save path
    model_path='checkpoints/tang.pth'

    # start words
    start_words='春江花月夜'

    # type of poem to generate; defaults to acrostic
    p_type='acrostic'

    # number of training epochs
    max_epoch = 200

    # path of the data file
    data_path='tang.npz'

    # mini-batch size
    batch_size=128

    # number of worker processes the DataLoader uses
    num_workers=4

    # number of LSTM layers
    num_layers=2

    # dimensionality of the word embeddings
    embedding_dim=128

    # dimensionality of the LSTM hidden state
    hidden_dim=256

    # save model weights and sample poems every this many epochs
    save_every=10

    # save path for poems generated during training
    out_path='out'

    # save path for poems generated at test time
    out_poetry_path='out/poetry.txt'

    # maximum length of a generated poem
    max_gen_len=200
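
As the docstring suggests, individual settings can be overridden by subclassing (a hypothetical example; the train() function below also accepts keyword overrides):

class DebugConfig(Config):
    # hypothetical settings for a quick local run
    batch_size = 32
    max_epoch = 20
    num_workers = 0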

Model definition

import torch
import torch.nn as nn

class PoetryModel(nn.Module):
    def __init__(self, vocab_size, conf, device):
        super(PoetryModel, self).__init__()
        self.num_layers = conf.num_layers
        self.hidden_dim = conf.hidden_dim
        self.device = device
        # word embedding layer
        self.embeddings = nn.Embedding(vocab_size, conf.embedding_dim)
        # 2-layer LSTM; batch_first is not set, so inputs are (seq_len, batch)
        self.lstm = nn.LSTM(conf.embedding_dim, conf.hidden_dim, num_layers=self.num_layers)
        # fully connected layer producing per-character logits
        # (CrossEntropyLoss applies the softmax during training)
        self.linear_out = nn.Linear(conf.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        '''
        :param input: (seq_len, batch_size) tensor of character ids
        :return: output logits and the final hidden state
        '''
        seq_len, batch_size = input.size()
        # embeds size: (seq_len, batch_size, embedding_dim)
        embeds = self.embeddings(input)
        if hidden is None:
            # initialize the hidden and cell states with zeros
            h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
        else:
            h0, c0 = hidden
        output, hidden = self.lstm(embeds, (h0, c0))
        # output size: (seq_len*batch_size, vocab_size)
        output = self.linear_out(output.view(seq_len * batch_size, -1))
        return output, hidden
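
A quick shape check of the forward pass (a sketch on CPU; the vocabulary size of 8000 is only illustrative, not the real size of the dataset):

device = torch.device('cpu')
conf = Config()
model = PoetryModel(vocab_size=8000, conf=conf, device=device)
dummy = torch.randint(0, 8000, (124, conf.batch_size))   # (seq_len, batch) of ids
output, hidden = model(dummy)
print(output.shape)   # torch.Size([124 * 128, 8000]) = (seq_len*batch, vocab)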

Model training

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(**kwargs):
    conf = Config()
    # allow any Config attribute to be overridden via keyword arguments
    for k, v in kwargs.items():
        setattr(conf, k, v)
    # load the data
    data, word2ix, ix2word = get_data(conf)
    # build the DataLoader
    dataloader = DataLoader(dataset=data, batch_size=conf.batch_size,
                            shuffle=True,
                            drop_last=True,
                            num_workers=conf.num_workers)
    # build the model
    model = PoetryModel(len(word2ix), conf, device).to(device)
    # optimizer
    optimizer = Adam(model.parameters())
    # loss function
    criterion = nn.CrossEntropyLoss()
    # start training
    for epoch in range(conf.max_epoch):
        epoch_loss = 0
        for i, batch in enumerate(dataloader):
            # (batch, seq_len) -> (seq_len, batch), as the model expects
            batch = batch.long().transpose(1, 0).contiguous()
            # language-model shift: predict character t+1 from characters up to t
            input, target = batch[:-1, :], batch[1:, :]
            input, target = input.to(device), target.to(device)
            optimizer.zero_grad()
            output, _ = model(input)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # report the average loss per batch
        print("epoch_%d_loss:%0.4f" % (epoch, epoch_loss / len(dataloader)))
        if epoch % conf.save_every == 0:
            # sample acrostic poems with the current weights
            fout = open('%s/p%d' % (conf.out_path, epoch), 'w')
            for word in list(conf.start_words):
                gen_poetry = generate(model, word, ix2word, word2ix, conf)
                fout.write("".join(gen_poetry) + "\n\n")
            fout.close()
            torch.save(model.state_dict(), "%s_%d.pth" % (conf.model_prefix, epoch))
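
The loop above calls a generate function that this post does not show. Below is a minimal sketch of what it might look like, assuming greedy (argmax) decoding and the '<START>'/'<EOP>' marker tokens from the book's preprocessing (both are assumptions, not the post's actual implementation):

def generate(model, start_word, ix2word, word2ix, conf):
    # Sketch: greedily extend a poem from a single start character.
    results = [start_word]
    hidden = None
    # prime the network with the <START> marker (assumed token name)
    input = torch.tensor([[word2ix['<START>']]], device=model.device)
    _, hidden = model(input, hidden)
    # then feed the start character itself
    input = torch.tensor([[word2ix[start_word]]], device=model.device)
    for _ in range(conf.max_gen_len):
        output, hidden = model(input, hidden)
        top_ix = output.argmax(dim=1).item()   # greedy: most likely next id
        w = ix2word[top_ix]
        if w == '<EOP>':                       # assumed end-of-poem marker
            break
        results.append(w)
        input = torch.tensor([[top_ix]], device=model.device)
    return results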

This post is based on Chen Yun (陈云)'s book 《PyTorch入门与实践》.
