
Transfer Learning

February 8, 2020 • Deep Learning

Pokemon Dataset

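I built an image classification dataset from Pokemon pictures collected online. It covers 5 Pokemon: Pikachu, Mewtwo, Squirtle, Charmander, and Bulbasaur.

Dataset link: https://pan.baidu.com/s/1Kept7FF88lb8TqPZMD_Yxw (extraction code: 1sdd)

There are 1168 images in total: 234 of Pikachu, 239 of Mewtwo, 223 of Squirtle, 238 of Charmander, and 234 of Bulbasaur.

Each directory is named after a Pokemon and holds that Pokemon's images, which come in three formats: jpg, png, and jpeg.

The dataset is split as follows: 60% training, 20% validation, 20% test. The ratio is applied to all 1168 images as a whole, not to each class separately.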

Load Data

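Defining a dataset in PyTorch mainly involves two classes: Dataset and DataLoader.

The Dataset class

Dataset is the parent class that every dataset-loading class in PyTorch should inherit from. Its two special methods, __len__() and __getitem__(), must be overridden; otherwise an error is raised when they are used.

__len__() should return the number of samples in the dataset, and __getitem__() should return the sample at a given index.

First, a simple custom Dataset: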

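```python
from torch.utils.data import Dataset


class NumbersDataset(Dataset):
    def __init__(self, training=True):
        # Training samples are 1..1000, test samples are 1001..1500
        if training:
            self.samples = list(range(1, 1001))
        else:
            self.samples = list(range(1001, 1501))

    def __len__(self):
        # Number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        # Return the sample stored at position idx
        return self.samples[idx]
```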
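A short usage sketch of the example above: len() calls __len__(), indexing calls __getitem__(), and a DataLoader (covered later in this post) uses both to assemble batches. The class is repeated so the snippet runs on its own; `train_set`, `test_set`, and `loader` are illustrative names.

```python
from torch.utils.data import Dataset, DataLoader


class NumbersDataset(Dataset):
    def __init__(self, training=True):
        self.samples = list(range(1, 1001)) if training else list(range(1001, 1501))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


train_set = NumbersDataset(training=True)
test_set = NumbersDataset(training=False)
print(len(train_set), len(test_set))  # 1000 500
print(train_set[0], test_set[-1])     # 1 1500

# DataLoader calls __len__/__getitem__ for us and collates samples into a batch tensor
loader = DataLoader(train_set, batch_size=4, shuffle=False)
first_batch = next(iter(loader))
print(first_batch)  # tensor([1, 2, 3, 4])
```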

Next, the images need some preprocessing:

  1. Image resize: 224×224 for ResNet18
  2. Data augmentation: rotation and crop
  3. ToTensor
  4. Normalize: mean and std

In __init__(), we first build the name->label mapping, where name is the name of each class folder, and then split the dataset 6:2:2:

```python
import os
from torch.utils.data import Dataset


class Pokemon(Dataset):
    def __init__(self, root, resize, model):
        # `model` selects the split: 'train', 'val', or anything else for 'test'
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # map each folder name to an integer label
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # image paths and labels (load_csv is defined below)
        self.images, self.labels = self.load_csv('images.csv')
        if model == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif model == 'val':  # next 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
```
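Because int() truncates, the three slices are not exactly 60/20/20. A small sketch of the arithmetic, assuming the 1167 images that actually end up in images.csv (as explained for load_csv() below). `split_sizes` is a hypothetical helper, not part of the post's code. Note also that the shuffle happens only when images.csv is first created, so the splits are random but stay fixed across later runs.

```python
def split_sizes(n):
    """Sizes of the train/val/test slices produced by the 6:2:2 slicing in __init__."""
    train_end = int(0.6 * n)  # end index of the training slice
    val_end = int(0.8 * n)    # end index of the validation slice
    return train_end, val_end - train_end, n - val_end


sizes = split_sizes(1167)
print(sizes)  # (700, 233, 234)
```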

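load_csv() writes the path of every image (the full path, including the class folder) together with its label into a CSV file. For example, if an image's path is pokemon\bulbasaur\00000000.png and its label is 0, the CSV file gets the row pokemon\bulbasaur\00000000.png,0. 1167 rows are written in total: one image is neither png, jpg, nor jpeg, so it is not picked up, and I simply left it out. load_csv() looks like this: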

```python
# Method of the Pokemon class (requires: import os, glob, random, csv)
def load_csv(self, filename):
    csv_path = os.path.join(self.root, filename)
    if not os.path.exists(csv_path):
        # First run: collect every image path, shuffle once, and cache the list as CSV
        images = []
        for name in self.name2label.keys():
            images += glob.glob(os.path.join(self.root, name, '*.png'))
            images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
        random.shuffle(images)
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:  # e.g. pokemon\bulbasaur\00000000.png
                name = img.split(os.sep)[-2]  # class folder, e.g. bulbasaur
                label = self.name2label[name]
                writer.writerow([img, label])  # row: pokemon\bulbasaur\00000000.png,0
            print('written into csv file:', filename)

    # Read the cached CSV back into two parallel lists
    images, labels = [], []
    with open(csv_path) as f:
        reader = csv.reader(f)
        for row in reader:
            image, label = row
            images.append(image)
            labels.append(int(label))
    assert len(images) == len(labels)
    return images, labels
```
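To see the folder-to-label mapping and the CSV caching in isolation, here is a standalone sketch that runs the same logic on a throwaway directory of empty placeholder files. The `build_index` function and the fake folder contents are hypothetical. The `.gif` file shows how an image in an unsupported format is silently skipped by the glob patterns, which is how 1168 images turned into 1167 rows.

```python
import csv
import glob
import os
import random
import tempfile


def build_index(root, filename='images.csv'):
    """Standalone sketch of the name2label + load_csv logic above."""
    name2label = {}
    for name in sorted(os.listdir(root)):
        if os.path.isdir(os.path.join(root, name)):
            name2label[name] = len(name2label)
    images = []
    for name in name2label:
        for pattern in ('*.png', '*.jpg', '*.jpeg'):
            images += glob.glob(os.path.join(root, name, pattern))
    random.shuffle(images)
    csv_path = os.path.join(root, filename)
    with open(csv_path, mode='w', newline='') as f:
        writer = csv.writer(f)
        for img in images:
            writer.writerow([img, name2label[img.split(os.sep)[-2]]])
    with open(csv_path, newline='') as f:
        rows = [(path, int(label)) for path, label in csv.reader(f)]
    return name2label, rows


with tempfile.TemporaryDirectory() as root:
    # Fake dataset: two class folders holding empty placeholder "images"
    for cls, files in {'pikachu': ['a.png', 'b.jpg'], 'bulbasaur': ['c.jpeg', 'd.gif']}.items():
        os.makedirs(os.path.join(root, cls))
        for fn in files:
            open(os.path.join(root, cls, fn), 'w').close()
    name2label, rows = build_index(root)
    print(name2label)  # {'bulbasaur': 0, 'pikachu': 1}
    print(len(rows))   # 3 (d.gif is not matched by any pattern)
```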

Next, __len__():

```python
# Method of the Pokemon class
def __len__(self):
    # Number of samples in the selected split
    return len(self.images)
```

Finally, __getitem__(). This one is more involved, because all we have is each image's path as a string, so it first has to be loaded as 3-channel image data, which PIL's Image.open(path).convert('RGB') does. The loaded image then passes through a series of transforms:

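```python
# Method of the Pokemon class
# (requires: import torch; from PIL import Image; from torchvision import transforms)
def __getitem__(self, idx):
    # idx ranges over [0, len(self.images))
    # self.images[idx] is a path such as pokemon\bulbasaur\00000000.png; self.labels[idx] is its label
    img, label = self.images[idx], self.labels[idx]
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path => 3-channel RGB image
        transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),  # enlarge slightly
        transforms.RandomRotation(15),           # rotate randomly by up to ±15 degrees
        transforms.CenterCrop(self.resize),      # crop the center back to resize x resize
        transforms.ToTensor(),                   # [H, W, C] in [0, 255] => [C, H, W] in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = tf(img)
    label = torch.tensor(label)
    return img, label
```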

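The Normalize parameters are the per-channel mean and standard deviation of the ImageNet training set, the values commonly used with torchvision's pretrained models, so they can be used as-is.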

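The DataLoader class

Dataset reads the data in and indexes it, but that alone is not enough. Real datasets are usually large, so loading them also calls for:

  1. Reading data in batches: batch_size
  2. Reading samples in random order, i.e. shuffling the data
  3. Loading data in parallel, using multiple CPU cores to speed things up

DataLoader provides these. Its commonly used parameters are:

  • batch_size: the number of samples in each batch
  • shuffle: whether to shuffle the data
  • num_workers: the number of worker processes used to load data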

We don't need to write DataLoader ourselves; we just hand it an instance of our Dataset subclass:

```python
from torch.utils.data import DataLoader

db = Pokemon('pokemon', 224, 'train')
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=4)
```
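The effect of batch_size and shuffle can be checked without the Pokemon images. This sketch swaps in a small TensorDataset of random tensors (an assumption for illustration only) and uses num_workers=0, which loads data in the main process:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the Pokemon dataset: 10 "images" of shape [3, 8, 8] with labels 0..4
images = torch.randn(10, 3, 8, 8)
labels = torch.arange(10) % 5
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
for x, y in loader:
    batch_sizes.append(x.shape[0])
    print(x.shape, y.shape)
print(batch_sizes)  # [4, 4, 2]
```

With 10 samples and batch_size=4, the last batch holds only 2 samples; passing drop_last=True discards such an incomplete batch instead.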

The complete data-loading code:

```python
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    def __init__(self, root, resize, model):
        # `model` selects the split: 'train', 'val', or anything else for 'test'
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # map each folder name to an integer label
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # image paths and labels
        self.images, self.labels = self.load_csv('images.csv')
        if model == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif model == 'val':  # next 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            # First run: collect every image path, shuffle once, and cache the list as CSV
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\bulbasaur\00000000.png
                    name = img.split(os.sep)[-2]  # e.g. bulbasaur
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('written into csv file:', filename)
        # Read the cached CSV
        images, labels = [], []
        with open(csv_path) as f:
            reader = csv.reader(f)
            for row in reader:
                image, label = row
                images.append(image)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # idx ranges over [0, len(self.images))
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => RGB image
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


if __name__ == '__main__':
    db = Pokemon('pokemon', 224, 'train')
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
```

Build Model

I covered building a ResNet in PyTorch in an earlier post, so I reuse that code here and only adjust its parameters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlk(nn.Module):
    """Residual block: two 3x3 convs plus a shortcut (1x1 conv when channels change)."""
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.extra = nn.Sequential()  # identity shortcut by default
        if ch_out != ch_in:
            # 1x1 conv matches the main path's channel count and spatial size
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut
        out = self.extra(x) + out
        out = F.relu(out)
        return out


class ResNet18(nn.Module):
    # Despite the name, this is a compact network with 4 residual blocks,
    # not the standard 18-layer ResNet. Shapes below assume 224x224 input.
    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # followed by 4 blocks
        self.blk1 = ResBlk(16, 32, stride=3)    # [b, 16, 74, 74] => [b, 32, 25, 25]
        self.blk2 = ResBlk(32, 64, stride=3)    # [b, 32, 25, 25] => [b, 64, 9, 9]
        self.blk3 = ResBlk(64, 128, stride=2)   # [b, 64, 9, 9]   => [b, 128, 5, 5]
        self.blk4 = ResBlk(128, 256, stride=2)  # [b, 128, 5, 5]  => [b, 256, 3, 3]
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # [b, 3, 224, 224] => [b, 16, 74, 74]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = x.view(x.size(0), -1)  # [b, 256, 3, 3] => [b, 2304]
        x = self.outlayer(x)
        return x
```
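The input size of outlayer, 256*3*3, holds only for 224×224 inputs. Each convolution yields floor((n + 2p − k) / s) + 1 pixels per side, so the spatial size can be traced through the network with a few lines of arithmetic. `conv_out` is a helper written just for this check. Inside each ResBlk only the first 3×3 convolution (padding 1) uses the block's stride; the second keeps the size unchanged, and the 1×1 shortcut produces the same size as the main path.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a Conv2d: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = conv_out(224, kernel=3, stride=3, padding=0)  # conv1
trace = [size]
for stride in (3, 3, 2, 2):  # blk1..blk4
    main = conv_out(size, kernel=3, stride=stride, padding=1)      # first 3x3 conv of the block
    shortcut = conv_out(size, kernel=1, stride=stride, padding=0)  # 1x1 shortcut conv
    assert main == shortcut  # both paths must agree for the addition to work
    size = main
    trace.append(size)

print(trace)              # [74, 25, 9, 5, 3]
print(256 * size * size)  # 2304, the in_features of outlayer
```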

Train and Test

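Training follows a strict train/validate/test procedure: during training, we run validation every few epochs and check whether the current validation accuracy is the best so far; if it is, we save the current model parameters. When training finishes, we load the best model and evaluate it on the test set. The code: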

```python
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
# Pokemon and ResNet18 are the classes defined above

batchsz = 32
lr = 1e-3
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, model='train')
val_db = Pokemon('pokemon', 224, model='val')
test_db = Pokemon('pokemon', 224, model='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=2)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


def evaluate(model, loader):
    """Return the classification accuracy of model on loader."""
    model.eval()  # BatchNorm uses its running statistics during evaluation
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    for epoch in range(epochs):
        model.train()  # back to training mode after each evaluation
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:  # validate every 2 epochs
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')

    print('best acc:', best_acc, 'best_epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from checkpoint!')

    test_acc = evaluate(model, test_loader)
    print('test_acc:', test_acc)
```
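One detail matters for evaluation here: both networks in this post contain BatchNorm layers, which normalize with the current batch's statistics in training mode but with stored running statistics in evaluation mode. That is why evaluation should run under model.eval(), with model.train() restored before training continues. A small sketch on a single BatchNorm2d layer with random data (the shapes and scaling are arbitrary):

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4) * 5 + 2  # batch statistics far from BN's initial running stats

bn.train()
out_train = bn(x)  # normalized with this batch's mean/var (running stats get updated)
bn.eval()
out_eval = bn(x)   # normalized with the stored running mean/var instead

print(torch.allclose(out_train, out_eval))  # False
```

In eval mode each sample's output no longer depends on the rest of its batch, so evaluation results do not change with the batch size.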

Putting it all together, the complete runnable code so far:

```python
import torch
import os, glob
import warnings
import random, csv
from PIL import Image
from torch import optim, nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings('ignore')


class Pokemon(Dataset):
    def __init__(self, root, resize, model):
        # `model` selects the split: 'train', 'val', or anything else for 'test'
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # map each folder name to an integer label
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # image paths and labels
        self.images, self.labels = self.load_csv('images.csv')
        if model == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif model == 'val':  # next 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            # First run: collect every image path, shuffle once, and cache the list as CSV
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\bulbasaur\00000000.png
                    name = img.split(os.sep)[-2]  # e.g. bulbasaur
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('written into csv file:', filename)
        # Read the cached CSV
        images, labels = [], []
        with open(csv_path) as f:
            reader = csv.reader(f)
            for row in reader:
                image, label = row
                images.append(image)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # idx ranges over [0, len(self.images))
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => RGB image
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


class ResBlk(nn.Module):
    """Residual block: two 3x3 convs plus a shortcut (1x1 conv when channels change)."""
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.extra = nn.Sequential()  # identity shortcut by default
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out  # shortcut
        out = F.relu(out)
        return out


class ResNet18(nn.Module):
    # Compact 4-block residual network (not the standard 18-layer ResNet); 224x224 input
    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        self.blk1 = ResBlk(16, 32, stride=3)    # => [b, 32, 25, 25]
        self.blk2 = ResBlk(32, 64, stride=3)    # => [b, 64, 9, 9]
        self.blk3 = ResBlk(64, 128, stride=2)   # => [b, 128, 5, 5]
        self.blk4 = ResBlk(128, 256, stride=2)  # => [b, 256, 3, 3]
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # => [b, 16, 74, 74]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = x.view(x.size(0), -1)  # => [b, 2304]
        x = self.outlayer(x)
        return x


batchsz = 32
lr = 1e-3
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, model='train')
val_db = Pokemon('pokemon', 224, model='val')
test_db = Pokemon('pokemon', 224, model='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=2)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


def evaluate(model, loader):
    """Return the classification accuracy of model on loader."""
    model.eval()
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    for epoch in range(epochs):
        model.train()
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:  # validate every 2 epochs
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')

    print('best acc:', best_acc, 'best_epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from checkpoint!')

    test_acc = evaluate(model, test_loader)
    print('test_acc:', test_acc)


if __name__ == '__main__':
    main()
```

Transfer Learning

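Running the code above gives a final test accuracy of roughly 0.88. Pushing it higher would take more engineering tricks or hyperparameter tuning.

Another option is transfer learning. Consider the figure below. It illustrates that when data is scarce (first panel), training can end up at many different possible models (second panel), yet only one of them becomes the final output, and its test accuracy may well be poor. Our Pokemon dataset, with just over 1000 images, is fairly small. But since Pokemon pictures and ImageNet are both natural images, they probably share some common features. So can we take models already trained on ImageNet and use them to help with our specific image classification task? That is transfer learning: train a classifier on task A, then transfer it to task B.

My own understanding of what transfer learning does: network initialization matters a great deal, and a poor initialization can make the final result very bad. A network already trained on task A effectively gives us a very good initialization. If we then tackle task B starting from that network and the two tasks are fairly similar, then, to exaggerate a little, the network may need only light fine-tuning to perform very well on B.

The figure below shows an actual transfer learning process. On the left is a network that has already been trained. We keep its shared part, which carries common knowledge, remove the last layer, and replace it with the layer we need.

Here is the core code first: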

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension."""
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


trained_model = resnet18(pretrained=True)  # ImageNet weights, downloaded on first use
model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                      Flatten(),                             # [b, 512, 1, 1] => [b, 512]
                      nn.Linear(512, 5)                      # [b, 512] => [b, 5]
                      )
```

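torchvision provides pretrained ResNets of various sizes; the weights are downloaded the first time they are used. We don't want ResNet18's last layer (its 1000-class ImageNet classifier), so list(trained_model.children())[:-1] collects every layer except the last into a list, and * unpacks that list into nn.Sequential. It is followed by our custom Flatten layer, which flattens the output so that it can be fed into the Linear layer. Note that torchvision 0.13 and later deprecate pretrained=True in favor of the weights argument, e.g. resnet18(weights=ResNet18_Weights.DEFAULT).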

These few lines implement transfer learning, and we no longer have to implement ResNet ourselves. The complete code:

```python
import torch
import os, glob
import warnings
import random, csv
from PIL import Image
from torch import optim, nn
from torchvision import transforms
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings('ignore')


class Pokemon(Dataset):
    def __init__(self, root, resize, model):
        # `model` selects the split: 'train', 'val', or anything else for 'test'
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # map each folder name to an integer label
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # image paths and labels
        self.images, self.labels = self.load_csv('images.csv')
        if model == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif model == 'val':  # next 20%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            # First run: collect every image path, shuffle once, and cache the list as CSV
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\bulbasaur\00000000.png
                    name = img.split(os.sep)[-2]  # e.g. bulbasaur
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('written into csv file:', filename)
        # Read the cached CSV
        images, labels = [], []
        with open(csv_path) as f:
            reader = csv.reader(f)
            for row in reader:
                image, label = row
                images.append(image)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # idx ranges over [0, len(self.images))
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => RGB image
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension."""
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


batchsz = 32
lr = 1e-3
epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, model='train')
val_db = Pokemon('pokemon', 224, model='val')
test_db = Pokemon('pokemon', 224, model='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=2)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


def evaluate(model, loader):
    """Return the classification accuracy of model on loader."""
    model.eval()
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():
    trained_model = resnet18(pretrained=True)  # ImageNet weights, downloaded on first use
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),                             # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)                      # [b, 512] => [b, 5]
                          ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    for epoch in range(epochs):
        model.train()
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:  # validate every 2 epochs
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')

    print('best acc:', best_acc, 'best_epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from checkpoint!')

    test_acc = evaluate(model, test_loader)
    print('test_acc:', test_acc)


if __name__ == '__main__':
    main()
```

The final test accuracy is around 0.94, much better than training our own network from scratch.

Last Modified: February 9, 2020