TensorFlow函數(shù):tf.strided_slice

2018-03-20 14:01 更新

tf.strided_slice函數(shù)

tf.strided_slice(
    input_,
    begin,
    end,
    strides=None,
    begin_mask=0,
    end_mask=0,
    ellipsis_mask=0,
    new_axis_mask=0,
    shrink_axis_mask=0,
    var=None,
    name=None
)

定義在:tensorflow/python/ops/array_ops.py

參見指南:張量變換>分割和連接

提取張量的一個分段切片(廣義 python 數(shù)組索引).

而不是直接調(diào)用這個操作,大多數(shù)用戶會想要使用 NumPy 的風(fēng)格的切片語法(例如,tensor[..., 3:4:-1, tf.newaxis, 3]),它們通過 tf.Tensor.getitemtf.Variable.getitem來支持.此運(yùn)算的接口是切片語法的低級編碼.

粗略地說,這個運(yùn)算從給定的 input_ 張量中提取一個尺寸 (end-begin)/stride 的片段.從 begin 片段指定的位置開始,繼續(xù)添加 stride 索引,直到所有維度都不小于 end.請注意,步幅可能是負(fù)值,這會導(dǎo)致反向切片.

給定一個 Python 的切片 input[spec0, spec1, ..., specn],這個函數(shù)將被調(diào)用如下.

begin,end 與 strides 將是長度 n 的向量.一般 n 不等于 input_ 張量的等級.

在每個掩碼字段(begin_mask、end_mask、ellipsis_mask、new_axis_mask、shrink_axis_mask)中,第 i 位將對應(yīng)于第 i 個規(guī)范.

如果設(shè)置了 begin_mask 的第 i 位,則忽略 begin[i],并使用該維度中的最大范圍來代替.end_mask 類似地工作,除了結(jié)束范圍.

foo[5:,:,:3] 在 7x8x9 張量上相當(dāng)于 foo[5:7,0:8,0:3].foo[::-1] 反轉(zhuǎn)形狀為 8 的張量.

如果設(shè)置了 ellipsis_mask 的第 i 位,則會在其他維度之間插入所需的許多未指定維度.ellipsis_mask 中只允許有一個非零位.

例如,foo[3:5,...,4:5] 在一個形狀 10x3x3x10 張量就相當(dāng)于:foo[3:5,:,:,4:5],并且:foo[3:5,...] 相當(dāng)于 foo[3:5,:,:,:].

如果設(shè)置了 new_axis_mask 的第 i 位,則 begin,end 和 stride 被忽略,并且在輸出張量中的該點(diǎn)處添加新的長度 1 維.

例如,foo[:4, tf.newaxis, :2] 會產(chǎn)生一個形狀 (4, 1, 2)張量.

如果設(shè)置了 shrink_axis_mask 的第 i 位,則意味著第 i 規(guī)范將維度縮小 1.begin[i],end[i] 和 strides[i] 必須意味著維度中的尺寸 1 的切片.例如,在Python中,可能 foo[:, 3, :] 會導(dǎo)致 shrink_axis_mask 等于 2.

注:begin 和 end都是零索引,strides 條目必須非零.

t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])  # [[[3, 3, 3]]]
tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])  # [[[3, 3, 3],
                                                      #   [4, 4, 4]]]
tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])  # [[[4, 4, 4],
                                                         #   [3, 3, 3]]]

函數(shù)參數(shù):

  • input_:一個 Tensor.
  • begin:一個 int32 或 int64 Tensor.
  • end:一個 int32 或 int64 Tensor.
  • strides:一個 int32 或 int64 Tensor.
  • begin_mask:一個 int32 mask.
  • end_mask:一個 int32 mask.
  • ellipsis_mask:一個 int32 mask.
  • new_axis_mask:一個 int32 mask.
  • shrink_axis_mask:一個 int32 mask.
  • var:與 input_None 對應(yīng)的變量
  • name:操作的名稱(可選).

函數(shù)返回值:

tf.strided_slice函數(shù)返回一個與 input 具有相同的類型的 Tensor.

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號