PyTorch 使用 PyTorch C ++前端

2020-09-16 14:19 更新

原文:PyTorch 使用 PyTorch C ++前端

校驗(yàn)者:qiwei_ji

PyTorch C ++前端是 PyTorch 機(jī)器學(xué)習(xí)框架的純 C ++接口。 雖然 PyTorch 的主要接口自然是 Python,但此 Python API 建立于大量的 C ++代碼庫(kù)之上,提供基本的數(shù)據(jù)結(jié)構(gòu)和功能,例如張量和自動(dòng)微分。 C ++前端公開了純 C ++ 11 API,該 API 使用機(jī)器學(xué)習(xí)訓(xùn)練和推理所需的工具,擴(kuò)展了此基礎(chǔ) C ++代碼庫(kù)。 該拓展包括用于神經(jīng)網(wǎng)絡(luò)建模的通用組件的內(nèi)置集合; 使用自定義模塊擴(kuò)展此集合的 API; 一個(gè)流行的優(yōu)化算法庫(kù),例如隨機(jī)梯度下降; 具有 API 的并行數(shù)據(jù)加載器,用于定義和加載數(shù)據(jù)集; 序列化例程等。

本教程將引導(dǎo)您完成使用 C ++前端訓(xùn)練模型的端到端示例。 具體來說,我們將訓(xùn)練 DCGAN (一種生成模型),以生成 MNIST 數(shù)字的圖像。 雖然從概念上講,這只是一個(gè)簡(jiǎn)單的示例,但它足以使您對(duì) PyTorch C ++前端有個(gè)大概的了解,并可以滿足訓(xùn)練更復(fù)雜模型的需求。 我們將從一些鼓舞人心的詞開始,說明您為什么要使用 C ++前端,然后直接深入定義和訓(xùn)練我們的模型。

Tip

本筆記概述了 C ++前端的組件和設(shè)計(jì)原理。

Tip

有關(guān) PyTorch C ++生態(tài)系統(tǒng)的文檔,請(qǐng)?jiān)L問 https://pytorch.org/cppdocs 。 您可以在此處找到高級(jí)描述以及 API 級(jí)文檔。

動(dòng)機(jī)

在我們開始 GAN 和 MNIST 數(shù)字的激動(dòng)人心的旅程之前,讓我們退一步來討論為什么您要使用 C ++前端而不是 Python。 我們(PyTorch 團(tuán)隊(duì))創(chuàng)建了 C ++前端,以便能夠在無法使用 Python 或根本不適合該工具的環(huán)境中進(jìn)行研究。 此類環(huán)境的示例包括:

  • 低延遲系統(tǒng):您可能希望在具有高幀率和低延遲要求的純 C ++游戲引擎中進(jìn)行強(qiáng)化學(xué)習(xí)研究。 與 Python 庫(kù)相比,使用純 C ++庫(kù)更適合這種環(huán)境。 由于 Python 解釋器運(yùn)行緩慢,Python 可能根本無法處理此類問題。
  • 高度多線程環(huán)境:由于全局解釋器鎖定(GIL),Python 一次不能運(yùn)行多個(gè)系統(tǒng)線程。 并行處理是一種替代方法,但可擴(kuò)展性不強(qiáng),并且存在很多缺點(diǎn)。 C ++沒有這樣的約束,線程易于使用和創(chuàng)建。 需要高度并行化的模型,例如深層神經(jīng)進(jìn)化中使用的模型,可以從中受益。
  • 現(xiàn)有的 C ++代碼庫(kù):您可能下載了 C ++應(yīng)用程序,其工作范圍從后端服務(wù)器中的網(wǎng)頁(yè)服務(wù)到照片編輯軟件中的 3D 圖形渲染應(yīng)有盡有,并且希望將機(jī)器學(xué)習(xí)方法集成到您的系統(tǒng)中。 C ++前端使您可以繼續(xù)使用 C ++,并省去在 Python 和 C ++之間來回綁定的麻煩,同時(shí)保留了傳統(tǒng) PyTorch(Python)大部分的靈活性和直觀性。

C ++前端與 Python 前端并非是競(jìng)爭(zhēng)關(guān)系。 前者是對(duì)后者的補(bǔ)充。 我們知道研究人員和工程師都喜歡 PyTorch,因?yàn)樗哂泻?jiǎn)單,靈活和直觀的 API。 我們的目標(biāo)是確保您可以在所有可能的環(huán)境(包括上述環(huán)境)中利用這些核心設(shè)計(jì)原則。 如果上述的這些情況之一很好地描述了您的用例,或者您只是感興趣或好奇,請(qǐng)?jiān)谝韵露温渲欣^續(xù)研究 C ++前端。

Tip

C ++前端試圖提供一個(gè)與 Python 前端盡可能接近的 API。 如果您對(duì) Python 前端有豐富的經(jīng)驗(yàn),并且問過自己“我可以使用 C ++前端做些什么 ?”,請(qǐng)像在 Python 中那樣編寫代碼,并且大多數(shù)情況下,相同的函數(shù)和方法都可以在 C ++中使用。 就像在 Python 中一樣(記得用雙冒號(hào)替換點(diǎn))。

編寫基本應(yīng)用程序

首先,編寫一個(gè)最小的 C ++應(yīng)用程序,以驗(yàn)證我們是否在同一頁(yè)面上了解我們的設(shè)置和構(gòu)建環(huán)境。 首先,您需要獲取 LibTorch 發(fā)行版的副本-我們現(xiàn)成的 zip 歸檔文件,其中打包了使用 C ++前端所需的所有相關(guān)首部,庫(kù)和 CMake 構(gòu)建文件。 LibTorch 發(fā)行版可在 PyTorch 網(wǎng)站上下載,適用于 Linux,MacOS 和 Windows。 本教程的其余部分將假定基本的 Ubuntu Linux 環(huán)境,但是您也可以在 MacOS 或 Windows 上進(jìn)行學(xué)習(xí)。

Tip

關(guān)于安裝 PyTorch的 C ++發(fā)行版 的注釋更詳細(xì)地描述了以下步驟。

Tip

在 Windows 上,調(diào)試和發(fā)行版本不兼容 ABI。 如果您打算以調(diào)試模式構(gòu)建項(xiàng)目,請(qǐng)嘗試使用 LibTorch 的調(diào)試版本。 另外,請(qǐng)確保在下面的cmake --build .行中指定正確的配置。

第一步,通過從 PyTorch 網(wǎng)站獲取的鏈接在本地下載 LibTorch 發(fā)行版。 對(duì)于普通的 Ubuntu Linux 環(huán)境,這意味著運(yùn)行以下步驟:

## If you need e.g. CUDA 9.0 support, please replace "cpu" with "cu90" in the URL below.
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

接下來,讓我們編寫一個(gè)名為dcgan.cpp的小型 C ++文件,其中包含torch/torch.h,現(xiàn)在只需打印出三乘三的單位矩陣即可:

#include <torch/torch.h>
#include <iostream>


int main() {
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}

要在以后構(gòu)建這個(gè)應(yīng)用程序以及我們完整的訓(xùn)練腳本,我們將使用以下CMakeLists.txt文件:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)


find_package(Torch REQUIRED)


add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)

注意

雖然 CMake 是 LibTorch 的推薦的構(gòu)建系統(tǒng),但這并不是硬性要求。 您還可以使用 Visual Studio 項(xiàng)目文件,QMake,普通 Makefile 或您認(rèn)為合適的任何其他構(gòu)建環(huán)境。 但是,我們不為此提供現(xiàn)成的支持。

在上面的 CMake 文件中記下第 4 行:find_package(Torch REQUIRED)。 這表示 CMake 在查找 LibTorch 庫(kù)的構(gòu)建配置。 為了使 CMake 知道在哪里找到這些文件,調(diào)用cmake時(shí)必須設(shè)置CMAKE_PREFIX_PATH。 在執(zhí)行此操作之前,讓我們就dcgan應(yīng)用程序的以下目錄結(jié)構(gòu)達(dá)成一致:_

dcgan/
  CMakeLists.txt
  dcgan.cpp

此外,我將指向未壓縮的 LibTorch 分布的路徑稱為/path/to/libtorch。 請(qǐng)注意,此必須是絕對(duì)路徑。 特別是,將CMAKE_PREFIX_PATH設(shè)置為../../libtorch之類的內(nèi)容會(huì)以意想不到的方式中斷, 應(yīng)該寫$PWD/../../libtorch以獲取相應(yīng)的絕對(duì)路徑。 現(xiàn)在,我們準(zhǔn)備構(gòu)建我們的應(yīng)用程序:

root@fa350df05ecf:/home# mkdir build
root@fa350df05ecf:/home# cd build
root@fa350df05ecf:/home/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /path/to/libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /home/build
root@fa350df05ecf:/home/build# cmake --build . --config Release
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan

上面,我們首先在dcgan目錄內(nèi)創(chuàng)建一個(gè)build文件夾,進(jìn)入該文件夾,運(yùn)行cmake命令以生成必要的 build(Make)文件,最后通過運(yùn)行cmake --build . --config Release成功編譯該項(xiàng)目。 現(xiàn)在我們準(zhǔn)備執(zhí)行最小的二進(jìn)制文件并完成有關(guān)基本項(xiàng)目配置的這一部分:

root@fa350df05ecf:/home/build# ./dcgan
1  0  0
0  1  0
0  0  1
[ Variable[CPUFloatType]{3,3} ]

在我看來這就像一個(gè)單位矩陣!

定義神經(jīng)網(wǎng)絡(luò)模型

現(xiàn)在我們已經(jīng)配置了基本環(huán)境,我們可以深入研究本教程中更有趣的部分。 首先,我們將討論如何在 C ++前端中定義模塊并與之交互。 我們將從基本的小規(guī)模示例模塊開始,然后使用 C ++前端提供的廣泛的內(nèi)置模塊庫(kù)來實(shí)現(xiàn)完整的 GAN。

模塊 API 基礎(chǔ)

與 Python 界面一致,基于 C ++前端的神經(jīng)網(wǎng)絡(luò)由稱為模塊的可重用構(gòu)建塊組成。 有一個(gè)基礎(chǔ)模塊類,所有其他模塊都從該基礎(chǔ)類派生。 在 Python 中,此類為torch.nn.Module,在 C ++中為torch::nn::Module。 除了實(shí)現(xiàn)模塊封裝的算法的forward()方法之外,模塊通常還包含以下三種子對(duì)象中的任何一種:參數(shù),緩沖區(qū)和子模塊。

參數(shù)和緩沖區(qū)以張量的形式存儲(chǔ)。 參數(shù)記錄梯度,但緩沖區(qū)不記錄。 參數(shù)通常是神經(jīng)網(wǎng)絡(luò)的可訓(xùn)練權(quán)重。 緩沖區(qū)的示例包括批量標(biāo)準(zhǔn)化的均值和方差。 為了重用特定的邏輯和狀態(tài)塊,PyTorch API 允許嵌套模塊。 嵌套模塊稱為子模塊。

參數(shù),緩沖區(qū)和子模塊是必須被注冊(cè)的。 注冊(cè)后,可以使用parameters()buffers()之類的方法來檢索整個(gè)(嵌套)模塊層次結(jié)構(gòu)中所有參數(shù)的容器。 類似地,使用to(...)之類的方法,例如 to(torch::kCUDA)將所有參數(shù)和緩沖區(qū)從 CPU 移到 CUDA 內(nèi)存,在整個(gè)模塊層次結(jié)構(gòu)上工作。

定義模塊和注冊(cè)參數(shù)

為了將這些詞寫成代碼,讓我們考慮一下用 Python 界面編寫的簡(jiǎn)單模塊:

import torch


class Net(torch.nn.Module):
  def __init__(self, N, M):
    super(Net, self).__init__()
    self.W = torch.nn.Parameter(torch.randn(N, M))
    self.b = torch.nn.Parameter(torch.randn(M))


  def forward(self, input):
    return torch.addmm(self.b, input, self.W)

在 C ++中,它看起來像這樣:

#include <torch/torch.h>


struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M) {
    W = register_parameter("W", torch::randn({N, M}));
    b = register_parameter("b", torch::randn(M));
  }
  torch::Tensor forward(torch::Tensor input) {
    return torch::addmm(b, input, W);
  }
  torch::Tensor W, b;
};

就像在 Python 中一樣,我們定義了一個(gè)名為Net的類(為簡(jiǎn)單起見,這里是struct而不是class),然后從模塊基類派生它。 在構(gòu)造函數(shù)內(nèi)部,我們使用torch::randn創(chuàng)建張量,就像在 Python 中使用torch.randn一樣。 一個(gè)有趣的區(qū)別是我們?nèi)绾巫?cè)參數(shù)。 在 Python 中,我們用torch.nn.Parameter類包裝了張量,而在 C ++中,我們不得不通過register_parameter方法傳遞張量。 這樣做的原因是 Python API 可以檢測(cè)到屬性為torch.nn.Parameter類型并自動(dòng)注冊(cè)此類張量。 在 C ++中,反射非常有限,因此提供了一種更傳統(tǒng)(而且并不是那么不可思議)的方法。

注冊(cè)子模塊并遍歷模塊層次結(jié)構(gòu)

同樣,我們可以注冊(cè)參數(shù),也可以注冊(cè)子模塊。 在 Python 中,將子模塊分配為模塊的屬性時(shí),會(huì)自動(dòng)檢測(cè)并注冊(cè)這些子模塊:

class Net(torch.nn.Module):
  def __init__(self, N, M):
      super(Net, self).__init__()
      # Registered as a submodule behind the scenes
      self.linear = torch.nn.Linear(N, M)
      self.another_bias = torch.nn.Parameter(torch.rand(M))


  def forward(self, input):
    return self.linear(input) + self.another_bias

例如,允許使用parameters()方法來遞歸訪問模塊層次結(jié)構(gòu)中的所有參數(shù):

>>> net = Net(4, 5)
>>> print(list(net.parameters()))
[Parameter containing:
tensor([0.0808, 0.8613, 0.2017, 0.5206, 0.5353], requires_grad=True), Parameter containing:
tensor([[-0.3740, -0.0976, -0.4786, -0.4928],
        [-0.1434,  0.4713,  0.1735, -0.3293],
        [-0.3467, -0.3858,  0.1980,  0.1986],
        [-0.1975,  0.4278, -0.1831, -0.2709],
        [ 0.3730,  0.4307,  0.3236, -0.0629]], requires_grad=True), Parameter containing:
tensor([ 0.2038,  0.4638, -0.2023,  0.1230, -0.0516], requires_grad=True)]

要在 C ++中注冊(cè)子模塊,請(qǐng)使用恰當(dāng)命名的register_module()方法注冊(cè)類似torch::nn::Linear的模塊:

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M)
      : linear(register_module("linear", torch::nn::Linear(N, M))) {
    another_bias = register_parameter("b", torch::randn(M));
  }
  torch::Tensor forward(torch::Tensor input) {
    return linear(input) + another_bias;
  }
  torch::nn::Linear linear;
  torch::Tensor another_bias;
};

Tip

您可以在torch::nn命名空間的文檔中找到可用的內(nèi)置模塊的完整列表,例如torch::nn::Linear,torch::nn::Dropouttorch::nn::Conv2d

微妙之處在于,為什么在構(gòu)造函數(shù)的初始值設(shè)定項(xiàng)列表中創(chuàng)建子模塊,而在構(gòu)造函數(shù)的主體內(nèi)部創(chuàng)建參數(shù)。 這是有充分的理由的,我們將在下面有關(guān) C ++前端的所有權(quán)模型的部分中對(duì)此進(jìn)行介紹。 但是,最終結(jié)果是,就像 Python 中一樣,我們可以遞歸訪問模塊樹的參數(shù)。 調(diào)用parameters()將返回std::vector<torch::Tensor>,我們可以對(duì)其進(jìn)行迭代:

int main() {
  Net net(4, 5);
  for (const auto& p : net.parameters()) {
    std::cout << p << std::endl;
  }
}

打印:

root@fa350df05ecf:/home/build# ./dcgan
0.0345
1.4456
-0.6313
-0.3585
-0.4008
[ Variable[CPUFloatType]{5} ]
-0.1647  0.2891  0.0527 -0.0354
0.3084  0.2025  0.0343  0.1824
-0.4630 -0.2862  0.2500 -0.0420
0.3679 -0.1482 -0.0460  0.1967
0.2132 -0.1992  0.4257  0.0739
[ Variable[CPUFloatType]{5,4} ]
0.01 *
3.6861
-10.1166
-45.0333
7.9983
-20.0705
[ Variable[CPUFloatType]{5} ]

具有三個(gè)參數(shù),就像在 Python 中一樣。 為了也查看這些參數(shù)的名稱,C ++ API 提供了named_parameters()方法,該方法返回OrderedDict,就像在 Python 中一樣:

Net net(4, 5);
for (const auto& pair : net.named_parameters()) {
  std::cout << pair.key() << ": " << pair.value() << std::endl;
}

我們可以再次執(zhí)行以查看輸出:

root@fa350df05ecf:/home/build# make && ./dcgan                                                                                                                                            11:13:48
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
b: -0.1863
-0.8611
-0.1228
1.3269
0.9858
[ Variable[CPUFloatType]{5} ]
linear.weight:  0.0339  0.2484  0.2035 -0.2103
-0.0715 -0.2975 -0.4350 -0.1878
-0.3616  0.1050 -0.4982  0.0335
-0.1605  0.4963  0.4099 -0.2883
0.1818 -0.3447 -0.1501 -0.0215
[ Variable[CPUFloatType]{5,4} ]
linear.bias: -0.0250
0.0408
0.3756
-0.2149
-0.3636
[ Variable[CPUFloatType]{5} ]

Note

torch::nn::Module的文檔包含在模塊層次結(jié)構(gòu)上運(yùn)行方法的完整列表中。

在轉(zhuǎn)發(fā)模式下運(yùn)行網(wǎng)絡(luò)

要使用 C ++執(zhí)行網(wǎng)絡(luò),我們只需調(diào)用我們自己定義的forward()方法:

int main() {
  Net net(4, 5);
  std::cout << net.forward(torch::ones({2, 4})) << std::endl;
}

打?。?/p>

root@fa350df05ecf:/home/build# ./dcgan
0.8559  1.1572  2.1069 -0.1247  0.8060
0.8559  1.1572  2.1069 -0.1247  0.8060
[ Variable[CPUFloatType]{2,5} ]

模塊所有權(quán)

至此,我們知道了如何使用 C ++定義模塊,注冊(cè)參數(shù),注冊(cè)子模塊,通過parameters()之類的方法遍歷模塊層次結(jié)構(gòu)并最終運(yùn)行模塊的forward()方法。 盡管在 C ++ API 中還有很多方法,類和主題需要使用,但我將為您提供完整菜單的文檔。 我們將在稍后實(shí)現(xiàn) DCGAN 模型和端到端訓(xùn)練管道的過程中,涉及更多概念。 在我們這樣做之前,讓我簡(jiǎn)要地談?wù)?C ++前端為torch::nn::Module的子類提供的所有權(quán)模型。

在本次討論中,所有權(quán)模型是指模塊的存儲(chǔ)和傳遞方式-確定特定模塊實(shí)例的所有者或所有者。 在 Python 中,對(duì)象始終是動(dòng)態(tài)分配的(在堆上),并具有引用語(yǔ)義。 這是非常容易使用且易于理解的。 實(shí)際上,在 Python 中,您可以很大程度上忽略對(duì)象的位置以及如何引用它們,而將精力集中在完成事情上。

C ++是一種較低級(jí)的語(yǔ)言,它在此領(lǐng)域提供了更多選擇。 這增加了復(fù)雜性,并嚴(yán)重影響了 C ++前端的設(shè)計(jì)和人體工程學(xué)。 特別是,對(duì)于 C ++前端中的模塊,我們可以選擇使用值語(yǔ)義參考語(yǔ)義。 第一種情況是最簡(jiǎn)單的,并且在到目前為止的示例中已進(jìn)行了展示:模塊對(duì)象分配在堆棧上,并在傳遞給函數(shù)時(shí)可以復(fù)制,移動(dòng)(使用std::move)或通過引用或指針獲?。?/p>

struct Net : torch::nn::Module { };


void a(Net net) { }
void b(Net& net) { }
void c(Net* net) { }


int main() {
  Net net;
  a(net);
  a(std::move(net));
  b(net);
  c(&net);
}

對(duì)于第二種情況-參考語(yǔ)義-我們可以使用std::shared_ptr。 引用語(yǔ)義的優(yōu)勢(shì)在于,就像在 Python 中一樣,它減少了思考如何將模塊傳遞給函數(shù)以及如何聲明參數(shù)的認(rèn)知開銷(假設(shè)您在任何地方都使用shared_ptr)。

struct Net : torch::nn::Module {};


void a(std::shared_ptr<Net> net) { }


int main() {
  auto net = std::make_shared<Net>();
  a(net);
}

根據(jù)我們的經(jīng)驗(yàn),來自動(dòng)態(tài)語(yǔ)言的研究人員非常喜歡引用語(yǔ)義而不是值語(yǔ)義,盡管后者比 C ++更“原生”。 同樣重要的是,torch::nn::Module的設(shè)計(jì)為了要與 Python API 的人體工程學(xué)保持緊密聯(lián)系,要共享所有權(quán)。 例如,采用我們先前的Net定義(此處為簡(jiǎn)稱):

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M)
    : linear(register_module("linear", torch::nn::Linear(N, M)))
  { }
  torch::nn::Linear linear;
};

為了使用linear子模塊,我們想將其直接存儲(chǔ)在我們的類中。 但是,我們還希望模塊基類了解并有權(quán)訪問此子模塊。 為此,它必須存儲(chǔ)對(duì)此子模塊的引用。 至此,我們已經(jīng)達(dá)到了共享所有權(quán)的需要。 torch::nn::Module類和具體的Net類都需要引用該子模塊。 因此,基類將模塊存儲(chǔ)為shared_ptr,因此具體類也必須存儲(chǔ)。

可是等等! 在以上代碼中我沒有看到任何關(guān)于shared_ptr的提示! 這是為什么? 好吧,因?yàn)?code>std::shared_ptr<MyModule>實(shí)在令人難受。 為了保持研究人員的生產(chǎn)力,我們提出了一個(gè)精心設(shè)計(jì)的方案,以隱藏shared_ptr的提法-通常保留給值語(yǔ)義的好處-同時(shí)保留參考語(yǔ)義。 要了解它是如何工作的,我們可以看一下核心庫(kù)中torch::nn::Linear模塊的簡(jiǎn)化定義(完整定義為,在此處):

struct LinearImpl : torch::nn::Module {
  LinearImpl(int64_t in, int64_t out);


  Tensor forward(const Tensor& input);


  Tensor weight, bias;
};


TORCH_MODULE(Linear);

簡(jiǎn)而言之:該模塊不是Linear,而是LinearImpl。 然后,宏TORCH_MODULE定義了實(shí)際的Linear類。 這個(gè)“生成的”類實(shí)際上是std::shared_ptr<LinearImpl>的包裝。 它是一個(gè)包裝器,而不是簡(jiǎn)單的 typedef,因此,除其他事項(xiàng)外,構(gòu)造函數(shù)仍可按預(yù)期工作,即,您仍然可以編寫torch::nn::Linear(3, 4)而不是std::make_shared<LinearImpl>(3, 4)。 我們將由宏創(chuàng)建的類稱為模塊持有人。 與(共享)指針一樣,您可以使用箭頭運(yùn)算符(例如model->forward(...))訪問基礎(chǔ)對(duì)象。 最終結(jié)果是一個(gè)所有權(quán)模型,該所有權(quán)模型非常類似于 Python API。 引用語(yǔ)義成為默認(rèn)語(yǔ)義,但是沒有額外輸入std::shared_ptrstd::make_shared。 對(duì)于我們的Net,使用模塊持有人 API 如下所示:

struct NetImpl : torch::nn::Module {};
TORCH_MODULE(Net);


void a(Net net) { }


int main() {
  Net net;
  a(net);
}

這里有一個(gè)微妙的問題值得一提。 默認(rèn)構(gòu)造的std::shared_ptr為“空”,即包含空指針。 什么是默認(rèn)構(gòu)造的LinearNet? 好吧,這是一個(gè)棘手的選擇。 我們可以說它應(yīng)該是一個(gè)空(null)std::shared_ptr<LinearImpl>。 但是,請(qǐng)記住Linear(3, 4)std::make_shared<LinearImpl>(3, 4)相同。 這意味著如果我們已確定Linear linear;應(yīng)該為空指針,則將無法構(gòu)造不采用任何構(gòu)造函數(shù)參數(shù)或都不使用所有缺省構(gòu)造函數(shù)的模塊。 因此,在當(dāng)前的 API 中,默認(rèn)構(gòu)造的模塊持有人(如Linear())將調(diào)用基礎(chǔ)模塊的默認(rèn)構(gòu)造函數(shù)(LinearImpl())。 如果基礎(chǔ)模塊沒有默認(rèn)構(gòu)造函數(shù),則會(huì)出現(xiàn)編譯器錯(cuò)誤。 要構(gòu)造空持有人,可以將nullptr傳遞給持有人的構(gòu)造函數(shù)。

實(shí)際上,這意味著您可以使用如先前所示的子模塊,在初始化程序列表中注冊(cè)并構(gòu)造該模塊:

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M)
    : linear(register_module("linear", torch::nn::Linear(N, M)))
  { }
  torch::nn::Linear linear;
};

或者,您可以先使用空指針構(gòu)造持有人,然后在構(gòu)造函數(shù)中為其分配值(Pythonistas 更熟悉):

struct Net : torch::nn::Module {
  Net(int64_t N, int64_t M) {
    linear = register_module("linear", torch::nn::Linear(N, M));
  }
  torch::nn::Linear linear{nullptr}; // construct an empty holder
};

結(jié)論:您應(yīng)該使用哪種所有權(quán)模型–哪種語(yǔ)義? C ++前端的 API 最能支持模塊所有者提供的所有權(quán)模型。 這種機(jī)制的唯一缺點(diǎn)是在模塊聲明下方多了一行樣板。 也就是說,最簡(jiǎn)單的模型仍然是 C ++模塊簡(jiǎn)介中顯示的值語(yǔ)義模型。 對(duì)于小的,簡(jiǎn)單的腳本,您也可以擺脫它。 但是,由于技術(shù)原因,您遲早會(huì)發(fā)現(xiàn)它并不總是受支持。 例如,序列化 API(torch::savetorch::load)僅支持模塊支架(或普通shared_ptr)。 因此,推薦使用模塊持有人 API 和 C ++前端定義模塊,此后我們將在本教程中使用此 API。

定義 DCGAN 模塊

現(xiàn)在,我們有必要的背景和簡(jiǎn)介來定義我們要在本文中解決的機(jī)器學(xué)習(xí)任務(wù)的模塊。 回顧一下:我們的任務(wù)是從 MNIST 數(shù)據(jù)集生成數(shù)字圖像。 我們想使用生成對(duì)抗網(wǎng)絡(luò)(GAN)解決此任務(wù)。 特別是,我們將使用 DCGAN 體系結(jié)構(gòu)-這是同類中最早的也是最簡(jiǎn)單的一種,但是完全可以完成此任務(wù)。

Tip

您可以在存儲(chǔ)庫(kù)中找到本教程中提供的完整源代碼。

什么是 GAN aGAN?

GAN 由兩個(gè)不同的神經(jīng)網(wǎng)絡(luò)模型組成:生成器鑒別器。 生成器從噪聲分布中接收樣本,其目的是將每個(gè)噪聲樣本轉(zhuǎn)換為類似于目標(biāo)分布的圖像(在我們的情況下為 MNIST 數(shù)據(jù)集)。 鑒別器又從 MNIST 數(shù)據(jù)集接收實(shí)際圖像,或從生成器接收圖像。 要求發(fā)出一個(gè)概率來判斷特定圖像的真實(shí)程度(接近1)或偽造(接近0)。 來自鑒別器的關(guān)于由發(fā)生器產(chǎn)生的圖像如何真實(shí)的反饋被用來訓(xùn)練發(fā)生器。 鑒別器對(duì)真實(shí)性有多好的反饋將用于優(yōu)化鑒別器。 從理論上講,生成器和鑒別器之間的微妙平衡使它們串聯(lián)起來得到改善,從而導(dǎo)致生成器生成與目標(biāo)分布無法區(qū)分的圖像,從而使鑒別器(那時(shí))的敏銳眼睛冒出了散發(fā)0.5的真實(shí)和真實(shí)可能性。 假圖片。 對(duì)我們來說,最終結(jié)果是一臺(tái)接收噪聲作為輸入并生成數(shù)字逼真的圖像作為其輸出的機(jī)器。

發(fā)電機(jī)模塊

我們首先定義生成器模塊,該模塊由一系列轉(zhuǎn)置的 2D 卷積,批處理歸一化和 ReLU 激活單元組成。 我們?cè)诙x自己的模塊的forward()方法中顯式地(在功能上)在模塊之間傳遞輸入:

struct DCGANGeneratorImpl : nn::Module {
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false))
 {
   // register_module() is needed if we want to use the parameters() method later on
   register_module("conv1", conv1);
   register_module("conv2", conv2);
   register_module("conv3", conv3);
   register_module("conv4", conv4);
   register_module("batch_norm1", batch_norm1);
   register_module("batch_norm2", batch_norm2);
   register_module("batch_norm3", batch_norm3);
 }


 torch::Tensor forward(torch::Tensor x) {
   x = torch::relu(batch_norm1(conv1(x)));
   x = torch::relu(batch_norm2(conv2(x)));
   x = torch::relu(batch_norm3(conv3(x)));
   x = torch::tanh(conv4(x));
   return x;
 }


 nn::ConvTranspose2d conv1, conv2, conv3, conv4;
 nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);


DCGANGenerator generator(kNoiseSize);

現(xiàn)在我們可以在DCGANGenerator上調(diào)用forward()將噪聲樣本映射到圖像。

選擇的特定模塊,例如nn::ConvTranspose2dnn::BatchNorm2d,遵循前面概述的結(jié)構(gòu)。 kNoiseSize常數(shù)確定輸入噪聲矢量的大小,并將其設(shè)置為100。 當(dāng)然,超參數(shù)是通過研究生的血統(tǒng)發(fā)現(xiàn)的。

注意

在超參數(shù)的發(fā)現(xiàn)中,沒有研究生受到傷害。 他們定期喂給 Soylent。

Note

簡(jiǎn)要介紹如何將選項(xiàng)傳遞給 C ++前端中的Conv2d等內(nèi)置模塊:每個(gè)模塊都有一些必需的選項(xiàng),例如BatchNorm2d的功能數(shù)量。 如果您只需要配置所需的選項(xiàng),則可以將它們直接傳遞給模塊的構(gòu)造函數(shù),例如BatchNorm2d(128)Dropout(0.5)Conv2d(8, 4, 2)(用于輸入通道數(shù),輸出通道數(shù)和內(nèi)核大小)。 但是,如果需要修改其他通常默認(rèn)設(shè)置的選項(xiàng),例如Conv2dbias,則需要構(gòu)造并傳遞選項(xiàng)對(duì)象。 C ++前端中的每個(gè)模塊都有一個(gè)關(guān)聯(lián)的選項(xiàng)結(jié)構(gòu),稱為ModuleOptions,其中Module是模塊的名稱,例如LinearLinearOptions。 這就是我們上面的Conv2d模塊的工作。

鑒別模塊

鑒別器類似地是卷積,批歸一化和激活的序列。 但是,卷積現(xiàn)在是常規(guī)的卷積,而不是轉(zhuǎn)置的卷積,我們使用 alpha 值為 0.2 的泄漏 ReLU 代替了普通的 ReLU。 同樣,最后的激活變?yōu)?Sigmoid,將值壓縮到 0 到 1 之間。然后,我們可以將這些壓縮后的值解釋為鑒別器分配給真實(shí)圖像的概率。

為了構(gòu)建鑒別器,我們將嘗試不同的方法:<cite>順序</cite>模塊。 像在 Python 中一樣,PyTorch 在此提供了兩種用于模型定義的 API:一種功能,其中的輸入通過連續(xù)的函數(shù)傳遞(例如,生成器模塊示例),而另一種面向?qū)ο蟮?,其中我們?gòu)建了<cite>順序</cite>模塊 包含整個(gè)模型作為子模塊。 使用<cite>順序</cite>,鑒別符將如下所示:

nn::Sequential discriminator(
  // Layer 1
  nn::Conv2d(
      nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
  nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  // Layer 2
  nn::Conv2d(
      nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
  nn::BatchNorm2d(128),
  nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  // Layer 3
  nn::Conv2d(
      nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
  nn::BatchNorm2d(256),
  nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  // Layer 4
  nn::Conv2d(
      nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
  nn::Sigmoid());

Tip

Sequential模塊僅執(zhí)行功能組合。 第一個(gè)子模塊的輸出成為第二個(gè)子模塊的輸入,第三個(gè)子模塊的輸出成為第四個(gè)子模塊的輸入,依此類推。

加載數(shù)據(jù)中

現(xiàn)在我們已經(jīng)定義了生成器和鑒別器模型,我們需要一些可以用來訓(xùn)練這些模型的數(shù)據(jù)。 與 Python 一樣,C ++前端也具有強(qiáng)大的并行數(shù)據(jù)加載器。 該數(shù)據(jù)加載器可以從數(shù)據(jù)集中讀取批次數(shù)據(jù)(您可以定義自己),并提供許多配置旋鈕。

Note

盡管 Python 數(shù)據(jù)加載器使用多重處理,但 C ++數(shù)據(jù)加載器實(shí)際上是多線程的,不會(huì)啟動(dòng)任何新進(jìn)程。

數(shù)據(jù)加載器是 C ++前端data API 的一部分,該 API 包含在torch::data::名稱空間中。 該 API 由幾個(gè)不同的組件組成:

  • 數(shù)據(jù)加載器類,
  • 用于定義數(shù)據(jù)集的 API,
  • 用于定義轉(zhuǎn)換的 API,可以將其應(yīng)用于數(shù)據(jù)集,
  • 用于定義采樣器的 API,該采樣器會(huì)生成用于對(duì)數(shù)據(jù)集建立索引的索引,
  • 現(xiàn)有數(shù)據(jù)集,變換和采樣器的庫(kù)。

對(duì)于本教程,我們可以使用 C ++前端附帶的MNIST數(shù)據(jù)集。 讓我們?yōu)榇藢?shí)例化一個(gè)torch::data::datasets::MNIST,并應(yīng)用兩個(gè)轉(zhuǎn)換:首先,我們對(duì)圖像進(jìn)行歸一化,以使其在-1+1的范圍內(nèi)(從01的原始范圍)。 其次,我們應(yīng)用Stack 歸類,它采用一批張量并將它們沿第一維堆疊為單個(gè)張量:

auto dataset = torch::data::datasets::MNIST("./mnist")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());

請(qǐng)注意,相對(duì)于執(zhí)行訓(xùn)練二進(jìn)制文件的位置,MNIST 數(shù)據(jù)集應(yīng)位于./mnist目錄中。 您可以使用此腳本下載 MNIST 數(shù)據(jù)集。

接下來,我們創(chuàng)建一個(gè)數(shù)據(jù)加載器并將其傳遞給此數(shù)據(jù)集。 為了創(chuàng)建一個(gè)新的數(shù)據(jù)加載器,我們使用torch::data::make_data_loader,它返回正確類型的std::unique_ptr(取決于數(shù)據(jù)集的類型,采樣器的類型以及其他一些實(shí)現(xiàn)細(xì)節(jié)):

auto data_loader = torch::data::make_data_loader(std::move(dataset));

數(shù)據(jù)加載器確實(shí)提供了很多選項(xiàng)。 您可以在處檢查全套。 例如,為了加快數(shù)據(jù)加載速度,我們可以增加工作人員的數(shù)量。 默認(rèn)數(shù)字為零,這表示將使用主線程。 如果將workers設(shè)置為2,將產(chǎn)生兩個(gè)線程并發(fā)加載數(shù)據(jù)。 我們還應(yīng)該將批次大小從其默認(rèn)值1增大到更合理的值,例如64(kBatchSize的值)。 因此,讓我們創(chuàng)建一個(gè)DataLoaderOptions對(duì)象并設(shè)置適當(dāng)?shù)膶傩裕?/a>

auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

現(xiàn)在,我們可以編寫一個(gè)循環(huán)來加載批量數(shù)據(jù),目前我們僅將其打印到控制臺(tái):

for (torch::data::Example<>& batch : *data_loader) {
  std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  for (int64_t i = 0; i < batch.data.size(0); ++i) {
    std::cout << batch.target[i].item<int64_t>() << " ";
  }
  std::cout << std::endl;
}

在這種情況下,數(shù)據(jù)加載器返回的類型為torch::data::Example。 此類型是一種簡(jiǎn)單的結(jié)構(gòu),其中的data字段用于數(shù)據(jù),而target字段用于標(biāo)簽。 因?yàn)槲覀冎皯?yīng)用了Stack歸類,所以數(shù)據(jù)加載器僅返回一個(gè)這樣的示例。 如果我們未應(yīng)用排序規(guī)則,則數(shù)據(jù)加載器將改為生成std::vector<torch::data::Example<>>,批處理中每個(gè)示例包含一個(gè)元素。

如果重新生成并運(yùn)行此代碼,則應(yīng)看到類似以下內(nèi)容的內(nèi)容:

root@fa350df05ecf:/home/build# make
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
root@fa350df05ecf:/home/build# make
[100%] Built target dcgan
root@fa350df05ecf:/home/build# ./dcgan
Batch size: 64 | Labels: 5 2 6 7 2 1 6 7 0 1 6 2 3 6 9 1 8 4 0 6 5 3 3 0 4 6 6 6 4 0 8 6 0 6 9 2 4 0 2 8 6 3 3 2 9 2 0 1 4 2 3 4 8 2 9 9 3 5 8 0 0 7 9 9
Batch size: 64 | Labels: 2 2 4 7 1 2 8 8 6 9 0 2 2 9 3 6 1 3 8 0 4 4 8 8 8 9 2 6 4 7 1 5 0 9 7 5 4 3 5 4 1 2 8 0 7 1 9 6 1 6 5 3 4 4 1 2 3 2 3 5 0 1 6 2
Batch size: 64 | Labels: 4 5 4 2 1 4 8 3 8 3 6 1 5 4 3 6 2 2 5 1 3 1 5 0 8 2 1 5 3 2 4 4 5 9 7 2 8 9 2 0 6 7 4 3 8 3 5 8 8 3 0 5 8 0 8 7 8 5 5 6 1 7 8 0
Batch size: 64 | Labels: 3 3 7 1 4 1 6 1 0 3 6 4 0 2 5 4 0 4 2 8 1 9 6 5 1 6 3 2 8 9 2 3 8 7 4 5 9 6 0 8 3 0 0 6 4 8 2 5 4 1 8 3 7 8 0 0 8 9 6 7 2 1 4 7
Batch size: 64 | Labels: 3 0 5 5 9 8 3 9 8 9 5 9 5 0 4 1 2 7 7 2 0 0 5 4 8 7 7 6 1 0 7 9 3 0 6 3 2 6 2 7 6 3 3 4 0 5 8 8 9 1 9 2 1 9 4 4 9 2 4 6 2 9 4 0
Batch size: 64 | Labels: 9 6 7 5 3 5 9 0 8 6 6 7 8 2 1 9 8 8 1 1 8 2 0 7 1 4 1 6 7 5 1 7 7 4 0 3 2 9 0 6 6 3 4 4 8 1 2 8 6 9 2 0 3 1 2 8 5 6 4 8 5 8 6 2
Batch size: 64 | Labels: 9 3 0 3 6 5 1 8 6 0 1 9 9 1 6 1 7 7 4 4 4 7 8 8 6 7 8 2 6 0 4 6 8 2 5 3 9 8 4 0 9 9 3 7 0 5 8 2 4 5 6 2 8 2 5 3 7 1 9 1 8 2 2 7
Batch size: 64 | Labels: 9 1 9 2 7 2 6 0 8 6 8 7 7 4 8 6 1 1 6 8 5 7 9 1 3 2 0 5 1 7 3 1 6 1 0 8 6 0 8 1 0 5 4 9 3 8 5 8 4 8 0 1 2 6 2 4 2 7 7 3 7 4 5 3
Batch size: 64 | Labels: 8 8 3 1 8 6 4 2 9 5 8 0 2 8 6 6 7 0 9 8 3 8 7 1 6 6 2 7 7 4 5 5 2 1 7 9 5 4 9 1 0 3 1 9 3 9 8 8 5 3 7 5 3 6 8 9 4 2 0 1 2 5 4 7
Batch size: 64 | Labels: 9 2 7 0 8 4 4 2 7 5 0 0 6 2 0 5 9 5 9 8 8 9 3 5 7 5 4 7 3 0 5 7 6 5 7 1 6 2 8 7 6 3 2 6 5 6 1 2 7 7 0 0 5 9 0 0 9 1 7 8 3 2 9 4
Batch size: 64 | Labels: 7 6 5 7 7 5 2 2 4 9 9 4 8 7 4 8 9 4 5 7 1 2 6 9 8 5 1 2 3 6 7 8 1 1 3 9 8 7 9 5 0 8 5 1 8 7 2 6 5 1 2 0 9 7 4 0 9 0 4 6 0 0 8 6
...

這意味著我們能夠成功地從 MNIST 數(shù)據(jù)集中加載數(shù)據(jù)。

編寫訓(xùn)練循環(huán)

現(xiàn)在,讓我們完成示例的算法部分,并實(shí)現(xiàn)生成器和鑒別器之間的精妙舞蹈。 首先,我們將創(chuàng)建兩個(gè)優(yōu)化器,一個(gè)用于生成器,一個(gè)用于區(qū)分器。 我們使用的優(yōu)化程序?qū)崿F(xiàn)了 Adam 算法:

torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).beta1(0.5));

Note

在撰寫本文時(shí),C ++前端提供了實(shí)現(xiàn) Adagrad,Adam,LBBFG,RMSprop 和 SGD 的優(yōu)化器。 文檔具有最新列表。

接下來,我們需要更新我們的訓(xùn)練循環(huán)。 我們將添加一個(gè)外部循環(huán)以在每個(gè)時(shí)期耗盡數(shù)據(jù)加載器,然后編寫 GAN 訓(xùn)練代碼:

for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {
    // Train discriminator with real images.
    discriminator->zero_grad();
    torch::Tensor real_images = batch.data;
    torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
    torch::Tensor real_output = discriminator->forward(real_images);
    torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
    d_loss_real.backward();


    // Train discriminator with fake images.
    torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
    torch::Tensor fake_images = generator->forward(noise);
    torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
    torch::Tensor fake_output = discriminator->forward(fake_images.detach());
    torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
    d_loss_fake.backward();


    torch::Tensor d_loss = d_loss_real + d_loss_fake;
    discriminator_optimizer.step();


    // Train generator.
    generator->zero_grad();
    fake_labels.fill_(1);
    fake_output = discriminator->forward(fake_images);
    torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
    g_loss.backward();
    generator_optimizer.step();


    std::printf(
        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
        epoch,
        kNumberOfEpochs,
        ++batch_index,
        batches_per_epoch,
        d_loss.item<float>(),
        g_loss.item<float>());
  }
}

上面,我們首先在真實(shí)圖像上評(píng)估鑒別器,為此應(yīng)為其分配較高的概率。 為此,我們使用torch::empty(batch.data.size(0)).uniform_(0.8, 1.0)作為目標(biāo)概率。

Note

我們選擇均勻分布在 0.8 到 1.0 之間的隨機(jī)值,而不是各處的 1.0,以使鑒別器訓(xùn)練更加可靠。 此技巧稱為標(biāo)簽平滑。

在評(píng)估鑒別器之前,我們將其參數(shù)的梯度歸零。 計(jì)算完損耗后,我們通過調(diào)用d_loss.backward()計(jì)算新的梯度來在網(wǎng)絡(luò)中反向傳播。 我們對(duì)虛假圖像重復(fù)此步驟。 我們不使用數(shù)據(jù)集中的圖像,而是讓生成器通過為它提供一批隨機(jī)噪聲來為此創(chuàng)建偽造圖像。 然后,我們將這些偽造圖像轉(zhuǎn)發(fā)給鑒別器。 這次,我們希望鑒別器發(fā)出低概率,最好是全零。 一旦計(jì)算了一批真實(shí)圖像和一批偽造圖像的鑒別器損耗,我們就可以一步一步地進(jìn)行鑒別器的優(yōu)化程序,以更新其參數(shù)。

為了訓(xùn)練生成器,我們?cè)俅问紫葘⑵涮荻葰w零,然后在偽圖像上重新評(píng)估鑒別器。 但是,這一次,我們希望鑒別器將概率分配為非常接近的概率,這將表明生成器可以生成使鑒別器認(rèn)為它們實(shí)際上是真實(shí)的圖像(來自數(shù)據(jù)集)。 為此,我們用全部填充fake_labels張量。 最后,我們逐步使用生成器的優(yōu)化器來更新其參數(shù)。

現(xiàn)在,我們應(yīng)該準(zhǔn)備在 CPU 上訓(xùn)練我們的模型。 我們還沒有任何代碼可以捕獲狀態(tài)或示例輸出,但是我們稍后會(huì)添加。 現(xiàn)在,讓我們觀察一下我們的模型正在做某事 –我們稍后將根據(jù)生成的圖像來驗(yàn)證這是否有意義。 重建和運(yùn)行應(yīng)打印如下內(nèi)容:

root@3c0711f20896:/home/build# make && ./dcgan
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcga
[ 1/10][100/938] D_loss: 0.6876 | G_loss: 4.1304
[ 1/10][200/938] D_loss: 0.3776 | G_loss: 4.3101
[ 1/10][300/938] D_loss: 0.3652 | G_loss: 4.6626
[ 1/10][400/938] D_loss: 0.8057 | G_loss: 2.2795
[ 1/10][500/938] D_loss: 0.3531 | G_loss: 4.4452
[ 1/10][600/938] D_loss: 0.3501 | G_loss: 5.0811
[ 1/10][700/938] D_loss: 0.3581 | G_loss: 4.5623
[ 1/10][800/938] D_loss: 0.6423 | G_loss: 1.7385
[ 1/10][900/938] D_loss: 0.3592 | G_loss: 4.7333
[ 2/10][100/938] D_loss: 0.4660 | G_loss: 2.5242
[ 2/10][200/938] D_loss: 0.6364 | G_loss: 2.0886
[ 2/10][300/938] D_loss: 0.3717 | G_loss: 3.8103
[ 2/10][400/938] D_loss: 1.0201 | G_loss: 1.3544
[ 2/10][500/938] D_loss: 0.4522 | G_loss: 2.6545
...

移至 GPU

盡管我們當(dāng)前的腳本可以在 CPU 上正常運(yùn)行,但是我們都知道卷積在 GPU 上要快得多。 讓我們快速討論如何將訓(xùn)練轉(zhuǎn)移到 GPU 上。 為此,我們需要做兩件事:將 GPU 設(shè)備規(guī)范傳遞給我們分配給自己的張量,并通過to()方法將所有其他張量明確復(fù)制到 C ++前端中所有張量和模塊上。 實(shí)現(xiàn)這兩者的最簡(jiǎn)單方法是在訓(xùn)練腳本的頂層創(chuàng)建torch::Device的實(shí)例,然后將該設(shè)備傳遞給諸如torch::zerosto()方法之類的張量工廠函數(shù)。 我們可以從使用 CPU 設(shè)備開始:

// Place this somewhere at the top of your training script.
torch::Device device(torch::kCPU);

新的張量分配,例如

torch::Tensor fake_labels = torch::zeros(batch.data.size(0));

應(yīng)該更新為以device作為最后一個(gè)參數(shù):

torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);

對(duì)于那些不在我們手中的張量,例如來自 MNIST 數(shù)據(jù)集的張量,我們必須插入顯式的to()調(diào)用。 這表示

torch::Tensor real_images = batch.data;

變成

torch::Tensor real_images = batch.data.to(device);

并且我們的模型參數(shù)也應(yīng)該移到正確的設(shè)備上:

generator->to(device);
discriminator->to(device);

Note

如果張量已經(jīng)存在于提供給to()的設(shè)備上,則該調(diào)用為空操作。 沒有多余的副本。

至此,我們已經(jīng)使之前的 CPU 駐留代碼更加明確。 但是,現(xiàn)在將設(shè)備更改為 CUDA 設(shè)備也非常容易:

torch::Device device(torch::kCUDA)

現(xiàn)在,所有張量都將駐留在 GPU 上,并調(diào)用快速 CUDA 內(nèi)核進(jìn)行所有操作,而無需我們更改任何下游代碼。 如果我們想指定一個(gè)特定的設(shè)備索引,則可以將其作為第二個(gè)參數(shù)傳遞給Device構(gòu)造函數(shù)。 如果我們希望不同的張量駐留在不同的設(shè)備上,則可以傳遞單獨(dú)的設(shè)備實(shí)例(例如,一個(gè)在 CUDA 設(shè)備 0 上,另一個(gè)在 CUDA 設(shè)備 1 上)。 我們甚至可以動(dòng)態(tài)地進(jìn)行此配置,這通常對(duì)于使我們的訓(xùn)練腳本更具可移植性很有用:

torch::Device device = torch::kCPU;
if (torch::cuda::is_available()) {
  std::cout << "CUDA is available! Training on GPU." << std::endl;
  device = torch::kCUDA;
}

甚至

torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

檢查點(diǎn)和恢復(fù)訓(xùn)練狀態(tài)

我們應(yīng)該對(duì)訓(xùn)練腳本進(jìn)行的最后擴(kuò)充是定期保存模型參數(shù)的狀態(tài),優(yōu)化器的狀態(tài)以及一些生成的圖像樣本。 如果我們的計(jì)算機(jī)在訓(xùn)練過程中崩潰,則前兩個(gè)將使我們能夠恢復(fù)訓(xùn)練狀態(tài)。 對(duì)于長(zhǎng)期的訓(xùn)練課程,這是絕對(duì)必要的。 幸運(yùn)的是,C ++前端提供了一個(gè) API,用于對(duì)模型和優(yōu)化器狀態(tài)以及單個(gè)張量進(jìn)行序列化和反序列化。

為此的核心 API 是torch::save(thing,filename)torch::load(thing,filename),其中thing可以是torch::nn::Module子類或優(yōu)化程序?qū)嵗?,例如我們?cè)谟?xùn)練腳本中擁有的Adam對(duì)象。 讓我們更新訓(xùn)練循環(huán),以一定間隔檢查模型和優(yōu)化器狀態(tài):

if (batch_index % kCheckpointEvery == 0) {
  // Checkpoint the model and optimizer state.
  torch::save(generator, "generator-checkpoint.pt");
  torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::save(discriminator, "discriminator-checkpoint.pt");
  torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  // Sample the generator and save the images.
  torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}, device));
  torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
  std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}

其中kCheckpointEvery是設(shè)置為類似于100之類的整數(shù),以便每批100都進(jìn)行檢查,而checkpoint_counter是每次創(chuàng)建檢查點(diǎn)時(shí)都會(huì)增加的計(jì)數(shù)器。

要恢復(fù)訓(xùn)練狀態(tài),可以在創(chuàng)建所有模型和優(yōu)化器之后但在訓(xùn)練循環(huán)之前添加如下代碼:

torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));


if (kRestoreFromCheckpoint) {
  torch::load(generator, "generator-checkpoint.pt");
  torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::load(discriminator, "discriminator-checkpoint.pt");
  torch::load(
      discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}


int64_t checkpoint_counter = 0;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {

檢查生成的圖像

我們的訓(xùn)練腳本現(xiàn)已完成。 我們準(zhǔn)備在 CPU 或 GPU 上訓(xùn)練 GAN。 為了檢查我們訓(xùn)練過程的中間輸出,為此我們添加了將代碼樣本定期保存到"dcgan-sample-xxx.pt"文件的代碼,我們可以編寫一個(gè)小的 Python 腳本來加載張量并使用 matplotlib 顯示它們:

from __future__ import print_function
from __future__ import unicode_literals


import argparse


import matplotlib.pyplot as plt
import torch


parser = argparse.ArgumentParser()
parser.add_argument("-i", "--sample-file", required=True)
parser.add_argument("-o", "--out-file", default="out.png")
parser.add_argument("-d", "--dimension", type=int, default=3)
options = parser.parse_args()


module = torch.jit.load(options.sample_file)
images = list(module.parameters())[0]


for index in range(options.dimension * options.dimension):
  image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
  array = image.numpy()
  axis = plt.subplot(options.dimension, options.dimension, 1 + index)
  plt.imshow(array, cmap="gray")
  axis.get_xaxis().set_visible(False)
  axis.get_yaxis().set_visible(False)


plt.savefig(options.out_file)
print("Saved ", options.out_file)

現(xiàn)在,讓我們訓(xùn)練模型約 30 個(gè)紀(jì)元:

root@3c0711f20896:/home/build# make && ./dcgan                                                                                                                                10:17:57
Scanning dependencies of target dcgan
[ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
[100%] Linking CXX executable dcgan
[100%] Built target dcgan
CUDA is available! Training on GPU.
[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
-> checkpoint 1
[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
-> checkpoint 2
[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.36760
-> checkpoint 3
[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
-> checkpoint 4
[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
-> checkpoint 5
[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
...
-> checkpoint 120
[30/30][938/938] D_loss: 0.3610 | G_loss: 3.8084

并在圖中顯示圖像:

root@3c0711f20896:/home/build# python display.py -i dcgan-sample-100.pt
Saved out.png

應(yīng)該看起來像這樣:

digits

數(shù)字! 萬歲! 現(xiàn)在,事情就在您的球場(chǎng)上:您可以改進(jìn)模型以使數(shù)字看起來更好嗎?

結(jié)論

希望本教程為您提供了 PyTorch C ++前端的摘要。 像 PyTorch 這樣的機(jī)器學(xué)習(xí)庫(kù)必然具有非常廣泛的 API。 因此,有許多概念我們沒有時(shí)間或空間來討論。 但是,我建議您嘗試使用該 API,并在遇到問題時(shí)查閱我們的文檔,尤其是庫(kù) API 部分。 另外,請(qǐng)記住,只要我們能夠做到,就可以期望 C ++前端遵循 Python 前端的設(shè)計(jì)和語(yǔ)義,因此您可以利用這一事實(shí)來提高學(xué)習(xí)率。

Tip

You can find the full source code presented in this tutorial in this repository.

與往常一樣,如果您遇到任何問題或疑問,可以使用我們的論壇GitHub 問題進(jìn)行聯(lián)系。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)