TensorFlow函數(shù)教程:tf.nn.static_rnn

2020-07-17 09:56 更新

tf.nn.static_rnn函數(shù)

別名:

  • tf.contrib.rnn.static_rnn
  • tf.nn.static_rnn
tf.nn.static_rnn(
    cell,
    inputs,
    initial_state=None,
    dtype=None,
    sequence_length=None,
    scope=None
)

定義在:tensorflow/python/ops/rnn.py。

創(chuàng)建由RNNCell cell指定的循環(huán)神經(jīng)網(wǎng)絡。

生成的最簡單的RNN網(wǎng)絡形式是:

  state = cell.zero_state(...)
  outputs = []
  for input_ in inputs:
    output, state = cell(input_, state)
    outputs.append(output)
  return (outputs, state)

但是,還有一些其他選項:

可以提供初始狀態(tài)。如果提供sequence_length向量,則執(zhí)行動態(tài)計算。這種計算方法不計算超過最小批處理的最大序列長度的RNN步驟(從而節(jié)省計算時間),并且將示例的序列長度的狀態(tài)適當?shù)貍鞑サ阶罱K狀態(tài)輸出。

在批處理行b的時間t上執(zhí)行的動態(tài)計算:

  (output, state)(b, t) =
    (t >= sequence_length(b))
      ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
      : cell(input(b, t), state(b, t - 1))

參數(shù):

  • cell:RNNCell的一個實例。
  • inputs:輸入的長度為T的列表,每個Tensor具有shape [batch_size, input_size];或這些元素的嵌套元組。
  • initial_state:(可選)RNN的初始狀態(tài)。如果cell.state_size是整數(shù),則必須是具有適當?shù)念愋秃蛃hape為[batch_size, cell.state_size]的Tensor。如果cell.state_size是一個元組,這應該是具有shape [batch_size, s]的張量元組,其中s位于cell.state_size。
  • dtype:(可選)初始狀態(tài)和預期輸出的數(shù)據(jù)類型。如果未提供initial_state或RNN狀態(tài)具有異構類型,則為必需。
  • sequence_length:指定輸入中每個序列的長度。int32或int64向量(張量),大小為[batch_size],值位于[0, T)。
  • scope:用于創(chuàng)建子圖的VariableScope;默認為“rnn”。

返回:

(outputs, state)對,其中:

  • outputs的長度為T的列表(每個輸入一個),或這些元素的嵌套元組。
  • state是最終狀態(tài)

可能引發(fā)的異常:

  • TypeError:如果cell不是RNNCell的實例。
  • ValueError:如果inputs為None或是一個空列表,或者無法通過形狀推斷從輸入推斷輸入深度(列大?。?。

實例:

import tensorflow as tf

x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim] 
x=tf.unstack(x,axis=1) #按時間步展開 
n_neurons = 5 #輸出神經(jīng)元數(shù)量
 
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)
 
print(len(output_seqs)) #四個時間步 
print(output_seqs[0]) #每個時間步輸出一個張量 
print(output_seqs[1]) #每個時間步輸出一個張量
print(states) #隱藏狀態(tài)

輸出結果如下:

4
Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號