W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
tf.estimator.DNNRegressor
繼承自: Estimator
定義在:tensorflow/python/estimator/canned/dnn.py.
TensorFlow DNN 模型的回歸器.
例:
sparse_feature_a = sparse_column_with_hash_bucket(...)
sparse_feature_b = sparse_column_with_hash_bucket(...)
sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a,
...)
sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
...)
estimator = DNNRegressor(
feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
hidden_units=[1024, 512, 256])
# Or estimator using the ProximalAdagradOptimizer optimizer with
# regularization.
estimator = DNNRegressor(
feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
hidden_units=[1024, 512, 256],
optimizer=tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
))
# Input builders
def input_fn_train: # returns x, y
pass
estimator.train(input_fn=input_fn_train, steps=100)
def input_fn_eval: # returns x, y
pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
def input_fn_predict: # returns x, None
pass
predictions = estimator.predict(input_fn=input_fn_predict)
輸入的 train 和 evaluate 應(yīng)具有以下特點,否則將會產(chǎn)生 KeyError:
損失通過使用均方誤差計算.
__init__(
hidden_units,
feature_columns,
model_dir=None,
label_dimension=1,
weight_column=None,
optimizer='Adagrad',
activation_fn=tf.nn.relu,
dropout=None,
input_layer_partitioner=None,
config=None
)
初始化一個 DNNRegressor 實例.
evaluate(
input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None
)
評估給定的評估數(shù)據(jù) input_fn 的模型.
對于每個步驟,調(diào)用 input_fn,它返回一組數(shù)據(jù).評估結(jié)束條件:達到 - steps 批處理,或 - input_fn 引發(fā) end-of-input 異常(OutOfRangeError 或 StopIteration).
包含 model_fn 按名稱鍵入指定的評估度量的 dict ,以及一個條目 global_step,它包含執(zhí)行此評估的全局步驟值.
export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None
)
將推理圖作為 SavedModel 導(dǎo)出到給定的目錄中.
此方法首先調(diào)用 serve_input_receiver_fn 來獲取特征 Tensors,然后調(diào)用此 Estimator 的 model_fn 以生成基于這些特征的模型圖,從而構(gòu)建新的圖.它在新的會話中將給定的檢查點(或缺少最新的檢查點)還原到此圖中.最后,它在給定的 export_dir_base 下面創(chuàng)建一個時間戳的導(dǎo)出目錄,并將 SavedModel 寫入其中,其中包含從此會話保存的單個 MetaGraphDef.
導(dǎo)出的 MetaGraphDef 將為從 model_fn 返回的 export_outputs 字典的每個元素提供一個 SignatureDef,使用相同的鍵命名.這些密鑰之一始終是signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示當一個服務(wù)請求沒有指定時將提供哪個簽名.對于每個簽名,輸出由相應(yīng)的ExportOutputs 提供,并且輸入始終是由 serve_input_receiver_fn 提供的輸入接收器.
額外的資產(chǎn)可以通過 extra_assets 參數(shù)寫入 SavedModel.這應(yīng)該是一個 dict,其中每個鍵都給出了相對于 assets.extra 目錄的目標路徑(包括文件名).相應(yīng)的值給出要復(fù)制的源文件的完整路徑.例如,復(fù)制單個文件而不重命名的簡單情況被指定為:{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.
導(dǎo)出目錄的字符串路徑.
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
返回給定功能的預(yù)測.
計算預(yù)測張量的值.
train(
input_fn,
hooks=None,
steps=None,
max_steps=None
)
訓(xùn)練一個給定訓(xùn)練數(shù)據(jù) input_fn 的模型.
返回 self,用于鏈接.
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: