tf.estimator.Estimator函數(shù)
Estimator類
定義在:tensorflow/python/estimator/estimator.py
estimator類對TensorFlow模型進行訓(xùn)練和計算.
Estimator對象包裝由model_fn指定的模型,其中,給定輸入和其他一些參數(shù),返回需要進行訓(xùn)練、計算,或預(yù)測的操作.
所有輸出(檢查點,事件文件等)都被寫入model_dir或其子目錄.如果model_dir未設(shè)置,則使用臨時目錄.
可以通過RunConfig對象(包含了有關(guān)執(zhí)行環(huán)境的信息)傳遞config參數(shù).它被傳遞給model_fn,如果model_fn有一個名為“config”的參數(shù)(和輸入函數(shù)以相同的方式).如果該config參數(shù)未被傳遞,則由Estimator進行實例化.不傳遞配置意味著使用對本地執(zhí)行有用的默認(rèn)值.Estimator使配置對模型可用(例如,允許根據(jù)可用的工作人員數(shù)量進行專業(yè)化),并且還使用其一些字段來控制內(nèi)部,特別是關(guān)于檢查點.
該params參數(shù)包含hyperparameter,如果model_fn有一個名為“PARAMS”的參數(shù),并且以相同的方式傳遞給輸入函數(shù),則將它傳遞給 model_fn.Estimator只是沿著參數(shù)傳遞,并不檢查它.因此,params的結(jié)構(gòu)完全取決于開發(fā)人員.
不能在子類中重寫任何Estimator方法(其構(gòu)造函數(shù)強制執(zhí)行此操作).子類應(yīng)使用model_fn來配置基類,并且可以添加實現(xiàn)專門功能的方法.
Eager兼容性
estimator與eager執(zhí)行不兼容.
屬性
- config
- model_dir
- model_fn
返回綁定到self.params的model_fn.
返回:返回具有以下簽名的model_fn: def model_fn(features, labels, mode, config) - params
方法
__init__
__init__(
model_fn,
model_dir=None,
config=None,
params=None,
warm_start_from=None
)
構(gòu)造一個Estimator實例.
請參閱Estimator了解更多信息.啟動一個Estimator的方法如下所示:
estimator = tf.estimator.DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
warm_start_from="/path/to/checkpoint/dir")
有關(guān)warm-start啟動配置的更多詳細(xì)信息,請參閱WarmStartSettings.
參數(shù):
- model_fn:模型函數(shù),具有以下簽名:
ARGS.- features:這是從input_fn傳遞給train、evaluate和predict返回的第一個項目.這應(yīng)該是一個相同的單一的Tensor或dict.
- labels:這是從input_fn傳遞給train、evaluate和predict返回的第二個項目.這應(yīng)該是相同的單個Tensor或dict(對于multi-head模型).如果模式是ModeKeys.PREDICT,則將傳遞labels=None.如果model_fn簽名不接受mode,model_fn必須仍然能夠處理labels=None.
- mode:可選的.指定train、evaluate和predict.參考ModeKeys.
- params:hyperparameters的可選字典.將在params參數(shù)中接收傳遞給Estimator的內(nèi)容.這允許從hyperparameters調(diào)整來配置Estimator.
- config:可選配置對象.將收到傳遞給Estimator的config參數(shù)或默認(rèn)值config.允許根據(jù)配置(如num_ps_replicas或model_dir)更新您的model_fn中的內(nèi)容.
返回:EstimatorSpec
- model_dir:保存模型參數(shù)、圖形等的目錄.這也可用于將目錄中的檢查點加載到Estimator中,以繼續(xù)訓(xùn)練以前保存的模型.如果為PathLike對象,則路徑將被解析.如果為None,則將使用config中的model_dir(如果設(shè)置的話).如果兩者都設(shè)置,則它們必須相同.如果兩者都是None,則會使用臨時目錄.
- config:配置對象.
- params:dict將傳遞到model_fn中的hyperparameters.key是參數(shù)的名稱,value是基本的Python類型.
- warm_start_from:可選的字符串文件路徑,用于從warm-start的檢查點;或tf.estimator.WarmStartSettings對象,用于完全配置warm-start.如果提供字符串文件路徑而不是WarmStartSettings,則所有變量都是warm-start的,并且假定詞匯表和張量名稱未更改.
可能引發(fā)的異常:
- RuntimeError:如果eager執(zhí)行已啟用.
- ValueError:參數(shù)model_fn不匹配params.
- ValueError:如果這是通過子類調(diào)用的,并且該類重寫了Estimator的一個成員.
evaluate
evaluate(
input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None
)
計算給定計算數(shù)據(jù)input_fn的模型.
對于每個步驟來說,調(diào)用input_fn返回一批數(shù)據(jù).計算直到: -steps批處理被處理,或-input_fn引發(fā)輸入結(jié)束異常(OutOfRangeError或StopIteration).
參數(shù):
- input_fn:構(gòu)造用于計算的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列選項之一:
- tf.data.Dataset對象:Dataset對象的輸出必須是一個具有相同約束的元組(特征(features),標(biāo)簽(labels)),其約束條件與下面相同.
- tuple
(features, labels):其中features是Tensor或者名為Tensor的字符串特征的字典,而labels是Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對輸入的期望.
- steps:計算模型所需的步驟數(shù).如果為None,則計算直到input_fn引發(fā)輸入異常時結(jié)束.
- hooks:SessionRunHook子類實例列表.用于計算調(diào)用中的回調(diào).
- checkpoint_path:計算特定檢查點的路徑.如果為None,則使用model_dir中的最新檢查點.
- name:需要使用的計算的名稱,如果用戶需要在不同的數(shù)據(jù)集上運行多個計算(如培訓(xùn)數(shù)據(jù)和測試數(shù)據(jù)).不同計算的度量標(biāo)準(zhǔn)保存在單獨的文件夾中,并單獨出現(xiàn)在tensorboard中.
返回值:
返回一個包含按name為鍵的model_fn中指定的計算指標(biāo)的詞典,以及包含執(zhí)行此技術(shù)的全局步驟的值的條目global_step.
可能引發(fā)的異常:
- ValueError:如果steps <= 0.
- ValueError:如果沒有模型被訓(xùn)練,名為model_dir,或者給定checkpoint_path是空的.
export_savedmodel
export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
strip_default_attrs=False
)
將推理圖作為SavedModel導(dǎo)出到給定的目錄中.
該方法通過首先調(diào)用serving_input_receiver_fn來獲取特征Tensors來構(gòu)建一個新圖,然后調(diào)用這個Estimator的model_fn來基于這些特征生成模型圖.它在新的會話中將給定的檢查點恢復(fù)到該圖中.最后它會在給定的export_dir_base下面創(chuàng)建一個時間戳導(dǎo)出目錄,并在其中寫入一個SavedModel,其中包含從此會話保存的單個MetaGraphDef.
導(dǎo)出的MetaGraphDef將為從model_fn返回的export_outputs字典的每個元素提供一個SignatureDef,該字典使用相同的key命名.其中一個key始終為signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示在服務(wù)請求未指定簽名時將提供哪個簽名.對于每個簽名,輸出由相應(yīng)的ExportOutputs提供,并且輸入始終是由serving_input_receiver_fn提供的輸入接收器.
額外的資產(chǎn)可以通過assets_extra參數(shù)寫入SavedModel.這應(yīng)該是一個字典,其中每個key給出與assets.extra目錄相關(guān)的目標(biāo)路徑(包括文件名).相應(yīng)的值給出了要復(fù)制的源文件的完整路徑.例如,在不重命名的情況下復(fù)制單個文件的簡單情況被指定為{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.
參數(shù):
- export_dir_base:包含一個目錄的字符串,在該目錄中創(chuàng)建包含導(dǎo)出的SavedModels的時間戳子目錄.
- serving_input_receiver_fn:一個不帶參數(shù)并返回一個ServingInputReceiver的函數(shù).
- assets_extra:指定如何在導(dǎo)出的SavedModel中填充assets.extra目錄的字典,如果不需要額外的資產(chǎn),則為 None.
- as_text:是否以文本格式編寫SavedModel原型.
- checkpoint_path:要導(dǎo)出的檢查點路徑.如果None(默認(rèn)),則選擇在模型目錄中找到的最近檢查點.
- strip_default_attrs:布爾值.如果True,則將從NodeDefs中刪除默認(rèn)值屬性.
返回值:
導(dǎo)出目錄的字符串路徑.
可能引發(fā)的異常:
- ValueError:如果未提供serving_input_receiver_fn,則不提供export_outputs,或者找不到檢查點.
get_variable_names
get_variable_names()
返回此模型中所有變量名稱的列表.
返回值:
返回名字列表.
可能引發(fā)的異常:
- ValueError:如果Estimator尚未產(chǎn)生檢查點.
get_variable_value
get_variable_value(name)
返回由名稱給出的變量的值.
參數(shù):
返回值:
Numpy數(shù)組 - 張量的值.
可能引發(fā)的異常:
- ValueError:如果Estimator尚未產(chǎn)生檢查點.
latest_checkpoint
latest_checkpoint()
查找model_dir中最新保存的檢查點文件的文件名.
返回值:
返回最新檢查點的完整路徑或None(未找到檢查點).
predict
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
對給定的features產(chǎn)生預(yù)測.
參數(shù):
- input_fn:構(gòu)造特征的函數(shù).預(yù)測繼續(xù),直到input_fn引發(fā)輸入端異常(OutOfRangeError或StopIteration).有關(guān)更多信息,請參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
- tf.data.Dataset對象:Dataset對象的輸出必須具有與下面相同的約束.
- features:一個Tensor或者名為Tensor的字符串特征的字典.feature被model_fn消耗.他們應(yīng)該滿足model_fn對輸入的期望.
- 一個元組,在這種情況下,第一個項被提取為feature.
- predict_keys:str列表,要預(yù)測的鍵名稱.如果EstimatorSpec.predictions是字典,則使用該方法.如果使用predict_keys,則剩余的預(yù)測將從字典中過濾.如果None,則返回全部.
- hooks:SessionRunHook子類實例列表.用于預(yù)測調(diào)用中的回調(diào).
- checkpoint_path:要預(yù)測的特定檢查點的路徑.如果為None,則使用model_dir中的最新的檢查點.
返回值:
predictions張量的計算值.
可能引發(fā)的異常:
- ValueError:在model_dir中找不到訓(xùn)練有素的模型.
- ValueError:如果批次的預(yù)測長度不相同.
- ValueError:如果predict_keys和predictions之間有沖突.例如,如果predict_keys不是None,但EstimatorSpec.predictions不是一個dict.
train
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)
訓(xùn)練給定訓(xùn)練數(shù)據(jù)input_fn的模型.
參數(shù):
- input_fn:提供作為minibatches培訓(xùn)的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
- tf.data.Dataset對象:Dataset對象的輸出必須是一個具有相同約束的元組(特征,標(biāo)簽)((features, labels)),其約束條件與下面相同.
- tuple
(features, labels):其中features是一個Tensor或者名為Tensor的字符串特征的字典,labels是一個Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對輸入的期望.
- hooks:SessionRunHook子類實例列表.用于訓(xùn)練循環(huán)內(nèi)的回調(diào).
- steps:訓(xùn)練模型的步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯誤或StopIteration異常.“steps”逐步運作.如果您調(diào)用兩次train(steps=10),則訓(xùn)練總共發(fā)生20個步驟.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在20步之前停止.如果你不想有增量行為,請改為設(shè)置.如果設(shè)置max_steps,max_steps必須None.
- max_steps:訓(xùn)練模型的總步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯誤或StopIteration異常.如果設(shè)置,steps必須None.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在max_steps步驟之前停止.兩次調(diào)用train(steps=100)意味著200次訓(xùn)練迭代.另一方面,兩次調(diào)用train(max_steps=100)意味著第二次調(diào)用將不會做任何迭代,因為第一次調(diào)用完成了所有100個步驟.
- saving_listeners:CheckpointSaverListener對象列表.用于在檢查點節(jié)省之前或之后立即執(zhí)行的回調(diào).
可能引發(fā)的異常:
- ValueError:如果steps和max_steps都不是None.
- ValueError:如果steps或max_steps其中之一小于等于0.
更多建議: