App下載

PyTorch中eval和no_grad有什么關(guān)系?

酒后表演藝術(shù)家 2021-08-20 15:13:16 瀏覽數(shù) (2985)
反饋

在進(jìn)行evaluate的時(shí)候,我們需要同時(shí)使用到eval和no_grad這兩個(gè)函數(shù),有些小伙伴就會(huì)問了,這兩個(gè)函數(shù)有什么功能呢,他們又有什么區(qū)別呢,今天小編就來介紹這兩個(gè)函數(shù)的區(qū)別。

首先這兩者有著本質(zhì)上區(qū)別

model.eval()是用來告知model內(nèi)的各個(gè)layer采取eval模式工作。這個(gè)操作主要是應(yīng)對(duì)諸如dropout和batchnorm這些在訓(xùn)練模式下需要采取不同操作的特殊layer。訓(xùn)練和測(cè)試的時(shí)候都可以開啟。

torch.no_grad()則是告知自動(dòng)求導(dǎo)引擎不要進(jìn)行求導(dǎo)操作。這個(gè)操作的意義在于加速計(jì)算、節(jié)約內(nèi)存。但是由于沒有g(shù)radient,也就沒有辦法進(jìn)行backward。所以只能在測(cè)試的時(shí)候開啟。

所以在evaluate的時(shí)候,需要同時(shí)使用兩者。

model = ...
dataset = ...
loss_fun = ...

# training
lr=0.001
model.train()
for x,y in dataset:
 model.zero_grad()
 p = model(x)
 l = loss_fun(p, y)
 l.backward()
 for p in model.parameters():
  p.data -= lr*p.grad
 
# evaluating
sum_loss = 0.0
model.eval()
with torch.no_grad():
 for x,y in dataset:
  p = model(x)
  l = loss_fun(p, y)
  sum_loss += l
print('total loss:', sum_loss)

另外no_grad還可以作為函數(shù)是修飾符來用,從而簡(jiǎn)化代碼。

def train(model, dataset, loss_fun, lr=0.001):
 model.train()
 for x,y in dataset:
  model.zero_grad()
  p = model(x)
  l = loss_fun(p, y)
  l.backward()
  for p in model.parameters():
   p.data -= lr*p.grad
 
@torch.no_grad()
def test(model, dataset, loss_fun):
 sum_loss = 0.0
 model.eval()
 for x,y in dataset:
  p = model(x)
  l = loss_fun(p, y)
  sum_loss += l
 return sum_loss

# main block:
model = ...
dataset = ...
loss_fun = ...

# training
train()
# test
sum_loss = test()
print('total loss:', sum_loss)

補(bǔ)充:pytorch中model.train、model.eval以及torch.no_grad的用法

1、model.train()

啟用 BatchNormalization 和 Dropout

model.train() 讓model變成訓(xùn)練模式,此時(shí) dropout和batch normalization的操作在訓(xùn)練起到防止網(wǎng)絡(luò)過擬合的問題

2、model.eval()

不啟用 BatchNormalization 和 Dropout

model.eval(),pytorch會(huì)自動(dòng)把BN和DropOut固定住,而用訓(xùn)練好的值。不然的話,一旦test的batch_size過小,很容易就會(huì)被BN層導(dǎo)致所生成圖片顏色失真極大

訓(xùn)練完train樣本后,生成的模型model要用來測(cè)試樣本。在model(test)之前,需要加上model.eval(),否則的話,有輸入數(shù)據(jù),即使不訓(xùn)練,它也會(huì)改變權(quán)值。這是model中含有batch normalization層所帶來的的性質(zhì)。

對(duì)于在訓(xùn)練和測(cè)試時(shí)為什么要這樣做,可以從下面兩段話理解:

在訓(xùn)練的時(shí)候, 會(huì)計(jì)算一個(gè)batch內(nèi)的mean 和var, 但是因?yàn)槭切atch小batch的訓(xùn)練的,所以會(huì)采用加權(quán)或者動(dòng)量的形式來將每個(gè)batch的 mean和var來累加起來,也就是說再算當(dāng)前的batch的時(shí)候,其實(shí)當(dāng)前的權(quán)重只是占了0.1, 之前所有訓(xùn)練過的占了0.9的權(quán)重,這樣做的好處是不至于因?yàn)槟骋粋€(gè)batch太過奇葩而導(dǎo)致的訓(xùn)練不穩(wěn)定。

好,現(xiàn)在假設(shè)訓(xùn)練完成了, 那么在整個(gè)訓(xùn)練集上面也得到了一個(gè)最終的”mean 和var”, BN層里面的參數(shù)也學(xué)習(xí)完了(如果指定學(xué)習(xí)的話),而現(xiàn)在需要測(cè)試了,測(cè)試的時(shí)候往往會(huì)一張圖一張圖的去測(cè),這時(shí)候沒有batch而言了,對(duì)單獨(dú)一個(gè)數(shù)據(jù)做 mean和var是沒有意義的, 那么怎么辦,實(shí)際上在測(cè)試的時(shí)候BN里面用的mean和var就是訓(xùn)練結(jié)束后的mean_final 和 val_final. 也可說是在測(cè)試的時(shí)候BN就是一個(gè)變換。所以在用pytorch的時(shí)候要注意這一點(diǎn),在訓(xùn)練之前要有model.train() 來告訴網(wǎng)絡(luò)現(xiàn)在開啟了訓(xùn)練模式,在eval的時(shí)候要用”model.eval()”, 用來告訴網(wǎng)絡(luò)現(xiàn)在要進(jìn)入測(cè)試模式了.因?yàn)檫@兩種模式下BN的作用是不同的。

3、torch.no_grad()

這條語句的作用是:在測(cè)試時(shí)不進(jìn)行梯度的計(jì)算,這樣可以在測(cè)試時(shí)有效減小顯存的占用,以免發(fā)生顯存溢出(OOM)。

這條語句通常加在網(wǎng)絡(luò)預(yù)測(cè)的那條代碼上。

4、pytorch中model.eval()和“with torch.no_grad()區(qū)別

兩者區(qū)別

在PyTorch中進(jìn)行validation時(shí),會(huì)使用model.eval()切換到測(cè)試模式,在該模式下,

主要用于通知dropout層和batchnorm層在train和val模式間切換

在train模式下,dropout網(wǎng)絡(luò)層會(huì)按照設(shè)定的參數(shù)p設(shè)置保留激活單元的概率(保留概率=p); batchnorm層會(huì)繼續(xù)計(jì)算數(shù)據(jù)的mean和var等參數(shù)并更新。

在val模式下,dropout層會(huì)讓所有的激活單元都通過,而batchnorm層會(huì)停止計(jì)算和更新mean和var,直接使用在訓(xùn)練階段已經(jīng)學(xué)出的mean和var值。

該模式不會(huì)影響各層的gradient計(jì)算行為,即gradient計(jì)算和存儲(chǔ)與training模式一樣,只是不進(jìn)行反傳(backprobagation)

而with torch.zero_grad()則主要是用于停止autograd模塊的工作,以起到加速和節(jié)省顯存的作用,具體行為就是停止gradient計(jì)算,從而節(jié)省了GPU算力和顯存,但是并不會(huì)影響dropout和batchnorm層的行為。

使用場(chǎng)景

如果不在意顯存大小和計(jì)算時(shí)間的話,僅僅使用model.eval()已足夠得到正確的validation的結(jié)果;而with torch.zero_grad()則是更進(jìn)一步加速和節(jié)省gpu空間(因?yàn)椴挥糜?jì)算和存儲(chǔ)gradient),從而可以更快計(jì)算,也可以跑更大的batch來測(cè)試。

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。


0 人點(diǎn)贊