MENU

PyTorch 训练神经网络玩游戏

June 2, 2019 • Read: 8491 • Deep Learning阅读设置

Game rules

很简单的一个小游戏,名字叫 "FizzBuzz",游戏规则如下:

从 1 开始数数,当遇到 3 的倍数的时候,说 fizz,当遇到 5 的倍数的时候,说 buzz,当遇到 15 的倍数的时候,就说 fizzbuzz,其他情况则正常数数

Game conversion to classification problem

可以想到,在这个游戏中,总共只有四类,fizzbuzzbuzzfizznumber

所以我们先定义一个函数,这个函数的作用是将输入的数字,离散为这四类中的某一类

  • def fizz_buzz_encode(i):
  • if i % 15 == 0:
  • return 3
  • elif i % 5 == 0:
  • return 2
  • elif i % 3 == 0:
  • return 1
  • else:
  • return 0

有了 encode 函数,还需要一个 decode 函数,参数是个数字,以及这个数字的类别,返回是这个数字应该喊什么,比方说 decode(15, 3),返回的就应该是 fizzbuzz,再比如 decode(7, 0),就应该返回 7

  • def fizz_buzz_decode(i, label):
  • return [str(i), 'fizz', 'buzz', 'fizzbuzz'][label]

写个测试函数测试一下

  • def helper(i):
  • print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
  • for i in range(1, 16):
  • helper(i)
  • 输出:
  • 1
  • 2
  • fizz
  • 4
  • buzz
  • fizz
  • 7
  • 8
  • fizz
  • buzz
  • 11
  • fizz
  • 13
  • 14
  • fizzbuzz

Generate training set

  • import numpy as np
  • import torch
  • from torch import nn

对于一个神经网络,我们的输入是一个数字,我们要他返回的是这个数字属于哪个类别(知道哪个类别之后调用 decode 函数就行了)

但其实输入如果单纯是个十进制数字特征不够明显,我们可以尝试把十进制转换为二进制,将 01 编码作为输入

  • NUM_DIGITS = 10
  • def binary_encode(i, NUM_DIGITS): # 将一个十进制数转换为二进制
  • return np.array([i >> d & 1 for d in range(NUM_DIGITS)][::-1])
  • #print(binary_encode(15, NUM_DIGITS))

然后生成训练集 Xy,我把 $[101,1024]$ 之间的所有整数转为二进制作为 X_train,调用 encode 函数生成的标签作为 y_train

  • X_train = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
  • y_train = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])

Construct neural network

首先设计网络结构

然后利用 PyTorch 定义模型

  • NUM_HIDDEN = 100 # 隐藏层100个神经元
  • model = nn.Sequential( # 网络结构:Input -> Hidden_Layer1 -> OutPut
  • nn.Linear(NUM_DIGITS, NUM_HIDDEN, bias = False), # z = w1*x, 其中w1.shape=(10, 100), x.shape=(923, 10)
  • nn.ReLU(), # z = relu(z), 其中z.shape=(923, 100)
  • nn.Linear(NUM_HIDDEN, 4, bias = False) # y_pred = z*w2, 其中z.shape(923, 100), w2.shape=(100, 4)
  • # 输出的是个923*4的矩阵
  • )

定义 Loss_Function 和梯度下降的方法

  • loss_fn = nn.CrossEntropyLoss() # 专为分类问题设计的Loss
  • optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) # lr is learning_rate

开始训练模型

  • BATCH_SIZE = 128
  • for epoch in range(10000):
  • for start in range(0, len(X_train), BATCH_SIZE):
  • end = start + BATCH_SIZE
  • batchX = X_train[start:end]
  • batchY = y_train[start:end]
  • y_pred = model(batchX)
  • loss = loss_fn(y_pred, batchY)
  • print('Epoch', epoch, loss.item())
  • optimizer.zero_grad()
  • loss.backward()
  • optimizer.step()

如果关于 BATCH_SIZEEPOCH 不清楚作用,可以看这篇文章

训练最终结果如下图,我们说,如果一个人通过瞎猜玩这个游戏,那他每次的正确率只有 $\frac {1}{4}$,但是从训练结果来看,很明显我们的网络的准确度比瞎猜要高很多

训练完以后生成测试数据 X_test

  • X_test = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(0, 101)])

然后用训练好的模型对测试数据进行预测,生成 y_test,假设测试数据有 100 个,那 y_test 的大小就是 (100, 4),4 列分别对应每个类型的概率,我们取出最大概率对应的下标值,带入 decode 函数,就能看到他在测试数据上的表现了

  • with torch.no_grad():
  • y_test = model(X_test)
  • #y_test.max(1)[1]
  • predicts = zip(range(0, 101), list(y_test.max(1)[1].data.tolist()))
  • print([fizz_buzz_decode(i, x) for i, x in predicts])
  • 输出:
  • ['0', '1', 'fizz', '3', 'buzz', 'fizz', '6', '7', 'fizz', 'buzz', '10', 'fizz', '12', '13', 'fizzbuzz', '15', '16', 'fizz', '18', 'buzz', '20', '21', '22', 'fizz', 'buzz', '25', 'fizz', '27', '28', 'fizzbuzz', '30', 'fizz', 'fizz', '33', 'buzz', 'fizz', '36', '37', 'fizz', 'buzz', '40', 'fizz', '42', '43', 'fizzbuzz', '45', '46', 'fizz', '48', 'buzz', 'fizz', '51', '52', 'fizz', 'fizz', '55', 'fizz', '57', '58', 'fizzbuzz', '60', '61', 'fizz', '63', 'buzz', 'fizz', '66', '67', 'fizz', 'buzz', '70', 'fizz', '72', '73', '74', '75', '76', 'fizz', '78', 'buzz', 'fizz', '81', '82', 'fizz', 'buzz', '85', 'fizz', '87', '88', 'fizzbuzz', '90', '91', 'fizz', '93', 'buzz', 'fizz', '96', '97', 'fizz']

最终测试的效果并不是特别好,但是从一些数据当中可以看到,我们这个网络实际还是找到了这个游戏的部分规律。单从 fizzbuzz 的结果来看,虽然他并没有准确的达到每次都在 15 的倍数输出,但是它隐约知道在 15 的倍数附近要输出

Last Modified: March 15, 2022
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

3 Comments
  1. 布啦豆 布啦豆

    我是个菜鸡,不懂就问

    X_test = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 100)])

    with torch.no_grad():

    y_test = model(X_test) y_test.max(1)[1]

    predicts = zip(range(0, 101), list(y_test.max(1)[1].data.tolist()))

    print([fizz_buzz_decode(i, x) for i, x in predicts])

    这里的 range (1, 100) 和 range (0, 101),两个 range 范围值是不是必须一致?

    1. mathor mathor

      @布啦豆没想到过了一年,终于找到问题在哪了,已修改,感谢

  2. bluesky bluesky

    所以结果还是很好的