Loading [MathJax]/jax/output/SVG/jax.js
MENU

Train / Val / Test 划分

January 8, 2020 • Read: 5989 • Deep Learning阅读设置

合理的 Train/Test 集划分会有效地减少 under-fitting 和 over-fitting 现象。以数字识别为例,正常一个数据集我们要划分出来训练部分和测设部分,如下图所示

左侧橘色部分作为训练部分,神经网络在该区域内不停地学习,将特征转入到函数中,学习好后得到一个函数模型。随后将上图右面白色区域的测试部分导入到该模型中,进行 accuracy 和 loss 的验证

通过不断地测试,查看模型是否调整到了一个最佳的参数,及结果是否发生了 over-fitting 现象

  • # 训练-测试代码写法
  • train_loader = torch.utils.data.DataLoader(
  • # 一般使用DataLoader函数来让机器学习或测试
  • datasets.MNIST('../data', train=True, download=True,
  • # 使用 train=True 或 train=False来进行数据集的划分
  • # train=True时为训练集,相反即为测试集
  • transform=transforms.Compose([
  • transforms.ToTensor(),
  • transforms.Normalize((0.1307,),(0.3081,))
  • ])),
  • batch_size=batch_size, shuffle=True)
  • test_loader = torch.utils.data.DataLoader(
  • datasets.MNIST('../data', train=False, download=True,
  • transform=transforms.Compose([
  • transforms.ToTensor(),
  • transforms.Normalize((0.1307,),(0.3081,))
  • ])),
  • batch_size=batch_size, shuffle=True)

这里注意,正常情况下数据集是要有 validation(验证)集的,若没有设置,即将 test 和 val 集合并为一个了

上面讲解了如何对数据集进行划分,那么如何进行循环学习验证测试呢?

  • for epoch in range(epochs):
  • for batch_idx, (data, target) in enumerate(train_loader):
  • ...
  • optimizer.zero_grad()
  • loss.backward()
  • optimizer.step()
  • if batch_idx % 100 == 0:
  • print('Train Epoch: {} [{} / {} ({:.0f}%)]\tLoss: {:.6f}'.format(
  • epoch, batch_idx * len(data), len(train_loadern.dataset),
  • 100. * batch_idx / len(train_loader), loss.item()))
  • # 每次循环都查看一次是否发生了over-fitting
  • # 如果发生了over-fitting,我们便将最后一次模型的状态作为最终的版本
  • test_loss = 0
  • correct = 0
  • for data, target in test_loader:
  • data = data.view(-1, 28*28)
  • pred = logits.data.max(1)[1]
  • correct += pred.eq(target.data).sum()
  • ...
  • test_loss /= len(test_loader.dataset)
  • print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
  • test_loss, correct, len(test_loader.dataset),
  • 100. * correct / len(test_loader.dataset)))

以一个实际例子的 Train Error 和 Test Error 来举例作图

由图看出在 Training 到第 5 次后,Test Error 便达到一个较低的位置。而后随着训练次数的增加,Test Error 逐渐增加,发生 over-fitting 现象。我们将训练次数在第 5 次的这个点叫做 check-point

其实正常情况下除了 Train Set 和挑选最佳参数的 Test Set 外,一般还要有 Validation Set。Val Set 代替 Test Set 的功能,而 Test Set 则要交给客户,进行实际验证的,正常情况下 Test Set 是不加入到测试中的

说个很具体的场景,就比方说 Kaggle 竞赛中,比赛的主办方给你训练的数据集,一般我们拿来都会自己分成 Train Set 和 Val Set 两部分,用 Train Set 训练,Val Set 挑选最佳参数,训练好以后提交模型,此时主办方会用它自己的数据集,即 Test Set 去测试你的模型,得到一个 Score

从上面的过程能看出,Val Set 可以理解为是从 Train Set 中拆出来的一部分,而与 Test Set 没有关系

  • print('train:', len(train_db), 'test:', len(test_db))
  • # 首先先查看train和test数据集的数量,看看是否满足预订的分配目标
  • train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
  • # 随机分配法将数据分为50k和10k的数量
  • train_loader = torch.utils.data.DataLoader(
  • train_db,
  • batch_size = batch_size, shuffle=True)
  • val_loader = torch.utils.data.DataLoader(
  • val_db,
  • batch_size = batch_size, shuffle=True)

但是这种训练方式也会有一些问题,如下图,假设总的数据量是 70k

划分完成之后,Test Set 中的数据集是无法使用的,这样就只有 50+10k 的数据可以用于学习

为了增加学习的样本,我们可以用 K-fold cross-validation 的方法,将这 60k 训练了的样本,再重新随机划分出 50k 的 Train Set 和 10k 的 Val Set

白色部分为新划分的 Val Set,两个黄色部分加一块为 Train Set。每进行一个 epoch,便将新的 Train Set 给了网络。这样做的好处是使得数据集中的每一个数据都有可能被网络学习,防止网络对相同的数据产生记忆

叫 K-fold cross-validation 的原因在于,假设有 60K 的 (Train+Val) Set 可供使用,分成了 N 份,每次取 N1N 份用来做 train,另外 1N 份用来做 validation

Last Modified: August 4, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

  • OωO
  • |´・ω・)ノ
  • ヾ(≧∇≦*)ゝ
  • (☆ω☆)
  • (╯‵□′)╯︵┴─┴
  •  ̄﹃ ̄
  • (/ω\)
  • ∠( ᐛ 」∠)_
  • (๑•̀ㅁ•́ฅ)
  • →_→
  • ୧(๑•̀⌄•́๑)૭
  • ٩(ˊᗜˋ*)و
  • (ノ°ο°)ノ
  • (´இ皿இ`)
  • ⌇●﹏●⌇
  • (ฅ´ω`ฅ)
  • (╯°A°)╯︵○○○
  • φ( ̄∇ ̄o)
  • ヾ(´・ ・`。)ノ"
  • ( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
  • (ó﹏ò。)
  • Σ(っ °Д °;)っ
  • ( ,,´・ω・)ノ"(´っω・`。)
  • ╮(╯▽╰)╭
  • o(*////▽////*)q
  • >﹏<
  • ( ๑´•ω•) "(ㆆᴗㆆ)
  • (。•ˇ‸ˇ•。)
  • 泡泡
  • 阿鲁
  • 颜文字