W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
tf.estimator.DNNLinearCombinedRegressor
繼承自: Estimator
定義在:tensorflow/python/estimator/canned/dnn_linear_combined.py.
TensorFlow Linear 和 DNN 的估計器加入了回歸模型.
注意:此估計器(estimator)也稱為 wide-n-deep.
例如:
numeric_feature = numeric_column(...)
sparse_column_a = categorical_column_with_hash_bucket(...)
sparse_column_b = categorical_column_with_hash_bucket(...)
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a,
...)
sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
...)
estimator = DNNLinearCombinedRegressor(
# wide settings
linear_feature_columns=[sparse_feature_a_x_sparse_feature_b],
linear_optimizer=tf.train.FtrlOptimizer(...),
# deep settings
dnn_feature_columns=[
sparse_feature_a_emb, sparse_feature_b_emb, numeric_feature],
dnn_hidden_units=[1000, 500, 100],
dnn_optimizer=tf.train.ProximalAdagradOptimizer(...))
# To apply L1 and L2 regularization, you can set optimizers as follows:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
# It is same for FtrlOptimizer.
# Input builders
def input_fn_train: # returns x, y
pass
estimator.train(input_fn=input_fn_train, steps=100)
def input_fn_eval: # returns x, y
pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
def input_fn_predict: # returns x, None
pass
predictions = estimator.predict(input_fn=input_fn_predict)
輸入的 train 和 evaluate 應具有以下特點,否則將會產(chǎn)生 KeyError:
損失通過使用均方誤差計算.
__init__(
model_dir=None,
linear_feature_columns=None,
linear_optimizer='Ftrl',
dnn_feature_columns=None,
dnn_optimizer='Adagrad',
dnn_hidden_units=None,
dnn_activation_fn=tf.nn.relu,
dnn_dropout=None,
label_dimension=1,
weight_column=None,
input_layer_partitioner=None,
config=None
)
初始化 DNNLinearCombinedRegressor 實例.
evaluate(
input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None
)
評估給定的評估數(shù)據(jù) input_fn 的模型.
對于每個步驟,調(diào)用 input_fn,它返回一組數(shù)據(jù).評估結(jié)束條件:達到 - steps 批處理,或 - input_fn 引發(fā) end-of-input 異常(OutOfRangeError 或 StopIteration).
包含 model_fn 按名稱鍵入指定的評估度量的 dict ,以及一個條目 global_step,它包含執(zhí)行此評估的全局步驟值.
export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None
)
將推理圖作為 SavedModel 導出到給定的目錄中.
此方法首先調(diào)用 serve_input_receiver_fn 來獲取特征 Tensors,然后調(diào)用此 Estimator 的 model_fn 以生成基于這些特征的模型圖,從而構(gòu)建新的圖.它在新的會話中將給定的檢查點(或缺少最新的檢查點)還原到此圖中.最后,它在給定的 export_dir_base 下面創(chuàng)建一個時間戳的導出目錄,并將 SavedModel 寫入其中,其中包含從此會話保存的單個 MetaGraphDef.
導出的 MetaGraphDef 將為從 model_fn 返回的 export_outputs 字典的每個元素提供一個 SignatureDef,使用相同的鍵命名.這些密鑰之一始終是signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示當一個服務(wù)請求沒有指定時將提供哪個簽名.對于每個簽名,輸出由相應的ExportOutputs 提供,并且輸入始終是由 serve_input_receiver_fn 提供的輸入接收器.
額外的資產(chǎn)可以通過 extra_assets 參數(shù)寫入 SavedModel.這應該是一個 dict,其中每個鍵都給出了相對于 assets.extra 目錄的目標路徑(包括文件名).相應的值給出要復制的源文件的完整路徑.例如,復制單個文件而不重命名的簡單情況被指定為:{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.
返回:
導出目錄的字符串路徑.
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
返回給定功能的預測.
計算預測張量的值.
train(
input_fn,
hooks=None,
steps=None,
max_steps=None
)
訓練一個給定訓練數(shù)據(jù) input_fn 的模型.
返回 self,用于鏈接.
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: