TensorFlow聚合梯度的條件累加器

2018-09-12 16:19 更新

tf.ConditionalAccumulator

tf.ConditionalAccumulator 類

定義在:tensorflow/python/ops/data_flow_ops.py

用于聚合梯度的條件累加器.

最新的梯度(即計算梯度的時間步長等于累加器的時間步長)被加到累加器中.

平均梯度的提取被阻塞,直到所需數(shù)量的梯度被累積為止.

屬性

  • accumulator_ref:底層累加器參考.
  • dtype:該累加器積累的梯度的數(shù)據(jù)類型.
  • name:底層累加器的名稱.

方法

__init__

__init__ (  
    dtype ,  
    shape = None ,  
    shared_name = None ,  
    name = 'conditional_accumulator' 
)

創(chuàng)建一個新的 ConditionalAccumulator.

ARGS:
  • dtype:累積梯度的數(shù)據(jù)類型.
  • shape:累積梯度的形狀.
  • shared_name:可選.如果非空,這個累加器將在多個會話的給定名稱下共享.
  • name:累加器的可選名稱.

apply_grad

apply_grad (  
    grad ,  
    local_step = 0 ,  
    name = None
  )

嘗試向累加器應(yīng)用梯度.

如果梯度是陳舊的,即 local_step 小于累加器的全局時間步長,則該嘗試將被靜默地丟棄.

ARGS:

  • grad:要應(yīng)用的梯度張量.
  • local_step:計算梯度的時間步長.
  • name:操作的可選名稱.

返回:

(有條件地) 將梯度應(yīng)用于累加器的操作.

注意:

  • ValueError:如果 grad 是錯誤的形狀

num_accumulated

num_accumulated ( name = None )

目前在累加器中聚合的梯度數(shù).

ARGS:

  • name:操作的可選名稱.

返回:

累加器中當(dāng)前累積的梯度數(shù).

set_global_step

set_global_step (  
    new_global_step ,  
    name = None
  )

設(shè)置累加器的全局時間步長.

如果嘗試設(shè)置的時間步長低于累加器自己的時間步長, 則操作會記錄一個警告.

ARGS:

  • new_global_step:新的時間步長的值,可以是變量或常量.
  • name:操作的可選名稱.

返回:

設(shè)置累加器時間步長的操作.

take_grad

take_grad (  
    num_required ,  
    name = None
  )

嘗試從累加器中提取平均梯度.

操作阻止直到足夠數(shù)量的梯度已成功應(yīng)用于累加器.

一旦成功,還會觸發(fā)以下操作:

  • 累加梯度的計數(shù)器復(fù)位為0.
  • 聚合梯度被重置為0張量.
  • 累加器的內(nèi)部時間步長增加1.

ARGS:

  • num_required:需要聚合的梯度次數(shù)
  • name:操作的可選名稱

返回:

一個持續(xù)平均梯度值的張量.

注意:

  • InvalidArgumentError:如果 num_required <1


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號