TensorFlow函數(shù):tf.estimator.EstimatorSpec

2018-04-27 14:06 更新

tf.estimator.EstimatorSpec函數(shù)

EstimatorSpec類

定義在:tensorflow/python/estimator/model_fn.py.

從model_fn返回的操作和對(duì)象并傳遞給Estimator.

EstimatorSpec完全定義了由Estimator運(yùn)行的模型.

屬性

  • eval_metric_ops
    字段號(hào)4的別名
  • evaluation_hooks
    字段號(hào)9的別名
  • export_outputs
    字段號(hào)5的別名
  • loss
    字段號(hào)2的別名
  • mode
    字段號(hào)0的別名
  • prediction_hooks
    字段號(hào)10的別名
  • predictions
    字段號(hào)1的別名
  • scaffold
    字段號(hào)8的別名
  • train_op
    字段號(hào)3的別名
  • training_chief_hooks
    字段號(hào)6的別名
  • training_hooks
    字段號(hào)7的別名

方法

__new__

@ staticmethod 
__new__ ( 
    cls , 
    mode , 
    predictions = None , 
    loss = None , 
    train_op = None , 
    eval_metric_ops = None , 
    export_outputs = None , 
    training_chief_hooks = None , 
    training_hooks = None , 
    scaffold = None , 
    evaluation_hooks = None , 
    prediction_hooks= 無
)

創(chuàng)建一個(gè)已經(jīng)驗(yàn)證的EstimatorSpec實(shí)例.

根據(jù)mode的值的不同,需要不同的參數(shù),即:

  • 對(duì)于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
  • 對(duì)于mode == ModeKeys.EVAL:必填字段是loss.
  • 為mode == ModeKeys.PREDICT:必填字段是predictions.

model_fn可以填充獨(dú)立于模式的所有參數(shù).在這種情況下,Estimator將忽略某些參數(shù).在eval和infer模式中,train_op將被忽略.例子如下:

def my_model_fn(mode, features, labels):
  predictions = ...
  loss = ...
  train_op = ...
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)

或者,model_fn可以填充適合給定模式的參數(shù).例:

def my_model_fn(mode, features, labels):
  if (mode == tf.estimator.ModeKeys.TRAIN or
      mode == tf.estimator.ModeKeys.EVAL):
    loss = ...
  else:
    loss = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = ...
  else:
    train_op = None
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = ...
  else:
    predictions = None

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)

函數(shù)參數(shù):

  • mode:一個(gè)ModeKeys,指定是training(訓(xùn)練)、evaluation(計(jì)算)還是prediction(預(yù)測).
  • predictions:預(yù)測Tensor或字典Tensor.
  • loss:訓(xùn)練損失Tensor,必須是標(biāo)量或形狀[1].
  • train_op:適用于訓(xùn)練的步驟.
  • eval_metric_ops:按名稱鍵入的度量結(jié)果字典.字典的值是調(diào)用度量函數(shù)的結(jié)果,即(metric_tensor, update_op)元組.應(yīng)該在沒有任何狀態(tài)影響的情況下進(jìn)行metric_tensor計(jì)算(通常是基于變量的純計(jì)算結(jié)果).例如,它不應(yīng)該觸發(fā)update_op或需要任何輸入提取.
  • export_outputs:描述要在服務(wù)期間導(dǎo)出到SavedModel并使用的輸出簽名.在字典{name: output}中:name:此輸出的任意名稱.output:一個(gè)ExportOutput對(duì)象,如ClassificationOutput,RegressionOutput或PredictOutput.Single-headed模型只需要在本字典中指定一個(gè)條目.Multi-headed模型應(yīng)為每個(gè)頭指定一個(gè)條目,其中之一必須使用signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY進(jìn)行命名.
  • training_chief_hooks:在訓(xùn)練期間可以在主要工作人員中運(yùn)行的tf.train.SessionRunHook對(duì)象的迭代.
  • training_hooks:在訓(xùn)練過程中可以對(duì)所有工作人員運(yùn)行的tf.train.SessionRunHook對(duì)象.
  • scaffold:可用于設(shè)置初始化,保護(hù)程序等用于訓(xùn)練的tf.train.Scaffold對(duì)象.
  • evaluation_hooks:評(píng)估期間要運(yùn)行的tf.train.SessionRunHook對(duì)象的可迭代性.
  • prediction_hooks:在預(yù)測期間可以運(yùn)行的tf.train.SessionRunHook對(duì)象的可迭代性.

返回值:

一個(gè)經(jīng)過驗(yàn)證的EstimatorSpec對(duì)象.

可能引發(fā)的異常:

  • ValueError:如果驗(yàn)證失敗,則會(huì)引發(fā)此異常.
  • TypeError:如果任何參數(shù)不是預(yù)期的類型.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)