App下載

pytorch 一行代碼查看網(wǎng)絡(luò)參數(shù)總量的實(shí)現(xiàn)

一瞬之光 2021-08-19 09:54:52 瀏覽數(shù) (3219)
反饋

在機(jī)器學(xué)習(xí)的代碼調(diào)試過(guò)程中,網(wǎng)絡(luò)參數(shù)總量是一個(gè)重要的參考數(shù)值。那么在pytorch中怎么查看網(wǎng)絡(luò)參數(shù)總量呢?接下來(lái)的這篇文章帶你了解。

大家還是直接看代碼吧~

netG = Generator()
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

補(bǔ)充:PyTorch查看網(wǎng)絡(luò)模型的參數(shù)量PARAMS和FLOPS等

在PyTorch中,可以使用torchstat這個(gè)庫(kù)來(lái)查看網(wǎng)絡(luò)模型的一些信息,包括總的參數(shù)量params、MAdd、顯卡內(nèi)存占用量和FLOPs等。

示例代碼如下:

from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d
model = resnet50()
stat(model, (3, 224, 224))

打印信息如下:

以上就是pytorch中怎么查看網(wǎng)絡(luò)參數(shù)總量的全部?jī)?nèi)容,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。



0 人點(diǎn)贊