TensorFlow張量變換函數(shù):tf.sequence_mask

2018-01-16 10:44 更新

tf.sequence_mask 函數(shù)

sequence_mask(
    lengths,
    maxlen=None,
    dtype=tf.bool,
    name=None
)

定義在:tensorflow/python/ops/array_ops.py.

參見指南:張量變換>分割和連接

返回一個(gè)表示每個(gè)單元的前N個(gè)位置的mask張量.

如果lengths的形狀為[d_1, d_2, ..., d_n],由此產(chǎn)生的張量mask有dtype類型和形狀[d_1, d_2, ..., d_n, maxlen],并且:

mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

示例如下:

tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                #  [True, True, True, False, False],
                                #  [True, True, False, False, False]]

tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
                                  #   [True, True, True]],
                                  #  [[True, True, False],
                                  #   [False, False, False]]]

函數(shù)參數(shù)

  • lengths:整數(shù)張量,其所有值小于等于maxlen.
  • maxlen:標(biāo)量整數(shù)張量,返回張量的最后維度的大?。荒J(rèn)值是lengths中的最大值.
  • dtype:結(jié)果張量的輸出類型.
  • name:操作的名字.

函數(shù)返回值

形狀為lengths.shape + (maxlen,)的mask張量,投射到指定的dtype.

函數(shù)中可能存在的異常

  • ValueError:如果maxlen不是標(biāo)量.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)