TensorFlow函數(shù)教程:tf.nn.dynamic_rnn

2019-01-31 13:47 更新

tf.nn.dynamic_rnn函數(shù)

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

定義在:tensorflow/python/ops/rnn.py.

請參閱指南:神經(jīng)網(wǎng)絡(luò)>遞歸神經(jīng)網(wǎng)絡(luò)

創(chuàng)建由 RNNCellcell指定的遞歸神經(jīng)網(wǎng)絡(luò)

執(zhí)行inputs的完全動態(tài)展開.

示例:

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)

參數(shù):

  • cell:RNNCell的一個(gè)實(shí)例.
  • inputs:RNN輸入.如果time_major == False(默認(rèn)),則是一個(gè)shape為[batch_size, max_time, ...]Tensor,或者這些元素的嵌套元組.如果time_major == True,則是一個(gè)shape為[max_time, batch_size, ...]Tensor,或這些元素的嵌套元組.這也可能是滿足此屬性的Tensors(可能是嵌套的)元組.前兩個(gè)維度必須匹配所有輸入,否則秩和其他形狀組件可能不同.在這種情況下,在每個(gè)時(shí)間步輸入到cell將復(fù)制這些元組的結(jié)構(gòu),時(shí)間維度(從中獲取時(shí)間)除外.在每個(gè)時(shí)間步輸入到個(gè)cell將是一個(gè)Tensor或(可能是嵌套的)Tensors元組,每個(gè)元素都有維度[batch_size, ...].
  • sequence_length:(可選)大小為[batch_size]的int32/int64的向量.超過批處理元素的序列長度時(shí)用于復(fù)制狀態(tài)和零輸出.所以它更多的是正確性而不是性能.
  • initial_state:(可選)RNN的初始狀態(tài).如果cell.state_size是整數(shù),則必須是具有適當(dāng)類型和shape為[batch_size, cell.state_size]Tensor.如果cell.state_size是一個(gè)元組,則應(yīng)該是張量元組,cell.state_size中為s設(shè)置shape[batch_size, s].
  • dtype:(可選)初始狀態(tài)和預(yù)期輸出的數(shù)據(jù)類型.如果未提供initial_state或RNN狀態(tài)具有異構(gòu)dtype,則是必需的.
  • parallel_iterations:(默認(rèn)值:32).并行運(yùn)行的迭代次數(shù).適用于那些沒有任何時(shí)間依賴性并且可以并行運(yùn)行的操作.該參數(shù)用于交換空間的時(shí)間.遠(yuǎn)大于1的值會使用更多內(nèi)存但占用更少時(shí)間,而較小值使用較少內(nèi)存但計(jì)算時(shí)間較長.
  • swap_memory:透明地交換推理中產(chǎn)生的張量,但是需要從GPU到CPU的支持.這允許訓(xùn)練通常不適合單個(gè)GPU的RNN,具有非常小的(或沒有)性能損失.
  • time_majorinputsoutputsTensor的形狀格式.如果是true,則這些 Tensors的shape必須為[max_time, batch_size, depth].如果是false,則這些Tensors的shape必須為[batch_size, max_time, depth].使用time_major = True更有效,因?yàn)樗苊饬薘NN計(jì)算開始和結(jié)束時(shí)的轉(zhuǎn)置.但是,大多數(shù)TensorFlow數(shù)據(jù)都是batch-major,因此默認(rèn)情況下,此函數(shù)接受輸入并以batch-major形式發(fā)出輸出.
  • scope:用于創(chuàng)建子圖的VariableScope;默認(rèn)為“rnn”.

返回:

一對(outputs, state),其中:

  • outputs:RNN輸出Tensor.

    如果time_major == False(默認(rèn)),這將是shape為[batch_size, max_time, cell.output_size]Tensor.

    如果time_major == True,這將是shape為[max_time, batch_size, cell.output_size]Tensor.

    注意,如果cell.output_size是整數(shù)或TensorShape對象的(可能是嵌套的)元組,那么outputs將是一個(gè)與cell.output_size具有相同結(jié)構(gòu)的元祖,它包含與cell.output_size中的形狀數(shù)據(jù)有對應(yīng)shape的Tensors.

  • state:最終的狀態(tài).如果cell.state_size是int,則會形成[batch_size, cell.state_size].如果它是TensorShape,則將形成[batch_size] + cell.state_size.如果它是一個(gè)(可能是嵌套的)int或TensorShape元組,那么這將是一個(gè)具有相應(yīng)shape的元組.如果單元格是LSTMCells,則state將是包含每個(gè)單元格的LSTMStateTuple的元組.

可能引發(fā)的異常:

  • TypeError:如果cell不是RNNCell的實(shí)例.
  • ValueError:如果輸入為None或是空列表.
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號