tf.nn.bidirectional_dynamic_rnn函數(shù)
tf.nn.bidirectional_dynamic_rnn(
cell_fw,
cell_bw,
inputs,
sequence_length=None,
initial_state_fw=None,
initial_state_bw=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
定義在:tensorflow/python/ops/rnn.py.
請(qǐng)參閱指南:神經(jīng)網(wǎng)絡(luò)>遞歸神經(jīng)網(wǎng)絡(luò)
創(chuàng)建雙向遞歸神經(jīng)網(wǎng)絡(luò)的動(dòng)態(tài)版本.
獲取輸入并構(gòu)建獨(dú)立的前向和后向RNN.前向和后向單元的input_size必須匹配.默認(rèn)情況下,兩個(gè)方向的初始狀態(tài)為零(但可以選擇設(shè)置),并且不返回任何中間狀態(tài) - 如果沒有給出長(zhǎng)度,則為給定的 (傳入) 長(zhǎng)度完全展開網(wǎng)絡(luò),或者在未給出長(zhǎng)度的情況下完全展開網(wǎng)絡(luò).
參數(shù):
- cell_fw:用于正向的RNNCell的實(shí)例.
- cell_bw:用于反向的RNNCell的實(shí)例.
- inputs:RNN輸入.如果time_major == False(默認(rèn)值),則必須是形狀為[batch_size, max_time, ...]的張量,或這些元素的嵌套元組;如果time_major == True,則必須是形狀為[max_time, batch_size, ...]的張量,或這些元素的嵌套元組.
- sequence_length:(可選)int32/int64向量,size為[batch_size],包含批處理中每個(gè)序列的實(shí)際長(zhǎng)度.如果未提供,則假定所有批次條目都是完整序列;對(duì)每個(gè)序列從時(shí)間0到max_time應(yīng)用時(shí)間反轉(zhuǎn).
- initial_state_fw:(可選)正向RNN的初始狀態(tài).這必須是適當(dāng)類型和形狀為[batch_size, cell_fw.state_size]的張量.如果cell_fw.state_size是一個(gè)元組,這應(yīng)該是一個(gè)張量元組,在cell_fw.state_size中s具有形狀[batch_size, s].
- initial_state_bw:(可選)與initial_state_fw相同,但使用與cell_bw相應(yīng)的屬性.
- dtype:(可選)初始狀態(tài)和預(yù)期輸出的數(shù)據(jù)類型;如果未提供initial_states或RNN狀態(tài)具有異構(gòu)dtype,則為必需有的.
- parallel_iterations:(默認(rèn)值:32),并行運(yùn)行的迭代次數(shù);那些沒有任何時(shí)間依賴性并且可以并行運(yùn)行的操作將是;該參數(shù)用于交換空間的時(shí)間;值>> 1使用更多內(nèi)存但占用更少時(shí)間,而較小值使用較少內(nèi)存但計(jì)算時(shí)間較長(zhǎng).
- swap_memory:透明地交換正向推理中產(chǎn)生的張量,但是從GPU到CPU需要向后支撐.這允許訓(xùn)練通常不適合單個(gè)GPU的RNN,具有非常小的(或沒有)性能損失.
- time_major:inputs和outputs張量的形狀格式;如果是true,那些Tensors形狀必須為[max_time, batch_size, depth],如果false,這些Tensors的形狀必須為[batch_size, max_time, depth];使用time_major = True更有效,因?yàn)樗苊饬薘NN計(jì)算開始和結(jié)束時(shí)的轉(zhuǎn)置,但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理主數(shù)據(jù),因此默認(rèn)情況下,此函數(shù)接受輸入并以批處理主體形式發(fā)出輸出.
- scope:用于創(chuàng)建子圖的VariableScope;默認(rèn)為“bidirectional_rnn”.
返回:
元組(outputs,output_states),其中:
可能引發(fā)的異常:
- TypeError:如果cell_fw或cell_bw不是RNNCell的實(shí)例.
更多建議: