W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
原文: https://pytorch.org/docs/stable/notes/autograd.html
本說明將概述 autograd 的工作方式并記錄操作。 不一定要完全了解所有這些內(nèi)容,但我們建議您熟悉它,因為它可以幫助您編寫更高效,更簡潔的程序,并可以幫助您進行調(diào)試。
每個張量都有一個標志:requires_grad
,允許從梯度計算中細粒度地排除子圖,并可以提高效率。
requires_grad
如果某個操作的單個輸入需要進行漸變,則其輸出也將需要進行漸變。 相反,僅當所有輸入都不需要漸變時,輸出才不需要。 在所有張量都不要求漸變的子圖中,永遠不會執(zhí)行向后計算。
>>> x = torch.randn(5, 5) # requires_grad=False by default
>>> y = torch.randn(5, 5) # requires_grad=False by default
>>> z = torch.randn((5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True
當您要凍結(jié)部分模型,或者事先知道您將不使用漸變色時,此功能特別有用。 一些參數(shù)。 例如,如果您想微調(diào)預(yù)訓(xùn)練的 CNN,只需在凍結(jié)的基數(shù)中切換requires_grad
標志,就不會保存任何中間緩沖區(qū),直到計算到達最后一層,仿射變換將使用權(quán)重為 需要梯度,網(wǎng)絡(luò)的輸出也將需要它們。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
## Replace the last fully-connected layer
## Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
## Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
Autograd 是反向自動分化系統(tǒng)。 從概念上講,autograd 會記錄一個圖形,記錄執(zhí)行操作時創(chuàng)建數(shù)據(jù)的所有操作,從而為您提供一個有向無環(huán)圖,其葉子為輸入張量,根為輸出張量。 通過從根到葉跟蹤該圖,您可以使用鏈式規(guī)則自動計算梯度。
在內(nèi)部,autograd 將該圖表示為Function
對象(真正的表達式)的圖,可以將其apply()
編輯以計算評估圖的結(jié)果。 在計算前向通過時,autograd 同時執(zhí)行請求的計算,并建立一個表示表示計算梯度的函數(shù)的圖形(每個 torch.Tensor
的.grad_fn
屬性是該圖形的入口)。 完成前向遍歷后,我們在后向遍歷中評估此圖以計算梯度。
需要注意的重要一點是,每次迭代都會從頭開始重新創(chuàng)建圖形,這正是允許使用任意 Python 控制流語句的原因,它可以在每次迭代時更改圖形的整體形狀和大小。 在開始訓(xùn)練之前,您不必編碼所有可能的路徑-跑步就是您的與眾不同。
在 autograd 中支持就地操作很困難,并且在大多數(shù)情況下,我們不鼓勵使用它們。 Autograd 積極的緩沖區(qū)釋放和重用使其非常高效,就地操作實際上很少顯著降低內(nèi)存使用量的情況很少。 除非您在高內(nèi)存壓力下進行操作,否則可能永遠不需要使用它們。
限制就地操作的適用性的主要原因有兩個:
Function
的所有輸入的創(chuàng)建者。 這可能很棘手,特別是如果有許多張量引用相同的存儲(例如通過索引或轉(zhuǎn)置創(chuàng)建的),并且如果修改后的輸入的存儲被任何其他Tensor
引用,則就地函數(shù)實際上會引發(fā)錯誤。每個張量都有一個版本計數(shù)器,每次在任何操作中被標記為臟時,該計數(shù)器都會增加。 當函數(shù)保存任何張量以供向后時,也會保存其包含 Tensor 的版本計數(shù)器。 訪問self.saved_tensors
后,將對其進行檢查,如果該值大于保存的值,則會引發(fā)錯誤。 這樣可以確保,如果您使用的是就地函數(shù)并且沒有看到任何錯誤,則可以確保計算出的梯度是正確的。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: