當(dāng)前位置:首頁 > 學(xué)習(xí)資源 > 講師博文 > 一文弄懂RNN、LSTM 和 GRU單元 !
RNN
循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network,RNN ),主要處理序列數(shù)據(jù),輸入的序列數(shù)據(jù)可以是連續(xù)的、長度不固定的序列數(shù)據(jù),也可以是固定的序列數(shù)據(jù)。循環(huán)神經(jīng)網(wǎng)絡(luò)能保持對(duì)過去事件和當(dāng)前事件的記憶,從而可以捕獲長距離樣本之間的關(guān)聯(lián)信息。循環(huán)神經(jīng)網(wǎng)絡(luò)在文字預(yù)測、語音識(shí)別等領(lǐng)域表現(xiàn)較大優(yōu)勢(shì)。
RNN網(wǎng)絡(luò)結(jié)構(gòu)解析
圖1是RNN網(wǎng)絡(luò)圖示
RNN存在的問題
存在梯度爆炸和消失的問題,對(duì)于長距離的句子的學(xué)習(xí)效果不好。
反向傳播中,對(duì)激活函數(shù)進(jìn)行求導(dǎo),如果此部分大于1,那么層數(shù)增多的時(shí)候,最終的求出的梯度更新將以指數(shù)形式增加,即發(fā)生梯度爆炸,如果此部分小于1,那么隨著層數(shù)增多,求出的梯度更新信息將會(huì)以指數(shù)形式衰減,即發(fā)生了梯度消失。
RNN代碼示例
pytorch 簡單代碼示例
rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
LSTM
長短期記憶網(wǎng)絡(luò)(LSTM,Long Short-Term Memory)是一種時(shí)間循環(huán)神經(jīng)網(wǎng)絡(luò),是為了解決一般的RNN(循環(huán)神經(jīng)網(wǎng)絡(luò))存在的長期依賴問題而專門設(shè)計(jì)出來的。
LSTM網(wǎng)絡(luò)結(jié)構(gòu)解析
LSTM網(wǎng)絡(luò)結(jié)構(gòu)如圖2所示
LSTM優(yōu)勢(shì)
RNN中只有一個(gè)隱藏狀態(tài),LSTM增加了一個(gè)元胞狀態(tài)單元,其在不同時(shí)刻有著可變的連接權(quán)重,以解決RNN中梯度消失或爆炸問題。隱藏狀態(tài)控制短期記憶,元胞狀態(tài)單元控制長期記憶,和配合形成長短期記憶。
LSTM代碼示例
pytorch 簡單代碼示例
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
GRU單元
門控循環(huán)單元(gated recurrent unit,GRU)是為了解決循環(huán)神經(jīng)網(wǎng)絡(luò)中計(jì)算梯度, 以及矩陣連續(xù)乘積導(dǎo)致梯度消失或梯度爆炸的問題而提出,GRU更簡單,通常它能夠獲得跟LSTM同等的效果,優(yōu)勢(shì)是計(jì)算的速度明顯更快。
GRU單元結(jié)構(gòu)解析
GRU單元結(jié)構(gòu)如圖3所示
GRU優(yōu)勢(shì)
GRU可以取得與LSTM想當(dāng)甚至更好的性能,且收斂速度更快。
GRU代碼示例
pytorch 簡單代碼示例
rnn = nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)