TensorFlow 返回張量的最大值索引

2020-07-16 10:59 更新

tf.argmax


argmax ( 
input ,
axis = None ,
name = None ,
dimension = None
)

定義在tensorflow/python/ops/math_ops.py.

參考指南:數(shù)學(xué)>序列比較和索引

返回在張量的坐標(biāo)軸上具有的最大值的索引.

請(qǐng)注意,在關(guān)聯(lián)的情況下,返回值的身份不能保證.

ARGS:

  • input:張量,必須是下列類型之一:float32,float64,int64,int32,uint8,uint16,int16,int8,complex64,complex128,qint8,quint8,qint32,half.
  • axis:張量,必須是以下類型之一:int32,int64.當(dāng)類型是 int32 時(shí),要滿足:0 <= axis < rank(input),描述輸入向量的哪個(gè)軸減少.對(duì)于矢量,使用 axis = 0.
  • name:操作的名稱(可選).

返回:

返回張量的 int 64 類型.


代碼示例:

  1. import tensorflow as tf
  2.  
  3. Vector = [1,1,2,5,3]           #定義一個(gè)向量
  4. X = [[1,3,2],[2,5,8],[7,5,9]]  #定義一個(gè)矩陣
  5.  
  6. with tf.Session() as sess:
  7.     a = tf.argmax(Vector, 0)
  8.     b = tf.argmax(X, 0)
  9.     c = tf.argmax(X, 1)
  10.     
  11.     print(sess.run(a))
  12.     print(sess.run(b))
  13.     print(sess.run(c))

運(yùn)行結(jié)果: 

  1. 3
  2. [2 1 2]
  3. [1 2 2]


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)