TensorFlow函數(shù):tf.while_loop

2018-04-20 10:29 更新

tf.while_loop函數(shù)

tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    name=None,
    maximum_iterations=None
)

定義在:tensorflow/python/ops/control_flow_ops.py.

請(qǐng)參閱指南:控制流程>控制流程操作

在條件cond成立時(shí)重復(fù)body.

cond是可返回的布爾標(biāo)量張量;body是一個(gè)可調(diào)用的函數(shù),它返回一個(gè)(可能是嵌套的)元組、namedtuple或者與loop_vars具有相同arity(長(zhǎng)度和結(jié)構(gòu))和類型的張量列表;loop_vars是一個(gè)(可能是嵌套的)元組,namedtuple或被傳遞給cond和body的張量的列表.cond和body的參數(shù)都盡可能的和loop_vars一樣多.

除了常規(guī)的Tensors或IndexedSlices,body可以接受并返回TensorArray對(duì)象.TensorArray對(duì)象的流將在循環(huán)之間和梯度計(jì)算期間適當(dāng)?shù)剞D(zhuǎn)發(fā).

請(qǐng)注意,while_loop只調(diào)用cond和body一次(調(diào)用while_loop內(nèi),而不是在所有Session.run()期間).while_loop將在cond和body調(diào)用期間創(chuàng)建的圖片段拼接在一起,并添加一些額外的圖形節(jié)點(diǎn),以創(chuàng)建重復(fù)body的圖形流,直到cond返回false為止.

為了正確性,tf.while_loop()嚴(yán)格執(zhí)行循環(huán)變量的形狀不變量.形狀不變是一個(gè)(可能是部分)形狀,在整個(gè)迭代循環(huán)中不變.如果迭代之后的循環(huán)變量的形狀被確定為比形狀不變性更一般或不兼容,則會(huì)引發(fā)錯(cuò)誤.例如,[11,None]的形狀比[11,17]的形狀更普遍,[11,21]與[11,17]不兼容.默認(rèn)情況下(如果沒(méi)有指定shape_invariants參數(shù)),則假設(shè)每個(gè)迭代中的loop_vars中的每個(gè)張量的初始形狀是相同的.該shape_invariants參數(shù)允許調(diào)用者為每個(gè)循環(huán)變量指定一個(gè)不太具體的形狀不變量,如果形狀在迭代之間變化,則需要該變量.該tf.Tensor.set_shape函數(shù)也可以用在body函數(shù)中來(lái)指示輸出循環(huán)變量具有特定的形狀.

SparseTensor和IndexedSlices的形狀不變的定義如下:

a)如果循環(huán)變量是SparseTensor,則形狀不變量必須是TensorShape([r]),其中r是由稀疏張量表示的稠密張量的秩.這意味著SparseTensor的三個(gè)張量的形狀是([None],[None,r],[r]).注意:此處不變的形狀是SparseTensor.dense_shape屬性的形狀.它必須是矢量的形狀.

b)如果循環(huán)變量是IndexedSlices,則形狀不變量必須是IndexedSlices的值張量的形狀不變量.這意味著IndexedSlices的三個(gè)張量的形狀是(shape,[shape [0]],[shape.ndims]).

while_loop實(shí)現(xiàn)非嚴(yán)格的語(yǔ)義,允許多個(gè)迭代并行運(yùn)行.并行迭代的最大數(shù)量可以通過(guò)parallel_iterations控制,這使用戶可以控制內(nèi)存消耗和執(zhí)行順序.對(duì)于正確的程序,while_loop應(yīng)該為任何parallel_iterations>0返回相同的結(jié)果.

對(duì)于訓(xùn)練,TensorFlow存儲(chǔ)在正向推斷中生成的張量,并且需要反向傳播.這些張量是內(nèi)存消耗的主要來(lái)源,并且在GPU上訓(xùn)練時(shí)經(jīng)常會(huì)導(dǎo)致OOM錯(cuò)誤.當(dāng)標(biāo)志swap_memory為true時(shí),我們將這些張量從GPU交換到CPU.例如,這允許我們訓(xùn)練具有很長(zhǎng)序列和大批量的RNN模型.

函數(shù)參數(shù):

  • cond:代表循環(huán)終止條件的可調(diào)用對(duì)象.
  • body:代表循環(huán)體的可調(diào)用對(duì)象.
  • loop_vars:一個(gè)(可能是嵌套的)元組,namedtuple或numpy數(shù)組、Tensor以及TensorArray對(duì)象的列表.
  • shape_invariants:循環(huán)變量的形狀不變量.
  • parallel_iterations:允許并行運(yùn)行的迭代次數(shù).它必須是一個(gè)正整數(shù).
  • back_prop:表示是否為此while循環(huán)啟用backprop.
  • swap_memory:此循環(huán)是否啟用GPU-CPU內(nèi)存交換.
  • name:返回張量的可選名稱前綴.
  • maximum_iterations:要運(yùn)行的while循環(huán)的可選最大迭代次數(shù).如果提供了,則cond輸出將與附加條件進(jìn)行AND運(yùn)算,以確保執(zhí)行的迭代次數(shù)不超過(guò)maximum_iterations.要運(yùn)行的 while 循環(huán)的最大迭代次數(shù).

返回值:

執(zhí)行循環(huán)后循環(huán)變量的輸出張量.當(dāng)loop_vars的長(zhǎng)度為1時(shí),這是一個(gè)Tensor、TensorArray或IndexedSlice,當(dāng)loop_vars的長(zhǎng)度大于1時(shí),它返回一個(gè)列表.

可能引發(fā)的異常:

  • TypeError:如果cond或者body是不可調(diào)用的.
  • ValueError:如果loop_vars是空的.

使用示例:

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])

嵌套和namedtuple的示例:

import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p: i < 10
b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)

使用shape_invariants的示例:

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)