LSTM正反向傳播推導

吳政龍
5 min readMar 14, 2019

--

之前寫了一篇實作,這次來寫個理論推導了.

LSTM來自於1997年的論文[1],為了解決RNN的長期運作,梯度消失/爆炸的問題,這篇論文用了四種gate(input, forget,output,cell)來決定是否存入輸入資料的特徵或是刪除之前留下的特徵,由於原始的論文和現在的套件所使用的方法有些微不同,因此在這邊以Pytorch文檔[2]上的算式為主.

把這個畫成流程圖,就像下圖

LSTM流程圖

至於每個gate,用下圖表示gate所需的元素

LSTM正向傳播

為了方便說明,我們把x(輸入)設成5個元素的1*5矩陣,而h(t-1)(上一刻的輸出)則設成有10個hidden_feature的向量,也就是10個元素的1*10矩陣.

輸入之向量
上一刻的輸出

而他們都會乘上各自的權重,再加上各自的偏差,如下圖

但和CNN的權重有所不同,LSTM的權重都是矩陣型態.

對應到輸入x的權重,它就是個5*10的矩陣,共50個可學習的變數

而對應到h的權重,是個10*10的矩陣,共100個可學習的變數

由於x本身是1*5的矩陣,而權重是5*10的矩陣,經過矩陣相乘後,就成為1*10的矩陣,h也同理,輸出為1*10的矩陣,而偏差都是1*10的矩陣,經過矩陣運算後,input gate也是1*10的矩陣.

同理,其他三個gate也會是1*10的矩陣.

接下來要講解各個gate的作用:

input gate

單純是個sigmoid,主要是把數據整成介於0~1之間的數字.

cell gate

經由tanh作用,數據介於-1~1,用來調整進入cell裏頭的數據,由於可以到-1,因此也有消減cell狀態的功用.

forget gate

經過sigmoid作用,數據介於0~1,用來決定把上一刻cell的狀態的多少比例存到這個cell裏頭.如果是0,則上一刻cell狀態就無法對此刻的cell造成影響.

output gate

經過sigmoid作用,數據介於0~1,用來決定把cell的狀態的多少比例輸出到ht.如果是0,則此刻的cell狀態則無法輸出.

hidden state

此為LSTM用來判斷狀態的依據,為1*10的矩陣,由於名字為“隱藏狀態”,當時還被這個名字困惑了一陣子,以為和CNN的隱藏層類似.在Pytorch裏頭,LSTMCell就是這樣的架構,至於LSTM則是把LSTMCell每一次運算出來的cell state 和hidden state儲存起來,其陣列長度是輸入的sequence length.如果是一般的序列狀態判斷,只要讓LSTM cell掃過一次,最後出來的hidden state就是判斷的依據了.而像是要憑著一小段序列推測整個序列,則需要把整個hidden state陣列都用上.

LSTM反向傳播

和CNN的反向傳播一樣,都是靠鏈鎖律,由於這個沒有convolutional kernel的反向傳播這麼複雜,因此這邊只用forget gate的權重更新當例子

LSTM cell誤差傳遞流程如下

根據鏈鎖律,誤差對forget gate的微分可寫成如下算式:

而forget gate對權重W的誤差傳遞如下

此段誤差傳遞可寫成下面算式

把兩段的算式串聯如下,即權重的更新流程.

LSTM Cell 在hidden state輸入端的權重數量為hidden state陣列長度的平方,而在x輸入端的權重數量為hidden state陣列長度*輸入陣列長度,加上共有4個gate的權重要算.因此運算量會比一般的神經元多上很多.

ref:

[1]https://www.bioinf.jku.at/publications/older/2604.pdf

[2]https://pytorch.org/docs/stable/nn.html?highlight=lstm#torch.nn.LSTM

--

--

No responses yet