TensorFlow Seq2seq庫(kù)(contrib)

2019-01-31 18:12 更新

用于構(gòu)建 seq2seq 模型和動(dòng)態(tài)解碼的模塊,建立在 tf.contrib.rnn 庫(kù)的頂部。

該庫(kù)由兩個(gè)主要組件組成:

  • tf.contrib.rnn.RNNCell 對(duì)象的新 attention 包裝。
  • 一種新的面向?qū)ο蟮膭?dòng)態(tài)解碼框架。

注意

attention 包裝是 RNNCell 包裝其他 RNNCell 對(duì)象并實(shí)現(xiàn) attention 的對(duì)象。attention 的形式由一個(gè)子類(lèi) tf.contrib.seq2seq.AttentionMechanism 決定。這些子類(lèi)描述了在創(chuàng)建包裝時(shí)要使用的 attention 形式(例如,加法與乘法)。AttentionMechanism 的一個(gè)實(shí)例是由一個(gè) memory 張量構(gòu)成,從中創(chuàng)建查詢鍵和值。

attention 機(jī)制

兩個(gè)基本的 attention 機(jī)制是:tf.contrib.seq2seq.BahdanauAttention (附加的 attention,參考)和 tf.contrib.seq2seq.LuongAttention(增加的 attention,參考)

該 memory 張量傳遞的 attention 機(jī)制的構(gòu)造,預(yù)計(jì)將被塑造 [batch_size, memory_max_time, memory_depth];并且通常一個(gè)額外的 memory_sequence_length 向量被接受。如果提供的話,memory 張量的行被零掩蔽,超過(guò)其真正的序列長(zhǎng)度。

attention 機(jī)制也具有深度概念,通常被確定為構(gòu)造參數(shù) num_units。對(duì)于某些類(lèi)型的 attention(如BahdanauAttention),查詢和內(nèi)存都將投射到深度 num_units 的張量。對(duì)于其他類(lèi)型(如LuongAttention),num_units 應(yīng)該匹配查詢的深度;memory 張量將被投射到這個(gè)深度。

attention 包裝器

基本的 attention 包裝是 tf.contrib.seq2seq.DynamicAttentionWrapper。這個(gè)包裝器接受一個(gè) RNNCell 實(shí)例,一個(gè)實(shí)例 AttentionMechanism 和一個(gè) attention 深度參數(shù)(attention_size);以及允許自定義中間計(jì)算的幾個(gè)可選參數(shù)。

在每個(gè)時(shí)間步驟,這個(gè)包裝器執(zhí)行的基本計(jì)算是:

cell_inputs = concat([inputs, prev_state.attention], -1)
cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
score = attention_mechanism(cell_output)
alignments = softmax(score)
context = matmul(alignments, attention_mechanism.values)
attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
next_state = AttentionWrapperState(
  cell_state=next_cell_state,
  attention=attention)
output = attention
return output, next_state

在實(shí)踐中,許多中間計(jì)算是可配置的。例如,初始連接 inputs 和 prev_state.attention 可以用另一種混合功能代替。在從分?jǐn)?shù)計(jì)算對(duì)齊時(shí),可以用其他選項(xiàng)替換函數(shù) softmax。最后,包裝器返回的輸出可以配置為值 cell_output 而不是 attention。

使用 DynamicAttentionWrapper 的好處是它能很好地與其他包裝器和下面描述的動(dòng)態(tài)解碼器一起播放。例如,你可以寫(xiě):

cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:0")
attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
  cell, attention_mechanism, attention_size=256)
attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
top_cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:1")
multi_cell = MultiRNNCell([attn_cell, top_cell])

所述 multi_rnn 單元將執(zhí)行對(duì) GPU 0 底層計(jì)算; attention 計(jì)算將在 GPU 1 上執(zhí)行,并立即傳遞到也在 GPU 1 上計(jì)算的頂層。attention 也在時(shí)間上傳遞到下一個(gè)時(shí)間步,并在下一個(gè)時(shí)間步驟復(fù)制到 GPU 0單元。(注意:這只是一個(gè)使用的例子,而不是建議的設(shè)備分區(qū)策略。)

TensorFlow 動(dòng)態(tài)解碼

解碼器基類(lèi)和功能

  • tf.contrib.seq2seq.Decoder
  • tf.contrib.seq2seq.dynamic_decode

基本解碼器

  • tf.contrib.seq2seq.BasicDecoderOutput
  • tf.contrib.seq2seq.BasicDecoder

解碼器助手

  • tf.contrib.seq2seq.Helper
  • tf.contrib.seq2seq.CustomHelper
  • tf.contrib.seq2seq.GreedyEmbeddingHelper
  • tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper
  • tf.contrib.seq2seq.ScheduledOutputTrainingHelper
  • tf.contrib.seq2seq.TrainingHelper
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)