TensorFlow處理RNN參數(shù)變量

2018-09-01 15:50 更新

tf.contrib.cudnn_rnn.RNNParamsSaveable


tf.contrib.cudnn_rnn.RNNParamsSaveable 類

定義在:tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py

用于處理 RNN 參數(shù)變量的 SaveableObject 實(shí)現(xiàn).

方法


__init__

__init__ (
params_to_canonical ,
canonical_to_params ,
param_variables ,
name = 'params_canonical'
)

創(chuàng)建一個(gè) RNNParamsSaveable 對(duì)象.

RNNParams 可以在檢查點(diǎn)文件中保存/恢復(fù),用于以規(guī)范格式保存/恢復(fù)權(quán)重和偏置參數(shù),其中參數(shù)逐層保存為張量.對(duì)于每個(gè)層,偏差張量在重量張量之后被保存.恢復(fù)時(shí),用戶可以根據(jù)需要命名 param_variables,并將權(quán)重和偏差張量恢復(fù)到這些變量.

對(duì)于 CudnnRNNRelu 或 CudnnRNNTanh,每個(gè)層的每個(gè)權(quán)重和每個(gè)偏移量都有兩個(gè)張量:張量0被用于從前一層輸入,張量1用于循環(huán)輸入.

對(duì)于 CudnnLSTM,每個(gè)層的每個(gè)權(quán)重和每個(gè)偏移量有8個(gè)張量;張量0-3被用于從前一層輸入;張量4-7用于循環(huán)輸入;張量0和4用于輸入門;張量1和5忘記門;張量2和6新的存儲(chǔ)門; 張量3和7是輸出門.

對(duì)于 CudnnGRU,每個(gè)層的每個(gè)權(quán)重和每個(gè)偏移量有6張張量;張量0-2被用于從前一層輸入;張量3-5用于循環(huán)輸入;張量0和3用于復(fù)位門;張量1和4更新門;張量2和5新的存儲(chǔ)門.

ARGS:

  • params_to_canonical:一種函數(shù), 用于將參數(shù)從特定格式轉(zhuǎn)換為 cuDNN 或其他 RNN ops 轉(zhuǎn)換到規(guī)范格式._CudnnRNN params_to_canonical () 應(yīng)在這里提供.
  • canonical_to_params:用于將參數(shù)從規(guī)范格式轉(zhuǎn)換為 cuDNN 或其他 RNN ops 的特定格式的函數(shù).函數(shù)必須返回一個(gè)標(biāo)量 (如 cuDNN) 或元組.此函數(shù)可以是 _CudnnRN.
  • param_variables:特定窗體中參數(shù)的變量列表.對(duì)于 cuDNN RNN ops,這是一個(gè)單一的加權(quán)和偏見(jiàn)合并變量;對(duì)于其他 RNN ops, 這可能是多個(gè)未或部分合并的變量, 分別用于權(quán)重和偏差.
  • name:RNNParamsSaveable 對(duì)象的名稱.

restore

restore(
restored_tensors ,
restored_shapes
)


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)