之前寫了一篇實作,這次來寫個理論推導了.
LSTM來自於1997年的論文[1],為了解決RNN的長期運作,梯度消失/爆炸的問題,這篇論文用了四種gate(input, forget,output,cell)來決定是否存入輸入資料的特徵或是刪除之前留下的特徵,由於原始的論文和現在的套件所使用的方法有些微不同,因此在這邊以Pytorch文檔[2]上的算式為主.
把這個畫成流程圖,就像下圖
至於每個gate,用下圖表示gate所需的元素
LSTM正向傳播
為了方便說明,我們把x(輸入)設成5個元素的1*5矩陣,而h(t-1)(上一刻的輸出)則設成有10個hidden_feature的向量,也就是10個元素的1*10矩陣.
而他們都會乘上各自的權重,再加上各自的偏差,如下圖
但和CNN的權重有所不同,LSTM的權重都是矩陣型態.
對應到輸入x的權重,它就是個5*10的矩陣,共50個可學習的變數
而對應到h的權重,是個10*10的矩陣,共100個可學習的變數
由於x本身是1*5的矩陣,而權重是5*10的矩陣,經過矩陣相乘後,就成為1*10的矩陣,h也同理,輸出為1*10的矩陣,而偏差都是1*10的矩陣,經過矩陣運算後,input gate也是1*10的矩陣.
同理,其他三個gate也會是1*10的矩陣.
接下來要講解各個gate的作用:
input gate
單純是個sigmoid,主要是把數據整成介於0~1之間的數字.
cell gate
經由tanh作用,數據介於-1~1,用來調整進入cell裏頭的數據,由於可以到-1,因此也有消減cell狀態的功用.
forget gate
經過sigmoid作用,數據介於0~1,用來決定把上一刻cell的狀態的多少比例存到這個cell裏頭.如果是0,則上一刻cell狀態就無法對此刻的cell造成影響.
output gate
經過sigmoid作用,數據介於0~1,用來決定把cell的狀態的多少比例輸出到ht.如果是0,則此刻的cell狀態則無法輸出.
hidden state
此為LSTM用來判斷狀態的依據,為1*10的矩陣,由於名字為“隱藏狀態”,當時還被這個名字困惑了一陣子,以為和CNN的隱藏層類似.在Pytorch裏頭,LSTMCell就是這樣的架構,至於LSTM則是把LSTMCell每一次運算出來的cell state 和hidden state儲存起來,其陣列長度是輸入的sequence length.如果是一般的序列狀態判斷,只要讓LSTM cell掃過一次,最後出來的hidden state就是判斷的依據了.而像是要憑著一小段序列推測整個序列,則需要把整個hidden state陣列都用上.
LSTM反向傳播
和CNN的反向傳播一樣,都是靠鏈鎖律,由於這個沒有convolutional kernel的反向傳播這麼複雜,因此這邊只用forget gate的權重更新當例子
LSTM cell誤差傳遞流程如下
根據鏈鎖律,誤差對forget gate的微分可寫成如下算式:
而forget gate對權重W的誤差傳遞如下
此段誤差傳遞可寫成下面算式
把兩段的算式串聯如下,即權重的更新流程.
LSTM Cell 在hidden state輸入端的權重數量為hidden state陣列長度的平方,而在x輸入端的權重數量為hidden state陣列長度*輸入陣列長度,加上共有4個gate的權重要算.因此運算量會比一般的神經元多上很多.
ref:
[1]https://www.bioinf.jku.at/publications/older/2604.pdf
[2]https://pytorch.org/docs/stable/nn.html?highlight=lstm#torch.nn.LSTM