TensorFlow 2.0 Study (22): Text Generation - Building the Model

Building the model

vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024
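# batch_size (64, judging by the output shapes below), seq_dataset, char2idx and idx2char
# are assumed to be defined in the previous post of this series (data preprocessing).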
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = keras.models.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim, 
                              batch_input_shape = [batch_size, None]),
        keras.layers.SimpleRNN(units = rnn_units,
                              return_sequences = True),
        keras.layers.Dense(vocab_size)
    ])
    return model

model = build_model(
    vocab_size = vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=batch_size)
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (64, None, 256)           16640     
_________________________________________________________________
simple_rnn (SimpleRNN)       (64, None, 1024)          1311744   
_________________________________________________________________
dense (Dense)                (64, None, 65)            66625     
=================================================================
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________
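As a quick sanity check, the parameter counts in the summary can be reproduced by hand (assuming vocab_size = 65, as the output shapes indicate; for a SimpleRNN the count is units * (units + input_dim + 1)):

embedding_params = 65 * 256                    # vocab_size * embedding_dim = 16,640
rnn_params = 1024 * (1024 + 256 + 1)           # units * (units + input_dim + 1) = 1,311,744
dense_params = (1024 + 1) * 65                 # (units + 1) * vocab_size = 66,625
print(embedding_params + rnn_params + dense_params)   # 1,395,009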
for input_example_batch, target_example_batch in seq_dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape)
    # 64 is the batch_size, 100 is the length of each sequence, 65 is a distribution over the vocabulary
(64, 100, 65)
# Randomly sample from the 65-way output distribution
# Picking the value with the highest probability is the greedy strategy; sampling at random is the stochastic strategy
# logits: in a classification task, the values before the softmax are the logits
sample_indices = tf.random.categorical(
    logits = example_batch_predictions[0], num_samples=1)
# (100, 65) -> (100, 1)
print(sample_indices)
# Squeeze into a 1-D vector
sample_indices = tf.squeeze(sample_indices, axis = -1)
print(sample_indices)
tf.Tensor(
[[50]
 [47]
 [16]
 [41]
 [41]
 [15]
 [ 0]
 [58]
 [48]
 [58]
 [62]
 [22]
 [48]
 [36]
 [36]
 [44]
 [45]
 [12]
 [ 7]
 [31]
 [22]
 [53]
 [32]
 [44]
 [26]
 [17]
 [ 1]
 [ 1]
 [31]
 [ 5]
 [35]
 [22]
 [64]
 [32]
 [15]
 [25]
 [60]
 [12]
 [ 3]
 [28]
 [11]
 [24]
 [28]
 [ 7]
 [39]
 [56]
 [18]
 [26]
 [55]
 [39]
 [10]
 [48]
 [28]
 [53]
 [43]
 [17]
 [48]
 [27]
 [23]
 [55]
 [ 5]
 [49]
 [64]
 [ 6]
 [11]
 [ 4]
 [32]
 [ 8]
 [23]
 [46]
 [18]
 [ 5]
 [64]
 [52]
 [44]
 [26]
 [16]
 [59]
 [37]
 [15]
 [27]
 [41]
 [16]
 [38]
 [ 6]
 [20]
 [42]
 [62]
 [24]
 [62]
 [14]
 [42]
 [12]
 [14]
 [12]
 [48]
 [ 5]
 [45]
 [42]
 [25]], shape=(100, 1), dtype=int64)
tf.Tensor(
[50 47 16 41 41 15  0 58 48 58 62 22 48 36 36 44 45 12  7 31 22 53 32 44
 26 17  1  1 31  5 35 22 64 32 15 25 60 12  3 28 11 24 28  7 39 56 18 26
 55 39 10 48 28 53 43 17 48 27 23 55  5 49 64  6 11  4 32  8 23 46 18  5
 64 52 44 26 16 59 37 15 27 41 16 38  6 20 42 62 24 62 14 42 12 14 12 48
  5 45 42 25], shape=(100,), dtype=int64)
print("Input:", repr("".join(idx2char[input_example_batch[0]])))
print()
print("Output:", repr("".join(idx2char[target_example_batch[0]])))
print()
print("Predictions:", repr("".join(idx2char[sample_indices])))
Input: 'ction, sir,\nEven by your own.\n\nAUFIDIUS:\nI cannot help it now,\nUnless, by using means, I lame the fo'

Output: 'tion, sir,\nEven by your own.\n\nAUFIDIUS:\nI cannot help it now,\nUnless, by using means, I lame the foo'

Predictions: "liDccC\ntjtxJjXXfg?-SJoTfNE  S'WJzTCMv?$P;LP-arFNqa:jPoeEjOKq'kz,;&T.KhF'znfNDuYCOcDZ,HdxLxBd?B?j'gdM"
# Custom loss function
def loss(labels, logits):
    return keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)

model.compile(optimizer='adam', loss=loss)
example_loss = loss(target_example_batch, example_batch_predictions)
print(example_loss.shape)
print(example_loss.numpy().mean())
(64, 100)
4.1735864
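This initial loss of about 4.17 is what you would expect from an untrained model: with 65 characters, a near-uniform output distribution gives a cross-entropy of roughly -ln(1/65).

import numpy as np
print(np.log(65))   # ≈ 4.174, close to the untrained model's loss above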
output_dir = "./text_generation_checkpoints"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epoch}')
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

epochs = 50
history = model.fit(seq_dataset, epochs=epochs, callbacks=[checkpoint_callback])
Train for 172 steps
Epoch 1/50
172/172 [==============================] - 83s 485ms/step - loss: 1.3785
Epoch 2/50
172/172 [==============================] - 102s 591ms/step - loss: 1.3592
Epoch 3/50
172/172 [==============================] - 97s 565ms/step - loss: 1.3413
Epoch 4/50
172/172 [==============================] - 97s 563ms/step - loss: 1.3255
Epoch 5/50
172/172 [==============================] - 93s 543ms/step - loss: 1.3118
Epoch 6/50
172/172 [==============================] - 97s 562ms/step - loss: 1.2964
Epoch 7/50
172/172 [==============================] - 99s 577ms/step - loss: 1.2826
Epoch 8/50
172/172 [==============================] - 96s 559ms/step - loss: 1.2695
Epoch 9/50
172/172 [==============================] - 95s 550ms/step - loss: 1.2572
Epoch 10/50
172/172 [==============================] - 93s 541ms/step - loss: 1.2468
Epoch 11/50
172/172 [==============================] - 93s 540ms/step - loss: 1.2338
Epoch 12/50
172/172 [==============================] - 91s 531ms/step - loss: 1.2217
Epoch 13/50
172/172 [==============================] - 95s 551ms/step - loss: 1.2104
Epoch 14/50
172/172 [==============================] - 95s 551ms/step - loss: 1.1972
Epoch 15/50
172/172 [==============================] - 95s 554ms/step - loss: 1.1882
Epoch 16/50
172/172 [==============================] - 90s 526ms/step - loss: 1.1761
Epoch 17/50
172/172 [==============================] - 93s 542ms/step - loss: 1.1636
Epoch 18/50
172/172 [==============================] - 97s 564ms/step - loss: 1.1555
Epoch 19/50
172/172 [==============================] - 94s 548ms/step - loss: 1.1408
Epoch 20/50
172/172 [==============================] - 93s 540ms/step - loss: 1.1322
Epoch 21/50
172/172 [==============================] - 94s 549ms/step - loss: 1.1215
Epoch 22/50
172/172 [==============================] - 95s 551ms/step - loss: 1.1115
Epoch 23/50
172/172 [==============================] - 95s 551ms/step - loss: 1.0999
Epoch 24/50
172/172 [==============================] - 96s 555ms/step - loss: 1.0902
Epoch 25/50
172/172 [==============================] - 94s 545ms/step - loss: 1.0794
Epoch 26/50
172/172 [==============================] - 97s 563ms/step - loss: 1.0724
Epoch 27/50
172/172 [==============================] - 94s 548ms/step - loss: 1.0603
Epoch 28/50
172/172 [==============================] - 96s 557ms/step - loss: 1.0528
Epoch 29/50
172/172 [==============================] - 95s 550ms/step - loss: 1.0471
Epoch 30/50
172/172 [==============================] - 99s 576ms/step - loss: 1.0338
Epoch 31/50
172/172 [==============================] - 98s 570ms/step - loss: 1.0278
Epoch 32/50
172/172 [==============================] - 97s 567ms/step - loss: 1.0208
Epoch 33/50
172/172 [==============================] - 94s 547ms/step - loss: 1.0127
Epoch 34/50
172/172 [==============================] - 99s 573ms/step - loss: 1.0064
Epoch 35/50
172/172 [==============================] - 99s 573ms/step - loss: 1.0021
Epoch 36/50
172/172 [==============================] - 97s 565ms/step - loss: 0.9938
Epoch 37/50
172/172 [==============================] - 96s 559ms/step - loss: 0.9892
Epoch 38/50
172/172 [==============================] - 100s 581ms/step - loss: 0.9835
Epoch 39/50
172/172 [==============================] - 96s 557ms/step - loss: 0.9790
Epoch 40/50
172/172 [==============================] - 98s 571ms/step - loss: 0.9725
Epoch 41/50
172/172 [==============================] - 109s 636ms/step - loss: 0.9690
Epoch 42/50
172/172 [==============================] - 102s 592ms/step - loss: 0.9675
Epoch 43/50
172/172 [==============================] - 101s 585ms/step - loss: 0.9615
Epoch 44/50
172/172 [==============================] - 100s 581ms/step - loss: 0.9564
Epoch 45/50
172/172 [==============================] - 101s 585ms/step - loss: 0.9571
Epoch 46/50
172/172 [==============================] - 99s 574ms/step - loss: 0.9521
Epoch 47/50
172/172 [==============================] - 101s 589ms/step - loss: 0.9527
Epoch 48/50
172/172 [==============================] - 100s 581ms/step - loss: 0.9459
Epoch 49/50
172/172 [==============================] - 101s 588ms/step - loss: 0.9439
Epoch 50/50
172/172 [==============================] - 101s 590ms/step - loss: 0.9455
tf.train.latest_checkpoint(output_dir)
'./text_generation_checkpoints\\ckpt_10'
# Rebuild the trained model for inference (batch_size = 1)
model2 = build_model(vocab_size,
                    embedding_dim,
                    rnn_units,
                    batch_size = 1)
# Load the trained weights
model2.load_weights(tf.train.latest_checkpoint(output_dir))
# Set the input shape
model2.build(tf.TensorShape([1, None]))

# Text generation procedure
# Start from an initial string, e.g. the character A
# A -> model -> b
# A.append(b) -> Ab
# Ab -> model -> c
# Ab.append(c) -> Abc
# Abc -> model -> ...
model2.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_3 (Embedding)      (1, None, 256)            16640     
_________________________________________________________________
simple_rnn_3 (SimpleRNN)     (1, None, 1024)           1311744   
_________________________________________________________________
dense_3 (Dense)              (1, None, 65)             66625     
=================================================================
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________
# Text generation
def generate_text(model, start_string, num_generate = 1000):
    # Convert the start string into character ids (1-D)
    input_eval = [char2idx[ch] for ch in start_string]
    # Expand dims to add the batch dimension (2-D)
    input_eval = tf.expand_dims(input_eval, 0)
    text_generated = []
    model.reset_states()
    
    for _ in range(num_generate):
        
        # predictions:[batch_size, input_eval_len, vocab_size]
        predictions = model(input_eval)
        # Remove the batch dimension
        # predictions: [input_eval_len, vocab_size]
        predictions = tf.squeeze(predictions, 0)
        # predicted ids: [input_eval_len, 1]; keep only the sample for the last timestep
        predicted_id = tf.random.categorical(
            predictions, num_samples = 1)[-1, 0].numpy()
        text_generated.append(idx2char[predicted_id])
        
        input_eval = tf.expand_dims([predicted_id], 0)
    return start_string + ''.join(text_generated)

new_text = generate_text(model2, 'All: ')
print(new_text)
All: ngs llou,
TINIf,
F t m ndaigome tooure d o ailowe and se?
Forerinenowiliounome! thetould kereriret mo is,

IORDIst, tss teey bante thes g'l
pofamyounge hu,
TE:

Thay yershand ome bar, bo t d bame pat ad l fidshey s ha, he K:
Whe

Cleel? pe bred t d mag our ts bulange y? d.
I t.
Timoure;
IXESTy yo d mol m HAUCHARD oucixan s, me dey tr mery male u scostivis f thy.
FFidir my ik thif My that m theve CII kersor, s agrs,
AUSoustoutithed ga, likimilerd nde uthigan t thatouthe oo bot, t buicithe s m't buche t he se t,
TESAUCHASA:

CO:

IOf hey nghat urow lll allie t! hire? w'ss lourally apr w, ndean le
MENGrff ilayo gh Icade pe he my
Fouser'son yoman;
FI:
Thilllildize: bas.
Thishe.
I ththes e and f ay tesseshicu, teade omowonanouss
ENCRIO:
VI bupootsie, ve HOMes he:
WAs Ay t t e pigulke or coungh ton touncancousir ANAnes banathie cen thy welfanghavere th vifo d s w soure iciee! bjur's,
A wanicomer t w'st be aint omofecofeye myowhe;
LYO:
While: burenowite howhame t,
Ofreer the tay?

BE:
HAs duc

As you can see, the results are not great.
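One likely contributing factor (an observation, not part of the original post): generate_text feeds the model one character per step, but the SimpleRNN above was not built with stateful=True, so its hidden state does not carry over between calls and each prediction effectively sees only the previous character. A minimal sketch of a stateful variant, under that assumption:

# Sketch only (assumption, not the original code): with stateful=True the hidden
# state persists between successive model(...) calls, so one-character-at-a-time
# generation still conditions on the whole generated prefix.
def build_stateful_model(vocab_size, embedding_dim, rnn_units, batch_size):
    return keras.models.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim,
                               batch_input_shape=[batch_size, None]),
        keras.layers.SimpleRNN(units=rnn_units,
                               return_sequences=True,
                               stateful=True),
        keras.layers.Dense(vocab_size)
    ])

Swapping SimpleRNN for keras.layers.LSTM or keras.layers.GRU would also typically produce more coherent text.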
