本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
BiRNN使用了正反双向RNN来共同预测某时刻的输出,使得单时刻的输出考虑了正反时刻的信息
本文介绍BiRNN双向循环神经网络的模型结构和思想,并展示BiRNN的具体代码实现
通过本文,可以快速了解什么是双向循环神经网络BiRNN,以及如何使用代码实现BiRNN来进行序列预测
本节介绍BiRNN双向循环神经网络的模型结构,以及为什么需要BiRNN模型
BiRNN双向循环神经网络是什么
BiRNN全称为(Bidirectional Recurrent Neural Networks)双向循环神经网,它的特色是使用了双向预测
BiRNN出自1997年的论文:《Bidirectional Recurrent Neural Networks》
在经典RNN中,只有前馈信息,也就t时刻的输出只考虑了t时刻之前的信息,但往往t时刻还需要参考后面的信息
例如,句子"你好像没睡",在第2时刻时,只捕捉到"你好"会以为是在问好,此时需要参考后面的信息才能预测准确
在BiRNN出来之前,处理上述问题的方法是独立训练两个RNN:
👉正向RNN:一个从t1预测到tn的RNN(称为正向RNN)
👉反向RNN:一个从tn预测到t1的RNN(称为反向RNN)
最后再用正、反向RNN在tk时刻的隐节点共同预测yk,如下:
总的来说,就是用两个独立的RNN,一个顺着预测,一个倒着预测,最终使用两个RNN的隐层共同预测输出
上述处理方法使用的是两个独立的正反RNN,而BiRNN则是将正反RNN综合成一个模型
BiRNN的模型结构如下:
BiRNN就是将一个正向RNN和一个反向RNN(将序列倒着预测的RNN)组合在一起进行预测y
举例,T5时刻的输出,正向RNN综合了t1...t5时刻的信息,而反向RNN综合了tn...t6,t5的信息,然后两者加权得到输出
BiRNN这样就可以把t5前后的所信息一起综合来进行预测y5
BiRNN的正向传播与反向传播
BiRNN的正向传播与反向传播基本只要基于单向RNN进行计算就可以
BiRNN的正馈与后馈如下:
👉BiRNN的正向传播:BiRNN预测y时,只需先独立算出正反RNN的所有隐节点
然后再把正反RNN对应的隐节点进行预测y就可以
👉BiRNN的反向传播:由于正反RNN的参数是独立的,BiRNN仍然使用单向RNN的训练方法就可以
因为参数独立,所以任何一个参数的梯度计算与单向RNN的梯度计算并没有本质的区别
BiRNN是在1997年提出的,最初提出时使用的是正反RNN,但之后其实已经演变为一种思想
例如,2005年的论文《Framewise Phoneme Classification with Bidirectional LSTM and Other Neural Network Architectures 》
就改用正反LSTM来解决语音问题,所以,总的来说,BiRNN就是使用双向模型来共同预测的一种方法
本节讲解如何用代码实现一个BiRNN双向循环神经网络
BiRNN-代码实现
在pytorch中,只需要把RNN隐神经元设为双向的(即参数bidirectional=True)就可以实现双向循环神经网络
具体代码示例如下:
import torch
import torch.nn as nn
# BiRNN神经网络的结构
class BiRnnNet(nn.Module):
def __init__(self,input_size,out_size,hiden_size):
super(BiRnnNet, self).__init__()
self.rnn = nn.LSTM(input_size, hiden_size,bidirectional=True) # 使用双向预测,此时它会返回正反向的隐节点
self.fc = nn.Linear(hiden_size*2, out_size) # 线性层,由于是正反隐节点,所以实际节点个数需要翻倍
def forward(self, x):
h,_ = self.rnn(x) # 计算循环隐层,此时h会返回正反向的隐节点
y = self.fc(h) # 使用线性层计算输出
return y,h
# 模型应用
model = BiRnnNet(1,1,3) # 初始化荆
x = torch.tensor([1.,2.,3.,4.]).unsqueeze(1).unsqueeze(2) # 生成一个输入
y,h = model(x) # 计算模型的输出
print('\nx:',x) # 打印输入
print('\ny:',y) # 打印模型的输出
print('\nh:',h) # 打印模型的隐节点
运行后结果如下:
上述代码中,就实现了一个双向LSTM循环神经网络,并设置了3个隐节点
因此,输入长度为4的序列x时,就得到了4个时刻的隐节点(隐节点长度为6)
然后根据每个时刻的隐节点,经由输出层得到最终每个时刻的输出y
End