MENU

Train / Val / Test划分

January 8, 2020 • Read: 500 • Deep Learning

合理的Train/Test集划分会有效地减少under-fitting和over-fitting现象

以数字识别为例,正常一个数据集我们要划分出来训练部分和测设部分,如下图所示

左侧橘色部分作为训练部分,神经网络在该区域内不停地学习,将特征转入到函数中,学习好后得到一个函数模型。随后将上图右面白色区域的测试部分导入到该模型中,进行accuracy和loss的验证

通过不断地测试,查看模型是否调整到了一个最佳的参数,及结果是否发生了over-fitting现象

# 训练-测试代码写法
train_loader = torch.utils.data.DataLoader(
# 一般使用DataLoader函数来让机器学习或测试
    datasets.MNIST('../data', train=True, download=True,
# 使用 train=True 或 train=False来进行数据集的划分
# train=True时为训练集,相反即为测试集
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,),(0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,),(0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

这里注意,正常情况下数据集是要有validation(验证)集的,若没有设置,即将test和val集合并为一个了

上面讲解了如何对数据集进行划分,那么如何进行循环学习验证测试呢?

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        ...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{} / {} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loadern.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    
    # 每次循环都查看一次是否发生了over-fitting
    # 如果发生了over-fitting,我们便将最后一次模型的状态作为最终的版本
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28*28)
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()
        ...
    
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

以一个实际例子的Train Error和Test Error来举例作图

由图看出在Training到第5次后,Test Error便达到一个较低的位置。而后随着训练次数的增加,Test Error逐渐增加,发生over-fitting现象。我们将训练次数在第5次的这个点叫做check-point

其实正常情况下除了Train Set和挑选最佳参数的Test Set外,一般还要有Validation Set。Val Set代替Test Set的功能,而Test Set则要交给客户,进行实际验证的,正常情况下Test Set是不加入到测试中的

说个很具体的场景,就比方说Kaggle竞赛中,比赛的主办方给你训练的数据集,一般我们拿来都会自己分成Train Set和Val Set两部分,用Train Set训练,Val Set挑选最佳参数,训练好以后提交模型,此时主办方会用它自己的数据集,即Test Set去测试你的模型,得到一个Score

从上面的过程能看出,Val Set可以理解为是从Train Set中拆出来的一部分,而与Test Set没有关系

print('train:', len(train_db), 'test:', len(test_db))
# 首先先查看train和test数据集的数量,看看是否满足预订的分配目标
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
# 随机分配法将数据分为50k和10k的数量
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size = batch_size, shuffle=True)

val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size = batch_size, shuffle=True)

但是这种训练方式也会有一些问题,如下图,假设总的数据量是70k

划分完成之后,Test Set中的数据集是无法使用的,这样就只有50+10k的数据可以用于学习

为了增加学习的样本,我们可以用K-fold cross-validation的方法,将这60k训练了的样本,再重新随机划分出50k的Train Set和10k的Val Set

白色部分为新划分的Val Set,两个黄色部分加一块为Train Set。每进行一个epoch,便将新的Train Set给了网络。这样做的好处是使得数据集中的每一个数据都有可能被网络学习,防止网络对相同的数据产生记忆

叫K-fold cross-validation的原因在于,假设有60K的(Train+Val) Set可供使用,分成了N份,每次取$\frac{N-1}{N}$份用来做train,另外$\frac{1}{N}$份用来做validation

Archives Tip
QR Code for this page
Tipping QR Code