TensorFlow張量變換函數:tf.gather

2018-12-24 14:48 更新
函數:tf.gather
gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

定義在:tensorflow/python/ops/array_ops.py.

參見指南:張量變換>切片和連接

根據索引從參數軸上收集切片.
索引必須是任何維度的整數張量 (通常為 0-D 或 1-D).生成輸出張量該張量的形狀為:params.shape[:axis] + indices.shape + params.shape[axis + 1:]

使用示例如下:

# Scalar indices (output is rank(params) - 1).
output[a_0, ..., a_n, b_0, ..., b_n] =
  params[a_0, ..., a_n, indices, b_0, ..., b_n]

# Vector indices (output is rank(params)).
output[a_0, ..., a_n, i, b_0, ..., b_n] =
  params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

# Higher rank indices (output is rank(params) + rank(indices) - 1).
output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
  params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

TensorFlow張量變換函數:tf.gather

參數:

  • params:一個張量.這個張量是用來收集數值的.該張量的秩必須至少是 axis + 1.
  • indices:一個張量.必須是以下類型之一:int32,int64.索引張量必須在 [0, params.shape[axis]) 范圍內.
  • axis:一個張量.必須是以下類型之一:int32,int64.在參數軸從中收集索引.默認為第一個維度.支持負索引.
  • name:操作的名稱(可選).

返回值:

該函數返回一個張量.與參數具有相同的類型.參數值從索引給定的索引中收集而來,并且形狀為:params.shape[:axis] + indices.shape + params.shape[axis + 1:].

以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號