MENU

PyTorch 中的 LSTM

February 7, 2020 • Read: 13026 • Deep Learning阅读设置

nn.LSTM

PyTorch LSTM API 文档

输入数据格式:

  • input:[seq_len, batch, input_size]
  • $h_0$:[num_layers * num_directions, batch, hidden_size]
  • $c_0$:[num_layers * num_directions, batch, hidden_size]

输出数据格式:

  • output:[seq_len, batch, hidden_size * num_directions]
  • $h_n$:[num_layers * num_directions, batch, hidden_size]
  • $c_n$:[num_layers * num_directions, batch, hidden_size]

接下来看个具体的例子

  • import torch
  • import torch.nn as nn
  • lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
  • x = torch.randn(10, 3, 100) # 一个句子10个单词,送进去3条句子,每个单词用一个100维的vector表示
  • out, (h, c) = lstm(x)
  • print(out.shape, h.shape, c.shape)
  • # torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])

nn.LSTMCell

PyTorch LSTMCell API 文档

和 RNNCell 类似,输入 input_size 的 shape 是 [batch, input_size],输出 $h_t$ 和 $c_t$ 的 shape 是 [batch, hidden_size]

看个一层的 LSTM 的例子

  • import torch
  • import torch.nn as nn
  • cell = nn.LSTMCell(input_size=100, hidden_size=20) # one layer LSTM
  • h = torch.zeros(3, 20)
  • c = torch.zeros(3, 20)
  • x = torch.randn(10, 3, 100)
  • for xt in x:
  • h, c = cell(xt, [h, c])
  • print(h.shape, c.shape) # torch.Size([3, 20]) torch.Size([3, 20])

两层的 LSTM 例子

  • import torch
  • import torch.nn as nn
  • cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
  • cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
  • h1 = torch.zeros(3, 30)
  • c1 = torch.zeros(3, 30)
  • h2 = torch.zeros(3, 20)
  • c2 = torch.zeros(3, 20)
  • x = torch.randn(10, 3, 100)
  • for xt in x:
  • h1, c1 = cell1(xt, [h1, c1])
  • h2, c2 = cell2(h1, [h2, c2])
  • print(h2.shape, c2.shape) # torch.Size([3, 20]) torch.Size([3, 20])
Last Modified: February 8, 2020
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

5 Comments
  1. dsy dsy

    学到了,我自己编写一个 LSTM 实现算法,不知道对不对,看到博主用 LSTMCell,看到了测试的希望

  2. noao noao

    x = torch.randn (10, 3, 100) # 一个句子 10 个单词,送进去 3 条句子,每个单词用一个 100 维的 vector 表示

    请问,为什么我在调试的时候看到是:(batch_size 批次大小,max_len 句子长度,词向量维度),和这里说的前两维反了,求正解.

    1. mathor mathor

      @noao 正常来说,大部分情况下我们确实会将句子处理成 [batch_size, max_len, dim] 这样的维度,但是 PyTorch 中的 RNN 类模型,默认的输入格式是 [max_len, batch_size, dim],所以你才会觉得前两个维度反了,是因为默认就需要 max_len 在前面

    2. noao noao

      @mathor 好的 谢谢博主~

    3. Huskky Huskky

      @mathor 请问博主,torch.randn (10,3,100),如果默认输入格式是 [max_len, batch_size, dim],那么生成的 shape 不应该是 310100,3 个 channel 的 10 个单词的 100 维 vector 吗?可是输出好像是 103100,10 个 channel 的 3 个单词的 100 维 vector 表示。这里没太明白,求正解 @(哈哈)