TensorFlow函數(shù):tf.estimator.WarmStartSettings

2018-09-29 11:05 更新

tf.estimator.WarmStartSettings函數(shù)

WarmStartSettings類

定義在:tensorflow/python/estimator/warm_starting_util.py.

在Estimators中進(jìn)行warm-starting的設(shè)置.

示例:使用 DNNEstimator 罐頭

emb_vocab_file = tf.feature_column.embedding_column(
    tf.feature_column.categorical_column_with_vocabulary_file(
        "sc_vocab_file", "new_vocab.txt", vocab_size=100),
    dimension=8)
emb_vocab_list = tf.feature_column.embedding_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
        "sc_vocab_list", vocabulary_list=["a", "b"]),
    dimension=8)
estimator = tf.estimator.DNNClassifier(
  hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
  warm_start_from=ws)

其中ws可以定義為:

模型中warm-start的所有權(quán)重(輸入層和隱藏權(quán)重).可以提供目錄或特定的檢查點(diǎn)(在前者的情況下,將使用最新的檢查點(diǎn)):

ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")

僅warm-start啟動(dòng)嵌入(輸入層)及其累加器變量:

ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
                       vars_to_warm_start=".*input_layer.*")

warm-start除優(yōu)化器累加器變量(DNN默認(rèn)為Adagrad)之外的所有內(nèi)容:

ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
                       vars_to_warm_start="^(?!.*(Adagrad))")

warm-start所有權(quán)重,但與sc_vocab_file對(duì)應(yīng)的嵌入?yún)?shù)與當(dāng)前模型中使用的詞匯不同:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    })

僅warm-start sc_vocab_file嵌入(并且沒有其他變量),它們與當(dāng)前模型中使用的詞匯不同:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    vars_to_warm_start=None,
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    })

對(duì)所有權(quán)重進(jìn)行warm-start,但sc_vocab_file對(duì)應(yīng)的參數(shù)與當(dāng)前檢查點(diǎn)中使用的詞匯不同,只有100個(gè)項(xiàng)被使用:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt",
    old_vocab_size=100
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    })

warm-start所有權(quán)重,但sc_vocab_file對(duì)應(yīng)的參數(shù)與當(dāng)前檢查點(diǎn)中使用的詞匯不同,sc_vocab_list對(duì)應(yīng)的參數(shù)與當(dāng)前檢查點(diǎn)有不同的名稱:

vocab_info = ws_util.VocabInfo(
    new_vocab=sc_vocab_file.vocabulary_file,
    new_vocab_size=sc_vocab_file.vocabulary_size,
    num_oov_buckets=sc_vocab_file.num_oov_buckets,
    old_vocab="old_vocab.txt",
    old_vocab_size=100
)
ws = WarmStartSettings(
    ckpt_to_initialize_from="/tmp",
    var_name_to_vocab_info={
        "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
    },
    var_name_to_prev_var_name={
        "input_layer/sc_vocab_list_embedding/embedding_weights":
            "old_tensor_name"
    })

屬性:

  • ckpt_to_initialize_from:[必需]一個(gè)字符串,用于指定具有檢查點(diǎn)文件的目錄或檢查點(diǎn)的路徑,以便從中啟動(dòng)模型參數(shù).
  • vars_to_warm_start:[可選]一個(gè)正則表達(dá)式,用于捕獲要啟動(dòng)哪個(gè)變量.默認(rèn)為'.*',它會(huì)warm-start所有變量.如果None明確給出,只有var_name_to_vocab_info中指定的變量將被warm-start.
  • var_name_to_vocab_info:[可選]字典變量名稱(字符串)的VocabInfo.變量名稱應(yīng)該是“完整的”變量,而不是分區(qū)的名稱.如果沒有明確提供,則假定該變量沒有詞匯表.
  • var_name_to_prev_var_name:[可選]將變量名稱(字符串)指定為之前ckpt_to_initialize_from中訓(xùn)練的變量的名稱.如果未明確提供,則假定變量的名稱在前一個(gè)檢查點(diǎn)和當(dāng)前模型之間相同.

函數(shù)屬性

  • ckpt_to_initialize_from

    字段編號(hào)0的別名

  • var_name_to_prev_var_name

    字段編號(hào)3的別名

  • var_name_to_vocab_info

    字段編號(hào)2的別名

  • vars_to_warm_start

    字段編號(hào)1的別名

函數(shù)方法

__new__

@staticmethod
__new__(
    cls,
    ckpt_to_initialize_from,
    vars_to_warm_start='.*',
    var_name_to_vocab_info=None,
    var_name_to_prev_var_name=None
)
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)