W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
用于構(gòu)建 seq2seq 模型和動態(tài)解碼的模塊,建立在 tf.contrib.rnn 庫的頂部。
該庫由兩個主要組件組成:
attention 包裝是 RNNCell 包裝其他 RNNCell 對象并實現(xiàn) attention 的對象。attention 的形式由一個子類 tf.contrib.seq2seq.AttentionMechanism 決定。這些子類描述了在創(chuàng)建包裝時要使用的 attention 形式(例如,加法與乘法)。AttentionMechanism 的一個實例是由一個 memory 張量構(gòu)成,從中創(chuàng)建查詢鍵和值。
兩個基本的 attention 機(jī)制是:tf.contrib.seq2seq.BahdanauAttention (附加的 attention,參考)和 tf.contrib.seq2seq.LuongAttention(增加的 attention,參考)
該 memory 張量傳遞的 attention 機(jī)制的構(gòu)造,預(yù)計將被塑造 [batch_size, memory_max_time, memory_depth];并且通常一個額外的 memory_sequence_length 向量被接受。如果提供的話,memory 張量的行被零掩蔽,超過其真正的序列長度。
attention 機(jī)制也具有深度概念,通常被確定為構(gòu)造參數(shù) num_units。對于某些類型的 attention(如BahdanauAttention),查詢和內(nèi)存都將投射到深度 num_units 的張量。對于其他類型(如LuongAttention),num_units 應(yīng)該匹配查詢的深度;memory 張量將被投射到這個深度。
基本的 attention 包裝是 tf.contrib.seq2seq.DynamicAttentionWrapper。這個包裝器接受一個 RNNCell 實例,一個實例 AttentionMechanism 和一個 attention 深度參數(shù)(attention_size);以及允許自定義中間計算的幾個可選參數(shù)。
在每個時間步驟,這個包裝器執(zhí)行的基本計算是:
cell_inputs = concat([inputs, prev_state.attention], -1)
cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
score = attention_mechanism(cell_output)
alignments = softmax(score)
context = matmul(alignments, attention_mechanism.values)
attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
next_state = AttentionWrapperState(
cell_state=next_cell_state,
attention=attention)
output = attention
return output, next_state
在實踐中,許多中間計算是可配置的。例如,初始連接 inputs 和 prev_state.attention 可以用另一種混合功能代替。在從分?jǐn)?shù)計算對齊時,可以用其他選項替換函數(shù) softmax。最后,包裝器返回的輸出可以配置為值 cell_output 而不是 attention。
使用 DynamicAttentionWrapper 的好處是它能很好地與其他包裝器和下面描述的動態(tài)解碼器一起播放。例如,你可以寫:
cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:0")
attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell, attention_mechanism, attention_size=256)
attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
top_cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:1")
multi_cell = MultiRNNCell([attn_cell, top_cell])
所述 multi_rnn 單元將執(zhí)行對 GPU 0 底層計算; attention 計算將在 GPU 1 上執(zhí)行,并立即傳遞到也在 GPU 1 上計算的頂層。attention 也在時間上傳遞到下一個時間步,并在下一個時間步驟復(fù)制到 GPU 0單元。(注意:這只是一個使用的例子,而不是建議的設(shè)備分區(qū)策略。)
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: