TensorFlow的estimator類函數(shù):tf.estimator.Estimator

2018-04-27 09:55 更新

tf.estimator.Estimator函數(shù)

Estimator類

定義在:tensorflow/python/estimator/estimator.py

estimator類對TensorFlow模型進行訓(xùn)練和計算.

Estimator對象包裝由model_fn指定的模型,其中,給定輸入和其他一些參數(shù),返回需要進行訓(xùn)練、計算,或預(yù)測的操作.

所有輸出(檢查點,事件文件等)都被寫入model_dir或其子目錄.如果model_dir未設(shè)置,則使用臨時目錄.

可以通過RunConfig對象(包含了有關(guān)執(zhí)行環(huán)境的信息)傳遞config參數(shù).它被傳遞給model_fn,如果model_fn有一個名為“config”的參數(shù)(和輸入函數(shù)以相同的方式).如果該config參數(shù)未被傳遞,則由Estimator進行實例化.不傳遞配置意味著使用對本地執(zhí)行有用的默認(rèn)值.Estimator使配置對模型可用(例如,允許根據(jù)可用的工作人員數(shù)量進行專業(yè)化),并且還使用其一些字段來控制內(nèi)部,特別是關(guān)于檢查點.

該params參數(shù)包含hyperparameter,如果model_fn有一個名為“PARAMS”的參數(shù),并且以相同的方式傳遞給輸入函數(shù),則將它傳遞給 model_fn.Estimator只是沿著參數(shù)傳遞,并不檢查它.因此,params的結(jié)構(gòu)完全取決于開發(fā)人員.

不能在子類中重寫任何Estimator方法(其構(gòu)造函數(shù)強制執(zhí)行此操作).子類應(yīng)使用model_fn來配置基類,并且可以添加實現(xiàn)專門功能的方法.

Eager兼容性

estimator與eager執(zhí)行不兼容.

屬性

  • config
  • model_dir
  • model_fn
    返回綁定到self.params的model_fn.
    返回:返回具有以下簽名的model_fn: def model_fn(features, labels, mode, config)
  • params

方法

__init__

__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None
)

構(gòu)造一個Estimator實例.

請參閱Estimator了解更多信息.啟動一個Estimator的方法如下所示:

estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")

有關(guān)warm-start啟動配置的更多詳細(xì)信息,請參閱WarmStartSettings.

參數(shù):

  • model_fn:模型函數(shù),具有以下簽名:
      ARGS.
    • features:這是從input_fn傳遞給train、evaluate和predict返回的第一個項目.這應(yīng)該是一個相同的單一的Tensor或dict.
    • labels:這是從input_fn傳遞給train、evaluate和predict返回的第二個項目.這應(yīng)該是相同的單個Tensor或dict(對于multi-head模型).如果模式是ModeKeys.PREDICT,則將傳遞labels=None.如果model_fn簽名不接受mode,model_fn必須仍然能夠處理labels=None.
    • mode:可選的.指定train、evaluate和predict.參考ModeKeys.
    • params:hyperparameters的可選字典.將在params參數(shù)中接收傳遞給Estimator的內(nèi)容.這允許從hyperparameters調(diào)整來配置Estimator.
    • config:可選配置對象.將收到傳遞給Estimator的config參數(shù)或默認(rèn)值config.允許根據(jù)配置(如num_ps_replicas或model_dir)更新您的model_fn中的內(nèi)容.
    • 返回:EstimatorSpec
  • model_dir:保存模型參數(shù)、圖形等的目錄.這也可用于將目錄中的檢查點加載到Estimator中,以繼續(xù)訓(xùn)練以前保存的模型.如果為PathLike對象,則路徑將被解析.如果為None,則將使用config中的model_dir(如果設(shè)置的話).如果兩者都設(shè)置,則它們必須相同.如果兩者都是None,則會使用臨時目錄.
  • config:配置對象.
  • params:dict將傳遞到model_fn中的hyperparameters.key是參數(shù)的名稱,value是基本的Python類型.
  • warm_start_from:可選的字符串文件路徑,用于從warm-start的檢查點;或tf.estimator.WarmStartSettings對象,用于完全配置warm-start.如果提供字符串文件路徑而不是WarmStartSettings,則所有變量都是warm-start的,并且假定詞匯表和張量名稱未更改.

可能引發(fā)的異常:

  • RuntimeError:如果eager執(zhí)行已啟用.
  • ValueError:參數(shù)model_fn不匹配params.
  • ValueError:如果這是通過子類調(diào)用的,并且該類重寫了Estimator的一個成員.

evaluate

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

計算給定計算數(shù)據(jù)input_fn的模型.

對于每個步驟來說,調(diào)用input_fn返回一批數(shù)據(jù).計算直到: -steps批處理被處理,或-input_fn引發(fā)輸入結(jié)束異常(OutOfRangeError或StopIteration).

參數(shù):

  • input_fn:構(gòu)造用于計算的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列選項之一:
    • tf.data.Dataset對象:Dataset對象的輸出必須是一個具有相同約束的元組(特征(features),標(biāo)簽(labels)),其約束條件與下面相同.
    • tuple (features, labels):其中features是Tensor或者名為Tensor的字符串特征的字典,而labels是Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對輸入的期望.
  • steps:計算模型所需的步驟數(shù).如果為None,則計算直到input_fn引發(fā)輸入異常時結(jié)束.
  • hooks:SessionRunHook子類實例列表.用于計算調(diào)用中的回調(diào).
  • checkpoint_path:計算特定檢查點的路徑.如果為None,則使用model_dir中的最新檢查點.
  • name:需要使用的計算的名稱,如果用戶需要在不同的數(shù)據(jù)集上運行多個計算(如培訓(xùn)數(shù)據(jù)和測試數(shù)據(jù)).不同計算的度量標(biāo)準(zhǔn)保存在單獨的文件夾中,并單獨出現(xiàn)在tensorboard中.

返回值:

返回一個包含按name為鍵的model_fn中指定的計算指標(biāo)的詞典,以及包含執(zhí)行此技術(shù)的全局步驟的值的條目global_step.

可能引發(fā)的異常:

  • ValueError:如果steps <= 0.
  • ValueError:如果沒有模型被訓(xùn)練,名為model_dir,或者給定checkpoint_path是空的.

export_savedmodel

export_savedmodel(
    export_dir_base,
    serving_input_receiver_fn,
    assets_extra=None,
    as_text=False,
    checkpoint_path=None,
    strip_default_attrs=False
)

將推理圖作為SavedModel導(dǎo)出到給定的目錄中.

該方法通過首先調(diào)用serving_input_receiver_fn來獲取特征Tensors來構(gòu)建一個新圖,然后調(diào)用這個Estimator的model_fn來基于這些特征生成模型圖.它在新的會話中將給定的檢查點恢復(fù)到該圖中.最后它會在給定的export_dir_base下面創(chuàng)建一個時間戳導(dǎo)出目錄,并在其中寫入一個SavedModel,其中包含從此會話保存的單個MetaGraphDef.

導(dǎo)出的MetaGraphDef將為從model_fn返回的export_outputs字典的每個元素提供一個SignatureDef,該字典使用相同的key命名.其中一個key始終為signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示在服務(wù)請求未指定簽名時將提供哪個簽名.對于每個簽名,輸出由相應(yīng)的ExportOutputs提供,并且輸入始終是由serving_input_receiver_fn提供的輸入接收器.

額外的資產(chǎn)可以通過assets_extra參數(shù)寫入SavedModel.這應(yīng)該是一個字典,其中每個key給出與assets.extra目錄相關(guān)的目標(biāo)路徑(包括文件名).相應(yīng)的值給出了要復(fù)制的源文件的完整路徑.例如,在不重命名的情況下復(fù)制單個文件的簡單情況被指定為{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.

參數(shù):

  • export_dir_base:包含一個目錄的字符串,在該目錄中創(chuàng)建包含導(dǎo)出的SavedModels的時間戳子目錄.
  • serving_input_receiver_fn:一個不帶參數(shù)并返回一個ServingInputReceiver的函數(shù).
  • assets_extra:指定如何在導(dǎo)出的SavedModel中填充assets.extra目錄的字典,如果不需要額外的資產(chǎn),則為 None.
  • as_text:是否以文本格式編寫SavedModel原型.
  • checkpoint_path:要導(dǎo)出的檢查點路徑.如果None(默認(rèn)),則選擇在模型目錄中找到的最近檢查點.
  • strip_default_attrs:布爾值.如果True,則將從NodeDefs中刪除默認(rèn)值屬性.

返回值:

導(dǎo)出目錄的字符串路徑.

可能引發(fā)的異常:

  • ValueError:如果未提供serving_input_receiver_fn,則不提供export_outputs,或者找不到檢查點.

get_variable_names

get_variable_names()

返回此模型中所有變量名稱的列表.

返回值:

返回名字列表.

可能引發(fā)的異常:

  • ValueError:如果Estimator尚未產(chǎn)生檢查點.

get_variable_value

get_variable_value(name)

返回由名稱給出的變量的值.

參數(shù):

  • name:字符串或字符串列表,張量的名稱.

返回值:

Numpy數(shù)組 - 張量的值.

可能引發(fā)的異常:

  • ValueError:如果Estimator尚未產(chǎn)生檢查點.

latest_checkpoint

latest_checkpoint()

查找model_dir中最新保存的檢查點文件的文件名.

返回值:

返回最新檢查點的完整路徑或None(未找到檢查點).

predict

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None
)

對給定的features產(chǎn)生預(yù)測.

參數(shù):

  • input_fn:構(gòu)造特征的函數(shù).預(yù)測繼續(xù),直到input_fn引發(fā)輸入端異常(OutOfRangeError或StopIteration).有關(guān)更多信息,請參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
    • tf.data.Dataset對象:Dataset對象的輸出必須具有與下面相同的約束.
    • features:一個Tensor或者名為Tensor的字符串特征的字典.feature被model_fn消耗.他們應(yīng)該滿足model_fn對輸入的期望.
    • 一個元組,在這種情況下,第一個項被提取為feature.
  • predict_keys:str列表,要預(yù)測的鍵名稱.如果EstimatorSpec.predictions是字典,則使用該方法.如果使用predict_keys,則剩余的預(yù)測將從字典中過濾.如果None,則返回全部.
  • hooks:SessionRunHook子類實例列表.用于預(yù)測調(diào)用中的回調(diào).
  • checkpoint_path:要預(yù)測的特定檢查點的路徑.如果為None,則使用model_dir中的最新的檢查點.

返回值:

predictions張量的計算值.

可能引發(fā)的異常:

  • ValueError:在model_dir中找不到訓(xùn)練有素的模型.
  • ValueError:如果批次的預(yù)測長度不相同.
  • ValueError:如果predict_keys和predictions之間有沖突.例如,如果predict_keys不是None,但EstimatorSpec.predictions不是一個dict.

train

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

訓(xùn)練給定訓(xùn)練數(shù)據(jù)input_fn的模型.

參數(shù):

  • input_fn:提供作為minibatches培訓(xùn)的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
    • tf.data.Dataset對象:Dataset對象的輸出必須是一個具有相同約束的元組(特征,標(biāo)簽)((features, labels)),其約束條件與下面相同.
    • tuple (features, labels):其中features是一個Tensor或者名為Tensor的字符串特征的字典,labels是一個Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對輸入的期望.
  • hooks:SessionRunHook子類實例列表.用于訓(xùn)練循環(huán)內(nèi)的回調(diào).
  • steps:訓(xùn)練模型的步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯誤或StopIteration異常.“steps”逐步運作.如果您調(diào)用兩次train(steps=10),則訓(xùn)練總共發(fā)生20個步驟.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在20步之前停止.如果你不想有增量行為,請改為設(shè)置.如果設(shè)置max_steps,max_steps必須None.
  • max_steps:訓(xùn)練模型的總步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯誤或StopIteration異常.如果設(shè)置,steps必須None.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在max_steps步驟之前停止.兩次調(diào)用train(steps=100)意味著200次訓(xùn)練迭代.另一方面,兩次調(diào)用train(max_steps=100)意味著第二次調(diào)用將不會做任何迭代,因為第一次調(diào)用完成了所有100個步驟.
  • saving_listeners:CheckpointSaverListener對象列表.用于在檢查點節(jié)省之前或之后立即執(zhí)行的回調(diào).

可能引發(fā)的異常:

  • ValueError:如果steps和max_steps都不是None.
  • ValueError:如果steps或max_steps其中之一小于等于0.
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號