PyTorch 筆記

2020-09-10 15:26 更新

自動求導(dǎo)機制

原文: https://pytorch.org/docs/stable/notes/autograd.html

本說明將概述 autograd 的工作方式并記錄操作。 不一定要完全了解所有這些內(nèi)容,但我們建議您熟悉它,因為它可以幫助您編寫更高效,更簡潔的程序,并可以幫助您進行調(diào)試。

從向后排除子圖

每個張量都有一個標志:requires_grad,允許從梯度計算中細粒度地排除子圖,并可以提高效率。

requires_grad

如果某個操作的單個輸入需要進行漸變,則其輸出也將需要進行漸變。 相反,僅當所有輸入都不需要漸變時,輸出才不需要。 在所有張量都不要求漸變的子圖中,永遠不會執(zhí)行向后計算。

  1. >>> x = torch.randn(5, 5) # requires_grad=False by default
  2. >>> y = torch.randn(5, 5) # requires_grad=False by default
  3. >>> z = torch.randn((5, 5), requires_grad=True)
  4. >>> a = x + y
  5. >>> a.requires_grad
  6. False
  7. >>> b = a + z
  8. >>> b.requires_grad
  9. True

當您要凍結(jié)部分模型,或者事先知道您將不使用漸變色時,此功能特別有用。 一些參數(shù)。 例如,如果您想微調(diào)預(yù)訓(xùn)練的 CNN,只需在凍結(jié)的基數(shù)中切換requires_grad標志,就不會保存任何中間緩沖區(qū),直到計算到達最后一層,仿射變換將使用權(quán)重為 需要梯度,網(wǎng)絡(luò)的輸出也將需要它們。

  1. model = torchvision.models.resnet18(pretrained=True)
  2. for param in model.parameters():
  3. param.requires_grad = False
  4. ## Replace the last fully-connected layer
  5. ## Parameters of newly constructed modules have requires_grad=True by default
  6. model.fc = nn.Linear(512, 100)
  7. ## Optimize only the classifier
  8. optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

autograd 如何編碼歷史

Autograd 是反向自動分化系統(tǒng)。 從概念上講,autograd 會記錄一個圖形,記錄執(zhí)行操作時創(chuàng)建數(shù)據(jù)的所有操作,從而為您提供一個有向無環(huán)圖,其葉子為輸入張量,根為輸出張量。 通過從根到葉跟蹤該圖,您可以使用鏈式規(guī)則自動計算梯度。

在內(nèi)部,autograd 將該圖表示為Function對象(真正的表達式)的圖,可以將其apply()編輯以計算評估圖的結(jié)果。 在計算前向通過時,autograd 同時執(zhí)行請求的計算,并建立一個表示表示計算梯度的函數(shù)的圖形(每個 torch.Tensor 的.grad_fn屬性是該圖形的入口)。 完成前向遍歷后,我們在后向遍歷中評估此圖以計算梯度。

需要注意的重要一點是,每次迭代都會從頭開始重新創(chuàng)建圖形,這正是允許使用任意 Python 控制流語句的原因,它可以在每次迭代時更改圖形的整體形狀和大小。 在開始訓(xùn)練之前,您不必編碼所有可能的路徑-跑步就是您的與眾不同。

使用 autograd 進行就地操作

在 autograd 中支持就地操作很困難,并且在大多數(shù)情況下,我們不鼓勵使用它們。 Autograd 積極的緩沖區(qū)釋放和重用使其非常高效,就地操作實際上很少顯著降低內(nèi)存使用量的情況很少。 除非您在高內(nèi)存壓力下進行操作,否則可能永遠不需要使用它們。

限制就地操作的適用性的主要原因有兩個:

  1. 就地操作可能會覆蓋計算梯度所需的值。
  2. 實際上,每個就地操作都需要實現(xiàn)來重寫計算圖。 異地版本僅分配新對象并保留對舊圖形的引用,而就地操作則需要更改表示此操作的Function的所有輸入的創(chuàng)建者。 這可能很棘手,特別是如果有許多張量引用相同的存儲(例如通過索引或轉(zhuǎn)置創(chuàng)建的),并且如果修改后的輸入的存儲被任何其他Tensor引用,則就地函數(shù)實際上會引發(fā)錯誤。

就地正確性檢查

每個張量都有一個版本計數(shù)器,每次在任何操作中被標記為臟時,該計數(shù)器都會增加。 當函數(shù)保存任何張量以供向后時,也會保存其包含 Tensor 的版本計數(shù)器。 訪問self.saved_tensors后,將對其進行檢查,如果該值大于保存的值,則會引發(fā)錯誤。 這樣可以確保,如果您使用的是就地函數(shù)并且沒有看到任何錯誤,則可以確保計算出的梯度是正確的。


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號