TensorFlow函數:tf.where

2018-04-19 10:47 更新

tf.where函數

tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

定義在:tensorflow/python/ops/array_ops.py.

請參閱指南:控制流程>比較運算符,數學函數>序列比較和索引

根據condition返回x或y中的元素.

如果x和y都為None,則該操作將返回condition中true元素的坐標.坐標以二維張量返回,其中第一維(行)表示真實元素的數量,第二維(列)表示真實元素的坐標.請記住,輸出張量的形狀可以根據輸入中的真實值的多少而變化.索引以行優(yōu)先順序輸出.

如果兩者都不是None,則x和y必須具有相同的形狀.如果x和y是標量,則condition張量必須是標量.如果x和y是更高級別的矢量,則condition必須是大小與x的第一維度相匹配的矢量,或者必須具有與x相同的形狀.

condition張量作為一個可以選擇的掩碼(mask),它根據每個元素的值來判斷輸出中的相應元素/行是否應從 x (如果為 true) 或 y (如果為 false)中選擇.

如果condition是向量,則x和y是更高級別的矩陣,那么它選擇從x和y復制哪個行(外部維度).如果condition與x和y具有相同的形狀,那么它將選擇從x和y復制哪個元素.

函數參數:

  • condition:一個bool類型的張量(Tensor).
  • x:可能與condition具有相同形狀的張量;如果condition的秩是1,則x可能有更高的排名,但其第一維度必須匹配condition的大小.
  • y:與x具有相同的形狀和類型的張量.
  • name:操作的名稱(可選).

返回值:

如果它們不是None,則返回與x,y具有相同類型與形狀的張量;張量具有形狀(num_true, dim_size(condition)).

可能引發(fā)的異常:

  • ValueError:當一個x或y正好不是None.
以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號