PyTorch 廣播語義

2020-09-10 15:29 更新
原文: https://pytorch.org/docs/stable/notes/broadcasting.html

許多 PyTorch 操作都支持NumPy Broadcasting Semantics

簡而言之,如果 PyTorch 操作支持廣播,則其 Tensor 參數(shù)可以自動(dòng)擴(kuò)展為相等大小(無需復(fù)制數(shù)據(jù))。

一般語義

如果滿足以下規(guī)則,則兩個(gè)張量是“可廣播的”:

  • 每個(gè)張量具有至少一個(gè)維度。
  • 從尾隨尺寸開始迭代尺寸尺寸時(shí),尺寸尺寸必須相等,其中之一為 1,或者不存在其中之一。

例如:

  1. >>> x=torch.empty(5,7,3)
  2. >>> y=torch.empty(5,7,3)
  3. ## same shapes are always broadcastable (i.e. the above rules always hold)
  4. >>> x=torch.empty((0,))
  5. >>> y=torch.empty(2,2)
  6. ## x and y are not broadcastable, because x does not have at least 1 dimension
  7. ## can line up trailing dimensions
  8. >>> x=torch.empty(5,3,4,1)
  9. >>> y=torch.empty( 3,1,1)
  10. ## x and y are broadcastable.
  11. ## 1st trailing dimension: both have size 1
  12. ## 2nd trailing dimension: y has size 1
  13. ## 3rd trailing dimension: x size == y size
  14. ## 4th trailing dimension: y dimension doesn't exist
  15. ## but:
  16. >>> x=torch.empty(5,2,4,1)
  17. >>> y=torch.empty( 3,1,1)
  18. ## x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果兩個(gè)張量xy是“可廣播的”,則所得張量大小的計(jì)算如下:

  • 如果xy的維數(shù)不相等,則在張量的維數(shù)前面加 1,以使其長度相等。
  • 然后,對于每個(gè)尺寸尺寸,所得尺寸尺寸是該尺寸上xy尺寸的最大值。

For Example:

  1. # can line up trailing dimensions to make reading easier
  2. >>> x=torch.empty(5,1,4,1)
  3. >>> y=torch.empty( 3,1,1)
  4. >>> (x+y).size()
  5. torch.Size([5, 3, 4, 1])
  6. ## but not necessary:
  7. >>> x=torch.empty(1)
  8. >>> y=torch.empty(3,1,7)
  9. >>> (x+y).size()
  10. torch.Size([3, 1, 7])
  11. >>> x=torch.empty(5,2,4,1)
  12. >>> y=torch.empty(3,1,1)
  13. >>> (x+y).size()
  14. RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

就地語義

一個(gè)復(fù)雜之處在于,就地操作不允許就地張量由于廣播而改變形狀。

For Example:

  1. >>> x=torch.empty(5,3,4,1)
  2. >>> y=torch.empty(3,1,1)
  3. >>> (x.add_(y)).size()
  4. torch.Size([5, 3, 4, 1])
  5. ## but:
  6. >>> x=torch.empty(1,3,1)
  7. >>> y=torch.empty(3,1,7)
  8. >>> (x.add_(y)).size()
  9. RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向后兼容

只要每個(gè)張量中的元素?cái)?shù)量相等,以前的 PyTorch 版本都可以在具有不同形狀的張量上執(zhí)行某些逐點(diǎn)函數(shù)。 然后,通過將每個(gè)張量視為一維來執(zhí)行逐點(diǎn)操作。 PyTorch 現(xiàn)在支持廣播,并且“一維”按點(diǎn)行為被認(rèn)為已棄用,并且在張量不可廣播但具有相同數(shù)量元素的情況下會(huì)生成 Python 警告。

注意,在兩個(gè)張量不具有相同形狀但可廣播且具有相同元素?cái)?shù)量的情況下,廣播的引入會(huì)導(dǎo)致向后不兼容的更改。 例如:

  1. >>> torch.add(torch.ones(4,1), torch.randn(4))

以前會(huì)產(chǎn)生一個(gè)具有大小:torch.Size([4,1])的張量,但現(xiàn)在會(huì)產(chǎn)生一個(gè)具有以下大?。簍orch.Size([4,4])的張量。 為了幫助確定代碼中可能存在廣播引入的向后不兼容的情況,可以將 <cite>torch.utils.backcompat.broadcast_warning.enabled</cite> 設(shè)置為 <cite>True</cite> ,這將生成一個(gè) python 在這種情況下發(fā)出警告。

For Example:

  1. >>> torch.utils.backcompat.broadcast_warning.enabled=True
  2. >>> torch.add(torch.ones(4,1), torch.ones(4))
  3. __main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
  4. Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號