『零基础+1』一文看懂LSTM原理-《动手学深度学习》
更新时间:2026-02-15 10:37:46
『零基础+1』一文看懂LSTM原理-《动手学深度学习》
长短期记忆网络(LSTM)是为解决隐变量模型长期信息保存与短期输入缺失问题而设计的。它包含了记忆元、输入门、遗忘门和输出门三个门控机制,通过特定计算控制信息的留存和更新。本文从数学原理、从零开始实现以及简洁实现等角度进行了详细介绍,并提及了变体(如带猫眼连接)以及与GRU的区别。最后,展示了训练和预测的示例,便于理解和应用。
1 长短期记忆网络(LSTM)
长久以来,隐变量模型面临了长期信息保留与短期输入缺失的问题。为解决此难题,最早的解决方案便是长短期记忆网络(Long Short-Term Memory Network, LSTM),源自Hochreiter & Schmidhuber(。
许多长短期记忆网络具有类似门控循环单元的性质。然而,其设计更为复杂,并且比门控循环单元早问世约之久。
1.1 门控记忆元
可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。
长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。
有些文献认为记忆元是隐状态的一种特殊类型,
它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。
为了控制记忆元,我们需要许多门。
其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。
另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。
我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,
采用相同的设计理念,旨在通过特定方式激活或抑制隐状态中输入的内容,这与门控循环单元的功能相似。
注:

Sigmoid 层的输出值在 到 之间表示各部分处理的信息程度。代表「完全关闭通道」,代表「门一直开着」。
一个 LSTM 有三个这样的门,控制 cell 的状态。
门实际上是一种信息选择机制,它通过sigmoid神经网络层对输入数据进行处理,生成一个在[ 区间内的概率值,并与原始输入数据相乘以实现去留控制。
门的操作是相同的,只是根据不同的设计思想,不同的数据流,叫不同的名字
1.2 输入门、忘记门和输出门
就如在门控循环单元中一样,当前时刻的输入与前一时刻的隐藏状态同样被送入长期记忆网络的门。
通过使用具有Sigmoid激活函数的全连接层处理,这些神经网络能够计算出输入门、遗忘门和输出门的值,确保它们的范围始终在(之间。
在 LSTMs 的初期阶段,我们需确定要抛弃哪些细胞中的信息,这一过程通过激活函数中的「记忆门」完成。
它的输入是 ht-和 xt,输出是一个介于间的数。Ct示了cell中所有值处于间的情况,即为“全保留”或“全抛弃”。

下一步,我们需要决定什么样的信息应该被存储起来。这个过程主要分两步。
首先是 sigmoid 层(输入门)决定我们需要更新哪些值;
随后,tanh 层生成了一个新的候选向量 C`,它能够加入状态中。

接下来,我们就可以更新 cell 的状态了。
将旧状态与 ft 相乘,忘记此前我们想要忘记的内容,然后加上 C`。此时遗忘门为ftft
得到的结果便是新的候选值,依照itit进行缩放。

最后,我们需要决定要输出什么。此输出将基于我们处理后的单元状态。
首先,我们会运行一个 sigmoid 层决定 cell 状态输出哪一部分。
随后,我们把 cell 状态通过 tanh 函数,将输出值保持在-1 到 1 间。
之后,我们再乘以 sigmoid 门的输出值,就可以得到结果了。
我们来细化一下长短期记忆网络的数学表达。
假设有hh个隐藏单元,批量大小为nn,输入数为dd。
因此,输入为Xt∈Rn×dXt∈Rn×d,
前一时间步的隐状态为Ht1∈Rn×hHt1∈Rn×h。
相应地,时间步tt的门被定义如下:
输入门是It∈Rn×hIt∈Rn×h,
遗忘门是Ft∈Rn×hFt∈Rn×h,
输出门是Ot∈Rn×hOt∈Rn×h。
它们的计算方法如下:
It=σ(XtWxi+Ht1Whi+bi)It=σ(XtWxi+Ht1Whi+bi)
Ft=σ(XtWxf+Ht1Whf+bf),Ft=σ(XtWxf+Ht1Whf+bf),
Ot=σ(XtWxo+Ht1Who+bo)Ot=σ(XtWxo+Ht1Who+bo)
其中Wxi,Wxf,Wxo∈Rd×hWxi,Wxf,Wxo∈Rd×h
和Whi,Whf,Who∈Rh×hWhi,Whf,Who∈Rh×h是权重参数,
bi,bf,bo∈R1×hbi,bf,bo∈R1×h是偏置参数。
我们将其中的一些操作集合命名为不同的记忆元名称
1.3 候选记忆元
由于缺乏指定不同门的操作,我们先简要介绍候选记忆单元(candidate memory cell),记作 \(C_t \in R^{n \times h}_t\)。它与上面描述的三个门的操作类似,但使用了双切线函数(tanh)作为激活函数,其值范围为(。下面是时间步 t 处的方程:
C~t=tanh(XtWxc+Ht1Whc+bc),C~t=tanh(XtWxc+Ht1Whc+bc),
其中Wxc∈Rd×hWxc∈Rd×h和Whc∈Rh×hWhc∈Rh×h是权重参数,bc∈R1×hbc∈R1×h是偏置参数。
1.4 记忆元
在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。 类似地,在长短期记忆网络中,也有两个门用于这样的目的: 输入门ItIt控制采用多少来自C~tC~t的新数据, 而遗忘门FtFt控制保留多少过去的 记忆元Ct1∈Rn×hCt1∈Rn×h的内容。 使用按元素乘法,得出:
Ct=Ft⊙Ct1+It⊙C~t.Ct=Ft⊙Ct1+It⊙C~t.
如果遗忘门保持为输入门始终为那么过去的记忆元Ctt随时间被保存并传递到当前的时间步。引入这一设计旨在减轻梯度消失问题,并更好地捕捉序列中的长距离依赖关系。
1.5 隐状态
最后,我们定义如何计算隐状态 \( H_t \in R^n \times h \),这正是输出门功能所在之处。在长短期记忆网络中,它仅是记忆单元激活函数 tanh 的版本控制。这样确保了 \( H_t \) 的值始终位于区间 ( 内。
Ht=Ot⊙tanh(Ct).(9.2.4)Ht=Ot⊙tanh(Ct).(9.2.4)
当输出门值接近,我们能够高效地向预测部分传输所有的记忆信息;相反,如果输出门的值接近则仅保留记忆元中的原有信息而不进行更新。
2 从零开始实现
现在,我们从零开始实现长短期记忆网络。
我们首先加载时光机器数据集。 In [1]
import paddlefrom paddle import nnfrom d2l import paddle as d2limport paddle.nn.functional as Function batch_size, num_steps = 32, 35train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)登录后复制
2.1 初始化模型参数
接下来,我们需要定义和初始化模型参数。
如前所述,超参数num_hiddens定义隐藏单元的数量。
我们按照标准差0.010.01的高斯分布初始化权重,并将偏置项设为00。 In [2]
def get_lstm_params(vocab_size, num_hiddens): num_inputs = num_outputs = vocab_size def normal(shape): return paddle.randn(shape=shape)*0.01 def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), paddle.zeros([num_hiddens])) W_xi, W_hi, b_i = three() # 输入门参数 W_xf, W_hf, b_f = three() # 遗忘门参数 W_xo, W_ho, b_o = three() # 输出门参数 W_xc, W_hc, b_c = three() # 候选记忆元参数 # 输出层参数 W_hq = normal((num_hiddens, num_outputs)) b_q = paddle.zeros([num_outputs]) # 附加梯度 params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] for param in params: param.stop_gradient = False return params登录后复制
2.2 定义模型
在[初始化函数]中,长短期记忆网络的隐藏状态需传递额外的记忆元素,单元值初始化为维度为(批次大小,隐藏单元数)。
因此,我们得到以下的状态初始化。 In [3]
def init_lstm_state(batch_size, num_hiddens): return (paddle.zeros([batch_size, num_hiddens]), paddle.zeros([batch_size, num_hiddens]))登录后复制
在[中,实际模型被定义为包含三个门控单元与额外的细胞记忆单元,但记忆单元仅影响输出层的状态更新,而不参与最终输出计算过程。
def lstm(inputs, state, params): [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params (H, C) = state outputs = [] for X in inputs: I = Function.sigmoid((X @ W_xi) + (H @ W_hi) + b_i) F = Function.sigmoid((X @ W_xf) + (H @ W_hf) + b_f) O = Function.sigmoid((X @ W_xo) + (H @ W_ho) + b_o) C_tilda = paddle.tanh((X @ W_xc) + (H @ W_hc) + b_c) C = F * C + I * C_tilda H = O * paddle.tanh(C) Y = (H @ W_hq) + b_q outputs.append(Y) return paddle.concat(outputs, axis=0), (H, C)登录后复制
2.3 训练 和 预测
让我们通过实例化8.5节中,引入的RNNModelScratch类来训练一个长短期记忆网络。
此外,我们还加入了额外的模型测试。 In [6]
## 训练vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu() num_epochs, lr = 500, 1.0model = d2l.RNNModelScratch(len(vocab), num_hiddens, device,get_lstm_params, init_lstm_state, lstm) d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)登录后复制 In [10]
## 预测# 自定义 prefix , num_preds 进行预测prefix = 'tr'num_preds = 5net = model d2l.predict_ch8(prefix, num_preds, net, vocab, device)登录后复制
'treasth'登录后复制
2.4 简洁实现
使用高级API,我们可以直接实例化LSTM模型。
高级API封装了前文介绍的所有配置细节。
这段代码性能更优,因为使用了预编译的操作符而非Python处理大量内容细节。
num_inputs = vocab_size lstm_layer = nn.LSTM(num_inputs, num_hiddens, time_major=True) model = d2l.RNNModel(lstm_layer, len(vocab)) d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)登录后复制
2.5 结构拓展
LSTM 的变种之一是 Gers & Schmidhuber 在 年提出的“猫眼连接”门控层可以接收细胞的状态。

上图展示了全加上「猫眼连接」的效果,但实际上论文中并不会加这么多。
另一种变体就是采用一对门,分别叫遗忘门(forget)及输入门(input)。
与分开决定遗忘及输入的内容不同,现在的变体会将这两个流程一同实现。
我们只有在将要输入新信息时才会遗忘,而也只会在忘记信息的同时才会有新的信息输入。

一个新的变体为LSTM(Long Short-Term Memory),由Hochreiner等人在提出。它包括一个遗忘门、一个输入门和一个更新门,共同构成记忆单元,同时结合了cell状态与隐藏层状态,进行了细微的调整。

GRU比LSTM少一个门,简化了结构;去除了细胞单元C,并省略了第二阶非线性处理。
这个模型比起标准 LSTM 模型简单一些,因此也变得更加流行了。
当然,这里所列举的只是一管窥豹,还有很多其它的变体,
比如 Yao, et al. (2015) 提出的 Depth Gated RNNs;或是另辟蹊径处理长期依赖问题的 Clockwork RNNs,由 Koutnik, et al. (2014) 提出。
哪个是最好的呢?而这些变化是否真的意义深远?
Greff, et al. (2015)曾经对比较流行的几种变种做过对比,发现它们基本上都差不多;
Jozefowicz, et al. (2015)测试了超过一万种 RNN 结构,发现有一些能够在特定任务上超过 LSTMs。
以上就是『零基础+1』一文看懂LSTM原理-动手学深度学习的详细内容,更多请关注其它相关文章!
