TensorFlow如何計(jì)算混淆矩陣

2018-09-14 16:08 更新

tf.confusion_matrix

confusion_matrix(
    labels,
    predictions,
    num_classes=None,
    dtype=tf.int32,
    name=None,
    weights=None
)

定義在:tensorflow/python/ops/confusion_matrix.py.

從預(yù)測(cè)和標(biāo)簽計(jì)算混淆矩陣.

計(jì)算一對(duì)預(yù)測(cè)和標(biāo)簽的 1維 int 數(shù)組的混淆矩陣.

矩陣的列表示預(yù)測(cè)標(biāo)簽,行表示實(shí)際標(biāo)簽.混淆矩陣總是形狀 [n, n] 的一個(gè)二維數(shù)組,其中 n 是給定分類任務(wù)的有效標(biāo)簽的數(shù)量.預(yù)測(cè)和標(biāo)簽都必須是相同形狀的 1維數(shù)組,以使此函數(shù)正常工作.

如果 num_classes 為 None,則 num_classes 將被設(shè)置為一個(gè)加上預(yù)測(cè)值或標(biāo)簽中的最大值.類標(biāo)簽預(yù)計(jì)從0開始.例如, 如果 num_classes 是三個(gè),那么可能的標(biāo)簽將是 [0, 1, 2].
如果權(quán)重不是 None,則每個(gè)預(yù)測(cè)都會(huì)對(duì)混淆矩陣單元的總值做出相應(yīng)的權(quán)重.

例如:

tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
    [[0 0 0 0 0]
     [0 0 1 0 0]
     [0 0 1 0 0]
     [0 0 0 0 0]
     [0 0 0 0 1]]
請(qǐng)注意,可能的標(biāo)簽被假設(shè)為[0, 1, 2, 3, 4] ,從而導(dǎo)致 5x5 混淆矩陣.

ARGS:

  • labels:用于分類任務(wù)的實(shí)際標(biāo)簽的 1維張量 .
  • predictions:給定分類的 1維預(yù)測(cè)的張量.
  • num_classes:分類任務(wù)可能具有的標(biāo)簽數(shù)量.如果未提供此值,則將使用預(yù)測(cè)和標(biāo)簽數(shù)組計(jì)算它.
  • dtype:混淆矩陣的數(shù)據(jù)類型.
  • name:作用域名稱.
  • weights:一個(gè)形狀匹配預(yù)測(cè)的可選張量.

返回:

表示混淆矩陣的 k X k 矩陣, 其中 k 是分類任務(wù)中可能的標(biāo)簽數(shù).

注意:

  • ValueError:如果預(yù)測(cè)和標(biāo)簽都不是1維向量, 且形狀不匹配, 或者權(quán)重不是 None, 并且其形狀與預(yù)測(cè)不匹配.


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)