pytorch中如果自己搭建網(wǎng)絡(luò)并且加載別人的與訓(xùn)練模型的話,如果模型和參數(shù)不嚴(yán)格匹配,就可能會(huì)出問(wèn)題,接下來(lái)記錄一下我的解決方法。
兩個(gè)有序字典找不同
模型的參數(shù)和pth文件的參數(shù)都是有序字典(OrderedDict),把字典中的鍵轉(zhuǎn)為列表就可以在for循環(huán)里迭代找不同了。
model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
err = 1
自己搭建模型的注意事項(xiàng)
搭網(wǎng)絡(luò)時(shí)要對(duì)照pth文件的字典順序搭,字典順序、權(quán)重尺寸(shape)和變量命名必須與pth文件完全一致。如果僅僅是變量命名不同,可采用類似的方法對(duì)模型的權(quán)重重新賦值。
model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
continue
model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)
完整的代碼見(jiàn)自己搭建resnet18網(wǎng)絡(luò)并加載torchvision自帶權(quán)重
新增的改進(jìn)代碼
model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
if m >= len1 or n >= len2:
break
layername1, layername2 = model_list1[m], model_list2[n]
w1, w2 = model_dict1[layername1], model_dict2[layername2]
if w1.shape != w2.shape:
continue
model_dict2[layername2] = model_dict1[layername1]
m += 1
n += 1
model.load_state_dict(model_dict2)
如果因?yàn)槟P筒黄ヅ?,運(yùn)行第14行語(yǔ)句后,可看自己情況手動(dòng)對(duì)m或n加上1。
補(bǔ)充:pytorch的一些坑:用預(yù)訓(xùn)練的vgg模型的部分層的特征報(bào)錯(cuò),如張量不匹配
看代碼吧~
#打算取VGG19的第二個(gè)全連接層的輸出,那么就需要構(gòu)建一個(gè)類,這個(gè)類要包含VGG的全部卷積層,
#以及到第二個(gè)全連接層的全部網(wǎng)絡(luò)還有他們對(duì)應(yīng)的參數(shù)
class Classification_att(nn.Module):
def __init__(self, rgb_range):
super(Classification_att, self).__init__()
self.vgg19 =models.vgg19(pretrained=True)
vgg = models.vgg19(pretrained=True).features
conv_modules = [m for m in vgg]
self.vgg_conv = nn.Sequential(*conv_modules[:37])
classfi = models.vgg19(pretrained=True).classifier
classif_modules = [n for n in classfi]
self.vgg_class = nn.Sequential(*classif_modules[:4])
vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
for p in self.vgg_conv.parameters():
p.requires_grad = False
for p in self.vgg_class.parameters():
p.requires_grad = False
self.classifi = nn.Sequential(
nn.Linear(4096, 1024),
nn.ReLU(True),
nn.Linear(1024, 256),
nn.ReLU(True),
nn.Linear(256, 64),
)
def forward(self, x):
x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear',
align_corners=False)
x = self.sub_mean(x)
x = self.vgg_conv(x)
x = self.vgg_class(x) #執(zhí)行這部報(bào)錯(cuò),說(shuō)張量不匹配
原因是因?yàn)榫矸e層的輸出不能直接連接全連接層,即使輸出的張量的總的大小是一致的
查看vgg的pytorch源碼發(fā)現(xiàn)是
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代碼沒(méi)有torch.flatten(x, 1)這步
所以自己的少了一步
x = torch.flatten(x, 1)
補(bǔ)上就好了!
以上就是pytorch加載預(yù)訓(xùn)練模型與自己的模型不匹配的解決方案,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。