
A PyTorch Implementation of AutoEncoder

February 11, 2020 • Deep Learning

A previous post covered the principles behind AutoEncoders; this post focuses on implementing an AutoEncoder in PyTorch.

AutoEncoder

An AutoEncoder is really just a very simple DNN. In the encoder, the number of neurons shrinks layer by layer, which performs dimensionality reduction; in the decoder, the number of neurons grows layer by layer, which expands the dimensionality back.

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 20] => [b, 64]
            nn.Linear(20, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return: [b, 1, 28, 28]
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encode
        x = self.encoder(x)
        # decode
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)
        return x

All of the code above is standard. One detail deserves attention, though: the last layer of the decoder is followed by Sigmoid rather than ReLU. Since we want to display the reconstructed images and the dataset is MNIST, the values in the output tensor need to be squashed into the range 0 to 1.
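
As a quick sanity check (an addition, not from the original post), the snippet below confirms that transforms.ToTensor() maps MNIST pixels into [0, 1], which is exactly the range a Sigmoid output can cover:

from torchvision import transforms, datasets

# illustrative check: inspect the pixel range produced by ToTensor()
ds = datasets.MNIST('mnist', train=True, download=True,
                    transform=transforms.ToTensor())
img, _ = ds[0]                              # img: [1, 28, 28], dtype float32
print(img.min().item(), img.max().item())   # values fall inside 0.0 .. 1.0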

Next, define the training and test sets and wrap each of them in a DataLoader.

mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
    transforms.ToTensor()
]), download=True)
mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
    transforms.ToTensor()
]), download=True)
mnist_test = DataLoader(mnist_test, batch_size=32)

Since the inputs are real numbers between 0 and 1, MSE is chosen as the loss function.

epochs = 1000
lr = 1e-3
model = AE()
criteon = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
print(model)
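
For reference (a small illustrative sketch, not from the original post), nn.MSELoss with its default reduction='mean' simply averages the squared per-element differences, shown here on random stand-in tensors:

import torch
from torch import nn

criteon = nn.MSELoss()                # default reduction='mean'
x = torch.rand(32, 1, 28, 28)         # stand-in for an input batch
x_hat = torch.rand(32, 1, 28, 28)     # stand-in for a reconstruction

loss = criteon(x_hat, x)
manual = ((x_hat - x) ** 2).mean()    # the same average of squared differences
print(torch.allclose(loss, manual))   # True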

In the usual (supervised) setting, we compare the network's output against the training labels to compute the loss. An AutoEncoder, however, is unsupervised and needs no labels: we simply compare the network's output with its input and compute the loss between the two.

viz = visdom.Visdom()  # requires a running visdom server (start it with `python -m visdom.server`)
for epoch in range(epochs):
    # labels are not needed, so a placeholder "_" is used instead
    for batchidx, (x, _) in enumerate(mnist_train):
        x_hat = model(x)
        loss = criteon(x_hat, x)
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 10 == 0:
        print(epoch, 'loss:', loss.item())
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

At this point the simplest possible AutoEncoder is done. The complete code is as follows:

import torch
import visdom
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch import nn, optim


class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 20] => [b, 64]
            nn.Linear(20, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return: [b, 1, 28, 28]
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encode
        x = self.encoder(x)
        # decode
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)
        return x


def main():
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32)

    epochs = 1000
    lr = 1e-3
    model = AE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    viz = visdom.Visdom()  # requires a running visdom server (`python -m visdom.server`)
    for epoch in range(epochs):
        # labels are not needed, so a placeholder "_" is used instead
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat = model(x)
            loss = criteon(x_hat, x)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            print(epoch, 'loss:', loss.item())
            x, _ = next(iter(mnist_test))
            with torch.no_grad():
                x_hat = model(x)
            viz.images(x, nrow=8, win='x', opts=dict(title='x'))
            viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()

Looking at the reconstructions in visdom, a plain AutoEncoder still falls somewhat short: many of the reconstructed images are no longer recognizable as specific digits.

Variational AutoEncoders

The shape transformation in the AutoEncoder is [b, 784] => [b, 20] => [b, 784]. A VAE follows the same pattern, but the meaning of the 20 differs: in a VAE the [b, 20] tensor is split into two [b, 10] tensors, $\mu$ and $\sigma$.
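
To make the split concrete, here is a minimal sketch (using a random tensor as a stand-in for the encoder output, not code from the original post) of how chunk divides a [b, 20] tensor into the two [b, 10] halves:

import torch

q = torch.randn(32, 20)           # stand-in for the encoder output [b, 20]
mu, sigma = q.chunk(2, dim=1)     # split into two chunks along dim 1
print(mu.shape, sigma.shape)      # torch.Size([32, 10]) torch.Size([32, 10])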

Let's first focus on the part that defines the network.

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        # mu: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 10] => [b, 64]
            nn.Linear(10, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return: x_hat [b, 1, 28, 28], kld (scalar)
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encode
        # [b, 20], holding both mean and sigma
        q = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = q.chunk(2, dim=1)
        # reparameterization trick, epsilon ~ N(0, 1)
        q = mu + sigma * torch.randn_like(sigma)
        # decode
        x_hat = self.decoder(q)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)
        # KL divergence between N(mu, sigma) and N(0, 1)
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz * 28 * 28)
        return x_hat, kld

After encoding, the resulting variable $h$ (called q in the code) is split into two halves with h.chunk(num, dim), where num is the number of chunks and dim is the dimension along which to split. We then sample from a standard normal distribution and transform the sample with $\mu$ and $\sigma$ (the reparameterization trick). The kld here is the KL Divergence, which forms one part of the loss; it is computed as follows:

$$ q(x) \sim \mathcal{N}(\mu, \sigma),\ p(x)\sim \mathcal{N}(0, 1) $$

$$ \begin{aligned} KL(q||p) &= \log \frac{1}{\sigma} + \frac{\sigma^2+\mu^2}{2} - \frac{1}{2}\\ &= -\log \sigma + \frac{1}{2}\sigma^2 + \frac{1}{2}\mu^2 - \frac{1}{2}\\ &= -\frac{1}{2} \log \sigma^2 + \frac{1}{2}\sigma^2 + \frac{1}{2}\mu^2 - \frac{1}{2}\\ &= \frac{1}{2}(\mu^2 + \sigma^2 - \log \sigma^2 - 1) \end{aligned} $$
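
As an optional sanity check (not part of the original post), the closed-form expression above can be compared numerically against torch.distributions.kl_divergence for randomly chosen $\mu$ and $\sigma$:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 10)
sigma = torch.rand(4, 10) + 0.1       # keep sigma strictly positive

# closed-form KL from the derivation above, summed over all elements
kld_manual = 0.5 * torch.sum(mu.pow(2) + sigma.pow(2) - torch.log(sigma.pow(2)) - 1)

# the same quantity computed by torch.distributions
kld_torch = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).sum()

print(torch.allclose(kld_manual, kld_torch, atol=1e-4))   # True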

import torch
import visdom
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        # mu: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 10] => [b, 64]
            nn.Linear(10, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return: x_hat [b, 1, 28, 28], kld (scalar)
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encode
        # [b, 20], holding both mean and sigma
        q = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = q.chunk(2, dim=1)
        # reparameterization trick, epsilon ~ N(0, 1)
        q = mu + sigma * torch.randn_like(sigma)
        # decode
        x_hat = self.decoder(q)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)
        # KL divergence between N(mu, sigma) and N(0, 1)
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz * 28 * 28)
        return x_hat, kld


def main():
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32)

    epochs = 1000
    lr = 1e-3
    model = VAE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    viz = visdom.Visdom()  # requires a running visdom server (`python -m visdom.server`)
    for epoch in range(epochs):
        # labels are not needed, so a placeholder "_" is used instead
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)
            if kld is not None:
                # total loss: reconstruction term plus KL regularizer
                elbo = loss + 1.0 * kld
                loss = elbo
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            print(epoch, 'loss:', loss.item(), 'kld:', kld.item())
            x, _ = next(iter(mnist_test))
            with torch.no_grad():
                x_hat, kld = model(x)
            viz.images(x, nrow=8, win='x', opts=dict(title='x'))
            viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()
Last Modified: August 2, 2021

1 comment so far
  1. Xiang Shen

     "criterion" is misspelled (as criteon), though it doesn't make much of a difference.