TensorFlow Seq2seq庫(contrib)

2019-01-31 18:12 更新

用于構(gòu)建 seq2seq 模型和動態(tài)解碼的模塊,建立在 tf.contrib.rnn 庫的頂部。

該庫由兩個主要組件組成:

  • tf.contrib.rnn.RNNCell 對象的新 attention 包裝。
  • 一種新的面向?qū)ο蟮膭討B(tài)解碼框架。

注意

attention 包裝是 RNNCell 包裝其他 RNNCell 對象并實現(xiàn) attention 的對象。attention 的形式由一個子類 tf.contrib.seq2seq.AttentionMechanism 決定。這些子類描述了在創(chuàng)建包裝時要使用的 attention 形式(例如,加法與乘法)。AttentionMechanism 的一個實例是由一個 memory 張量構(gòu)成,從中創(chuàng)建查詢鍵和值。

attention 機(jī)制

兩個基本的 attention 機(jī)制是:tf.contrib.seq2seq.BahdanauAttention (附加的 attention,參考)和 tf.contrib.seq2seq.LuongAttention(增加的 attention,參考)

該 memory 張量傳遞的 attention 機(jī)制的構(gòu)造,預(yù)計將被塑造 [batch_size, memory_max_time, memory_depth];并且通常一個額外的 memory_sequence_length 向量被接受。如果提供的話,memory 張量的行被零掩蔽,超過其真正的序列長度。

attention 機(jī)制也具有深度概念,通常被確定為構(gòu)造參數(shù) num_units。對于某些類型的 attention(如BahdanauAttention),查詢和內(nèi)存都將投射到深度 num_units 的張量。對于其他類型(如LuongAttention),num_units 應(yīng)該匹配查詢的深度;memory 張量將被投射到這個深度。

attention 包裝器

基本的 attention 包裝是 tf.contrib.seq2seq.DynamicAttentionWrapper。這個包裝器接受一個 RNNCell 實例,一個實例 AttentionMechanism 和一個 attention 深度參數(shù)(attention_size);以及允許自定義中間計算的幾個可選參數(shù)。

在每個時間步驟,這個包裝器執(zhí)行的基本計算是:

cell_inputs = concat([inputs, prev_state.attention], -1)
cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
score = attention_mechanism(cell_output)
alignments = softmax(score)
context = matmul(alignments, attention_mechanism.values)
attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
next_state = AttentionWrapperState(
  cell_state=next_cell_state,
  attention=attention)
output = attention
return output, next_state

在實踐中,許多中間計算是可配置的。例如,初始連接 inputs 和 prev_state.attention 可以用另一種混合功能代替。在從分?jǐn)?shù)計算對齊時,可以用其他選項替換函數(shù) softmax。最后,包裝器返回的輸出可以配置為值 cell_output 而不是 attention。

使用 DynamicAttentionWrapper 的好處是它能很好地與其他包裝器和下面描述的動態(tài)解碼器一起播放。例如,你可以寫:

cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:0")
attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
  cell, attention_mechanism, attention_size=256)
attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
top_cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:1")
multi_cell = MultiRNNCell([attn_cell, top_cell])

所述 multi_rnn 單元將執(zhí)行對 GPU 0 底層計算; attention 計算將在 GPU 1 上執(zhí)行,并立即傳遞到也在 GPU 1 上計算的頂層。attention 也在時間上傳遞到下一個時間步,并在下一個時間步驟復(fù)制到 GPU 0單元。(注意:這只是一個使用的例子,而不是建議的設(shè)備分區(qū)策略。)

TensorFlow 動態(tài)解碼

解碼器基類和功能

  • tf.contrib.seq2seq.Decoder
  • tf.contrib.seq2seq.dynamic_decode

基本解碼器

  • tf.contrib.seq2seq.BasicDecoderOutput
  • tf.contrib.seq2seq.BasicDecoder

解碼器助手

  • tf.contrib.seq2seq.Helper
  • tf.contrib.seq2seq.CustomHelper
  • tf.contrib.seq2seq.GreedyEmbeddingHelper
  • tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper
  • tf.contrib.seq2seq.ScheduledOutputTrainingHelper
  • tf.contrib.seq2seq.TrainingHelper
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號