本文轉(zhuǎn)載至知乎ID:Charles(白露未晞)知乎個(gè)人專欄
本文轉(zhuǎn)載至知乎ID:Charles(白露未晞)知乎個(gè)人專欄
下載W3Cschool手機(jī)App,0基礎(chǔ)隨時(shí)隨地學(xué)編程>>戳此了解
導(dǎo)語
好幾天沒推文的罪惡感讓我決定今天來水一篇文章。
和之前“Python玩CartPole”那篇推文一樣,這也是來自于PyTorch官方教程的一個(gè)簡單實(shí)例。
為了展示我的誠意,我依舊會(huì)由淺入深地講解本文使用到的基本模型:Seq2Seq以及Attention機(jī)制。
內(nèi)容依舊會(huì)很長~~~
希望對初入NLP/DeepLearning的童鞋有所幫助~
廢話不多說,直接進(jìn)入正題~~~
相關(guān)文件
百度網(wǎng)盤下載鏈接: https://pan.baidu.com/s/1y3KcMboz_xZJ9Afh5nRkUw
密碼: qvhd
參考文獻(xiàn)
官方英文教程鏈接:
http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
另外:
對英文文獻(xiàn)閱讀有困難的同學(xué)也不必?fù)?dān)心,我已經(jīng)把這個(gè)教程翻譯為中文放到了相關(guān)文件中。
開發(fā)工具
系統(tǒng):Windows10
Python版本:3.6.4
相關(guān)模塊:
torch模塊;
numpy模塊;
matplotlib模塊;
以及一些Python自帶的模塊。
其中PyTorch版本為:
0.3.0
環(huán)境搭建
安裝Python并添加到環(huán)境變量,pip安裝需要的相關(guān)模塊即可。
補(bǔ)充說明:
PyTorch暫時(shí)不支持直接pip安裝。
有兩個(gè)選擇:
(1)安裝anaconda3后在anaconda3的環(huán)境下安裝(直接pip安裝即可);
(2)使用編譯好的whl文件安裝,下載鏈接為:
https://pan.baidu.com/s/1dF6ayLr#list/path=%2Fpytorch
原理介紹
PS:
部分內(nèi)容參考了相關(guān)網(wǎng)絡(luò)博客和書籍。
(1)單層網(wǎng)絡(luò)
單層網(wǎng)絡(luò)的結(jié)構(gòu)類似下圖:
輸入的x經(jīng)過變換wx+b和激活函數(shù)f得到輸出y。
相信對機(jī)器學(xué)習(xí)/深度學(xué)習(xí)有初步了解的同學(xué)都知道,這其實(shí)就是單層感知機(jī)嘛~~~
為了方便起見,我們把它畫成這樣(請忽視我拙劣的繪圖水平):
x為輸入向量,y為輸出向量,箭頭表示一次變換,也就是y=f(Wx+b)。
(2)經(jīng)典RNN
在實(shí)際中,我們會(huì)遇到很多序列形的數(shù)據(jù):
X1,X2,X3,X4...
例如我們的機(jī)器翻譯模型,X1可以看作是第一個(gè)單詞,X2可以看作是第二個(gè)單詞,以此類推。
原始的神經(jīng)網(wǎng)絡(luò)并不能很好地處理序列形的數(shù)據(jù),于是救世主RNN出現(xiàn)了,它引入了隱狀態(tài)h的概念,利用h對序列形的數(shù)據(jù)提取特征,接著再轉(zhuǎn)換為輸出。下面詳細(xì)說明一下其計(jì)算過程(下圖中的h0為初始隱藏狀態(tài),為簡單起見,我們假設(shè)它是根據(jù)具體模型而設(shè)置的一個(gè)合理值):
其中:
再重申一遍,所有的字母均為向量,箭頭代表對向量做一次變換。
h2的計(jì)算與h1類似,并且每一步使用的參數(shù)P、Q、b都是一樣的,也就是說每個(gè)步驟的參數(shù)共享:
其中:
以此類推(記住參數(shù)都是一樣的?。?!),該計(jì)算可以無限地持續(xù)下去(不限于圖中的長度4?。?!)。
那么RNN的輸出又如何得到呢?
RNN的輸出值是通過h進(jìn)行計(jì)算的:
其中:
類似地,有y2、y3、y4...:
當(dāng)然,和前面一樣,這里的參數(shù)W和c也是共享的。
以上就是最經(jīng)典的RNN結(jié)構(gòu),我們可以發(fā)現(xiàn)其存在一個(gè)致命的缺點(diǎn):
輸入和輸出序列必須是等長的!
這個(gè)缺點(diǎn)導(dǎo)致了經(jīng)典RNN的適用范圍并沒有想象中的那么大。
(3)改進(jìn)經(jīng)典RNN
情況1(輸入為N,輸出為1):
假設(shè)我們的問題要求我們輸入的是一個(gè)序列,輸出的是一個(gè)單獨(dú)的數(shù)值。那么我們只在最后一個(gè)h上進(jìn)行輸出變換就可以了:
情況2(輸入為1,輸出為N):
當(dāng)輸入只是單一數(shù)值而非序列時(shí)該怎么辦呢?
我們可以只在序列開始進(jìn)行輸入計(jì)算:
當(dāng)然你也可以把輸入信息x作為每個(gè)階段的輸入:
情況3(輸入為N,輸出為M):
這是RNN最重要的一個(gè)變種,這種結(jié)構(gòu)也被稱為:
Encoder-Decoder模型,或者說Seq2Seq模型。
我們的機(jī)器翻譯模型就是以它為基礎(chǔ)的。
Seq2Seq結(jié)構(gòu)先將輸入數(shù)據(jù)編碼成一個(gè)上下文量c:
其中:
即上下文量c可以直接等于最后一個(gè)隱藏狀態(tài),也可以是對最后的隱藏狀態(tài)做一個(gè)變換V得到,當(dāng)然也可以是對所有的隱藏狀態(tài)做一個(gè)變換V得到等等。
上述RNN結(jié)構(gòu)一般稱為Encoder。
得到c之后,我們需要另外一個(gè)RNN網(wǎng)絡(luò)對其進(jìn)行解碼操作,即Decoder。你可以把這個(gè)c當(dāng)作初始狀態(tài)h'0輸入到Decoder中:
當(dāng)然你也可以把c當(dāng)作Decoder每一步的輸入:
算了,補(bǔ)充說明一下吧:
缺少輸入的部分(比如某些藍(lán)色的方塊沒有x輸入)你完全可以把x作為0處理然后再代入經(jīng)典RNN所列出的公式中計(jì)算輸出,其他的也類似。
(4)Attention機(jī)制
在Encoder-Decoder結(jié)構(gòu)中,Encoder把所有的輸入序列都編碼成一個(gè)統(tǒng)一的語義特征c后再進(jìn)行解碼,當(dāng)輸入序列較長時(shí),c很可能無法勝任存儲輸入序列所有信息的任務(wù)。
Attention機(jī)制很好地解決了上述問題。它在Decoder每一步輸入不同的c:
其中,c根據(jù)Encoder中的h生成:
aij代表Encoder中第j階段的hj和Decoder中第i階段的相關(guān)性。
那么這些權(quán)重aij該如何確定呢?aij自然也是從模型中學(xué)得的,我們一般認(rèn)為它與Encoder的第j個(gè)階段的隱狀態(tài)和Decoder的第i-1階段的隱狀態(tài)有關(guān)。
比如我們要計(jì)算a1j:
然后我們需要計(jì)算a2j:
以此類推。
(5)最后任務(wù):法語翻譯成英語
有了前面的鋪墊,相信大家都能看懂官網(wǎng)的教程。
在這里我們僅做簡單的介紹,詳細(xì)的建模和實(shí)現(xiàn)過程可以參考我翻譯的官方文檔。
Encoder網(wǎng)絡(luò)為:
Decoder網(wǎng)絡(luò)為:
其中,encoder最后一個(gè)隱藏狀態(tài)作為decoder的初始隱藏狀態(tài)。attention機(jī)制的權(quán)重計(jì)算類似(4)中所述。GRU網(wǎng)絡(luò)的結(jié)構(gòu)為:
GRU網(wǎng)絡(luò)結(jié)構(gòu)在此就不作詳細(xì)的介紹了,篇幅太長的話估計(jì)沒人看得下去吧,就先這樣了~~~
在相關(guān)文件中我也提供了4篇相關(guān)的論文供感興趣者閱讀與研究。(T_T純英文的~~~)
結(jié)果展示
在cmd窗口運(yùn)行Translation.py文件即可。
誤差曲線:
訓(xùn)練過程中cmd窗口的輸出:
模型測試:
作為對比:
和最后一個(gè)測試結(jié)果一模一樣有木有?。?!
當(dāng)然,有些翻譯結(jié)果就不怎么理想了。因?yàn)槟P秃陀?xùn)練數(shù)據(jù)過于簡單了(T_T這里就不舉例了)~~~
最后四句話的attention圖:
That's all~~~
更多
感興趣的同學(xué)可以進(jìn)一步修改模型來獲得更好的結(jié)果,當(dāng)然也可以找找其他數(shù)據(jù)集制作諸如中翻英之類的模型~~~