MENU

AutoEncoder的PyTorch实现

February 11, 2020 • Read: 109 • Deep Learning

之前的文章叙述了AutoEncoder的原理,这篇文章主要侧重于用PyTorch实现AutoEncoder

AutoEncoder

其实AutoEncoder就是非常简单的DNN。在encoder中神经元随着层数的增加逐渐变少,也就是降维的过程。而在decoder中神经元随着层数的增加逐渐变多,也就是升维的过程

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 20] => [b, 64]
            nn.Linear(20, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param [b, 1, 28, 28]:
        :return [b, 1, 28, 28]:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encode
        x = self.encoder(x)
        # decode
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x

上面代码都是基本操作,有一个地方需要特别注意,在decoder网络中,最后跟的不是ReLU而是Sigmoid函数,因为我们想要将图片打印出来看一下,而使用的数据集是MNIST,所以要将tensor里面的值最终都压缩到0-1之间

然后定义训练集和测试集,将它们分别带入到DataLoader中

mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
    transforms.ToTensor()
]), download=True)
mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
    transforms.ToTensor()
]), download=True)
mnist_test = DataLoader(mnist_test, batch_size=32)

由于input是0-1之间的实数,所以Loss function选择MSE

epochs = 1000
lr = 1e-3
model = AE()
criteon = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
print(model)

在通常(监督学习)情况下,我们需要将网络的输出output和训练集的label进行对比,计算loss。但AutoEncoder是无监督学习,不需要label,我们只需要将网络的输出output和网络的输入input进行对比,计算loss即可

viz = visdom.Visdom()
    for epoch in range(epochs):
        # 不需要label,所以用一个占位符"_"代替
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat = model(x)
            loss = criteon(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if epoch % 10 == 0:
            print(epoch, 'loss:', loss.item())
        x, _ = iter(mnist_test).next()
        with torch.no_grad():
            x_hat = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

到这里,最简单的AutoEncoder代码已经写完了,完整代码如下:

import torch
import visdom
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch import nn, optim


class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 20] => [b, 64]
            nn.Linear(20, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param [b, 1, 28, 28]:
        :return [b, 1, 28, 28]:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x

def main():
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32)

    epochs = 1000
    lr = 1e-3
    model = AE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    viz = visdom.Visdom()
    for epoch in range(epochs):
        # 不需要label,所以用一个占位符"_"代替
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat = model(x)
            loss = criteon(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if epoch % 10 == 0:
            print(epoch, 'loss:', loss.item())
        x, _ = iter(mnist_test).next()
        with torch.no_grad():
            x_hat = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()

得到的效果如下图所示,普通的AutoEncoder还是差了一点,可以看到很多图片已经看不清具体代表的数字了

Variational AutoEncoders

AutoEncoder的shape变化是[b, 784] => [b, 20] => [b, 784],虽然VAE也是这样,但其中的20并不一样,对于VAE来说,[b, 20]要分成两个[b, 10],分别是$\mu$和$\sigma$,具体形式见下图

最主要先关注一下定义网络的部分

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 10] => [b, 64]
            nn.Linear(10, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param [b, 1, 28, 28]:
        :return [b, 1, 28, 28]:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encoder
        # [b, 20] including mean and sigma
        q = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = q.chunk(2, dim=1)
        # reparameterize trick,  epsilon~N(0, 1)
        q = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(q)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        # KL
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

Encode以后的变量$h$要分成两半儿,利用h.chunk(num, dim)实现,num表示要分成几块,dim值表示在什么维度上进行。然后随机采样出标准正态分布的数据,用$\mu$和$\sigma$对其进行变换。这里的kld指的是KL Divergence,它是Loss的一部分,其计算过程如下:

import torch
import visdom
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets



class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            nn.Linear(64, 20),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            # [b, 10] => [b, 64]
            nn.Linear(10, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param [b, 1, 28, 28]:
        :return [b, 1, 28, 28]:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encoder
        # [b, 20] including mean and sigma
        q = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = q.chunk(2, dim=1)
        # reparameterize trick,  epsilon~N(0, 1)
        q = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(q)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        # KL
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

def main():
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32)

    epochs = 1000
    lr = 1e-3
    model = VAE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    viz = visdom.Visdom()
    for epoch in range(epochs):
        # 不需要label,所以用一个占位符"_"代替
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = loss + 1.0 * kld
                loss = elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            print(epoch, 'loss:', loss.item(), 'kld', kld.item())
        x, _ = iter(mnist_test).next()
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment