PyTorch 使用 PyTorch C ++前端

2020-09-16 14:19 更新

原文:PyTorch 使用 PyTorch C ++前端

校驗(yàn)者:qiwei_ji

PyTorch C ++前端是 PyTorch 機(jī)器學(xué)習(xí)框架的純 C ++接口。 雖然 PyTorch 的主要接口自然是 Python,但此 Python API 建立于大量的 C ++代碼庫(kù)之上,提供基本的數(shù)據(jù)結(jié)構(gòu)和功能,例如張量和自動(dòng)微分。 C ++前端公開了純 C ++ 11 API,該 API 使用機(jī)器學(xué)習(xí)訓(xùn)練和推理所需的工具,擴(kuò)展了此基礎(chǔ) C ++代碼庫(kù)。 該拓展包括用于神經(jīng)網(wǎng)絡(luò)建模的通用組件的內(nèi)置集合; 使用自定義模塊擴(kuò)展此集合的 API; 一個(gè)流行的優(yōu)化算法庫(kù),例如隨機(jī)梯度下降; 具有 API 的并行數(shù)據(jù)加載器,用于定義和加載數(shù)據(jù)集; 序列化例程等。

本教程將引導(dǎo)您完成使用 C ++前端訓(xùn)練模型的端到端示例。 具體來說,我們將訓(xùn)練 DCGAN (一種生成模型),以生成 MNIST 數(shù)字的圖像。 雖然從概念上講,這只是一個(gè)簡(jiǎn)單的示例,但它足以使您對(duì) PyTorch C ++前端有個(gè)大概的了解,并可以滿足訓(xùn)練更復(fù)雜模型的需求。 我們將從一些鼓舞人心的詞開始,說明您為什么要使用 C ++前端,然后直接深入定義和訓(xùn)練我們的模型。

Tip

本筆記概述了 C ++前端的組件和設(shè)計(jì)原理。

Tip

有關(guān) PyTorch C ++生態(tài)系統(tǒng)的文檔,請(qǐng)?jiān)L問 https://pytorch.org/cppdocs 。 您可以在此處找到高級(jí)描述以及 API 級(jí)文檔。

動(dòng)機(jī)

在我們開始 GAN 和 MNIST 數(shù)字的激動(dòng)人心的旅程之前,讓我們退一步來討論為什么您要使用 C ++前端而不是 Python。 我們(PyTorch 團(tuán)隊(duì))創(chuàng)建了 C ++前端,以便能夠在無法使用 Python 或根本不適合該工具的環(huán)境中進(jìn)行研究。 此類環(huán)境的示例包括:

  • 低延遲系統(tǒng):您可能希望在具有高幀率和低延遲要求的純 C ++游戲引擎中進(jìn)行強(qiáng)化學(xué)習(xí)研究。 與 Python 庫(kù)相比,使用純 C ++庫(kù)更適合這種環(huán)境。 由于 Python 解釋器運(yùn)行緩慢,Python 可能根本無法處理此類問題。
  • 高度多線程環(huán)境:由于全局解釋器鎖定(GIL),Python 一次不能運(yùn)行多個(gè)系統(tǒng)線程。 并行處理是一種替代方法,但可擴(kuò)展性不強(qiáng),并且存在很多缺點(diǎn)。 C ++沒有這樣的約束,線程易于使用和創(chuàng)建。 需要高度并行化的模型,例如深層神經(jīng)進(jìn)化中使用的模型,可以從中受益。
  • 現(xiàn)有的 C ++代碼庫(kù):您可能下載了 C ++應(yīng)用程序,其工作范圍從后端服務(wù)器中的網(wǎng)頁(yè)服務(wù)到照片編輯軟件中的 3D 圖形渲染應(yīng)有盡有,并且希望將機(jī)器學(xué)習(xí)方法集成到您的系統(tǒng)中。 C ++前端使您可以繼續(xù)使用 C ++,并省去在 Python 和 C ++之間來回綁定的麻煩,同時(shí)保留了傳統(tǒng) PyTorch(Python)大部分的靈活性和直觀性。

C ++前端與 Python 前端并非是競(jìng)爭(zhēng)關(guān)系。 前者是對(duì)后者的補(bǔ)充。 我們知道研究人員和工程師都喜歡 PyTorch,因?yàn)樗哂泻?jiǎn)單,靈活和直觀的 API。 我們的目標(biāo)是確保您可以在所有可能的環(huán)境(包括上述環(huán)境)中利用這些核心設(shè)計(jì)原則。 如果上述的這些情況之一很好地描述了您的用例,或者您只是感興趣或好奇,請(qǐng)?jiān)谝韵露温渲欣^續(xù)研究 C ++前端。

Tip

C ++前端試圖提供一個(gè)與 Python 前端盡可能接近的 API。 如果您對(duì) Python 前端有豐富的經(jīng)驗(yàn),并且問過自己“我可以使用 C ++前端做些什么 ?”,請(qǐng)像在 Python 中那樣編寫代碼,并且大多數(shù)情況下,相同的函數(shù)和方法都可以在 C ++中使用。 就像在 Python 中一樣(記得用雙冒號(hào)替換點(diǎn))。

編寫基本應(yīng)用程序

首先,編寫一個(gè)最小的 C ++應(yīng)用程序,以驗(yàn)證我們是否在同一頁(yè)面上了解我們的設(shè)置和構(gòu)建環(huán)境。 首先,您需要獲取 LibTorch 發(fā)行版的副本-我們現(xiàn)成的 zip 歸檔文件,其中打包了使用 C ++前端所需的所有相關(guān)首部,庫(kù)和 CMake 構(gòu)建文件。 LibTorch 發(fā)行版可在 PyTorch 網(wǎng)站上下載,適用于 Linux,MacOS 和 Windows。 本教程的其余部分將假定基本的 Ubuntu Linux 環(huán)境,但是您也可以在 MacOS 或 Windows 上進(jìn)行學(xué)習(xí)。

Tip

關(guān)于安裝 PyTorch的 C ++發(fā)行版 的注釋更詳細(xì)地描述了以下步驟。

Tip

在 Windows 上,調(diào)試和發(fā)行版本不兼容 ABI。 如果您打算以調(diào)試模式構(gòu)建項(xiàng)目,請(qǐng)嘗試使用 LibTorch 的調(diào)試版本。 另外,請(qǐng)確保在下面的cmake --build .行中指定正確的配置。

第一步,通過從 PyTorch 網(wǎng)站獲取的鏈接在本地下載 LibTorch 發(fā)行版。 對(duì)于普通的 Ubuntu Linux 環(huán)境,這意味著運(yùn)行以下步驟:

  1. ## If you need e.g. CUDA 9.0 support, please replace "cpu" with "cu90" in the URL below.
  2. wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
  3. unzip libtorch-shared-with-deps-latest.zip

接下來,讓我們編寫一個(gè)名為dcgan.cpp的小型 C ++文件,其中包含torch/torch.h,現(xiàn)在只需打印出三乘三的單位矩陣即可:

  1. #include <torch/torch.h>
  2. #include <iostream>
  3. int main() {
  4. torch::Tensor tensor = torch::eye(3);
  5. std::cout << tensor << std::endl;
  6. }

要在以后構(gòu)建這個(gè)應(yīng)用程序以及我們完整的訓(xùn)練腳本,我們將使用以下CMakeLists.txt文件:

  1. cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
  2. project(dcgan)
  3. find_package(Torch REQUIRED)
  4. add_executable(dcgan dcgan.cpp)
  5. target_link_libraries(dcgan "${TORCH_LIBRARIES}")
  6. set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)

注意

雖然 CMake 是 LibTorch 的推薦的構(gòu)建系統(tǒng),但這并不是硬性要求。 您還可以使用 Visual Studio 項(xiàng)目文件,QMake,普通 Makefile 或您認(rèn)為合適的任何其他構(gòu)建環(huán)境。 但是,我們不為此提供現(xiàn)成的支持。

在上面的 CMake 文件中記下第 4 行:find_package(Torch REQUIRED)。 這表示 CMake 在查找 LibTorch 庫(kù)的構(gòu)建配置。 為了使 CMake 知道在哪里找到這些文件,調(diào)用cmake時(shí)必須設(shè)置CMAKE_PREFIX_PATH。 在執(zhí)行此操作之前,讓我們就dcgan應(yīng)用程序的以下目錄結(jié)構(gòu)達(dá)成一致:_

  1. dcgan/
  2. CMakeLists.txt
  3. dcgan.cpp

此外,我將指向未壓縮的 LibTorch 分布的路徑稱為/path/to/libtorch。 請(qǐng)注意,此必須是絕對(duì)路徑。 特別是,將CMAKE_PREFIX_PATH設(shè)置為../../libtorch之類的內(nèi)容會(huì)以意想不到的方式中斷, 應(yīng)該寫$PWD/../../libtorch以獲取相應(yīng)的絕對(duì)路徑。 現(xiàn)在,我們準(zhǔn)備構(gòu)建我們的應(yīng)用程序:

  1. root@fa350df05ecf:/home# mkdir build
  2. root@fa350df05ecf:/home# cd build
  3. root@fa350df05ecf:/home/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  4. -- The C compiler identification is GNU 5.4.0
  5. -- The CXX compiler identification is GNU 5.4.0
  6. -- Check for working C compiler: /usr/bin/cc
  7. -- Check for working C compiler: /usr/bin/cc -- works
  8. -- Detecting C compiler ABI info
  9. -- Detecting C compiler ABI info - done
  10. -- Detecting C compile features
  11. -- Detecting C compile features - done
  12. -- Check for working CXX compiler: /usr/bin/c++
  13. -- Check for working CXX compiler: /usr/bin/c++ -- works
  14. -- Detecting CXX compiler ABI info
  15. -- Detecting CXX compiler ABI info - done
  16. -- Detecting CXX compile features
  17. -- Detecting CXX compile features - done
  18. -- Looking for pthread.h
  19. -- Looking for pthread.h - found
  20. -- Looking for pthread_create
  21. -- Looking for pthread_create - not found
  22. -- Looking for pthread_create in pthreads
  23. -- Looking for pthread_create in pthreads - not found
  24. -- Looking for pthread_create in pthread
  25. -- Looking for pthread_create in pthread - found
  26. -- Found Threads: TRUE
  27. -- Found torch: /path/to/libtorch/lib/libtorch.so
  28. -- Configuring done
  29. -- Generating done
  30. -- Build files have been written to: /home/build
  31. root@fa350df05ecf:/home/build# cmake --build . --config Release
  32. Scanning dependencies of target dcgan
  33. [ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
  34. [100%] Linking CXX executable dcgan
  35. [100%] Built target dcgan

上面,我們首先在dcgan目錄內(nèi)創(chuàng)建一個(gè)build文件夾,進(jìn)入該文件夾,運(yùn)行cmake命令以生成必要的 build(Make)文件,最后通過運(yùn)行cmake --build . --config Release成功編譯該項(xiàng)目。 現(xiàn)在我們準(zhǔn)備執(zhí)行最小的二進(jìn)制文件并完成有關(guān)基本項(xiàng)目配置的這一部分:

  1. root@fa350df05ecf:/home/build# ./dcgan
  2. 1 0 0
  3. 0 1 0
  4. 0 0 1
  5. [ Variable[CPUFloatType]{3,3} ]

在我看來這就像一個(gè)單位矩陣!

定義神經(jīng)網(wǎng)絡(luò)模型

現(xiàn)在我們已經(jīng)配置了基本環(huán)境,我們可以深入研究本教程中更有趣的部分。 首先,我們將討論如何在 C ++前端中定義模塊并與之交互。 我們將從基本的小規(guī)模示例模塊開始,然后使用 C ++前端提供的廣泛的內(nèi)置模塊庫(kù)來實(shí)現(xiàn)完整的 GAN。

模塊 API 基礎(chǔ)

與 Python 界面一致,基于 C ++前端的神經(jīng)網(wǎng)絡(luò)由稱為模塊的可重用構(gòu)建塊組成。 有一個(gè)基礎(chǔ)模塊類,所有其他模塊都從該基礎(chǔ)類派生。 在 Python 中,此類為torch.nn.Module,在 C ++中為torch::nn::Module。 除了實(shí)現(xiàn)模塊封裝的算法的forward()方法之外,模塊通常還包含以下三種子對(duì)象中的任何一種:參數(shù),緩沖區(qū)和子模塊。

參數(shù)和緩沖區(qū)以張量的形式存儲(chǔ)。 參數(shù)記錄梯度,但緩沖區(qū)不記錄。 參數(shù)通常是神經(jīng)網(wǎng)絡(luò)的可訓(xùn)練權(quán)重。 緩沖區(qū)的示例包括批量標(biāo)準(zhǔn)化的均值和方差。 為了重用特定的邏輯和狀態(tài)塊,PyTorch API 允許嵌套模塊。 嵌套模塊稱為子模塊。

參數(shù),緩沖區(qū)和子模塊是必須被注冊(cè)的。 注冊(cè)后,可以使用parameters()buffers()之類的方法來檢索整個(gè)(嵌套)模塊層次結(jié)構(gòu)中所有參數(shù)的容器。 類似地,使用to(...)之類的方法,例如 to(torch::kCUDA)將所有參數(shù)和緩沖區(qū)從 CPU 移到 CUDA 內(nèi)存,在整個(gè)模塊層次結(jié)構(gòu)上工作。

定義模塊和注冊(cè)參數(shù)

為了將這些詞寫成代碼,讓我們考慮一下用 Python 界面編寫的簡(jiǎn)單模塊:

  1. import torch
  2. class Net(torch.nn.Module):
  3. def __init__(self, N, M):
  4. super(Net, self).__init__()
  5. self.W = torch.nn.Parameter(torch.randn(N, M))
  6. self.b = torch.nn.Parameter(torch.randn(M))
  7. def forward(self, input):
  8. return torch.addmm(self.b, input, self.W)

在 C ++中,它看起來像這樣:

  1. #include <torch/torch.h>
  2. struct Net : torch::nn::Module {
  3. Net(int64_t N, int64_t M) {
  4. W = register_parameter("W", torch::randn({N, M}));
  5. b = register_parameter("b", torch::randn(M));
  6. }
  7. torch::Tensor forward(torch::Tensor input) {
  8. return torch::addmm(b, input, W);
  9. }
  10. torch::Tensor W, b;
  11. };

就像在 Python 中一樣,我們定義了一個(gè)名為Net的類(為簡(jiǎn)單起見,這里是struct而不是class),然后從模塊基類派生它。 在構(gòu)造函數(shù)內(nèi)部,我們使用torch::randn創(chuàng)建張量,就像在 Python 中使用torch.randn一樣。 一個(gè)有趣的區(qū)別是我們?nèi)绾巫?cè)參數(shù)。 在 Python 中,我們用torch.nn.Parameter類包裝了張量,而在 C ++中,我們不得不通過register_parameter方法傳遞張量。 這樣做的原因是 Python API 可以檢測(cè)到屬性為torch.nn.Parameter類型并自動(dòng)注冊(cè)此類張量。 在 C ++中,反射非常有限,因此提供了一種更傳統(tǒng)(而且并不是那么不可思議)的方法。

注冊(cè)子模塊并遍歷模塊層次結(jié)構(gòu)

同樣,我們可以注冊(cè)參數(shù),也可以注冊(cè)子模塊。 在 Python 中,將子模塊分配為模塊的屬性時(shí),會(huì)自動(dòng)檢測(cè)并注冊(cè)這些子模塊:

  1. class Net(torch.nn.Module):
  2. def __init__(self, N, M):
  3. super(Net, self).__init__()
  4. # Registered as a submodule behind the scenes
  5. self.linear = torch.nn.Linear(N, M)
  6. self.another_bias = torch.nn.Parameter(torch.rand(M))
  7. def forward(self, input):
  8. return self.linear(input) + self.another_bias

例如,允許使用parameters()方法來遞歸訪問模塊層次結(jié)構(gòu)中的所有參數(shù):

  1. >>> net = Net(4, 5)
  2. >>> print(list(net.parameters()))
  3. [Parameter containing:
  4. tensor([0.0808, 0.8613, 0.2017, 0.5206, 0.5353], requires_grad=True), Parameter containing:
  5. tensor([[-0.3740, -0.0976, -0.4786, -0.4928],
  6. [-0.1434, 0.4713, 0.1735, -0.3293],
  7. [-0.3467, -0.3858, 0.1980, 0.1986],
  8. [-0.1975, 0.4278, -0.1831, -0.2709],
  9. [ 0.3730, 0.4307, 0.3236, -0.0629]], requires_grad=True), Parameter containing:
  10. tensor([ 0.2038, 0.4638, -0.2023, 0.1230, -0.0516], requires_grad=True)]

要在 C ++中注冊(cè)子模塊,請(qǐng)使用恰當(dāng)命名的register_module()方法注冊(cè)類似torch::nn::Linear的模塊:

  1. struct Net : torch::nn::Module {
  2. Net(int64_t N, int64_t M)
  3. : linear(register_module("linear", torch::nn::Linear(N, M))) {
  4. another_bias = register_parameter("b", torch::randn(M));
  5. }
  6. torch::Tensor forward(torch::Tensor input) {
  7. return linear(input) + another_bias;
  8. }
  9. torch::nn::Linear linear;
  10. torch::Tensor another_bias;
  11. };

Tip

您可以在torch::nn命名空間的文檔中找到可用的內(nèi)置模塊的完整列表,例如torch::nn::Linear,torch::nn::Dropouttorch::nn::Conv2d

微妙之處在于,為什么在構(gòu)造函數(shù)的初始值設(shè)定項(xiàng)列表中創(chuàng)建子模塊,而在構(gòu)造函數(shù)的主體內(nèi)部創(chuàng)建參數(shù)。 這是有充分的理由的,我們將在下面有關(guān) C ++前端的所有權(quán)模型的部分中對(duì)此進(jìn)行介紹。 但是,最終結(jié)果是,就像 Python 中一樣,我們可以遞歸訪問模塊樹的參數(shù)。 調(diào)用parameters()將返回std::vector<torch::Tensor>,我們可以對(duì)其進(jìn)行迭代:

  1. int main() {
  2. Net net(4, 5);
  3. for (const auto& p : net.parameters()) {
  4. std::cout << p << std::endl;
  5. }
  6. }

打印:

  1. root@fa350df05ecf:/home/build# ./dcgan
  2. 0.0345
  3. 1.4456
  4. -0.6313
  5. -0.3585
  6. -0.4008
  7. [ Variable[CPUFloatType]{5} ]
  8. -0.1647 0.2891 0.0527 -0.0354
  9. 0.3084 0.2025 0.0343 0.1824
  10. -0.4630 -0.2862 0.2500 -0.0420
  11. 0.3679 -0.1482 -0.0460 0.1967
  12. 0.2132 -0.1992 0.4257 0.0739
  13. [ Variable[CPUFloatType]{5,4} ]
  14. 0.01 *
  15. 3.6861
  16. -10.1166
  17. -45.0333
  18. 7.9983
  19. -20.0705
  20. [ Variable[CPUFloatType]{5} ]

具有三個(gè)參數(shù),就像在 Python 中一樣。 為了也查看這些參數(shù)的名稱,C ++ API 提供了named_parameters()方法,該方法返回OrderedDict,就像在 Python 中一樣:

  1. Net net(4, 5);
  2. for (const auto& pair : net.named_parameters()) {
  3. std::cout << pair.key() << ": " << pair.value() << std::endl;
  4. }

我們可以再次執(zhí)行以查看輸出:

  1. root@fa350df05ecf:/home/build# make && ./dcgan 11:13:48
  2. Scanning dependencies of target dcgan
  3. [ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
  4. [100%] Linking CXX executable dcgan
  5. [100%] Built target dcgan
  6. b: -0.1863
  7. -0.8611
  8. -0.1228
  9. 1.3269
  10. 0.9858
  11. [ Variable[CPUFloatType]{5} ]
  12. linear.weight: 0.0339 0.2484 0.2035 -0.2103
  13. -0.0715 -0.2975 -0.4350 -0.1878
  14. -0.3616 0.1050 -0.4982 0.0335
  15. -0.1605 0.4963 0.4099 -0.2883
  16. 0.1818 -0.3447 -0.1501 -0.0215
  17. [ Variable[CPUFloatType]{5,4} ]
  18. linear.bias: -0.0250
  19. 0.0408
  20. 0.3756
  21. -0.2149
  22. -0.3636
  23. [ Variable[CPUFloatType]{5} ]

Note

torch::nn::Module的文檔包含在模塊層次結(jié)構(gòu)上運(yùn)行方法的完整列表中。

在轉(zhuǎn)發(fā)模式下運(yùn)行網(wǎng)絡(luò)

要使用 C ++執(zhí)行網(wǎng)絡(luò),我們只需調(diào)用我們自己定義的forward()方法:

  1. int main() {
  2. Net net(4, 5);
  3. std::cout << net.forward(torch::ones({2, 4})) << std::endl;
  4. }

打?。?/p>

  1. root@fa350df05ecf:/home/build# ./dcgan
  2. 0.8559 1.1572 2.1069 -0.1247 0.8060
  3. 0.8559 1.1572 2.1069 -0.1247 0.8060
  4. [ Variable[CPUFloatType]{2,5} ]

模塊所有權(quán)

至此,我們知道了如何使用 C ++定義模塊,注冊(cè)參數(shù),注冊(cè)子模塊,通過parameters()之類的方法遍歷模塊層次結(jié)構(gòu)并最終運(yùn)行模塊的forward()方法。 盡管在 C ++ API 中還有很多方法,類和主題需要使用,但我將為您提供完整菜單的文檔。 我們將在稍后實(shí)現(xiàn) DCGAN 模型和端到端訓(xùn)練管道的過程中,涉及更多概念。 在我們這樣做之前,讓我簡(jiǎn)要地談?wù)?C ++前端為torch::nn::Module的子類提供的所有權(quán)模型。

在本次討論中,所有權(quán)模型是指模塊的存儲(chǔ)和傳遞方式-確定特定模塊實(shí)例的所有者或所有者。 在 Python 中,對(duì)象始終是動(dòng)態(tài)分配的(在堆上),并具有引用語(yǔ)義。 這是非常容易使用且易于理解的。 實(shí)際上,在 Python 中,您可以很大程度上忽略對(duì)象的位置以及如何引用它們,而將精力集中在完成事情上。

C ++是一種較低級(jí)的語(yǔ)言,它在此領(lǐng)域提供了更多選擇。 這增加了復(fù)雜性,并嚴(yán)重影響了 C ++前端的設(shè)計(jì)和人體工程學(xué)。 特別是,對(duì)于 C ++前端中的模塊,我們可以選擇使用值語(yǔ)義參考語(yǔ)義。 第一種情況是最簡(jiǎn)單的,并且在到目前為止的示例中已進(jìn)行了展示:模塊對(duì)象分配在堆棧上,并在傳遞給函數(shù)時(shí)可以復(fù)制,移動(dòng)(使用std::move)或通過引用或指針獲?。?/p>

  1. struct Net : torch::nn::Module { };
  2. void a(Net net) { }
  3. void b(Net& net) { }
  4. void c(Net* net) { }
  5. int main() {
  6. Net net;
  7. a(net);
  8. a(std::move(net));
  9. b(net);
  10. c(&net);
  11. }

對(duì)于第二種情況-參考語(yǔ)義-我們可以使用std::shared_ptr。 引用語(yǔ)義的優(yōu)勢(shì)在于,就像在 Python 中一樣,它減少了思考如何將模塊傳遞給函數(shù)以及如何聲明參數(shù)的認(rèn)知開銷(假設(shè)您在任何地方都使用shared_ptr)。

  1. struct Net : torch::nn::Module {};
  2. void a(std::shared_ptr<Net> net) { }
  3. int main() {
  4. auto net = std::make_shared<Net>();
  5. a(net);
  6. }

根據(jù)我們的經(jīng)驗(yàn),來自動(dòng)態(tài)語(yǔ)言的研究人員非常喜歡引用語(yǔ)義而不是值語(yǔ)義,盡管后者比 C ++更“原生”。 同樣重要的是,torch::nn::Module的設(shè)計(jì)為了要與 Python API 的人體工程學(xué)保持緊密聯(lián)系,要共享所有權(quán)。 例如,采用我們先前的Net定義(此處為簡(jiǎn)稱):

  1. struct Net : torch::nn::Module {
  2. Net(int64_t N, int64_t M)
  3. : linear(register_module("linear", torch::nn::Linear(N, M)))
  4. { }
  5. torch::nn::Linear linear;
  6. };

為了使用linear子模塊,我們想將其直接存儲(chǔ)在我們的類中。 但是,我們還希望模塊基類了解并有權(quán)訪問此子模塊。 為此,它必須存儲(chǔ)對(duì)此子模塊的引用。 至此,我們已經(jīng)達(dá)到了共享所有權(quán)的需要。 torch::nn::Module類和具體的Net類都需要引用該子模塊。 因此,基類將模塊存儲(chǔ)為shared_ptr,因此具體類也必須存儲(chǔ)。

可是等等! 在以上代碼中我沒有看到任何關(guān)于shared_ptr的提示! 這是為什么? 好吧,因?yàn)?code>std::shared_ptr<MyModule>實(shí)在令人難受。 為了保持研究人員的生產(chǎn)力,我們提出了一個(gè)精心設(shè)計(jì)的方案,以隱藏shared_ptr的提法-通常保留給值語(yǔ)義的好處-同時(shí)保留參考語(yǔ)義。 要了解它是如何工作的,我們可以看一下核心庫(kù)中torch::nn::Linear模塊的簡(jiǎn)化定義(完整定義為,在此處):

  1. struct LinearImpl : torch::nn::Module {
  2. LinearImpl(int64_t in, int64_t out);
  3. Tensor forward(const Tensor& input);
  4. Tensor weight, bias;
  5. };
  6. TORCH_MODULE(Linear);

簡(jiǎn)而言之:該模塊不是Linear,而是LinearImpl。 然后,宏TORCH_MODULE定義了實(shí)際的Linear類。 這個(gè)“生成的”類實(shí)際上是std::shared_ptr<LinearImpl>的包裝。 它是一個(gè)包裝器,而不是簡(jiǎn)單的 typedef,因此,除其他事項(xiàng)外,構(gòu)造函數(shù)仍可按預(yù)期工作,即,您仍然可以編寫torch::nn::Linear(3, 4)而不是std::make_shared<LinearImpl>(3, 4)。 我們將由宏創(chuàng)建的類稱為模塊持有人。 與(共享)指針一樣,您可以使用箭頭運(yùn)算符(例如model->forward(...))訪問基礎(chǔ)對(duì)象。 最終結(jié)果是一個(gè)所有權(quán)模型,該所有權(quán)模型非常類似于 Python API。 引用語(yǔ)義成為默認(rèn)語(yǔ)義,但是沒有額外輸入std::shared_ptrstd::make_shared。 對(duì)于我們的Net,使用模塊持有人 API 如下所示:

  1. struct NetImpl : torch::nn::Module {};
  2. TORCH_MODULE(Net);
  3. void a(Net net) { }
  4. int main() {
  5. Net net;
  6. a(net);
  7. }

這里有一個(gè)微妙的問題值得一提。 默認(rèn)構(gòu)造的std::shared_ptr為“空”,即包含空指針。 什么是默認(rèn)構(gòu)造的LinearNet? 好吧,這是一個(gè)棘手的選擇。 我們可以說它應(yīng)該是一個(gè)空(null)std::shared_ptr<LinearImpl>。 但是,請(qǐng)記住Linear(3, 4)std::make_shared<LinearImpl>(3, 4)相同。 這意味著如果我們已確定Linear linear;應(yīng)該為空指針,則將無法構(gòu)造不采用任何構(gòu)造函數(shù)參數(shù)或都不使用所有缺省構(gòu)造函數(shù)的模塊。 因此,在當(dāng)前的 API 中,默認(rèn)構(gòu)造的模塊持有人(如Linear())將調(diào)用基礎(chǔ)模塊的默認(rèn)構(gòu)造函數(shù)(LinearImpl())。 如果基礎(chǔ)模塊沒有默認(rèn)構(gòu)造函數(shù),則會(huì)出現(xiàn)編譯器錯(cuò)誤。 要構(gòu)造空持有人,可以將nullptr傳遞給持有人的構(gòu)造函數(shù)。

實(shí)際上,這意味著您可以使用如先前所示的子模塊,在初始化程序列表中注冊(cè)并構(gòu)造該模塊:

  1. struct Net : torch::nn::Module {
  2. Net(int64_t N, int64_t M)
  3. : linear(register_module("linear", torch::nn::Linear(N, M)))
  4. { }
  5. torch::nn::Linear linear;
  6. };

或者,您可以先使用空指針構(gòu)造持有人,然后在構(gòu)造函數(shù)中為其分配值(Pythonistas 更熟悉):

  1. struct Net : torch::nn::Module {
  2. Net(int64_t N, int64_t M) {
  3. linear = register_module("linear", torch::nn::Linear(N, M));
  4. }
  5. torch::nn::Linear linear{nullptr}; // construct an empty holder
  6. };

結(jié)論:您應(yīng)該使用哪種所有權(quán)模型–哪種語(yǔ)義? C ++前端的 API 最能支持模塊所有者提供的所有權(quán)模型。 這種機(jī)制的唯一缺點(diǎn)是在模塊聲明下方多了一行樣板。 也就是說,最簡(jiǎn)單的模型仍然是 C ++模塊簡(jiǎn)介中顯示的值語(yǔ)義模型。 對(duì)于小的,簡(jiǎn)單的腳本,您也可以擺脫它。 但是,由于技術(shù)原因,您遲早會(huì)發(fā)現(xiàn)它并不總是受支持。 例如,序列化 API(torch::savetorch::load)僅支持模塊支架(或普通shared_ptr)。 因此,推薦使用模塊持有人 API 和 C ++前端定義模塊,此后我們將在本教程中使用此 API。

定義 DCGAN 模塊

現(xiàn)在,我們有必要的背景和簡(jiǎn)介來定義我們要在本文中解決的機(jī)器學(xué)習(xí)任務(wù)的模塊。 回顧一下:我們的任務(wù)是從 MNIST 數(shù)據(jù)集生成數(shù)字圖像。 我們想使用生成對(duì)抗網(wǎng)絡(luò)(GAN)解決此任務(wù)。 特別是,我們將使用 DCGAN 體系結(jié)構(gòu)-這是同類中最早的也是最簡(jiǎn)單的一種,但是完全可以完成此任務(wù)。

Tip

您可以在存儲(chǔ)庫(kù)中找到本教程中提供的完整源代碼。

什么是 GAN aGAN?

GAN 由兩個(gè)不同的神經(jīng)網(wǎng)絡(luò)模型組成:生成器鑒別器。 生成器從噪聲分布中接收樣本,其目的是將每個(gè)噪聲樣本轉(zhuǎn)換為類似于目標(biāo)分布的圖像(在我們的情況下為 MNIST 數(shù)據(jù)集)。 鑒別器又從 MNIST 數(shù)據(jù)集接收實(shí)際圖像,或從生成器接收圖像。 要求發(fā)出一個(gè)概率來判斷特定圖像的真實(shí)程度(接近1)或偽造(接近0)。 來自鑒別器的關(guān)于由發(fā)生器產(chǎn)生的圖像如何真實(shí)的反饋被用來訓(xùn)練發(fā)生器。 鑒別器對(duì)真實(shí)性有多好的反饋將用于優(yōu)化鑒別器。 從理論上講,生成器和鑒別器之間的微妙平衡使它們串聯(lián)起來得到改善,從而導(dǎo)致生成器生成與目標(biāo)分布無法區(qū)分的圖像,從而使鑒別器(那時(shí))的敏銳眼睛冒出了散發(fā)0.5的真實(shí)和真實(shí)可能性。 假圖片。 對(duì)我們來說,最終結(jié)果是一臺(tái)接收噪聲作為輸入并生成數(shù)字逼真的圖像作為其輸出的機(jī)器。

發(fā)電機(jī)模塊

我們首先定義生成器模塊,該模塊由一系列轉(zhuǎn)置的 2D 卷積,批處理歸一化和 ReLU 激活單元組成。 我們?cè)诙x自己的模塊的forward()方法中顯式地(在功能上)在模塊之間傳遞輸入:

  1. struct DCGANGeneratorImpl : nn::Module {
  2. DCGANGeneratorImpl(int kNoiseSize)
  3. : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
  4. .bias(false)),
  5. batch_norm1(256),
  6. conv2(nn::ConvTranspose2dOptions(256, 128, 3)
  7. .stride(2)
  8. .padding(1)
  9. .bias(false)),
  10. batch_norm2(128),
  11. conv3(nn::ConvTranspose2dOptions(128, 64, 4)
  12. .stride(2)
  13. .padding(1)
  14. .bias(false)),
  15. batch_norm3(64),
  16. conv4(nn::ConvTranspose2dOptions(64, 1, 4)
  17. .stride(2)
  18. .padding(1)
  19. .bias(false))
  20. {
  21. // register_module() is needed if we want to use the parameters() method later on
  22. register_module("conv1", conv1);
  23. register_module("conv2", conv2);
  24. register_module("conv3", conv3);
  25. register_module("conv4", conv4);
  26. register_module("batch_norm1", batch_norm1);
  27. register_module("batch_norm2", batch_norm2);
  28. register_module("batch_norm3", batch_norm3);
  29. }
  30. torch::Tensor forward(torch::Tensor x) {
  31. x = torch::relu(batch_norm1(conv1(x)));
  32. x = torch::relu(batch_norm2(conv2(x)));
  33. x = torch::relu(batch_norm3(conv3(x)));
  34. x = torch::tanh(conv4(x));
  35. return x;
  36. }
  37. nn::ConvTranspose2d conv1, conv2, conv3, conv4;
  38. nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
  39. };
  40. TORCH_MODULE(DCGANGenerator);
  41. DCGANGenerator generator(kNoiseSize);

現(xiàn)在我們可以在DCGANGenerator上調(diào)用forward()將噪聲樣本映射到圖像。

選擇的特定模塊,例如nn::ConvTranspose2dnn::BatchNorm2d,遵循前面概述的結(jié)構(gòu)。 kNoiseSize常數(shù)確定輸入噪聲矢量的大小,并將其設(shè)置為100。 當(dāng)然,超參數(shù)是通過研究生的血統(tǒng)發(fā)現(xiàn)的。

注意

在超參數(shù)的發(fā)現(xiàn)中,沒有研究生受到傷害。 他們定期喂給 Soylent。

Note

簡(jiǎn)要介紹如何將選項(xiàng)傳遞給 C ++前端中的Conv2d等內(nèi)置模塊:每個(gè)模塊都有一些必需的選項(xiàng),例如BatchNorm2d的功能數(shù)量。 如果您只需要配置所需的選項(xiàng),則可以將它們直接傳遞給模塊的構(gòu)造函數(shù),例如BatchNorm2d(128)Dropout(0.5)Conv2d(8, 4, 2)(用于輸入通道數(shù),輸出通道數(shù)和內(nèi)核大小)。 但是,如果需要修改其他通常默認(rèn)設(shè)置的選項(xiàng),例如Conv2dbias,則需要構(gòu)造并傳遞選項(xiàng)對(duì)象。 C ++前端中的每個(gè)模塊都有一個(gè)關(guān)聯(lián)的選項(xiàng)結(jié)構(gòu),稱為ModuleOptions,其中Module是模塊的名稱,例如LinearLinearOptions。 這就是我們上面的Conv2d模塊的工作。

鑒別模塊

鑒別器類似地是卷積,批歸一化和激活的序列。 但是,卷積現(xiàn)在是常規(guī)的卷積,而不是轉(zhuǎn)置的卷積,我們使用 alpha 值為 0.2 的泄漏 ReLU 代替了普通的 ReLU。 同樣,最后的激活變?yōu)?Sigmoid,將值壓縮到 0 到 1 之間。然后,我們可以將這些壓縮后的值解釋為鑒別器分配給真實(shí)圖像的概率。

為了構(gòu)建鑒別器,我們將嘗試不同的方法:<cite>順序</cite>模塊。 像在 Python 中一樣,PyTorch 在此提供了兩種用于模型定義的 API:一種功能,其中的輸入通過連續(xù)的函數(shù)傳遞(例如,生成器模塊示例),而另一種面向?qū)ο蟮?,其中我們?gòu)建了<cite>順序</cite>模塊 包含整個(gè)模型作為子模塊。 使用<cite>順序</cite>,鑒別符將如下所示:

  1. nn::Sequential discriminator(
  2. // Layer 1
  3. nn::Conv2d(
  4. nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
  5. nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  6. // Layer 2
  7. nn::Conv2d(
  8. nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
  9. nn::BatchNorm2d(128),
  10. nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  11. // Layer 3
  12. nn::Conv2d(
  13. nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
  14. nn::BatchNorm2d(256),
  15. nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
  16. // Layer 4
  17. nn::Conv2d(
  18. nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
  19. nn::Sigmoid());

Tip

Sequential模塊僅執(zhí)行功能組合。 第一個(gè)子模塊的輸出成為第二個(gè)子模塊的輸入,第三個(gè)子模塊的輸出成為第四個(gè)子模塊的輸入,依此類推。

加載數(shù)據(jù)中

現(xiàn)在我們已經(jīng)定義了生成器和鑒別器模型,我們需要一些可以用來訓(xùn)練這些模型的數(shù)據(jù)。 與 Python 一樣,C ++前端也具有強(qiáng)大的并行數(shù)據(jù)加載器。 該數(shù)據(jù)加載器可以從數(shù)據(jù)集中讀取批次數(shù)據(jù)(您可以定義自己),并提供許多配置旋鈕。

Note

盡管 Python 數(shù)據(jù)加載器使用多重處理,但 C ++數(shù)據(jù)加載器實(shí)際上是多線程的,不會(huì)啟動(dòng)任何新進(jìn)程。

數(shù)據(jù)加載器是 C ++前端data API 的一部分,該 API 包含在torch::data::名稱空間中。 該 API 由幾個(gè)不同的組件組成:

  • 數(shù)據(jù)加載器類,
  • 用于定義數(shù)據(jù)集的 API,
  • 用于定義轉(zhuǎn)換的 API,可以將其應(yīng)用于數(shù)據(jù)集,
  • 用于定義采樣器的 API,該采樣器會(huì)生成用于對(duì)數(shù)據(jù)集建立索引的索引,
  • 現(xiàn)有數(shù)據(jù)集,變換和采樣器的庫(kù)。

對(duì)于本教程,我們可以使用 C ++前端附帶的MNIST數(shù)據(jù)集。 讓我們?yōu)榇藢?shí)例化一個(gè)torch::data::datasets::MNIST,并應(yīng)用兩個(gè)轉(zhuǎn)換:首先,我們對(duì)圖像進(jìn)行歸一化,以使其在-1+1的范圍內(nèi)(從01的原始范圍)。 其次,我們應(yīng)用Stack 歸類,它采用一批張量并將它們沿第一維堆疊為單個(gè)張量:

  1. auto dataset = torch::data::datasets::MNIST("./mnist")
  2. .map(torch::data::transforms::Normalize<>(0.5, 0.5))
  3. .map(torch::data::transforms::Stack<>());

請(qǐng)注意,相對(duì)于執(zhí)行訓(xùn)練二進(jìn)制文件的位置,MNIST 數(shù)據(jù)集應(yīng)位于./mnist目錄中。 您可以使用此腳本下載 MNIST 數(shù)據(jù)集。

接下來,我們創(chuàng)建一個(gè)數(shù)據(jù)加載器并將其傳遞給此數(shù)據(jù)集。 為了創(chuàng)建一個(gè)新的數(shù)據(jù)加載器,我們使用torch::data::make_data_loader,它返回正確類型的std::unique_ptr(取決于數(shù)據(jù)集的類型,采樣器的類型以及其他一些實(shí)現(xiàn)細(xì)節(jié)):

  1. auto data_loader = torch::data::make_data_loader(std::move(dataset));

數(shù)據(jù)加載器確實(shí)提供了很多選項(xiàng)。 您可以在處檢查全套。 例如,為了加快數(shù)據(jù)加載速度,我們可以增加工作人員的數(shù)量。 默認(rèn)數(shù)字為零,這表示將使用主線程。 如果將workers設(shè)置為2,將產(chǎn)生兩個(gè)線程并發(fā)加載數(shù)據(jù)。 我們還應(yīng)該將批次大小從其默認(rèn)值1增大到更合理的值,例如64(kBatchSize的值)。 因此,讓我們創(chuàng)建一個(gè)DataLoaderOptions對(duì)象并設(shè)置適當(dāng)?shù)膶傩裕?/a>

  1. auto data_loader = torch::data::make_data_loader(
  2. std::move(dataset),
  3. torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

現(xiàn)在,我們可以編寫一個(gè)循環(huán)來加載批量數(shù)據(jù),目前我們僅將其打印到控制臺(tái):

  1. for (torch::data::Example<>& batch : *data_loader) {
  2. std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  3. for (int64_t i = 0; i < batch.data.size(0); ++i) {
  4. std::cout << batch.target[i].item<int64_t>() << " ";
  5. }
  6. std::cout << std::endl;
  7. }

在這種情況下,數(shù)據(jù)加載器返回的類型為torch::data::Example。 此類型是一種簡(jiǎn)單的結(jié)構(gòu),其中的data字段用于數(shù)據(jù),而target字段用于標(biāo)簽。 因?yàn)槲覀冎皯?yīng)用了Stack歸類,所以數(shù)據(jù)加載器僅返回一個(gè)這樣的示例。 如果我們未應(yīng)用排序規(guī)則,則數(shù)據(jù)加載器將改為生成std::vector<torch::data::Example<>>,批處理中每個(gè)示例包含一個(gè)元素。

如果重新生成并運(yùn)行此代碼,則應(yīng)看到類似以下內(nèi)容的內(nèi)容:

  1. root@fa350df05ecf:/home/build# make
  2. Scanning dependencies of target dcgan
  3. [ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
  4. [100%] Linking CXX executable dcgan
  5. [100%] Built target dcgan
  6. root@fa350df05ecf:/home/build# make
  7. [100%] Built target dcgan
  8. root@fa350df05ecf:/home/build# ./dcgan
  9. Batch size: 64 | Labels: 5 2 6 7 2 1 6 7 0 1 6 2 3 6 9 1 8 4 0 6 5 3 3 0 4 6 6 6 4 0 8 6 0 6 9 2 4 0 2 8 6 3 3 2 9 2 0 1 4 2 3 4 8 2 9 9 3 5 8 0 0 7 9 9
  10. Batch size: 64 | Labels: 2 2 4 7 1 2 8 8 6 9 0 2 2 9 3 6 1 3 8 0 4 4 8 8 8 9 2 6 4 7 1 5 0 9 7 5 4 3 5 4 1 2 8 0 7 1 9 6 1 6 5 3 4 4 1 2 3 2 3 5 0 1 6 2
  11. Batch size: 64 | Labels: 4 5 4 2 1 4 8 3 8 3 6 1 5 4 3 6 2 2 5 1 3 1 5 0 8 2 1 5 3 2 4 4 5 9 7 2 8 9 2 0 6 7 4 3 8 3 5 8 8 3 0 5 8 0 8 7 8 5 5 6 1 7 8 0
  12. Batch size: 64 | Labels: 3 3 7 1 4 1 6 1 0 3 6 4 0 2 5 4 0 4 2 8 1 9 6 5 1 6 3 2 8 9 2 3 8 7 4 5 9 6 0 8 3 0 0 6 4 8 2 5 4 1 8 3 7 8 0 0 8 9 6 7 2 1 4 7
  13. Batch size: 64 | Labels: 3 0 5 5 9 8 3 9 8 9 5 9 5 0 4 1 2 7 7 2 0 0 5 4 8 7 7 6 1 0 7 9 3 0 6 3 2 6 2 7 6 3 3 4 0 5 8 8 9 1 9 2 1 9 4 4 9 2 4 6 2 9 4 0
  14. Batch size: 64 | Labels: 9 6 7 5 3 5 9 0 8 6 6 7 8 2 1 9 8 8 1 1 8 2 0 7 1 4 1 6 7 5 1 7 7 4 0 3 2 9 0 6 6 3 4 4 8 1 2 8 6 9 2 0 3 1 2 8 5 6 4 8 5 8 6 2
  15. Batch size: 64 | Labels: 9 3 0 3 6 5 1 8 6 0 1 9 9 1 6 1 7 7 4 4 4 7 8 8 6 7 8 2 6 0 4 6 8 2 5 3 9 8 4 0 9 9 3 7 0 5 8 2 4 5 6 2 8 2 5 3 7 1 9 1 8 2 2 7
  16. Batch size: 64 | Labels: 9 1 9 2 7 2 6 0 8 6 8 7 7 4 8 6 1 1 6 8 5 7 9 1 3 2 0 5 1 7 3 1 6 1 0 8 6 0 8 1 0 5 4 9 3 8 5 8 4 8 0 1 2 6 2 4 2 7 7 3 7 4 5 3
  17. Batch size: 64 | Labels: 8 8 3 1 8 6 4 2 9 5 8 0 2 8 6 6 7 0 9 8 3 8 7 1 6 6 2 7 7 4 5 5 2 1 7 9 5 4 9 1 0 3 1 9 3 9 8 8 5 3 7 5 3 6 8 9 4 2 0 1 2 5 4 7
  18. Batch size: 64 | Labels: 9 2 7 0 8 4 4 2 7 5 0 0 6 2 0 5 9 5 9 8 8 9 3 5 7 5 4 7 3 0 5 7 6 5 7 1 6 2 8 7 6 3 2 6 5 6 1 2 7 7 0 0 5 9 0 0 9 1 7 8 3 2 9 4
  19. Batch size: 64 | Labels: 7 6 5 7 7 5 2 2 4 9 9 4 8 7 4 8 9 4 5 7 1 2 6 9 8 5 1 2 3 6 7 8 1 1 3 9 8 7 9 5 0 8 5 1 8 7 2 6 5 1 2 0 9 7 4 0 9 0 4 6 0 0 8 6
  20. ...

這意味著我們能夠成功地從 MNIST 數(shù)據(jù)集中加載數(shù)據(jù)。

編寫訓(xùn)練循環(huán)

現(xiàn)在,讓我們完成示例的算法部分,并實(shí)現(xiàn)生成器和鑒別器之間的精妙舞蹈。 首先,我們將創(chuàng)建兩個(gè)優(yōu)化器,一個(gè)用于生成器,一個(gè)用于區(qū)分器。 我們使用的優(yōu)化程序?qū)崿F(xiàn)了 Adam 算法:

  1. torch::optim::Adam generator_optimizer(
  2. generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  3. torch::optim::Adam discriminator_optimizer(
  4. discriminator->parameters(), torch::optim::AdamOptions(5e-4).beta1(0.5));

Note

在撰寫本文時(shí),C ++前端提供了實(shí)現(xiàn) Adagrad,Adam,LBBFG,RMSprop 和 SGD 的優(yōu)化器。 文檔具有最新列表。

接下來,我們需要更新我們的訓(xùn)練循環(huán)。 我們將添加一個(gè)外部循環(huán)以在每個(gè)時(shí)期耗盡數(shù)據(jù)加載器,然后編寫 GAN 訓(xùn)練代碼:

  1. for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  2. int64_t batch_index = 0;
  3. for (torch::data::Example<>& batch : *data_loader) {
  4. // Train discriminator with real images.
  5. discriminator->zero_grad();
  6. torch::Tensor real_images = batch.data;
  7. torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
  8. torch::Tensor real_output = discriminator->forward(real_images);
  9. torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
  10. d_loss_real.backward();
  11. // Train discriminator with fake images.
  12. torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
  13. torch::Tensor fake_images = generator->forward(noise);
  14. torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
  15. torch::Tensor fake_output = discriminator->forward(fake_images.detach());
  16. torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
  17. d_loss_fake.backward();
  18. torch::Tensor d_loss = d_loss_real + d_loss_fake;
  19. discriminator_optimizer.step();
  20. // Train generator.
  21. generator->zero_grad();
  22. fake_labels.fill_(1);
  23. fake_output = discriminator->forward(fake_images);
  24. torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
  25. g_loss.backward();
  26. generator_optimizer.step();
  27. std::printf(
  28. "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
  29. epoch,
  30. kNumberOfEpochs,
  31. ++batch_index,
  32. batches_per_epoch,
  33. d_loss.item<float>(),
  34. g_loss.item<float>());
  35. }
  36. }

上面,我們首先在真實(shí)圖像上評(píng)估鑒別器,為此應(yīng)為其分配較高的概率。 為此,我們使用torch::empty(batch.data.size(0)).uniform_(0.8, 1.0)作為目標(biāo)概率。

Note

我們選擇均勻分布在 0.8 到 1.0 之間的隨機(jī)值,而不是各處的 1.0,以使鑒別器訓(xùn)練更加可靠。 此技巧稱為標(biāo)簽平滑。

在評(píng)估鑒別器之前,我們將其參數(shù)的梯度歸零。 計(jì)算完損耗后,我們通過調(diào)用d_loss.backward()計(jì)算新的梯度來在網(wǎng)絡(luò)中反向傳播。 我們對(duì)虛假圖像重復(fù)此步驟。 我們不使用數(shù)據(jù)集中的圖像,而是讓生成器通過為它提供一批隨機(jī)噪聲來為此創(chuàng)建偽造圖像。 然后,我們將這些偽造圖像轉(zhuǎn)發(fā)給鑒別器。 這次,我們希望鑒別器發(fā)出低概率,最好是全零。 一旦計(jì)算了一批真實(shí)圖像和一批偽造圖像的鑒別器損耗,我們就可以一步一步地進(jìn)行鑒別器的優(yōu)化程序,以更新其參數(shù)。

為了訓(xùn)練生成器,我們?cè)俅问紫葘⑵涮荻葰w零,然后在偽圖像上重新評(píng)估鑒別器。 但是,這一次,我們希望鑒別器將概率分配為非常接近的概率,這將表明生成器可以生成使鑒別器認(rèn)為它們實(shí)際上是真實(shí)的圖像(來自數(shù)據(jù)集)。 為此,我們用全部填充fake_labels張量。 最后,我們逐步使用生成器的優(yōu)化器來更新其參數(shù)。

現(xiàn)在,我們應(yīng)該準(zhǔn)備在 CPU 上訓(xùn)練我們的模型。 我們還沒有任何代碼可以捕獲狀態(tài)或示例輸出,但是我們稍后會(huì)添加。 現(xiàn)在,讓我們觀察一下我們的模型正在做某事 –我們稍后將根據(jù)生成的圖像來驗(yàn)證這是否有意義。 重建和運(yùn)行應(yīng)打印如下內(nèi)容:

  1. root@3c0711f20896:/home/build# make && ./dcgan
  2. Scanning dependencies of target dcgan
  3. [ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
  4. [100%] Linking CXX executable dcgan
  5. [100%] Built target dcga
  6. [ 1/10][100/938] D_loss: 0.6876 | G_loss: 4.1304
  7. [ 1/10][200/938] D_loss: 0.3776 | G_loss: 4.3101
  8. [ 1/10][300/938] D_loss: 0.3652 | G_loss: 4.6626
  9. [ 1/10][400/938] D_loss: 0.8057 | G_loss: 2.2795
  10. [ 1/10][500/938] D_loss: 0.3531 | G_loss: 4.4452
  11. [ 1/10][600/938] D_loss: 0.3501 | G_loss: 5.0811
  12. [ 1/10][700/938] D_loss: 0.3581 | G_loss: 4.5623
  13. [ 1/10][800/938] D_loss: 0.6423 | G_loss: 1.7385
  14. [ 1/10][900/938] D_loss: 0.3592 | G_loss: 4.7333
  15. [ 2/10][100/938] D_loss: 0.4660 | G_loss: 2.5242
  16. [ 2/10][200/938] D_loss: 0.6364 | G_loss: 2.0886
  17. [ 2/10][300/938] D_loss: 0.3717 | G_loss: 3.8103
  18. [ 2/10][400/938] D_loss: 1.0201 | G_loss: 1.3544
  19. [ 2/10][500/938] D_loss: 0.4522 | G_loss: 2.6545
  20. ...

移至 GPU

盡管我們當(dāng)前的腳本可以在 CPU 上正常運(yùn)行,但是我們都知道卷積在 GPU 上要快得多。 讓我們快速討論如何將訓(xùn)練轉(zhuǎn)移到 GPU 上。 為此,我們需要做兩件事:將 GPU 設(shè)備規(guī)范傳遞給我們分配給自己的張量,并通過to()方法將所有其他張量明確復(fù)制到 C ++前端中所有張量和模塊上。 實(shí)現(xiàn)這兩者的最簡(jiǎn)單方法是在訓(xùn)練腳本的頂層創(chuàng)建torch::Device的實(shí)例,然后將該設(shè)備傳遞給諸如torch::zerosto()方法之類的張量工廠函數(shù)。 我們可以從使用 CPU 設(shè)備開始:

  1. // Place this somewhere at the top of your training script.
  2. torch::Device device(torch::kCPU);

新的張量分配,例如

  1. torch::Tensor fake_labels = torch::zeros(batch.data.size(0));

應(yīng)該更新為以device作為最后一個(gè)參數(shù):

  1. torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);

對(duì)于那些不在我們手中的張量,例如來自 MNIST 數(shù)據(jù)集的張量,我們必須插入顯式的to()調(diào)用。 這表示

  1. torch::Tensor real_images = batch.data;

變成

  1. torch::Tensor real_images = batch.data.to(device);

并且我們的模型參數(shù)也應(yīng)該移到正確的設(shè)備上:

  1. generator->to(device);
  2. discriminator->to(device);

Note

如果張量已經(jīng)存在于提供給to()的設(shè)備上,則該調(diào)用為空操作。 沒有多余的副本。

至此,我們已經(jīng)使之前的 CPU 駐留代碼更加明確。 但是,現(xiàn)在將設(shè)備更改為 CUDA 設(shè)備也非常容易:

  1. torch::Device device(torch::kCUDA)

現(xiàn)在,所有張量都將駐留在 GPU 上,并調(diào)用快速 CUDA 內(nèi)核進(jìn)行所有操作,而無需我們更改任何下游代碼。 如果我們想指定一個(gè)特定的設(shè)備索引,則可以將其作為第二個(gè)參數(shù)傳遞給Device構(gòu)造函數(shù)。 如果我們希望不同的張量駐留在不同的設(shè)備上,則可以傳遞單獨(dú)的設(shè)備實(shí)例(例如,一個(gè)在 CUDA 設(shè)備 0 上,另一個(gè)在 CUDA 設(shè)備 1 上)。 我們甚至可以動(dòng)態(tài)地進(jìn)行此配置,這通常對(duì)于使我們的訓(xùn)練腳本更具可移植性很有用:

  1. torch::Device device = torch::kCPU;
  2. if (torch::cuda::is_available()) {
  3. std::cout << "CUDA is available! Training on GPU." << std::endl;
  4. device = torch::kCUDA;
  5. }

甚至

  1. torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

檢查點(diǎn)和恢復(fù)訓(xùn)練狀態(tài)

我們應(yīng)該對(duì)訓(xùn)練腳本進(jìn)行的最后擴(kuò)充是定期保存模型參數(shù)的狀態(tài),優(yōu)化器的狀態(tài)以及一些生成的圖像樣本。 如果我們的計(jì)算機(jī)在訓(xùn)練過程中崩潰,則前兩個(gè)將使我們能夠恢復(fù)訓(xùn)練狀態(tài)。 對(duì)于長(zhǎng)期的訓(xùn)練課程,這是絕對(duì)必要的。 幸運(yùn)的是,C ++前端提供了一個(gè) API,用于對(duì)模型和優(yōu)化器狀態(tài)以及單個(gè)張量進(jìn)行序列化和反序列化。

為此的核心 API 是torch::save(thing,filename)torch::load(thing,filename),其中thing可以是torch::nn::Module子類或優(yōu)化程序?qū)嵗?,例如我們?cè)谟?xùn)練腳本中擁有的Adam對(duì)象。 讓我們更新訓(xùn)練循環(huán),以一定間隔檢查模型和優(yōu)化器狀態(tài):

  1. if (batch_index % kCheckpointEvery == 0) {
  2. // Checkpoint the model and optimizer state.
  3. torch::save(generator, "generator-checkpoint.pt");
  4. torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
  5. torch::save(discriminator, "discriminator-checkpoint.pt");
  6. torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  7. // Sample the generator and save the images.
  8. torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}, device));
  9. torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
  10. std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
  11. }

其中kCheckpointEvery是設(shè)置為類似于100之類的整數(shù),以便每批100都進(jìn)行檢查,而checkpoint_counter是每次創(chuàng)建檢查點(diǎn)時(shí)都會(huì)增加的計(jì)數(shù)器。

要恢復(fù)訓(xùn)練狀態(tài),可以在創(chuàng)建所有模型和優(yōu)化器之后但在訓(xùn)練循環(huán)之前添加如下代碼:

  1. torch::optim::Adam generator_optimizer(
  2. generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  3. torch::optim::Adam discriminator_optimizer(
  4. discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  5. if (kRestoreFromCheckpoint) {
  6. torch::load(generator, "generator-checkpoint.pt");
  7. torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
  8. torch::load(discriminator, "discriminator-checkpoint.pt");
  9. torch::load(
  10. discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  11. }
  12. int64_t checkpoint_counter = 0;
  13. for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
  14. int64_t batch_index = 0;
  15. for (torch::data::Example<>& batch : *data_loader) {

檢查生成的圖像

我們的訓(xùn)練腳本現(xiàn)已完成。 我們準(zhǔn)備在 CPU 或 GPU 上訓(xùn)練 GAN。 為了檢查我們訓(xùn)練過程的中間輸出,為此我們添加了將代碼樣本定期保存到"dcgan-sample-xxx.pt"文件的代碼,我們可以編寫一個(gè)小的 Python 腳本來加載張量并使用 matplotlib 顯示它們:

  1. from __future__ import print_function
  2. from __future__ import unicode_literals
  3. import argparse
  4. import matplotlib.pyplot as plt
  5. import torch
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument("-i", "--sample-file", required=True)
  8. parser.add_argument("-o", "--out-file", default="out.png")
  9. parser.add_argument("-d", "--dimension", type=int, default=3)
  10. options = parser.parse_args()
  11. module = torch.jit.load(options.sample_file)
  12. images = list(module.parameters())[0]
  13. for index in range(options.dimension * options.dimension):
  14. image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
  15. array = image.numpy()
  16. axis = plt.subplot(options.dimension, options.dimension, 1 + index)
  17. plt.imshow(array, cmap="gray")
  18. axis.get_xaxis().set_visible(False)
  19. axis.get_yaxis().set_visible(False)
  20. plt.savefig(options.out_file)
  21. print("Saved ", options.out_file)

現(xiàn)在,讓我們訓(xùn)練模型約 30 個(gè)紀(jì)元:

  1. root@3c0711f20896:/home/build# make && ./dcgan 10:17:57
  2. Scanning dependencies of target dcgan
  3. [ 50%] Building CXX object CMakeFiles/dcgan.dir/dcgan.cpp.o
  4. [100%] Linking CXX executable dcgan
  5. [100%] Built target dcgan
  6. CUDA is available! Training on GPU.
  7. [ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
  8. -> checkpoint 1
  9. [ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
  10. -> checkpoint 2
  11. [ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.36760
  12. -> checkpoint 3
  13. [ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
  14. -> checkpoint 4
  15. [ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
  16. -> checkpoint 5
  17. [ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
  18. ...
  19. -> checkpoint 120
  20. [30/30][938/938] D_loss: 0.3610 | G_loss: 3.8084

并在圖中顯示圖像:

  1. root@3c0711f20896:/home/build# python display.py -i dcgan-sample-100.pt
  2. Saved out.png

應(yīng)該看起來像這樣:

digits

數(shù)字! 萬歲! 現(xiàn)在,事情就在您的球場(chǎng)上:您可以改進(jìn)模型以使數(shù)字看起來更好嗎?

結(jié)論

希望本教程為您提供了 PyTorch C ++前端的摘要。 像 PyTorch 這樣的機(jī)器學(xué)習(xí)庫(kù)必然具有非常廣泛的 API。 因此,有許多概念我們沒有時(shí)間或空間來討論。 但是,我建議您嘗試使用該 API,并在遇到問題時(shí)查閱我們的文檔,尤其是庫(kù) API 部分。 另外,請(qǐng)記住,只要我們能夠做到,就可以期望 C ++前端遵循 Python 前端的設(shè)計(jì)和語(yǔ)義,因此您可以利用這一事實(shí)來提高學(xué)習(xí)率。

Tip

You can find the full source code presented in this tutorial in this repository.

與往常一樣,如果您遇到任何問題或疑問,可以使用我們的論壇GitHub 問題進(jìn)行聯(lián)系。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)