PyTorch | Game: FizzBuzz

FizzBuzz


FizzBuzz is a simple little game. The rules are as follows: count upward from 1; when you hit a multiple of 3, say "fizz"; when you hit a multiple of 5, say "buzz"; when you hit a multiple of 15, say "fizzbuzz"; otherwise, just say the number itself.

We can write a small program that decides whether to return the plain number or fizz, buzz, or fizzbuzz.

def fizz_buzz_encode(i):
    # Map a number to its class: 0 = plain number, 1 = fizz, 2 = buzz, 3 = fizzbuzz
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    # Map a class index back to the string we should say
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

def helper(i):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))

for i in range(1, 16):
    helper(i)

Output

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz

First we define the model's inputs and outputs (the training data). Each number is encoded as its 10-bit binary representation; we train on 101 through 1023 (2 ** 10 - 1) and hold out 1 through 100 for testing.

import numpy as np
import torch

NUM_DIGITS = 10

def binary_encode(i, num_digits):
    # Encode i as a fixed-length binary vector, most significant bit first
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

# Train on 101..1023; the numbers 1..100 are held out for testing
trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])
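
As a quick sanity check (an illustrative snippet, not part of the original post), we can print the encoding of a sample number:

print(binary_encode(11, NUM_DIGITS))  # -> [0 0 0 0 0 0 1 0 1 1], i.e. 11 in binary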

Next we define the model with PyTorch.

NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4) # 4 logits, after softmax, we get a probability distribution
)
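
The final Linear layer outputs 4 raw scores (logits) per input. As a small sketch (not from the original post; the model is still untrained at this point), torch.softmax turns those logits into a probability distribution over the 4 classes:

with torch.no_grad():
    logits = model(torch.Tensor([binary_encode(7, NUM_DIGITS)]))
    print(torch.softmax(logits, dim=1))  # shape (1, 4); each row sums to 1
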
  • To teach our model the FizzBuzz game, we need a loss function and an optimization algorithm.
  • The optimizer repeatedly lowers the value of the loss function, so the model achieves as low a loss as possible on this task.
  • Since FizzBuzz is essentially a classification problem, we use the Cross Entropy Loss.
  • For the optimizer we choose Adam.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
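
Note that torch.nn.CrossEntropyLoss applies log-softmax internally and expects raw logits plus integer class labels, which is why the model above has no softmax layer. A tiny illustration (hypothetical values, not from the original post):

sample_logits = torch.tensor([[2.0, 0.1, 0.1, 0.1]])  # logits favoring class 0
sample_target = torch.LongTensor([0])                  # true class is 0
print(loss_fn(sample_logits, sample_target))           # modest loss, roughly 0.37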

Below is the training code.

BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pred = model(batchX)              # forward pass
        loss = loss_fn(y_pred, batchY)

        optimizer.zero_grad()               # clear gradients from the previous step
        loss.backward()                     # backpropagation
        optimizer.step()                    # parameter update
    if epoch % 100 == 99:
        print("Epoch", epoch, loss.item())  # loss of the last mini-batch

Output

Epoch 99 0.09734082221984863
Epoch 199 0.05357144773006439
Epoch 299 0.0314633809030056
Epoch 399 0.019847147166728973
Epoch 499 0.012578330934047699
Epoch 599 0.00820158887654543
Epoch 699 0.0055471863597631454
Epoch 799 0.0038466511759907007
Epoch 899 0.002692295704036951
Epoch 999 0.001864218502305448
Epoch 1099 0.0013077397597953677
Epoch 1199 0.0009263207321055233
Epoch 1299 0.0006638177437707782
Epoch 1399 0.00046644656686112285
Epoch 1499 0.0003352142230141908
Epoch 1599 0.0002458970120642334
Epoch 1699 0.00017381591896992177
Epoch 1799 0.0001258551055798307
Epoch 1899 9.327056613983586e-05
Epoch 1999 6.702014798065647e-05
Epoch 2099 4.869134136242792e-05
Epoch 2199 3.6228873796062544e-05
Epoch 2299 2.7699532438418828e-05
Epoch 2399 2.139510070264805e-05
Epoch 2499 1.5258415260177571e-05
Epoch 2599 1.1108341823273804e-05
Epoch 2699 8.216499736590777e-06
Epoch 2799 5.898595645703608e-06
Epoch 2899 4.3091636143799406e-06
Epoch 2999 2.9404811812128173e-06
Epoch 3099 2.1413443391793407e-06
Epoch 3199 1.5232258192554582e-06
Epoch 3299 1.0508059631320066e-06
Epoch 3399 8.34463719456835e-07
Epoch 3499 5.739700554840965e-07
Epoch 3599 3.9736403323331615e-07
Epoch 3699 2.958154823318182e-07
Epoch 3799 2.0309724391154305e-07
Epoch 3899 1.589456815054291e-07
Epoch 3999 1.1037895575327639e-07
Epoch 4099 8.388801120418066e-08
Epoch 4199 5.739705954965757e-08
Epoch 4299 3.973642392907095e-08
Epoch 4399 2.4106079763441812e-06
Epoch 4499 2.428332663839683e-07
Epoch 4599 2.3400301074616436e-07
Epoch 4699 2.0309704495957703e-07
Epoch 4799 1.7219102232957084e-07
Epoch 4899 1.5011529796993273e-07
Epoch 4999 1.2803954518858518e-07
Epoch 5099 1.103789202261396e-07
Epoch 5199 9.713345860973277e-08
Epoch 5299 7.947283364728719e-08
Epoch 5399 7.505768451210315e-08
Epoch 5499 6.622737203088036e-08
Epoch 5599 4.85667435157211e-08
Epoch 5699 4.4151583722396026e-08
Epoch 5799 3.5321267688459557e-08
Epoch 5899 3.5321267688459557e-08
Epoch 5999 3.5321267688459557e-08
Epoch 6099 3.090611144784816e-08
Epoch 6199 2.6490951654523087e-08
Epoch 6299 2.6490951654523087e-08
Epoch 6399 2.6490951654523087e-08
Epoch 6499 2.6490951654523087e-08
Epoch 6599 2.6490951654523087e-08
Epoch 6699 2.2075791861198013e-08
Epoch 6799 1.7660633844229778e-08
Epoch 6899 1.7660633844229778e-08
Epoch 6999 1.3245475827261544e-08
Epoch 7099 8.830316922114889e-09
Epoch 7199 0.0
Epoch 7299 0.0
Epoch 7399 4.4151584610574446e-09
Epoch 7499 0.0
Epoch 7599 8.830316922114889e-09
Epoch 7699 2.119274284950734e-07
Epoch 7799 6.6227364925453e-08
Epoch 7899 3.973642037635727e-08
Epoch 7999 3.090610789513448e-08
Epoch 8099 1.7660633844229778e-08
Epoch 8199 1.7660633844229778e-08
Epoch 8299 1.3245475827261544e-08
Epoch 8399 8.830316922114889e-09
Epoch 8499 4.4151584610574446e-09
Epoch 8599 4.4151584610574446e-09
Epoch 8699 4.4151584610574446e-09
Epoch 8799 4.4151584610574446e-09
Epoch 8899 0.0
Epoch 8999 0.0
Epoch 9099 0.0
Epoch 9199 0.0
Epoch 9299 0.0
Epoch 9399 0.0
Epoch 9499 0.0
Epoch 9599 0.0
Epoch 9699 0.0
Epoch 9799 0.0
Epoch 9899 0.0
Epoch 9999 0.0
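
The loss reaching 0.0 indicates the model fits the training set essentially perfectly (occasional spikes, such as around epoch 4399, are normal with stochastic optimizers like Adam). As a hedged sketch (not in the original post), one could confirm this by measuring training accuracy directly:

with torch.no_grad():
    train_pred = model(trX).max(1)[1]                 # predicted class per training example
    print((train_pred == trY).float().mean().item())  # should be 1.0 if the loss is 0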

Finally, we let the trained model play FizzBuzz on the numbers 1 to 100.

testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
with torch.no_grad():  # no gradients needed at inference time
    testY = model(testX)

# max(1)[1] is the argmax over the 4 logits, i.e. the predicted class
predictions = zip(range(1, 101), testY.max(1)[1].tolist())
print([fizz_buzz_decode(i, x) for i, x in predictions])

Output

['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', 'buzz', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', 'fizz', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', 'fizz', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', '81', '82', '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', 'fizz', 'fizz', 'buzz']

We can also compute the accuracy against the true labels:

y_pred = testY.max(1)[1].tolist()
y = [fizz_buzz_encode(i) for i in range(1, 101)]
correctSum = 0
for i in range(100):
    if y[i] == y_pred[i]:
        correctSum += 1
accuracy = correctSum / 100
print(accuracy)

Output

0.95

Pretty good results on the test data!
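
To see exactly which numbers the model gets wrong (an illustrative follow-up, not part of the original post):

misses = [i + 1 for i in range(100) if y[i] != y_pred[i]]
print(misses)  # reading the output above: [23, 38, 46, 81, 98]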

Reposted from blog.csdn.net/Oh_MyBug/article/details/104426371