TensorFlow函數(shù):tf.nn.bidirectional_dynamic_rnn

2019-01-31 11:30 更新

tf.nn.bidirectional_dynamic_rnn函數(shù)

tf.nn.bidirectional_dynamic_rnn(
    cell_fw,
    cell_bw,
    inputs,
    sequence_length=None,
    initial_state_fw=None,
    initial_state_bw=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

定義在:tensorflow/python/ops/rnn.py.

請(qǐng)參閱指南:神經(jīng)網(wǎng)絡(luò)>遞歸神經(jīng)網(wǎng)絡(luò)

創(chuàng)建雙向遞歸神經(jīng)網(wǎng)絡(luò)的動(dòng)態(tài)版本.

獲取輸入并構(gòu)建獨(dú)立的前向和后向RNN.前向和后向單元的input_size必須匹配.默認(rèn)情況下,兩個(gè)方向的初始狀態(tài)為零(但可以選擇設(shè)置),并且不返回任何中間狀態(tài) - 如果沒有給出長(zhǎng)度,則為給定的 (傳入) 長(zhǎng)度完全展開網(wǎng)絡(luò),或者在未給出長(zhǎng)度的情況下完全展開網(wǎng)絡(luò).

參數(shù):

  • cell_fw:用于正向的RNNCell的實(shí)例.
  • cell_bw:用于反向的RNNCell的實(shí)例.
  • inputs:RNN輸入.如果time_major == False(默認(rèn)值),則必須是形狀為[batch_size, max_time, ...]的張量,或這些元素的嵌套元組;如果time_major == True,則必須是形狀為[max_time, batch_size, ...]的張量,或這些元素的嵌套元組.
  • sequence_length:(可選)int32/int64向量,size為[batch_size],包含批處理中每個(gè)序列的實(shí)際長(zhǎng)度.如果未提供,則假定所有批次條目都是完整序列;對(duì)每個(gè)序列從時(shí)間0到max_time應(yīng)用時(shí)間反轉(zhuǎn).
  • initial_state_fw:(可選)正向RNN的初始狀態(tài).這必須是適當(dāng)類型和形狀為[batch_size, cell_fw.state_size]的張量.如果cell_fw.state_size是一個(gè)元組,這應(yīng)該是一個(gè)張量元組,在cell_fw.state_size中s具有形狀[batch_size, s].
  • initial_state_bw:(可選)與initial_state_fw相同,但使用與cell_bw相應(yīng)的屬性.
  • dtype:(可選)初始狀態(tài)和預(yù)期輸出的數(shù)據(jù)類型;如果未提供initial_states或RNN狀態(tài)具有異構(gòu)dtype,則為必需有的.
  • parallel_iterations:(默認(rèn)值:32),并行運(yùn)行的迭代次數(shù);那些沒有任何時(shí)間依賴性并且可以并行運(yùn)行的操作將是;該參數(shù)用于交換空間的時(shí)間;值>> 1使用更多內(nèi)存但占用更少時(shí)間,而較小值使用較少內(nèi)存但計(jì)算時(shí)間較長(zhǎng).
  • swap_memory:透明地交換正向推理中產(chǎn)生的張量,但是從GPU到CPU需要向后支撐.這允許訓(xùn)練通常不適合單個(gè)GPU的RNN,具有非常小的(或沒有)性能損失.
  • time_major:inputs和outputs張量的形狀格式;如果是true,那些Tensors形狀必須為[max_time, batch_size, depth],如果false,這些Tensors的形狀必須為[batch_size, max_time, depth];使用time_major = True更有效,因?yàn)樗苊饬薘NN計(jì)算開始和結(jié)束時(shí)的轉(zhuǎn)置,但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理主數(shù)據(jù),因此默認(rèn)情況下,此函數(shù)接受輸入并以批處理主體形式發(fā)出輸出.
  • scope:用于創(chuàng)建子圖的VariableScope;默認(rèn)為“bidirectional_rnn”.

返回:

元組(outputs,output_states),其中:

  • outputs:包含正向和反向rnn輸出的元組(output_fw,output_bw).

    • 如果time_major == False(默認(rèn)值),則output_fw將是形狀為[batch_size, max_time, cell_fw.output_size]的張量,則output_bw將是形狀為[batch_size, max_time, cell_bw.output_size]的張量;
    • 如果time_major == True,則output_fw將是形狀為[max_time, batch_size, cell_fw.output_size]的張量;output_bw將會(huì)是形狀為[max_time, batch_size, cell_bw.output_size]的張量.
      與bidirectional_rnn不同,它返回一個(gè)元組而不是單個(gè)連接的張量.如果優(yōu)選連接的,則正向和反向輸出可以連接為tf.concat(outputs, 2).
  • output_states:包含雙向rnn的正向和反向最終狀態(tài)的元組.

可能引發(fā)的異常:

  • TypeError:如果cell_fw或cell_bw不是RNNCell的實(shí)例.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)