PyTorch NLP From Scratch: 基于注意力機(jī)制的 seq2seq 神經(jīng)網(wǎng)絡(luò)翻譯

2020-09-11 10:28 更新
原文: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

作者: Sean Robertson

這是關(guān)于“從頭開始進(jìn)行 NLP”的第三篇也是最后一篇教程,我們?cè)谄渲芯帉懽约旱念惡秃瘮?shù)來預(yù)處理數(shù)據(jù)以完成 NLP 建模任務(wù)。 我們希望在完成本教程后,您將繼續(xù)學(xué)習(xí)緊接著本教程的三本教程, <cite>torchtext</cite> 如何為您處理許多此類預(yù)處理。

在這個(gè)項(xiàng)目中,我們將教授將法語翻譯成英語的神經(jīng)網(wǎng)絡(luò)。

  1. [KEY: > input, = target, < output]
  2. > il est en train de peindre un tableau .
  3. = he is painting a picture .
  4. < he is painting a picture .
  5. > pourquoi ne pas essayer ce vin delicieux ?
  6. = why not try that delicious wine ?
  7. < why not try that delicious wine ?
  8. > elle n est pas poete mais romanciere .
  9. = she is not a poet but a novelist .
  10. < she not not a poet but a novelist .
  11. > vous etes trop maigre .
  12. = you re too skinny .
  13. < you re all alone .

……取得不同程度的成功。

通過序列到序列網(wǎng)絡(luò)的簡(jiǎn)單但強(qiáng)大的構(gòu)想,使這成為可能,在該網(wǎng)絡(luò)中,兩個(gè)循環(huán)神經(jīng)網(wǎng)絡(luò)協(xié)同工作,將一個(gè)序列轉(zhuǎn)換為另一個(gè)序列。 編碼器網(wǎng)絡(luò)將輸入序列壓縮為一個(gè)向量,而解碼器網(wǎng)絡(luò)將該向量展開為一個(gè)新序列。

為了改進(jìn)此模型,我們將使用注意機(jī)制,該機(jī)制可讓解碼器學(xué)習(xí)將注意力集中在輸入序列的特定范圍內(nèi)。

推薦讀物:

我假設(shè)您至少已經(jīng)安裝了 PyTorch,了解 Python 和了解 Tensors:

  • https://pytorch.org/ 有關(guān)安裝說明
  • 使用 PyTorch 進(jìn)行深度學(xué)習(xí):60 分鐘的閃電戰(zhàn)通常開始使用 PyTorch
  • 使用示例學(xué)習(xí) PyTorch 進(jìn)行廣泛而深入的概述
  • PyTorch(以前的 Torch 用戶)(如果您以前是 Lua Torch 用戶)

了解序列到序列網(wǎng)絡(luò)及其工作方式也將很有用:

您還將找到先前的 NLP 從零開始:使用字符級(jí) RNN 對(duì)名稱進(jìn)行分類的教程,以及 NLP 從零開始:使用字符級(jí) RNN 生成名稱的指南,因?yàn)檫@些概念是有用的 分別與編碼器和解碼器模型非常相似。

有關(guān)更多信息,請(qǐng)閱讀介紹以下主題的論文:

要求

  1. from __future__ import unicode_literals, print_function, division
  2. from io import open
  3. import unicodedata
  4. import string
  5. import re
  6. import random
  7. import torch
  8. import torch.nn as nn
  9. from torch import optim
  10. import torch.nn.functional as F
  11. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

加載數(shù)據(jù)文件

該項(xiàng)目的數(shù)據(jù)是成千上萬的英語到法語翻譯對(duì)的集合。

開放數(shù)據(jù)堆棧交換上的這個(gè)問題使我指向開放翻譯站點(diǎn) https://tatoeba.org/ ,該站點(diǎn)可從  https://tatoeba.org/下載。 eng / downloads -更好的是,有人在這里做了額外的工作,將語言對(duì)拆分為單獨(dú)的文本文件: https://www.manythings.org/anki/

英文對(duì)法文對(duì)太大,無法包含在倉庫中,因此請(qǐng)先下載到data/eng-fra.txt,然后再繼續(xù)。 該文件是制表符分隔的翻譯對(duì)列表:

  1. I am cold. J'ai froid.

注意:

從的下載數(shù)據(jù),并將其提取到當(dāng)前目錄。

與字符級(jí) RNN 教程中使用的字符編碼類似,我們將一種語言中的每個(gè)單詞表示為一個(gè)單向矢量,或者零外的一個(gè)巨大矢量(除了單個(gè)索引(在單詞的索引處))。 與一種語言中可能存在的數(shù)十個(gè)字符相比,單詞有很多,因此編碼向量要大得多。 但是,我們將作弊并整理數(shù)據(jù)以使每種語言僅使用幾千個(gè)單詞。

我們需要每個(gè)單詞一個(gè)唯一的索引,以便以后用作網(wǎng)絡(luò)的輸入和目標(biāo)。 為了跟蹤所有這些信息,我們將使用一個(gè)名為Lang的幫助程序類,該類具有單詞→索引(word2index)和索引→單詞(index2word)字典,以及每個(gè)要使用的單詞word2count的計(jì)數(shù) 以便以后替換稀有詞。

  1. SOS_token = 0
  2. EOS_token = 1
  3. class Lang:
  4. def __init__(self, name):
  5. self.name = name
  6. self.word2index = {}
  7. self.word2count = {}
  8. self.index2word = {0: "SOS", 1: "EOS"}
  9. self.n_words = 2 # Count SOS and EOS
  10. def addSentence(self, sentence):
  11. for word in sentence.split(' '):
  12. self.addWord(word)
  13. def addWord(self, word):
  14. if word not in self.word2index:
  15. self.word2index[word] = self.n_words
  16. self.word2count[word] = 1
  17. self.index2word[self.n_words] = word
  18. self.n_words += 1
  19. else:
  20. self.word2count[word] += 1

這些文件全部為 Unicode,為簡(jiǎn)化起見,我們將 Unicode 字符轉(zhuǎn)換為 ASCII,將所有內(nèi)容都轉(zhuǎn)換為小寫,并修剪大多數(shù)標(biāo)點(diǎn)符號(hào)。

  1. # Turn a Unicode string to plain ASCII, thanks to
  2. ## https://stackoverflow.com/a/518232/2809427
  3. def unicodeToAscii(s):
  4. return ''.join(
  5. c for c in unicodedata.normalize('NFD', s)
  6. if unicodedata.category(c) != 'Mn'
  7. )
  8. ## Lowercase, trim, and remove non-letter characters
  9. def normalizeString(s):
  10. s = unicodeToAscii(s.lower().strip())
  11. s = re.sub(r"([.!?])", r" \1", s)
  12. s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
  13. return s

要讀取數(shù)據(jù)文件,我們將文件分成幾行,然后將行分成兩對(duì)。 這些文件都是英語→其他語言的,因此,如果我們要從其他語言→英語進(jìn)行翻譯,我添加了reverse標(biāo)志來反轉(zhuǎn)對(duì)。

  1. def readLangs(lang1, lang2, reverse=False):
  2. print("Reading lines...")
  3. # Read the file and split into lines
  4. lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
  5. read().strip().split('\n')
  6. # Split every line into pairs and normalize
  7. pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
  8. # Reverse pairs, make Lang instances
  9. if reverse:
  10. pairs = [list(reversed(p)) for p in pairs]
  11. input_lang = Lang(lang2)
  12. output_lang = Lang(lang1)
  13. else:
  14. input_lang = Lang(lang1)
  15. output_lang = Lang(lang2)
  16. return input_lang, output_lang, pairs

由于示例句子的數(shù)量很多,并且我們想快速訓(xùn)練一些東西,因此我們將數(shù)據(jù)集修剪為僅相對(duì)較短和簡(jiǎn)單的句子。 在這里,最大長(zhǎng)度為 10 個(gè)字(包括結(jié)尾的標(biāo)點(diǎn)符號(hào)),并且我們正在過濾翻譯成“我是”或“他是”等形式的句子(考慮到前面已替換掉撇號(hào)的情況)。

  1. MAX_LENGTH = 10
  2. eng_prefixes = (
  3. "i am ", "i m ",
  4. "he is", "he s ",
  5. "she is", "she s ",
  6. "you are", "you re ",
  7. "we are", "we re ",
  8. "they are", "they re "
  9. )
  10. def filterPair(p):
  11. return len(p[0].split(' ')) < MAX_LENGTH and \
  12. len(p[1].split(' ')) < MAX_LENGTH and \
  13. p[1].startswith(eng_prefixes)
  14. def filterPairs(pairs):
  15. return [pair for pair in pairs if filterPair(pair)]

準(zhǔn)備數(shù)據(jù)的完整過程是:

  • 讀取文本文件并拆分為行,將行拆分為成對(duì)
  • 規(guī)范文本,按長(zhǎng)度和內(nèi)容過濾
  • 成對(duì)建立句子中的單詞列表
  1. def prepareData(lang1, lang2, reverse=False):
  2. input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
  3. print("Read %s sentence pairs" % len(pairs))
  4. pairs = filterPairs(pairs)
  5. print("Trimmed to %s sentence pairs" % len(pairs))
  6. print("Counting words...")
  7. for pair in pairs:
  8. input_lang.addSentence(pair[0])
  9. output_lang.addSentence(pair[1])
  10. print("Counted words:")
  11. print(input_lang.name, input_lang.n_words)
  12. print(output_lang.name, output_lang.n_words)
  13. return input_lang, output_lang, pairs
  14. input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
  15. print(random.choice(pairs))

得出:

  1. Reading lines...
  2. Read 135842 sentence pairs
  3. Trimmed to 10599 sentence pairs
  4. Counting words...
  5. Counted words:
  6. fra 4345
  7. eng 2803
  8. ['je ne suis pas grand .', 'i m not tall .']

Seq2Seq 模型

遞歸神經(jīng)網(wǎng)絡(luò)(RNN)是在序列上運(yùn)行并將其自身的輸出用作后續(xù)步驟的輸入的網(wǎng)絡(luò)。

序列到序列網(wǎng)絡(luò)或 seq2seq 網(wǎng)絡(luò)或編碼器解碼器網(wǎng)絡(luò)是由兩個(gè)稱為編碼器和解碼器的 RNN 組成的模型。 編碼器讀取輸入序列并輸出單個(gè)向量,而解碼器讀取該向量以產(chǎn)生輸出序列。

與使用單個(gè) RNN 進(jìn)行序列預(yù)測(cè)(每個(gè)輸入對(duì)應(yīng)一個(gè)輸出)不同,seq2seq 模型使我們擺脫了序列長(zhǎng)度和順序的限制,這使其非常適合在兩種語言之間進(jìn)行翻譯。

考慮一下句子“ Je ne suis pas le chat noir”→“我不是黑貓”。 輸入句子中的大多數(shù)單詞在輸出句子中具有直接翻譯,但是順序略有不同,例如 “黑貓聊天”和“黑貓”。 由于采用“ ne / pas”結(jié)構(gòu),因此在輸入句子中還有一個(gè)單詞。 直接從輸入單詞的序列中產(chǎn)生正確的翻譯將是困難的。

使用 seq2seq 模型,編碼器創(chuàng)建單個(gè)矢量,在理想情況下,該矢量將輸入序列的“含義”編碼為單個(gè)矢量-句子的某些 N 維空間中的單個(gè)點(diǎn)。

編碼器

seq2seq 網(wǎng)絡(luò)的編碼器是 RNN,它為輸入句子中的每個(gè)單詞輸出一些值。 對(duì)于每個(gè)輸入字,編碼器輸出一個(gè)向量和一個(gè)隱藏狀態(tài),并將隱藏狀態(tài)用于下一個(gè)輸入字。

  1. class EncoderRNN(nn.Module):
  2. def __init__(self, input_size, hidden_size):
  3. super(EncoderRNN, self).__init__()
  4. self.hidden_size = hidden_size
  5. self.embedding = nn.Embedding(input_size, hidden_size)
  6. self.gru = nn.GRU(hidden_size, hidden_size)
  7. def forward(self, input, hidden):
  8. embedded = self.embedding(input).view(1, 1, -1)
  9. output = embedded
  10. output, hidden = self.gru(output, hidden)
  11. return output, hidden
  12. def initHidden(self):
  13. return torch.zeros(1, 1, self.hidden_size, device=device)

解碼器

解碼器是另一個(gè) RNN,它采用編碼器輸出矢量并輸出單詞序列來創(chuàng)建翻譯。

簡(jiǎn)單解碼器

在最簡(jiǎn)單的 seq2seq 解碼器中,我們僅使用編碼器的最后一個(gè)輸出。 最后的輸出有時(shí)稱為上下文向量,因?yàn)樗鼘?duì)整個(gè)序列的上下文進(jìn)行編碼。 該上下文向量用作解碼器的初始隱藏狀態(tài)。

在解碼的每個(gè)步驟中,為解碼器提供輸入令牌和隱藏狀態(tài)。 初始輸入令牌是字符串開始&lt;SOS&gt;令牌,第一個(gè)隱藏狀態(tài)是上下文向量(編碼器的最后一個(gè)隱藏狀態(tài))。

  1. class DecoderRNN(nn.Module):
  2. def __init__(self, hidden_size, output_size):
  3. super(DecoderRNN, self).__init__()
  4. self.hidden_size = hidden_size
  5. self.embedding = nn.Embedding(output_size, hidden_size)
  6. self.gru = nn.GRU(hidden_size, hidden_size)
  7. self.out = nn.Linear(hidden_size, output_size)
  8. self.softmax = nn.LogSoftmax(dim=1)
  9. def forward(self, input, hidden):
  10. output = self.embedding(input).view(1, 1, -1)
  11. output = F.relu(output)
  12. output, hidden = self.gru(output, hidden)
  13. output = self.softmax(self.out(output[0]))
  14. return output, hidden
  15. def initHidden(self):
  16. return torch.zeros(1, 1, self.hidden_size, device=device)

我鼓勵(lì)您訓(xùn)練并觀察該模型的結(jié)果,但是為了節(jié)省空間,我們將直接努力,并引入注意機(jī)制。

注意解碼器

如果僅上下文向量在編碼器和解碼器之間傳遞,則該單個(gè)向量承擔(dān)對(duì)整個(gè)句子進(jìn)行編碼的負(fù)擔(dān)。

注意使解碼器網(wǎng)絡(luò)可以針對(duì)解碼器自身輸出的每一步,“專注”于編碼器輸出的不同部分。 首先,我們計(jì)算一組注意權(quán)重。 將這些與編碼器輸出向量相乘以創(chuàng)建加權(quán)組合。 結(jié)果(在代碼中稱為attn_applied)應(yīng)包含有關(guān)輸入序列特定部分的信息,從而幫助解碼器選擇正確的輸出字。

計(jì)算注意力權(quán)重的方法是使用另一個(gè)前饋層attn,并使用解碼器的輸入和隱藏狀態(tài)作為輸入。 由于訓(xùn)練數(shù)據(jù)中包含各種大小的句子,因此要實(shí)際創(chuàng)建和訓(xùn)練該層,我們必須選擇可以應(yīng)用的最大句子長(zhǎng)度(輸入長(zhǎng)度??,用于編碼器輸出)。 最大長(zhǎng)度的句子將使用所有注意權(quán)重,而較短的句子將僅使用前幾個(gè)。

  1. class AttnDecoderRNN(nn.Module):
  2. def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
  3. super(AttnDecoderRNN, self).__init__()
  4. self.hidden_size = hidden_size
  5. self.output_size = output_size
  6. self.dropout_p = dropout_p
  7. self.max_length = max_length
  8. self.embedding = nn.Embedding(self.output_size, self.hidden_size)
  9. self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
  10. self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
  11. self.dropout = nn.Dropout(self.dropout_p)
  12. self.gru = nn.GRU(self.hidden_size, self.hidden_size)
  13. self.out = nn.Linear(self.hidden_size, self.output_size)
  14. def forward(self, input, hidden, encoder_outputs):
  15. embedded = self.embedding(input).view(1, 1, -1)
  16. embedded = self.dropout(embedded)
  17. attn_weights = F.softmax(
  18. self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
  19. attn_applied = torch.bmm(attn_weights.unsqueeze(0),
  20. encoder_outputs.unsqueeze(0))
  21. output = torch.cat((embedded[0], attn_applied[0]), 1)
  22. output = self.attn_combine(output).unsqueeze(0)
  23. output = F.relu(output)
  24. output, hidden = self.gru(output, hidden)
  25. output = F.log_softmax(self.out(output[0]), dim=1)
  26. return output, hidden, attn_weights
  27. def initHidden(self):
  28. return torch.zeros(1, 1, self.hidden_size, device=device)

注意;

還有其他形式的注意力可以通過使用相對(duì)位置方法來解決長(zhǎng)度限制問題。 閱讀基于注意力的神經(jīng)機(jī)器翻譯的有效方法中的“本地注意力”信息。

訓(xùn)練

準(zhǔn)備訓(xùn)練數(shù)據(jù)

為了訓(xùn)練,對(duì)于每一對(duì),我們將需要一個(gè)輸入張量(輸入句子中單詞的索引)和目標(biāo)張量(目標(biāo)句子中單詞的索引)。 創(chuàng)建這些向量時(shí),我們會(huì)將 EOS 令牌附加到兩個(gè)序列上。

  1. def indexesFromSentence(lang, sentence):
  2. return [lang.word2index[word] for word in sentence.split(' ')]
  3. def tensorFromSentence(lang, sentence):
  4. indexes = indexesFromSentence(lang, sentence)
  5. indexes.append(EOS_token)
  6. return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
  7. def tensorsFromPair(pair):
  8. input_tensor = tensorFromSentence(input_lang, pair[0])
  9. target_tensor = tensorFromSentence(output_lang, pair[1])
  10. return (input_tensor, target_tensor)

訓(xùn)練模型

為了進(jìn)行訓(xùn)練,我們通過編碼器運(yùn)行輸入語句,并跟蹤每個(gè)輸出和最新的隱藏狀態(tài)。 然后,為解碼器提供&lt;SOS&gt;令牌作為其第一個(gè)輸入,并將編碼器的最后一個(gè)隱藏狀態(tài)作為其第一個(gè)隱藏狀態(tài)。

“教師強(qiáng)制”的概念是使用實(shí)際目標(biāo)輸出作為每個(gè)下一個(gè)輸入,而不是使用解碼器的猜測(cè)作為下一個(gè)輸入。 使用教師強(qiáng)制會(huì)導(dǎo)致其收斂更快,但是當(dāng)使用受過訓(xùn)練的網(wǎng)絡(luò)時(shí),可能會(huì)顯示不穩(wěn)定。

您可以觀察到以教師為主導(dǎo)的網(wǎng)絡(luò)的輸出,這些輸出閱讀的是連貫的語法,但卻偏離了正確的翻譯-直觀地,它學(xué)會(huì)了代表輸出語法,并且一旦老師說了最初的幾個(gè)單詞就可以“理解”含義,但是 首先,它還沒有正確地學(xué)習(xí)如何從翻譯中創(chuàng)建句子。

由于 PyTorch 的 autograd 具有給我們的自由,我們可以通過簡(jiǎn)單的 if 語句隨意選擇是否使用教師強(qiáng)迫。 調(diào)高teacher_forcing_ratio以使用更多功能。

  1. teacher_forcing_ratio = 0.5
  2. def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
  3. encoder_hidden = encoder.initHidden()
  4. encoder_optimizer.zero_grad()
  5. decoder_optimizer.zero_grad()
  6. input_length = input_tensor.size(0)
  7. target_length = target_tensor.size(0)
  8. encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
  9. loss = 0
  10. for ei in range(input_length):
  11. encoder_output, encoder_hidden = encoder(
  12. input_tensor[ei], encoder_hidden)
  13. encoder_outputs[ei] = encoder_output[0, 0]
  14. decoder_input = torch.tensor([[SOS_token]], device=device)
  15. decoder_hidden = encoder_hidden
  16. use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
  17. if use_teacher_forcing:
  18. # Teacher forcing: Feed the target as the next input
  19. for di in range(target_length):
  20. decoder_output, decoder_hidden, decoder_attention = decoder(
  21. decoder_input, decoder_hidden, encoder_outputs)
  22. loss += criterion(decoder_output, target_tensor[di])
  23. decoder_input = target_tensor[di] # Teacher forcing
  24. else:
  25. # Without teacher forcing: use its own predictions as the next input
  26. for di in range(target_length):
  27. decoder_output, decoder_hidden, decoder_attention = decoder(
  28. decoder_input, decoder_hidden, encoder_outputs)
  29. topv, topi = decoder_output.topk(1)
  30. decoder_input = topi.squeeze().detach() # detach from history as input
  31. loss += criterion(decoder_output, target_tensor[di])
  32. if decoder_input.item() == EOS_token:
  33. break
  34. loss.backward()
  35. encoder_optimizer.step()
  36. decoder_optimizer.step()
  37. return loss.item() / target_length

這是一個(gè)幫助功能,用于在給定當(dāng)前時(shí)間和進(jìn)度%的情況下打印經(jīng)過的時(shí)間和估計(jì)的剩余時(shí)間。

  1. import time
  2. import math
  3. def asMinutes(s):
  4. m = math.floor(s / 60)
  5. s -= m * 60
  6. return '%dm %ds' % (m, s)
  7. def timeSince(since, percent):
  8. now = time.time()
  9. s = now - since
  10. es = s / (percent)
  11. rs = es - s
  12. return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

整個(gè)訓(xùn)練過程如下所示:

  • 啟動(dòng)計(jì)時(shí)器
  • 初始化優(yōu)化器和標(biāo)準(zhǔn)
  • 創(chuàng)建一組訓(xùn)練對(duì)
  • 啟動(dòng)空損耗陣列進(jìn)行繪圖

然后我們多次調(diào)用train,并偶爾打印進(jìn)度(示例的百分比,到目前為止的時(shí)間,估計(jì)的時(shí)間)和平均損失。

  1. def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
  2. start = time.time()
  3. plot_losses = []
  4. print_loss_total = 0 # Reset every print_every
  5. plot_loss_total = 0 # Reset every plot_every
  6. encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
  7. decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
  8. training_pairs = [tensorsFromPair(random.choice(pairs))
  9. for i in range(n_iters)]
  10. criterion = nn.NLLLoss()
  11. for iter in range(1, n_iters + 1):
  12. training_pair = training_pairs[iter - 1]
  13. input_tensor = training_pair[0]
  14. target_tensor = training_pair[1]
  15. loss = train(input_tensor, target_tensor, encoder,
  16. decoder, encoder_optimizer, decoder_optimizer, criterion)
  17. print_loss_total += loss
  18. plot_loss_total += loss
  19. if iter % print_every == 0:
  20. print_loss_avg = print_loss_total / print_every
  21. print_loss_total = 0
  22. print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
  23. iter, iter / n_iters * 100, print_loss_avg))
  24. if iter % plot_every == 0:
  25. plot_loss_avg = plot_loss_total / plot_every
  26. plot_losses.append(plot_loss_avg)
  27. plot_loss_total = 0
  28. showPlot(plot_losses)

繪圖結(jié)果

使用訓(xùn)練時(shí)保存的損失值數(shù)組plot_losses,使用 matplotlib 進(jìn)行繪制。

  1. import matplotlib.pyplot as plt
  2. plt.switch_backend('agg')
  3. import matplotlib.ticker as ticker
  4. import numpy as np
  5. def showPlot(points):
  6. plt.figure()
  7. fig, ax = plt.subplots()
  8. # this locator puts ticks at regular intervals
  9. loc = ticker.MultipleLocator(base=0.2)
  10. ax.yaxis.set_major_locator(loc)
  11. plt.plot(points)

評(píng)價(jià)

評(píng)估與訓(xùn)練基本相同,但是沒有目標(biāo),因此我們只需將解碼器的預(yù)測(cè)反饋給每一步。 每當(dāng)它預(yù)測(cè)一個(gè)單詞時(shí),我們都會(huì)將其添加到輸出字符串中,如果它預(yù)測(cè)到 EOS 令牌,我們將在此處停止。 我們還將存儲(chǔ)解碼器的注意輸出,以供以后顯示。

  1. def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
  2. with torch.no_grad():
  3. input_tensor = tensorFromSentence(input_lang, sentence)
  4. input_length = input_tensor.size()[0]
  5. encoder_hidden = encoder.initHidden()
  6. encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
  7. for ei in range(input_length):
  8. encoder_output, encoder_hidden = encoder(input_tensor[ei],
  9. encoder_hidden)
  10. encoder_outputs[ei] += encoder_output[0, 0]
  11. decoder_input = torch.tensor([[SOS_token]], device=device) # SOS
  12. decoder_hidden = encoder_hidden
  13. decoded_words = []
  14. decoder_attentions = torch.zeros(max_length, max_length)
  15. for di in range(max_length):
  16. decoder_output, decoder_hidden, decoder_attention = decoder(
  17. decoder_input, decoder_hidden, encoder_outputs)
  18. decoder_attentions[di] = decoder_attention.data
  19. topv, topi = decoder_output.data.topk(1)
  20. if topi.item() == EOS_token:
  21. decoded_words.append('<EOS>')
  22. break
  23. else:
  24. decoded_words.append(output_lang.index2word[topi.item()])
  25. decoder_input = topi.squeeze().detach()
  26. return decoded_words, decoder_attentions[:di + 1]

我們可以從訓(xùn)練集中評(píng)估隨機(jī)句子,并打印出輸入,目標(biāo)和輸出以做出一些主觀的質(zhì)量判斷:

  1. def evaluateRandomly(encoder, decoder, n=10):
  2. for i in range(n):
  3. pair = random.choice(pairs)
  4. print('>', pair[0])
  5. print('=', pair[1])
  6. output_words, attentions = evaluate(encoder, decoder, pair[0])
  7. output_sentence = ' '.join(output_words)
  8. print('<', output_sentence)
  9. print('')

訓(xùn)練與評(píng)估

有了所有這些幫助器功能(看起來像是額外的工作,但它使運(yùn)行多個(gè)實(shí)驗(yàn)更加容易),我們實(shí)際上可以初始化網(wǎng)絡(luò)并開始訓(xùn)練。

請(qǐng)記住,輸入句子已被嚴(yán)格過濾。 對(duì)于這個(gè)小的數(shù)據(jù)集,我們可以使用具有 256 個(gè)隱藏節(jié)點(diǎn)和單個(gè) GRU 層的相對(duì)較小的網(wǎng)絡(luò)。 在 MacBook CPU 上運(yùn)行約 40 分鐘后,我們將獲得一些合理的結(jié)果。

注意:

如果運(yùn)行此筆記本,則可以進(jìn)行訓(xùn)練,中斷內(nèi)核,評(píng)估并在以后繼續(xù)進(jìn)行訓(xùn)練。 注釋掉編碼器和解碼器已初始化的行,然后再次運(yùn)行trainIters

  1. hidden_size = 256
  2. encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
  3. attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
  4. trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
  • ../_images/sphx_glr_seq2seq_translation_tutorial_002.png

得出:

  1. 1m 54s (- 26m 42s) (5000 6%) 2.8452
  2. 3m 44s (- 24m 19s) (10000 13%) 2.2926
  3. 5m 34s (- 22m 17s) (15000 20%) 1.9628
  4. 7m 24s (- 20m 23s) (20000 26%) 1.7224
  5. 9m 15s (- 18m 31s) (25000 33%) 1.4997
  6. 11m 7s (- 16m 41s) (30000 40%) 1.3610
  7. 12m 58s (- 14m 49s) (35000 46%) 1.2299
  8. 14m 48s (- 12m 57s) (40000 53%) 1.0881
  9. 16m 38s (- 11m 5s) (45000 60%) 0.9991
  10. 18m 29s (- 9m 14s) (50000 66%) 0.9053
  11. 20m 19s (- 7m 23s) (55000 73%) 0.8031
  12. 22m 8s (- 5m 32s) (60000 80%) 0.7141
  13. 23m 58s (- 3m 41s) (65000 86%) 0.6693
  14. 25m 48s (- 1m 50s) (70000 93%) 0.6342
  15. 27m 38s (- 0m 0s) (75000 100%) 0.5604
  1. evaluateRandomly(encoder1, attn_decoder1)

得出:

  1. > je suis tres serieux .
  2. = i m quite serious .
  3. < i m very serious . <EOS>
  4. > tu es creatif .
  5. = you re creative .
  6. < you re creative . <EOS>
  7. > j attends de vos nouvelles .
  8. = i m looking forward to hearing from you .
  9. < i m looking forward to hearing from you . <EOS>
  10. > tu es un de ces pauvres types !
  11. = you re such a jerk .
  12. < you re such a jerk . <EOS>
  13. > je ne suis pas si preoccupe .
  14. = i m not that worried .
  15. < i m not that worried . <EOS>
  16. > vous etes avides .
  17. = you re greedy .
  18. < you re greedy . <EOS>
  19. > ils ne sont pas satisfaits .
  20. = they re not happy .
  21. < they re not happy . <EOS>
  22. > nous avons tous peur .
  23. = we re all afraid .
  24. < we re all scared . <EOS>
  25. > nous sommes tous uniques .
  26. = we re all unique .
  27. < we re all unique . <EOS>
  28. > c est un tres chouette garcon .
  29. = he s a very nice boy .
  30. < he s a very nice boy . <EOS>

可視化注意力

注意機(jī)制的一個(gè)有用特性是其高度可解釋的輸出。 因?yàn)樗糜诩訖?quán)輸入序列的特定編碼器輸出,所以我們可以想象一下在每個(gè)時(shí)間步長(zhǎng)上網(wǎng)絡(luò)最關(guān)注的位置。

您可以簡(jiǎn)單地運(yùn)行plt.matshow(attentions)以將注意力輸出顯示為矩陣,其中列為輸入步驟,行為輸出步驟:

  1. output_words, attentions = evaluate(
  2. encoder1, attn_decoder1, "je suis trop froid .")
  3. plt.matshow(attentions.numpy())

../_images/sphx_glr_seq2seq_translation_tutorial_003.png

為了獲得更好的觀看體驗(yàn),我們將做一些額外的工作來添加軸和標(biāo)簽:

  1. def showAttention(input_sentence, output_words, attentions):
  2. # Set up figure with colorbar
  3. fig = plt.figure()
  4. ax = fig.add_subplot(111)
  5. cax = ax.matshow(attentions.numpy(), cmap='bone')
  6. fig.colorbar(cax)
  7. # Set up axes
  8. ax.set_xticklabels([''] + input_sentence.split(' ') +
  9. ['<EOS>'], rotation=90)
  10. ax.set_yticklabels([''] + output_words)
  11. # Show label at every tick
  12. ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
  13. ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
  14. plt.show()
  15. def evaluateAndShowAttention(input_sentence):
  16. output_words, attentions = evaluate(
  17. encoder1, attn_decoder1, input_sentence)
  18. print('input =', input_sentence)
  19. print('output =', ' '.join(output_words))
  20. showAttention(input_sentence, output_words, attentions)
  21. evaluateAndShowAttention("elle a cinq ans de moins que moi .")
  22. evaluateAndShowAttention("elle est trop petit .")
  23. evaluateAndShowAttention("je ne crains pas de mourir .")
  24. evaluateAndShowAttention("c est un jeune directeur plein de talent .")
  • ../_images/sphx_glr_seq2seq_translation_tutorial_004.png
  • ../_images/sphx_glr_seq2seq_translation_tutorial_005.png
  • ../_images/sphx_glr_seq2seq_translation_tutorial_006.png
  • ../_images/sphx_glr_seq2seq_translation_tutorial_007.png

得出:

  1. input = elle a cinq ans de moins que moi .
  2. output = she is five years years years years . <EOS>
  3. input = elle est trop petit .
  4. output = she is too short . <EOS>
  5. input = je ne crains pas de mourir .
  6. output = i m not scared of dying . <EOS>
  7. input = c est un jeune directeur plein de talent .
  8. output = he s a talented young director . <EOS>

練習(xí)題

  • 嘗試使用其他數(shù)據(jù)集另一對(duì)語言人機(jī)→機(jī)器(例如 IOT 命令)聊天→回復(fù)問題→答案
  • 用預(yù)先訓(xùn)練的單詞嵌入(例如 word2vec 或 GloVe)替換嵌入
  • 嘗試使用更多層,更多隱藏單元和更多句子。 比較訓(xùn)練時(shí)間和結(jié)果。
  • 如果您使用的翻原文件中,對(duì)具有兩個(gè)相同的詞組(I am test \t I am test),則可以將其用作自動(dòng)編碼器。 嘗試這個(gè):訓(xùn)練為自動(dòng)編碼器僅保存編碼器網(wǎng)絡(luò)從那里訓(xùn)練新的解碼器進(jìn)行翻譯

腳本的總運(yùn)行時(shí)間:(27 分鐘 45.966 秒)

Download Python source code: seq2seq_translation_tutorial.py Download Jupyter notebook: seq2seq_translation_tutorial.ipynb

由獅身人面像畫廊生成的畫廊


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)