tf.one_hot函數(shù):返回one-hot張量

2018-11-07 18:01 更新
tf.one_hot 函數(shù)
one_hot(
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)

定義在:tensorflow/python/ops/array_ops.py.

參見指南:張量變換>分割和連接

返回一個(gè) one-hot 張量.

索引中由索引表示的位置取值 on_value,而所有其他位置都取值 off_value.

on_value 和 off_value必須具有匹配的數(shù)據(jù)類型.如果還提供了 dtype,則它們必須與 dtype 指定的數(shù)據(jù)類型相同.

如果未提供 on_value,則默認(rèn)值將為 1,其類型為 dtype.

如果未提供 off_value,則默認(rèn)值為 0,其類型為 dtype.

如果輸入的索引的秩為 N,則輸出的秩為 N+1.新的坐標(biāo)軸在維度上創(chuàng)建 axis(默認(rèn)值:新坐標(biāo)軸在末尾追加).

如果索引是標(biāo)量,則輸出形狀將是長(zhǎng)度 depth 的向量.

如果索引是長(zhǎng)度 features 的向量,則輸出形狀將為:

features x depth if axis == -1
depth x features if axis == 0

如果索引是具有形狀 [batch, features] 的矩陣(批次),則輸出形狀將是: 

batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0

如果 dtype 沒有提供,則它會(huì)嘗試假定 on_value 或者 off_value 的數(shù)據(jù)類型,如果其中一個(gè)或兩個(gè)都傳入.如果沒有提供 on_value、off_value 或 dtype,則dtype 將默認(rèn)為值 tf.float32.

注意:如果一個(gè)非數(shù)值數(shù)據(jù)類型輸出期望(tf.string,tf.bool等),都o(jì)n_value與off_value 必須被提供給one_hot.

示例

示例-1

假設(shè)如下:

indices = [0, 2, -1, 1]
depth = 3
on_value = 5.0
off_value = 0.0
axis = -1

那么輸出為 [4 x 3]:

output =
[5.0 0.0 0.0]  // one_hot(0)
[0.0 0.0 5.0]  // one_hot(2)
[0.0 0.0 0.0]  // one_hot(-1)
[0.0 5.0 0.0]  // one_hot(1)

示例-2

假設(shè)如下:

indices = [[0, 2], [1, -1]]
depth = 3
on_value = 1.0
off_value = 0.0
axis = -1

那么輸出是 [2 x 2 x 3]:

output =
[
  [1.0, 0.0, 0.0]  // one_hot(0)
  [0.0, 0.0, 1.0]  // one_hot(2)
][
  [0.0, 1.0, 0.0]  // one_hot(1)
  [0.0, 0.0, 0.0]  // one_hot(-1)
]

使用 on_value 和 off_value 的默認(rèn)值:

indices = [0, 1, 2]
depth = 3

輸出將是:

output =
[[1., 0., 0.],
 [0., 1., 0.],
 [0., 0., 1.]]

參數(shù):

  • indices:指數(shù)的張量.
  • depth:一個(gè)標(biāo)量,用于定義一個(gè) one hot 維度的深度.
  • on_value:定義在 indices[j] = i 時(shí)填充輸出的值的標(biāo)量.(默認(rèn):1)
  • off_value:定義在 indices[j] != i 時(shí)填充輸出的值的標(biāo)量.(默認(rèn):0)
  • axis:要填充的軸(默認(rèn):-1,一個(gè)新的最內(nèi)層軸).
  • dtype:輸出張量的數(shù)據(jù)類型.

返回值:

  • output: one-hot 張量.

可能引發(fā)的異常:

  • TypeError:如果 on_value 或者 off_value 的類型不匹配 dtype.
  • TypeError:如果 on_value 和 off_value 的 dtype 不匹配.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)