Transfer Learning for Computer Vision Tutorial

Updated: 2020-09-07 17:25
Original: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

Author: Sasank Chilamkurthy

In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. You can read more about transfer learning in the cs231n notes.

Quoting those notes:

  In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.

These two major transfer learning scenarios look as follows:

  • Finetuning the ConvNet: instead of random initialization, we initialize the network with a pretrained network, such as one trained on the ImageNet 1000 dataset. The rest of the training proceeds as usual.
  • ConvNet as a fixed feature extractor: here, we freeze the weights of the entire network except the final fully connected layer. That last fully connected layer is replaced with a new one with random weights, and only this layer is trained.
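The practical difference between the two scenarios is simply which parameters the optimizer updates. A minimal sketch with a toy two-layer stand-in for a pretrained backbone plus classifier head (the layer sizes here are made up for illustration):

```python
import torch.nn as nn
import torch.optim as optim

# Toy stand-in for "pretrained backbone + classifier head".
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Scenario 1: finetuning -- every parameter stays trainable.
opt_finetune = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Scenario 2: fixed feature extractor -- freeze everything,
# then re-enable gradients only for the final layer.
for p in model.parameters():
    p.requires_grad = False
for p in model[-1].parameters():
    p.requires_grad = True
opt_frozen = optim.SGD(model[-1].parameters(), lr=0.001, momentum=0.9)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)
```

In scenario 2 only the head's parameters (here 4*2 weights + 2 biases = 10 of 46) receive gradient updates; the rest of the tutorial builds both setups with resnet18.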
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

Load Data

We will use the torchvision and torch.utils.data packages to load the data.

The problem we are going to solve today is training a model to classify ants and bees. We have about 120 training images each for ants and bees, and 75 validation images for each class. Usually this is a very small dataset to generalize on if trained from scratch; since we are using transfer learning, we should be able to generalize reasonably well.

This dataset is a very small subset of ImageNet.

Note

Download the data from here and extract it to the current directory.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
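Normalize subtracts the per-channel ImageNet mean and divides by the per-channel std; the imshow helper later inverts exactly this. A quick numpy check of the round trip on a synthetic mid-gray pixel:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

pixel = np.array([0.5, 0.5, 0.5])      # a mid-gray RGB pixel in [0, 1]
normalized = (pixel - mean) / std      # what transforms.Normalize computes
restored = std * normalized + mean     # the inversion imshow performs

print(np.allclose(restored, pixel))
```

The same mean/std pair must be used in both directions, which is why imshow hard-codes the same constants as data_transforms.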

Visualize a Few Images

Let's visualize a few training images so as to understand the data augmentations.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

../_images/sphx_glr_transfer_learning_tutorial_001.png

Training the Model

Now, let's write a general function to train a model. Here, we will illustrate:

  • Scheduling the learning rate
  • Saving the best model

In the following, the parameter scheduler is an LR scheduler object from torch.optim.lr_scheduler.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
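Note that the statistics accumulate loss.item() (the mean loss over the batch) weighted by the batch size, so a smaller final batch is counted correctly. A pure-Python sketch with made-up per-batch losses shows why the weighting matters:

```python
# Hypothetical per-batch mean losses and batch sizes (last batch is smaller).
batch_losses = [0.8, 0.6, 0.5]
batch_sizes = [4, 4, 2]

running_loss = 0.0
for loss, n in zip(batch_losses, batch_sizes):
    running_loss += loss * n  # same pattern as loss.item() * inputs.size(0)

epoch_loss = running_loss / sum(batch_sizes)   # divide by dataset size
print(epoch_loss)
```

Here the weighted epoch loss is 0.66, whereas naively averaging the three batch means would give about 0.633 and over-weight the small last batch.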

Visualizing the Model Predictions

A generic function to display predictions for a few images

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

Finetuning the ConvNet

Load a pretrained model and reset the final fully connected layer.

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
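StepLR with step_size=7 and gamma=0.1 multiplies the learning rate by 0.1 once every 7 epochs, so the schedule has a simple closed form that can be checked without building an optimizer:

```python
base_lr, gamma, step_size = 0.001, 0.1, 7

def lr_at(epoch):
    # Learning rate StepLR yields at a given epoch (closed form).
    return base_lr * gamma ** (epoch // step_size)

print([lr_at(e) for e in (0, 6, 7, 14, 24)])
```

Over the 25 epochs of this run the rate therefore steps through 0.001, 0.0001, 1e-05, and finally 1e-06 from epoch 21 on.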

Train and Evaluate

It should take around 15-25 minutes on CPU. On GPU, though, it takes less than a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.5582 Acc: 0.6967
val Loss: 0.1987 Acc: 0.9216

Epoch 1/24
----------
train Loss: 0.4663 Acc: 0.8238
val Loss: 0.2519 Acc: 0.8889

Epoch 2/24
----------
train Loss: 0.5978 Acc: 0.7623
val Loss: 1.2933 Acc: 0.6601

Epoch 3/24
----------
train Loss: 0.4471 Acc: 0.8320
val Loss: 0.2576 Acc: 0.8954

Epoch 4/24
----------
train Loss: 0.3654 Acc: 0.8115
val Loss: 0.2977 Acc: 0.9150

Epoch 5/24
----------
train Loss: 0.4404 Acc: 0.8197
val Loss: 0.3330 Acc: 0.8627

Epoch 6/24
----------
train Loss: 0.6416 Acc: 0.7623
val Loss: 0.3174 Acc: 0.8693

Epoch 7/24
----------
train Loss: 0.4058 Acc: 0.8361
val Loss: 0.2551 Acc: 0.9085

Epoch 8/24
----------
train Loss: 0.2294 Acc: 0.9098
val Loss: 0.2603 Acc: 0.9085

Epoch 9/24
----------
train Loss: 0.2805 Acc: 0.8730
val Loss: 0.2765 Acc: 0.8954

Epoch 10/24
----------
train Loss: 0.3139 Acc: 0.8525
val Loss: 0.2639 Acc: 0.9020

Epoch 11/24
----------
train Loss: 0.3198 Acc: 0.8648
val Loss: 0.2458 Acc: 0.9020

Epoch 12/24
----------
train Loss: 0.2947 Acc: 0.8811
val Loss: 0.2835 Acc: 0.8889

Epoch 13/24
----------
train Loss: 0.3097 Acc: 0.8730
val Loss: 0.2542 Acc: 0.9085

Epoch 14/24
----------
train Loss: 0.1849 Acc: 0.9303
val Loss: 0.2710 Acc: 0.9085

Epoch 15/24
----------
train Loss: 0.2764 Acc: 0.8934
val Loss: 0.2522 Acc: 0.9085

Epoch 16/24
----------
train Loss: 0.2214 Acc: 0.9098
val Loss: 0.2620 Acc: 0.9085

Epoch 17/24
----------
train Loss: 0.2949 Acc: 0.8525
val Loss: 0.2600 Acc: 0.9085

Epoch 18/24
----------
train Loss: 0.2237 Acc: 0.9139
val Loss: 0.2666 Acc: 0.9020

Epoch 19/24
----------
train Loss: 0.2456 Acc: 0.8852
val Loss: 0.2521 Acc: 0.9150

Epoch 20/24
----------
train Loss: 0.2351 Acc: 0.8852
val Loss: 0.2781 Acc: 0.9085

Epoch 21/24
----------
train Loss: 0.2654 Acc: 0.8730
val Loss: 0.2560 Acc: 0.9085

Epoch 22/24
----------
train Loss: 0.1955 Acc: 0.9262
val Loss: 0.2605 Acc: 0.9020

Epoch 23/24
----------
train Loss: 0.2285 Acc: 0.8893
val Loss: 0.2650 Acc: 0.9085

Epoch 24/24
----------
train Loss: 0.2360 Acc: 0.9221
val Loss: 0.2690 Acc: 0.8954

Training complete in 1m 7s
Best val Acc: 0.921569
visualize_model(model_ft)

../_images/sphx_glr_transfer_learning_tutorial_002.png

ConvNet as a Fixed Feature Extractor

Here, we need to freeze all of the network except the final layer. We set requires_grad = False to freeze the parameters so that the gradients are not computed in backward().

You can read more about this in the documentation.

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

Train and Evaluate

On CPU this will take about half the time compared to the previous scenario. This is expected, since gradients do not need to be computed for most of the network; the forward pass, however, still does.
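The saving can be made concrete by counting trainable parameters: in the frozen model, only the new fc layer (512 inputs, 2 outputs, so 1026 parameters) receives gradients. A sketch with a toy frozen "backbone" layer standing in for resnet18's body, to avoid downloading the pretrained weights here:

```python
import torch.nn as nn

# Toy stand-in mirroring model_conv: a frozen body plus a trainable head.
backbone = nn.Linear(512, 512)   # hypothetical frozen layer
head = nn.Linear(512, 2)         # like model_conv.fc: in_features=512, 2 classes
for p in backbone.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in head.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in backbone.parameters())
print(trainable, frozen)
```

Even in this toy setup the frozen layer dwarfs the head; in the real resnet18 the ratio is far more extreme (about 11 million frozen parameters versus 1026 trainable ones).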

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.5633 Acc: 0.7008
val Loss: 0.2159 Acc: 0.9412

Epoch 1/24
----------
train Loss: 0.4394 Acc: 0.7623
val Loss: 0.2000 Acc: 0.9150

Epoch 2/24
----------
train Loss: 0.5182 Acc: 0.7623
val Loss: 0.1897 Acc: 0.9346

Epoch 3/24
----------
train Loss: 0.3993 Acc: 0.8074
val Loss: 0.3029 Acc: 0.8824

Epoch 4/24
----------
train Loss: 0.4163 Acc: 0.8607
val Loss: 0.2190 Acc: 0.9412

Epoch 5/24
----------
train Loss: 0.4741 Acc: 0.7951
val Loss: 0.1903 Acc: 0.9477

Epoch 6/24
----------
train Loss: 0.4266 Acc: 0.8115
val Loss: 0.2178 Acc: 0.9281

Epoch 7/24
----------
train Loss: 0.3623 Acc: 0.8238
val Loss: 0.2080 Acc: 0.9412

Epoch 8/24
----------
train Loss: 0.3979 Acc: 0.8279
val Loss: 0.1796 Acc: 0.9412

Epoch 9/24
----------
train Loss: 0.3534 Acc: 0.8648
val Loss: 0.2043 Acc: 0.9412

Epoch 10/24
----------
train Loss: 0.3849 Acc: 0.8115
val Loss: 0.2012 Acc: 0.9346

Epoch 11/24
----------
train Loss: 0.3814 Acc: 0.8361
val Loss: 0.2088 Acc: 0.9412

Epoch 12/24
----------
train Loss: 0.3443 Acc: 0.8648
val Loss: 0.1823 Acc: 0.9477

Epoch 13/24
----------
train Loss: 0.2931 Acc: 0.8525
val Loss: 0.1853 Acc: 0.9477

Epoch 14/24
----------
train Loss: 0.2749 Acc: 0.8811
val Loss: 0.2068 Acc: 0.9412

Epoch 15/24
----------
train Loss: 0.3387 Acc: 0.8566
val Loss: 0.2080 Acc: 0.9477

Epoch 16/24
----------
train Loss: 0.2992 Acc: 0.8648
val Loss: 0.2096 Acc: 0.9346

Epoch 17/24
----------
train Loss: 0.3396 Acc: 0.8648
val Loss: 0.1870 Acc: 0.9412

Epoch 18/24
----------
train Loss: 0.3956 Acc: 0.8320
val Loss: 0.1858 Acc: 0.9412

Epoch 19/24
----------
train Loss: 0.3379 Acc: 0.8402
val Loss: 0.1729 Acc: 0.9542

Epoch 20/24
----------
train Loss: 0.2555 Acc: 0.8811
val Loss: 0.2186 Acc: 0.9281

Epoch 21/24
----------
train Loss: 0.3764 Acc: 0.8484
val Loss: 0.1817 Acc: 0.9477

Epoch 22/24
----------
train Loss: 0.2747 Acc: 0.8975
val Loss: 0.2042 Acc: 0.9412

Epoch 23/24
----------
train Loss: 0.3072 Acc: 0.8689
val Loss: 0.1924 Acc: 0.9477

Epoch 24/24
----------
train Loss: 0.3479 Acc: 0.8402
val Loss: 0.1835 Acc: 0.9477

Training complete in 0m 34s
Best val Acc: 0.954248
visualize_model(model_conv)

plt.ioff()
plt.show()

../_images/sphx_glr_transfer_learning_tutorial_003.png

Further Learning

If you would like to learn more about the applications of transfer learning, check out our Quantized Transfer Learning for Computer Vision Tutorial.

Total running time of the script: (1 minutes 53.551 seconds)

Download Python source code: transfer_learning_tutorial.py Download Jupyter notebook: transfer_learning_tutorial.ipynb

Gallery generated by Sphinx-Gallery

