之前的文章叙述了 AutoEncoder 的原理,这篇文章主要侧重于用 PyTorch 实现 AutoEncoder
AutoEncoder
其实 AutoEncoder 就是非常简单的 DNN。在 encoder 中神经元随着层数的增加逐渐变少,也就是降维的过程。而在 decoder 中神经元随着层数的增加逐渐变多,也就是升维的过程
- class AE(nn.Module):
- def __init__(self):
- super(AE, self).__init__()
- self.encoder = nn.Sequential(
- # [b, 784] => [b, 256]
- nn.Linear(784, 256),
- nn.ReLU(),
- # [b, 256] => [b, 64]
- nn.Linear(256, 64),
- nn.ReLU(),
- # [b, 64] => [b, 20]
- nn.Linear(64, 20),
- nn.ReLU()
- )
- self.decoder = nn.Sequential(
- # [b, 20] => [b, 64]
- nn.Linear(20, 64),
- nn.ReLU(),
- # [b, 64] => [b, 256]
- nn.Linear(64, 256),
- nn.ReLU(),
- # [b, 256] => [b, 784]
- nn.Linear(256, 784),
- nn.Sigmoid()
- )
-
- def forward(self, x):
- """
- :param [b, 1, 28, 28]:
- :return [b, 1, 28, 28]:
- """
- batchsz = x.size(0)
- # flatten
- x = x.view(batchsz, -1)
- # encode
- x = self.encoder(x)
- # decode
- x = self.decoder(x)
- # reshape
- x = x.view(batchsz, 1, 28, 28)
-
- return x
上面代码都是基本操作,有一个地方需要特别注意,在 decoder 网络中,最后跟的不是 ReLU 而是 Sigmoid 函数,因为我们想要将图片打印出来看一下,而使用的数据集是 MNIST,所以要将 tensor 里面的值最终都压缩到 0-1 之间
然后定义训练集和测试集,将它们分别带入到 DataLoader 中
- mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
- transforms.ToTensor()
- ]), download=True)
- mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
-
- mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
- transforms.ToTensor()
- ]), download=True)
- mnist_test = DataLoader(mnist_test, batch_size=32)
由于 input 是 0-1 之间的实数,所以 Loss function 选择 MSE
- epochs = 1000
- lr = 1e-3
- model = AE()
- criteon = nn.MSELoss()
- optimizer = optim.Adam(model.parameters(), lr=lr)
- print(model)
在通常(监督学习)情况下,我们需要将网络的输出 output 和训练集的 label 进行对比,计算 loss。但 AutoEncoder 是无监督学习,不需要 label,我们只需要将网络的输出 output 和网络的输入 input 进行对比,计算 loss 即可
- viz = visdom.Visdom()
- for epoch in range(epochs):
- # 不需要label,所以用一个占位符"_"代替
- for batchidx, (x, _) in enumerate(mnist_train):
- x_hat = model(x)
- loss = criteon(x_hat, x)
-
- # backprop
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- if epoch % 10 == 0:
- print(epoch, 'loss:', loss.item())
- x, _ = iter(mnist_test).next()
- with torch.no_grad():
- x_hat = model(x)
- viz.images(x, nrow=8, win='x', opts=dict(title='x'))
- viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
到这里,最简单的 AutoEncoder 代码已经写完了,完整代码如下:
- import torch
- import visdom
- from torch.utils.data import DataLoader
- from torchvision import transforms, datasets
- from torch import nn, optim
-
- class AE(nn.Module):
- def __init__(self):
- super(AE, self).__init__()
- self.encoder = nn.Sequential(
- # [b, 784] => [b, 256]
- nn.Linear(784, 256),
- nn.ReLU(),
- # [b, 256] => [b, 64]
- nn.Linear(256, 64),
- nn.ReLU(),
- # [b, 64] => [b, 20]
- nn.Linear(64, 20),
- nn.ReLU()
- )
- self.decoder = nn.Sequential(
- # [b, 20] => [b, 64]
- nn.Linear(20, 64),
- nn.ReLU(),
- # [b, 64] => [b, 256]
- nn.Linear(64, 256),
- nn.ReLU(),
- # [b, 256] => [b, 784]
- nn.Linear(256, 784),
- nn.Sigmoid()
- )
-
- def forward(self, x):
- """
- :param [b, 1, 28, 28]:
- :return [b, 1, 28, 28]:
- """
- batchsz = x.size(0)
- # flatten
- x = x.view(batchsz, -1)
- # encoder
- x = self.encoder(x)
- # decoder
- x = self.decoder(x)
- # reshape
- x = x.view(batchsz, 1, 28, 28)
-
- return x
-
- def main():
- mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
- transforms.ToTensor()
- ]), download=True)
- mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
-
- mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
- transforms.ToTensor()
- ]), download=True)
- mnist_test = DataLoader(mnist_test, batch_size=32)
-
- epochs = 1000
- lr = 1e-3
- model = AE()
- criteon = nn.MSELoss()
- optimizer = optim.Adam(model.parameters(), lr=lr)
- print(model)
-
- viz = visdom.Visdom()
- for epoch in range(epochs):
- # 不需要label,所以用一个占位符"_"代替
- for batchidx, (x, _) in enumerate(mnist_train):
- x_hat = model(x)
- loss = criteon(x_hat, x)
-
- # backprop
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- if epoch % 10 == 0:
- print(epoch, 'loss:', loss.item())
- x, _ = iter(mnist_test).next()
- with torch.no_grad():
- x_hat = model(x)
- viz.images(x, nrow=8, win='x', opts=dict(title='x'))
- viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
-
- if __name__ == '__main__':
- main()
得到的效果如下图所示,普通的 AutoEncoder 还是差了一点,可以看到很多图片已经看不清具体代表的数字了
Variational AutoEncoders
AutoEncoder 的 shape 变化是 [b, 784] => [b, 20] => [b, 784]
,虽然 VAE 也是这样,但其中的 20 并不一样,对于 VAE 来说,[b, 20]
要分成两个 [b, 10]
,分别是 $\mu$ 和 $\sigma$,具体形式见下图
最主要先关注一下定义网络的部分
- class VAE(nn.Module):
- def __init__(self):
- super(VAE, self).__init__()
- # [b, 784] => [b, 20]
- # u: [b, 10]
- # sigma: [b, 10]
- self.encoder = nn.Sequential(
- # [b, 784] => [b, 256]
- nn.Linear(784, 256),
- nn.ReLU(),
- # [b, 256] => [b, 64]
- nn.Linear(256, 64),
- nn.ReLU(),
- # [b, 64] => [b, 20]
- nn.Linear(64, 20),
- nn.ReLU()
- )
- self.decoder = nn.Sequential(
- # [b, 10] => [b, 64]
- nn.Linear(10, 64),
- nn.ReLU(),
- # [b, 64] => [b, 256]
- nn.Linear(64, 256),
- nn.ReLU(),
- # [b, 256] => [b, 784]
- nn.Linear(256, 784),
- nn.Sigmoid()
- )
-
- def forward(self, x):
- """
- :param [b, 1, 28, 28]:
- :return [b, 1, 28, 28]:
- """
- batchsz = x.size(0)
- # flatten
- x = x.view(batchsz, -1)
- # encoder
- # [b, 20] including mean and sigma
- q = self.encoder(x)
- # [b, 20] => [b, 10] and [b, 10]
- mu, sigma = q.chunk(2, dim=1)
- # reparameterize trick, epsilon~N(0, 1)
- q = mu + sigma * torch.randn_like(sigma)
-
- # decoder
- x_hat = self.decoder(q)
- # reshape
- x_hat = x_hat.view(batchsz, 1, 28, 28)
-
- # KL
- kld = 0.5 * torch.sum(
- torch.pow(mu, 2) +
- torch.pow(sigma, 2) -
- torch.log(1e-8 + torch.pow(sigma, 2)) - 1
- ) / (batchsz*28*28)
-
- return x_hat, kld
Encode 以后的变量 $h$ 要分成两半儿,利用 h.chunk(num, dim)
实现,num 表示要分成几块,dim 值表示在什么维度上进行。然后随机采样出标准正态分布的数据,用 $\mu$ 和 $\sigma$ 对其进行变换。这里的 kld
指的是 KL Divergence,它是 Loss 的一部分,其计算过程如下:
$$ q(x) \sim \mathcal{N}(\mu, \sigma),\ p(x)\sim \mathcal{N}(0, 1) $$
$$ \begin{aligned} KL(q||p) &= \log \frac{1}{\sigma} + \frac{\sigma^2+\mu^2}{2} - \frac{1}{2}\\ &= -\log \sigma + \frac{1}{2}\sigma^2 + \frac{1}{2}\mu^2 - \frac{1}{2}\\ &= -\frac{1}{2} \log \sigma^2 + \frac{1}{2}\sigma^2 + \frac{1}{2}\mu^2 - \frac{1}{2}\\ &= \frac{1}{2}(\mu^2 + \sigma^2 - \log \sigma^2 - 1) \end{aligned} $$
- import torch
- import visdom
- import numpy as np
- from torch import nn, optim
- from torch.utils.data import DataLoader
- from torchvision import transforms, datasets
-
- class VAE(nn.Module):
- def __init__(self):
- super(VAE, self).__init__()
- # [b, 784] => [b, 20]
- # u: [b, 10]
- # sigma: [b, 10]
- self.encoder = nn.Sequential(
- # [b, 784] => [b, 256]
- nn.Linear(784, 256),
- nn.ReLU(),
- # [b, 256] => [b, 64]
- nn.Linear(256, 64),
- nn.ReLU(),
- # [b, 64] => [b, 20]
- nn.Linear(64, 20),
- nn.ReLU()
- )
- self.decoder = nn.Sequential(
- # [b, 10] => [b, 64]
- nn.Linear(10, 64),
- nn.ReLU(),
- # [b, 64] => [b, 256]
- nn.Linear(64, 256),
- nn.ReLU(),
- # [b, 256] => [b, 784]
- nn.Linear(256, 784),
- nn.Sigmoid()
- )
-
- def forward(self, x):
- """
- :param [b, 1, 28, 28]:
- :return [b, 1, 28, 28]:
- """
- batchsz = x.size(0)
- # flatten
- x = x.view(batchsz, -1)
- # encoder
- # [b, 20] including mean and sigma
- q = self.encoder(x)
- # [b, 20] => [b, 10] and [b, 10]
- mu, sigma = q.chunk(2, dim=1)
- # reparameterize trick, epsilon~N(0, 1)
- q = mu + sigma * torch.randn_like(sigma)
-
- # decoder
- x_hat = self.decoder(q)
- # reshape
- x_hat = x_hat.view(batchsz, 1, 28, 28)
-
- # KL
- kld = 0.5 * torch.sum(
- torch.pow(mu, 2) +
- torch.pow(sigma, 2) -
- torch.log(1e-8 + torch.pow(sigma, 2)) - 1
- ) / (batchsz*28*28)
-
- return x_hat, kld
-
- def main():
- mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
- transforms.ToTensor()
- ]), download=True)
- mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
-
- mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
- transforms.ToTensor()
- ]), download=True)
- mnist_test = DataLoader(mnist_test, batch_size=32)
-
- epochs = 1000
- lr = 1e-3
- model = VAE()
- criteon = nn.MSELoss()
- optimizer = optim.Adam(model.parameters(), lr=lr)
- print(model)
-
- viz = visdom.Visdom()
- for epoch in range(epochs):
- # 不需要label,所以用一个占位符"_"代替
- for batchidx, (x, _) in enumerate(mnist_train):
- x_hat, kld = model(x)
- loss = criteon(x_hat, x)
-
- if kld is not None:
- elbo = loss + 1.0 * kld
- loss = elbo
-
- # backprop
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- if epoch % 10 == 0:
- print(epoch, 'loss:', loss.item(), 'kld', kld.item())
- x, _ = iter(mnist_test).next()
- with torch.no_grad():
- x_hat, kld = model(x)
- viz.images(x, nrow=8, win='x', opts=dict(title='x'))
- viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
-
- if __name__ == '__main__':
- main()
criterion 拼错了,不过没多大影响