TensorFlow函數(shù):tf.RegisterGradient

2018-12-19 11:20 更新
tf.RegisterGradient 函數(shù)

RegisterGradient 類

定義在:tensorflow/python/framework/ops.py.

請(qǐng)參閱指南:構(gòu)建圖>定義新操作

用于注冊(cè) op 類型的漸變函數(shù)的修飾符.

這個(gè)修飾符只在定義一個(gè)新的 op 類型時(shí)使用.對(duì)于具有 m 個(gè)輸入和 n 個(gè)輸出的運(yùn)算,梯度函數(shù)是一個(gè)采用原始的 Operation 和 n Tensor 對(duì)象(表示與 op 的每個(gè)輸出相關(guān)的梯度),并返回 m Tensor 對(duì)象(表示相對(duì)于 op 的每個(gè)輸入的部分梯度)的函數(shù).

例如,假設(shè)該類型的"Sub"操作需要兩個(gè)輸入 x 和 y,并返回一個(gè)單一的輸出 x - y,則以下梯度函數(shù)將被注冊(cè):

@tf.RegisterGradient("Sub")
def _sub_grad(unused_op, grad):
  return grad, tf.negative(grad)

修飾符參數(shù) op_type 是操作的字符串類型.這對(duì)應(yīng)于定義操作的原始 OpDef. name 字段.

方法

__init__

__init__(op_type)

該方法使用op_type作為Operation類型創(chuàng)建一個(gè)新的修飾符.

參數(shù):

  • op_type:操作的字符串類型.這對(duì)應(yīng)于定義操作的原始 OpDef. name 字段.

__call__

鬼斧神工

__call__(f)

該方法可以將函數(shù) f 注冊(cè)為 op_type 的梯度函數(shù). 

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)