本节将介绍在 pytorch 中非常重要的类:nn.Module
。在实现自己设计的网络时,必须要继承这个类,示例写法如下
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- # 先定义自己的类
- class MyNN(nn.Module):
- def __init__(self, inp, outp):
- # 初始化自己定义的类
- super(MyNN, self).__init__()
-
- self.w = nn.Parameter(torch.randn(outp, inp))
- self.b = nn.Parameter(torch.randn(outp))
-
- # 定义前向传播
- def forward(self, x):
- x = x @ self.w.t() + self.b
- return x
那么 nn.Module
这个类有哪些功能?
nn.Module
提供了很多已经编写好的功能,如Linear
、ReLU
、Conv2d
、Dropout
等- 书写代码方便。例如我们要定义一个基本的 CNN 结构,代码如下
- self.net = nn.Sequential(
- # .Sequential()相当于设置了一个容器(Container)
- # 将需要进行forward的函数写在其中
-
- nn.Conv2d(1, 32, 5, 1, 1),
- nn.MaxPool2d(2, 2),
- nn.ReLU(True),
- nn.BatchNorm2d(32),
-
- nn.Conv2d(32, 64, 3, 1, 1),
- nn.ReLU(True),
- nn.BatchNorm2d(64),
-
- nn.Conv2d(64, 64, 3, 1, 1),
- nn.MaxPool2d(2, 2),
- nn.ReLU(True),
- nn.BatchNorm2d(64),
-
- nn.Conv2d(64, 128, 3, 1, 1),
- nn.ReLU(True),
- nn.BatchNorm2d(128)
- )
或者需要将自己设计的层连接在一起的情况
- class Faltten(nn.Module):
- def __init__(self):
- super(Faltten, self).__init__()
-
- def forward(self, input):
- return input.view(inputt.size(0), -1)
-
- class TestNet(nn.Module):
- def __init__(self):
- super(TestNet, self).__init__()
- self.net = nn.Sequential(
- nn.Conv2d(1, 16, stride=1, padding=1),
- nn.MaxPool2d(2, 2),
- Flatten(),
- nn.Linear(1*14*14, 10)
- )
-
- def forward(self, x):
- return self.net(x)
- 使用
nn.Module
可以对网络中的参数进行有效的管理
- net = nn.Sequential(
- nn.Linear(in_features=4, out_features=2),
- nn.Linear(in_features=2, out_features=2)
- )
-
- # 隐藏层的编号是从0开始的
- list(net.parameters())[0] # [0]是layer0的w
- list(net.parameters())[3].shape # [3]是layer1的b
- dict(net.named_parameters()).items() # 返回所有层的参数
-
- optimizer = optim.SGD(net.parameters(), lr=1e-3)
输出
- torch.Size([2, 4])
- torch.Size([2])
- dict_items([('0.weight', Parameter containing:
- tensor([[ 0.0195, 0.4698, -0.4913, -0.3336],
- [ 0.1422, 0.2908, -0.2469, 0.0583]], requires_grad=True)), ('0.bias', Parameter containing:
- tensor([-0.4704, -0.1133], requires_grad=True)), ('1.weight', Parameter containing:
- tensor([[-0.6511, 0.2442],
- [ 0.5658, 0.4419]], requires_grad=True)), ('1.bias', Parameter containing:
- tensor([ 0.0114, -0.5664], requires_grad=True))])
- 可以很方便的将所有运算都转入到 GPU 上去,使用
.device()
函数
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- net = Net()
- net.to(device)
- 可以很方便的进行 save 和 load,以防止突然发生的断点和系统崩溃现象
- torch.save(net.state_dict(), 'ckpt.mdl')
- net.load_state_dict(torch.load('ckpt.mdl'))
- 还可以很方便的切换 train 和 test 的状态
- # train
- net.train()
-
- # test
- net.eval()