tf.keras.backend.rnn函數(shù)
tf.keras.backend.rnn(
step_function,
inputs,
initial_states,
go_backwards=False,
mask=None,
constants=None,
unroll=False,
input_length=None,
time_major=False,
zero_output_for_mask=False
)
定義在:tensorflow/python/keras/backend.py。
在張量的時(shí)間維度迭代。
參數(shù):
- step_function:RNN步驟函數(shù)。參數(shù):input:具有shape (samples, ...)的張量(沒有時(shí)間維度),表示在特定時(shí)間步驟的一批樣品的輸入。states:張量列表。返回:output:具有shape (samples, output_dim)的張量 (沒有時(shí)間維度)。new_states:張量列表,長度和shape與“states”相同。列表中的第一個(gè)狀態(tài)必須是前一個(gè)時(shí)間步的輸出張量。
- inputs:shape為(samples, time, ...) (至少3D)的時(shí)間數(shù)據(jù)的張量,或嵌套張量,并且每個(gè)都具有shape (samples, time, ...)。
- initial_states:shape為(samples, state_size) 的張量(無時(shí)間維度),包含step函數(shù)中使用的狀態(tài)的初始值。在state_size是嵌套形狀的情況下,initial_states的形狀也將遵循嵌套結(jié)構(gòu)。
- go_backwards:布爾值。如果為True,則以相反的順序?qū)r(shí)間維度進(jìn)行迭代,并返回相反的序列。
- mask:shape為(samples, time, 1)的二進(jìn)制張量,對于每個(gè)被屏蔽的元素都為零。
- constants:每個(gè)步驟傳遞的常量值列表。
- unroll:是否展開RNN或使用符號while_loop。
- input_length:如果指定,則假設(shè)時(shí)間維度為此長度。
- time_major:布爾值。如果為true,則輸入和輸出將位于shape (timesteps, batch, ...),而在False情況下,它將是 (batch, timesteps, ...)。使用time_major = True更有效,因?yàn)樗苊饬薘NN計(jì)算開始和結(jié)束時(shí)的轉(zhuǎn)置。但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理主數(shù)據(jù),因此默認(rèn)情況下,此函數(shù)接受批處理主要形式的輸入并發(fā)出輸出。
- zero_output_for_mask:布爾值。如果為True,則屏蔽時(shí)間步長的輸出將為零,而在False情況下,將返回先前時(shí)間步長的輸出。
返回:
一個(gè)元組,(last_output, outputs, new_states)。
last_output:shape為(samples, ...) 輸出的rnn的最新輸出。
outputs:shape為(samples, time, ...)的張量,其中每個(gè)條目 outputs[s, t] 是樣本 s 在時(shí)間 t 的步驟函數(shù)輸出值。
new_states:張量列表,步長函數(shù)返回的最新狀態(tài),shape為(samples, ...)。
可能引發(fā)的異常:
- ValueError:如果輸入維度小于3。
- ValueError:如果unroll是True,但輸入時(shí)間步長不是固定數(shù)字。
- ValueError:如果提供了mask(非None),但未提供狀態(tài)(len(states)== 0)。
更多建議: