tf.cond函數(shù)的使用

2020-07-30 10:51 更新

tf.cond

cond ( 
    pred , 
    true_fn = None , 
    false_fn = None , 
    strict = False , 
    name = None , 
    fn1 = None , 
    fn2 = None
 )

定義在:tensorflow/python/ops/control_flow_ops.py.

請參閱指南:控制流程>控制流程操作

在TensorFlow中,tf.cond()類似于c語言中的if...else...,用來控制數(shù)據(jù)流向,但是僅僅類似而已,其中差別還是挺大的。關(guān)于tf.cond()函數(shù)的具體操作,我參考了tf的說明文檔。

format:tf.cond(pred, fn1, fn2, name=None)

Return :either fn1() or fn2() based on the boolean predicate `pred`.(注意這里,也就是說'fnq'和‘fn2’是兩個函數(shù))

arguments:`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have the same non-zero number and type of outputs('fnq'和‘fn2’返回的是非零的且類型相同的輸出)

如果斷言 pred 為 true ;則返回 true_fn() ,否則返回 false_fn().(棄用參數(shù))

某些參數(shù)已棄用.它們將在將來的版本中被刪除.有關(guān)更新的說明: fn1/fn2 不推薦使用,支持使用 true_fn/false_fn 參數(shù).

true_fn 和 false_fn 都返回輸出張量的列表.true_fn false_fn 必須具有相同的非零數(shù)和輸出類型.
請注意, 條件執(zhí)行僅適用于在 true_fn 和 false_fn 中定義的操作. 請考慮以下簡單程序:

z = tf.multiply( a , b ) 
result = tf.cond( X < Y, lambda:tf.add( X ,Z), lambda: tf.square( Y ))

如果x < y,tf.add 將執(zhí)行并且 tf.square 操作不執(zhí)行.因為 z 是需要的至少一個分支的條件,因為 tf.multiply 操作始終無條件地執(zhí)行.雖然這種行為與 TensorFlow 的數(shù)據(jù)流模型是一致的,但有時候,有些用戶會期待一種較為惰性的語義.

請注意,cond 調(diào)用 true_fn 和 false_fn 一次(在調(diào)用 cond 的內(nèi)部,而不是在 Session.run()期間 ).cond 將在 true_fn 和 false_fn 期間創(chuàng)建的圖形片段一起使用一些附加的圖形節(jié)點來確保右分支根據(jù) pred 的值執(zhí)行.

tf.cond 支持嵌套結(jié)構(gòu)在 tensorflow.python.util.nest 中的實現(xiàn).true_fn 和 false_fn 必須返回列表,元組和/或命名元組的相同(可能嵌套的)值結(jié)構(gòu).

單例列表和元組是唯一的例外:當(dāng)由 true_fn 或者 false_fn 隱式解壓縮到單個值時.通過傳遞 strict=True 禁用此行為.

ARGS:

  • pred:標(biāo)量決定是否返回 true_fn 或 false_fn 結(jié)果.
  • true_fn:如果 pred 為 true,則被調(diào)用.
  • false_fn:如果 pred 為 false,則被調(diào)用.
  • strict:啟用/禁用 “嚴格”模式的布爾值.
  • name:返回的張量的可選名稱前綴.

返回:

通過調(diào)用 true_fn 或 false_fn 返回的張量.如果 callables 返回單一實例列表, 則從列表中提取元素.

注意:

  • TypeError: 如果 true_fn 或 false_fn 是不可調(diào)用的.
  • ValueError:如果 true_fn 和 false_fn 不返回相同數(shù)量的張量, 或返回不同類型的張量.

例:

x = tf.constant(2 ) 
y = tf.constant(5 )
def  f1 (): return tf .multiply( x , 17 )
def  f2 (): return tf .add ( y , 23 ) 
r = tf .cond( tf.less( X ,y ), f1 , f2 )
#r 設(shè)置為f1().
#f2 中的操作(例如,tf.add)不執(zhí)行.


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號