W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
tf.nn.static_rnn(
cell,
inputs,
initial_state=None,
dtype=None,
sequence_length=None,
scope=None
)
定義在:tensorflow/python/ops/rnn.py。
創(chuàng)建由RNNCell cell
指定的循環(huán)神經(jīng)網(wǎng)絡(luò)。
生成的最簡單的RNN網(wǎng)絡(luò)形式是:
state = cell.zero_state(...)
outputs = []
for input_ in inputs:
output, state = cell(input_, state)
outputs.append(output)
return (outputs, state)
但是,還有一些其他選項:
可以提供初始狀態(tài)。如果提供sequence_length向量,則執(zhí)行動態(tài)計算。這種計算方法不計算超過最小批處理的最大序列長度的RNN步驟(從而節(jié)省計算時間),并且將示例的序列長度的狀態(tài)適當(dāng)?shù)貍鞑サ阶罱K狀態(tài)輸出。
在批處理行b的時間t上執(zhí)行的動態(tài)計算:
(output, state)(b, t) =
(t >= sequence_length(b))
? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
: cell(input(b, t), state(b, t - 1))
參數(shù):
返回:
(outputs, state)對,其中:
可能引發(fā)的異常:
import tensorflow as tf
x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim]
x=tf.unstack(x,axis=1) #按時間步展開
n_neurons = 5 #輸出神經(jīng)元數(shù)量
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)
print(len(output_seqs)) #四個時間步
print(output_seqs[0]) #每個時間步輸出一個張量
print(output_seqs[1]) #每個時間步輸出一個張量
print(states) #隱藏狀態(tài)
輸出結(jié)果如下:
4
Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: