TensorFlow函數(shù):tf.reverse_sequence

2018-12-26 10:53 更新

tf.reverse_sequence 函數(shù)

reverse_sequence(
    input,
    seq_lengths,
    seq_axis=None,
    batch_axis=None,
    name=None,
    seq_dim=None,
    batch_dim=None
)

定義在:tensorflow/python/ops/array_ops.py.

參見指南:張量變換>分割和連接

反轉可變長度切片.

此操作首先沿著維度batch_axis對input進行分割,并且對于每個切片 i,將第一個 seq_lengths 元素沿維度 seq_axis 反轉.

seq_lengths 的元素必須服從 seq_lengths[i] <= input.dims[seq_dim],seq_lengths必須是一個長度input.dims[batch_dim]的矢量.

然后沿著維度batch_axis的輸出切片i由輸入切片i給出,與第一個 seq_lengths [i] 切片沿維度 seq_axis 反轉.

例如:

# 給定如下:
batch_dim = 0 seq_dim = 1 input.dims = (4, 8, ...) seq_lengths = [7, 2, 3, 5] # 然后輸入的切片在seq_dim上反轉,但只到seq_lengths:
output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] #當條目通過seq_lens被復制通過:output[0, 7:, :, ...] = input[0, 7:, :, ...] output[1, 2:, :, ...] = input[1, 2:, :, ...] output[2, 3:, :, ...] = input[2, 3:, :, ...] output[3, 2:, :, ...] = input[3, 2:, :, ...]

相反的情況如下:

# 給定如下:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]

# 然后在seq_dim上切換輸入的切片,但只能切換到seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] # 通過seq_lens的條目被復制:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]

參數(shù):

  • input:一個Tensor.要反轉的輸入.
  • seq_lengths:一個Tensor.必須是以下類型之一:int32,int64;長度為input.dims(batch_dim)和max(seq_lengths) <= input.dims(seq_dim)的一維
  • seq_axis:int;部分逆轉的維度.
  • batch_axis:可選int.默認為0.逆向執(zhí)行的維度.
  • name:操作的名稱(可選).

返回:

該函數(shù)返回Tensor,與input有相同的類型和形狀,是部分反轉的輸入.

以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號