MENU

PyTorch nn.Module

January 19, 2020 • Read: 3910 • Deep Learning阅读设置

本节将介绍在 pytorch 中非常重要的类:nn.Module。在实现自己设计的网络时,必须要继承这个类,示例写法如下

  • import torch
  • import torch.nn as nn
  • import torch.nn.functional as F
  • # 先定义自己的类
  • class MyNN(nn.Module):
  • def __init__(self, inp, outp):
  • # 初始化自己定义的类
  • super(MyNN, self).__init__()
  • self.w = nn.Parameter(torch.randn(outp, inp))
  • self.b = nn.Parameter(torch.randn(outp))
  • # 定义前向传播
  • def forward(self, x):
  • x = x @ self.w.t() + self.b
  • return x

那么 nn.Module 这个类有哪些功能?

  • nn.Module 提供了很多已经编写好的功能,如 LinearReLUConv2dDropout
  • 书写代码方便。例如我们要定义一个基本的 CNN 结构,代码如下
  • self.net = nn.Sequential(
  • # .Sequential()相当于设置了一个容器(Container)
  • # 将需要进行forward的函数写在其中
  • nn.Conv2d(1, 32, 5, 1, 1),
  • nn.MaxPool2d(2, 2),
  • nn.ReLU(True),
  • nn.BatchNorm2d(32),
  • nn.Conv2d(32, 64, 3, 1, 1),
  • nn.ReLU(True),
  • nn.BatchNorm2d(64),
  • nn.Conv2d(64, 64, 3, 1, 1),
  • nn.MaxPool2d(2, 2),
  • nn.ReLU(True),
  • nn.BatchNorm2d(64),
  • nn.Conv2d(64, 128, 3, 1, 1),
  • nn.ReLU(True),
  • nn.BatchNorm2d(128)
  • )

或者需要将自己设计的层连接在一起的情况

  • class Faltten(nn.Module):
  • def __init__(self):
  • super(Faltten, self).__init__()
  • def forward(self, input):
  • return input.view(inputt.size(0), -1)
  • class TestNet(nn.Module):
  • def __init__(self):
  • super(TestNet, self).__init__()
  • self.net = nn.Sequential(
  • nn.Conv2d(1, 16, stride=1, padding=1),
  • nn.MaxPool2d(2, 2),
  • Flatten(),
  • nn.Linear(1*14*14, 10)
  • )
  • def forward(self, x):
  • return self.net(x)
  • 使用 nn.Module 可以对网络中的参数进行有效的管理
  • net = nn.Sequential(
  • nn.Linear(in_features=4, out_features=2),
  • nn.Linear(in_features=2, out_features=2)
  • )
  • # 隐藏层的编号是从0开始的
  • list(net.parameters())[0] # [0]是layer0的w
  • list(net.parameters())[3].shape # [3]是layer1的b
  • dict(net.named_parameters()).items() # 返回所有层的参数
  • optimizer = optim.SGD(net.parameters(), lr=1e-3)

输出

  • torch.Size([2, 4])
  • torch.Size([2])
  • dict_items([('0.weight', Parameter containing:
  • tensor([[ 0.0195, 0.4698, -0.4913, -0.3336],
  • [ 0.1422, 0.2908, -0.2469, 0.0583]], requires_grad=True)), ('0.bias', Parameter containing:
  • tensor([-0.4704, -0.1133], requires_grad=True)), ('1.weight', Parameter containing:
  • tensor([[-0.6511, 0.2442],
  • [ 0.5658, 0.4419]], requires_grad=True)), ('1.bias', Parameter containing:
  • tensor([ 0.0114, -0.5664], requires_grad=True))])
  • 可以很方便的将所有运算都转入到 GPU 上去,使用.device() 函数
  • device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  • net = Net()
  • net.to(device)
  • 可以很方便的进行 save 和 load,以防止突然发生的断点和系统崩溃现象
  • torch.save(net.state_dict(), 'ckpt.mdl')
  • net.load_state_dict(torch.load('ckpt.mdl'))
  • 还可以很方便的切换 train 和 test 的状态
  • # train
  • net.train()
  • # test
  • net.eval()
Last Modified: August 2, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment