TensorFlow函數(shù)教程:tf.nn.pool

2019-01-31 13:51 更新

tf.nn.pool函數(shù)

tf.nn.pool(
    input,
    window_shape,
    pooling_type,
    padding,
    dilation_rate=None,
    strides=None,
    name=None,
    data_format=None
)

定義在:tensorflow/python/ops/nn_ops.py.

請參閱指南:神經(jīng)網(wǎng)絡(luò)>池操作

執(zhí)行N-D池操作.

在data_format不以“NC”開頭的情況下,計算0 <= b <batch_size,0 <= x [i] <output_spatial_shape [i],0 <= c <num_channels:

output[b, x[0], ..., x[N-1], c] =
  REDUCE_{z[0], ..., z[N-1]}
    input[b,
          x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
          ...
          x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
          c],

其中,還原函數(shù)REDUCE取決于pooling_type的值,并且pad_before是根據(jù)此處注釋中描述的padding的值定義的.減少從不包括越界位置.

在data_format以“NC”開頭的情況下,輸入和輸出簡單地轉(zhuǎn)置如下:

pool(input, data_format, **kwargs) =
  tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                    **kwargs),
               [0, N+1] + range(1, N+1))

參數(shù):

  • input:秩為N + 2的Tensor,如果data_format不以“NC”(默認(rèn)),則shape為[batch_size] + input_spatial_shape + [num_channels];如果data_format以“NC”開頭,則shape為[batch_size, num_channels] + input_spatial_shape.池化僅在空間維度上發(fā)生.
  • window_shape:N個int> = 1的序列.
  • pooling_type:指定池操作,必須是“AVG”或“MAX”.
  • padding:填充算法必須為“SAME”或“VALID”.
  • dilation_rate: 可選.擴張速度.N個int> = 1的列表.默認(rèn)為[1] * N.如果dilation_rate的任何值> 1,則步幅的所有值必須為1.
  • strides: 可選.N個int> = 1的序列.默認(rèn)為[1] * N.如果步幅的任何值> 1,則dilation_rate的所有值必須為1.
  • name: 可選.操作的名稱.
  • data_format:string或None.指定input和輸出的通道維度是最后一個維度(默認(rèn),或者data_format不是以“NC”開頭),還是第二個維度(如果data_format以“NC”開頭).對于N = 1,有效值為“NWC”(默認(rèn))和“NCW”.對于N = 2,有效值是“NHWC”(默認(rèn))和“NCHW”.對于N = 3,有效值為“NDHWC”(默認(rèn))和“NCDHW”.

返回:

秩為N + 2的張量,如果data_format為None或者不以“NC”開頭,則shape為[batch_size] + output_spatial_shape + [num_channels],或者如果data_format以“NC”開頭,則shape為

[batch_size,num_channels] + output_spatial_shape,其中output_spatial_shape取決于填充的值:

  • 如果padding =“SAME”:output_spatial_shape [i] = ceil(input_spatial_shape [i] / strides [i])
  • 如果padding =“VALID”:output_spatial_shape [i] = ceil((input_spatial_shape [i] - (window_shape [i] - 1)* dilation_rate [i])/ strides [i])

可能引發(fā)的異常:

  • ValueError:如果參數(shù)無效.
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號