PyTorch 在 C ++中加載 TorchScript 模型

2020-09-09 15:56 更新
原文: https://pytorch.org/tutorials/advanced/cpp_export.html

顧名思義,PyTorch 的主要接口是 Python 編程語(yǔ)言。 盡管 Python 是許多需要?jiǎng)討B(tài)性和易于迭代的場(chǎng)景的合適且首選的語(yǔ)言,但是在同樣許多情況下,Python 的這些屬性恰恰是不利的。 后者經(jīng)常應(yīng)用的一種環(huán)境是生產(chǎn) –低延遲和嚴(yán)格部署要求的土地。 對(duì)于生產(chǎn)場(chǎng)景,即使僅將 C ++綁定到 Java,Rust 或 Go 之類的另一種語(yǔ)言中,它也是經(jīng)常選擇的語(yǔ)言。 以下段落將概述 PyTorch 提供的從現(xiàn)有 Python 模型到序列化表示形式的路徑,該序列化表示形式可以加載和完全由 C ++執(zhí)行,不依賴于 Python。

步驟 1:將 PyTorch 模型轉(zhuǎn)換為 Torch 腳本

PyTorch 模型從 Python 到 C ++的旅程由 Torch 腳本啟用,它是 PyTorch 模型的表示形式,可以由 Torch 腳本編譯器理解,編譯和序列化。 如果您是從使用香草“渴望” API 編寫的現(xiàn)有 PyTorch 模型開始的,則必須首先將模型轉(zhuǎn)換為 Torch 腳本。 在最常見的情況下(如下所述),只需很少的努力。 如果您已經(jīng)有了 Torch 腳本模塊,則可以跳到本教程的下一部分。

有兩種將 PyTorch 模型轉(zhuǎn)換為 Torch 腳本的方法。 第一種稱為跟蹤,該機(jī)制通過(guò)使用示例輸入對(duì)模型的結(jié)構(gòu)進(jìn)行一次評(píng)估并記錄這些輸入在模型中的流量來(lái)捕獲模型的結(jié)構(gòu)。 這適用于有限使用控制流的模型。 第二種方法是在模型中添加顯式批注,以告知 Torch Script 編譯器可以根據(jù) Torch Script 語(yǔ)言施加的約束直接解析和編譯模型代碼。

小費(fèi)

您可以在官方torch腳本參考中找到這兩種方法的完整文檔,以及使用方法的進(jìn)一步指導(dǎo)。

通過(guò)跟蹤轉(zhuǎn)換為 Torch 腳本

要將 PyTorch 模型通過(guò)跟蹤轉(zhuǎn)換為 Torch 腳本,必須將模型的實(shí)例以及示例輸入傳遞給torch.jit.trace函數(shù)。 這將產(chǎn)生一個(gè)torch.jit.ScriptModule對(duì)象,并將模型評(píng)估的軌跡嵌入到模塊的forward方法中:

  1. import torch
  2. import torchvision
  3. ## An instance of your model.
  4. model = torchvision.models.resnet18()
  5. ## An example input you would normally provide to your model's forward() method.
  6. example = torch.rand(1, 3, 224, 224)
  7. ## Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
  8. traced_script_module = torch.jit.trace(model, example)

現(xiàn)在可以對(duì)跟蹤的ScriptModule進(jìn)行評(píng)估,使其與常規(guī) PyTorch 模塊相同:

  1. In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
  2. In[2]: output[0, :5]
  3. Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

通過(guò)注釋轉(zhuǎn)換為 Torch 腳本

在某些情況下,例如,如果模型采用特定形式的控制流,則可能需要直接在 Torch 腳本中編寫模型并相應(yīng)地注釋模型。 例如,假設(shè)您具有以下香草 Pytorch 模型:

  1. import torch
  2. class MyModule(torch.nn.Module):
  3. def __init__(self, N, M):
  4. super(MyModule, self).__init__()
  5. self.weight = torch.nn.Parameter(torch.rand(N, M))
  6. def forward(self, input):
  7. if input.sum() > 0:
  8. output = self.weight.mv(input)
  9. else:
  10. output = self.weight + input
  11. return output

因?yàn)榇四K的forward方法使用取決于輸入的控制流,所以它不適合跟蹤。 相反,我們可以將其轉(zhuǎn)換為ScriptModule。 為了將模塊轉(zhuǎn)換為ScriptModule,需要使用torch.jit.script編譯模塊,如下所示:

  1. class MyModule(torch.nn.Module):
  2. def __init__(self, N, M):
  3. super(MyModule, self).__init__()
  4. self.weight = torch.nn.Parameter(torch.rand(N, M))
  5. def forward(self, input):
  6. if input.sum() > 0:
  7. output = self.weight.mv(input)
  8. else:
  9. output = self.weight + input
  10. return output
  11. my_module = MyModule(10,20)
  12. sm = torch.jit.script(my_module)

如果您需要在nn.Module中排除某些方法,因?yàn)樗鼈兪褂昧?TorchScript 尚不支持的 Python 功能,則可以使用@torch.jit.ignore注釋這些方法

my_module是已準(zhǔn)備好進(jìn)行序列化的ScriptModule的實(shí)例。

步驟 2:將腳本模塊序列化為文件

跟蹤或注釋 PyTorch 模型后,一旦您有了ScriptModule,就可以將其序列化為文件了。 稍后,您將可以使用 C ++從此文件加載模塊并執(zhí)行它,而無(wú)需依賴 Python。 假設(shè)我們要序列化先前在跟蹤示例中顯示的ResNet18模型。 要執(zhí)行此序列化,只需在模塊上調(diào)用保存并為其傳遞文件名:

  1. traced_script_module.save("traced_resnet_model.pt")

這將在您的工作目錄中生成一個(gè)traced_resnet_model.pt文件。 如果您還想序列化my_module,請(qǐng)致電my_module.save("my_module_model.pt")。我們現(xiàn)在已經(jīng)正式離開 Python 領(lǐng)域,并準(zhǔn)備跨入 C ++領(lǐng)域。

步驟 3:在 C ++中加載腳本模塊

要在 C ++中加載序列化的 PyTorch 模型,您的應(yīng)用程序必須依賴于 PyTorch C ++ API –也稱為 LibTorch 。 LibTorch 發(fā)行版包含共享庫(kù),頭文件和 CMake 構(gòu)建配置文件的集合。 雖然 CMake 不是依賴 LibTorch 的要求,但它是推薦的方法,將來(lái)會(huì)得到很好的支持。 對(duì)于本教程,我們將使用 CMake 和 LibTorch 構(gòu)建一個(gè)最小的 C ++應(yīng)用程序,該應(yīng)用程序簡(jiǎn)單地加載并執(zhí)行序列化的 PyTorch 模型。

最小的 C ++應(yīng)用程序

讓我們從討論加載模塊的代碼開始。 以下將已經(jīng)做:

  1. #include <torch/script.h> // One-stop header.
  2. #include <iostream>
  3. #include <memory>
  4. int main(int argc, const char* argv[]) {
  5. if (argc != 2) {
  6. std::cerr << "usage: example-app <path-to-exported-script-module>\n";
  7. return -1;
  8. }
  9. torch::jit::script::Module module;
  10. try {
  11. // Deserialize the ScriptModule from a file using torch::jit::load().
  12. module = torch::jit::load(argv[1]);
  13. }
  14. catch (const c10::Error& e) {
  15. std::cerr << "error loading the model\n";
  16. return -1;
  17. }
  18. std::cout << "ok\n";
  19. }

&lt;torch/script.h&gt;標(biāo)頭包含了運(yùn)行示例所需的 LibTorch 庫(kù)中的所有相關(guān)包含。 我們的應(yīng)用程序接受序列化 PyTorch ScriptModule的文件路徑作為其唯一的命令行參數(shù),然后繼續(xù)使用torch::jit::load()函數(shù)對(duì)該模塊進(jìn)行反序列化,該函數(shù)將這個(gè)文件路徑作為輸入。 作為回報(bào),我們收到一個(gè)torch::jit::script::Module對(duì)象。 我們將稍后討論如何執(zhí)行它。

取決于 LibTorch 和構(gòu)建應(yīng)用程序

假設(shè)我們將以上代碼存儲(chǔ)到名為example-app.cpp的文件中。 最小的CMakeLists.txt構(gòu)建起來(lái)看起來(lái)很簡(jiǎn)單:

  1. cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
  2. project(custom_ops)
  3. find_package(Torch REQUIRED)
  4. add_executable(example-app example-app.cpp)
  5. target_link_libraries(example-app "${TORCH_LIBRARIES}")
  6. set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

構(gòu)建示例應(yīng)用程序的最后一件事是 LibTorch 發(fā)行版。 您可以隨時(shí)從 PyTorch 網(wǎng)站上的下載頁(yè)面獲取最新的穩(wěn)定版本。 如果下載并解壓縮最新的歸檔文件,則應(yīng)該收到具有以下目錄結(jié)構(gòu)的文件夾:

  1. libtorch/
  2. bin/
  3. include/
  4. lib/
  5. share/
  • lib/文件夾包含您必須鏈接的共享庫(kù),
  • include/文件夾包含程序需要包含的頭文件,
  • share/文件夾包含必要的 CMake 配置,以啟用上面的簡(jiǎn)單find_package(Torch)命令。

小竅門

在 Windows 上,調(diào)試和發(fā)行版本不兼容 ABI。 如果您打算以調(diào)試模式構(gòu)建項(xiàng)目,請(qǐng)嘗試使用 LibTorch 的調(diào)試版本。 另外,請(qǐng)確保在下面的cmake --build .行中指定正確的配置。

最后一步是構(gòu)建應(yīng)用程序。 為此,假定示例目錄的布局如下:

  1. example-app/
  2. CMakeLists.txt
  3. example-app.cpp

現(xiàn)在,我們可以運(yùn)行以下命令從example-app/文件夾中構(gòu)建應(yīng)用程序:

  1. mkdir build
  2. cd build
  3. cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  4. cmake --build . --config Release

其中/path/to/libtorch應(yīng)該是解壓縮的 LibTorch 發(fā)行版的完整路徑。 如果一切順利,它將看起來(lái)像這樣:

  1. root@4b5a67132e81:/example-app# mkdir build
  2. root@4b5a67132e81:/example-app# cd build
  3. root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  4. -- The C compiler identification is GNU 5.4.0
  5. -- The CXX compiler identification is GNU 5.4.0
  6. -- Check for working C compiler: /usr/bin/cc
  7. -- Check for working C compiler: /usr/bin/cc -- works
  8. -- Detecting C compiler ABI info
  9. -- Detecting C compiler ABI info - done
  10. -- Detecting C compile features
  11. -- Detecting C compile features - done
  12. -- Check for working CXX compiler: /usr/bin/c++
  13. -- Check for working CXX compiler: /usr/bin/c++ -- works
  14. -- Detecting CXX compiler ABI info
  15. -- Detecting CXX compiler ABI info - done
  16. -- Detecting CXX compile features
  17. -- Detecting CXX compile features - done
  18. -- Looking for pthread.h
  19. -- Looking for pthread.h - found
  20. -- Looking for pthread_create
  21. -- Looking for pthread_create - not found
  22. -- Looking for pthread_create in pthreads
  23. -- Looking for pthread_create in pthreads - not found
  24. -- Looking for pthread_create in pthread
  25. -- Looking for pthread_create in pthread - found
  26. -- Found Threads: TRUE
  27. -- Configuring done
  28. -- Generating done
  29. -- Build files have been written to: /example-app/build
  30. root@4b5a67132e81:/example-app/build# make
  31. Scanning dependencies of target example-app
  32. [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
  33. [100%] Linking CXX executable example-app
  34. [100%] Built target example-app

如果我們提供到先前創(chuàng)建的跟蹤ResNet18模型traced_resnet_model.pt到生成的example-app二進(jìn)制文件的路徑,則應(yīng)該以友好的“ ok”作為獎(jiǎng)勵(lì)。 請(qǐng)注意,如果嘗試使用my_module_model.pt運(yùn)行此示例,則會(huì)收到一條錯(cuò)誤消息,提示您輸入的形狀不兼容。 my_module_model.pt期望使用 1D 而不是 4D。

  1. root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
  2. ok

步驟 4:在 C ++中執(zhí)行腳本模塊

在用 C ++成功加載序列化的ResNet18之后,我們現(xiàn)在離執(zhí)行它僅幾行代碼了! 讓我們將這些行添加到 C ++應(yīng)用程序的main()函數(shù)中:

  1. // Create a vector of inputs.
  2. std::vector<torch::jit::IValue> inputs;
  3. inputs.push_back(torch::ones({1, 3, 224, 224}));
  4. // Execute the model and turn its output into a tensor.
  5. at::Tensor output = module.forward(inputs).toTensor();
  6. std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

前兩行設(shè)置了我們模型的輸入。 我們創(chuàng)建一個(gè)torch::jit::IValue的向量(類型擦除的值類型script::Module方法接受并返回),并添加單個(gè)輸入。 要?jiǎng)?chuàng)建輸入張量,我們使用torch::ones(),等效于 C ++ API 中的torch.ones。 然后,我們運(yùn)行script::Moduleforward方法,并將其傳遞給我們創(chuàng)建的輸入向量。 作為回報(bào),我們得到一個(gè)新的IValue,我們可以通過(guò)調(diào)用toTensor()將其轉(zhuǎn)換為張量。

小竅門

要總體上了解有關(guān)torch::ones和 PyTorch C ++ API 之類的功能的更多信息,請(qǐng)參閱 https://pytorch.org/cppdocs 上的文檔。 PyTorch C ++ API 提供了與 Python API 差不多的功能奇偶校驗(yàn),使您可以像在 Python 中一樣進(jìn)一步操縱和處理張量。

在最后一行,我們打印輸出的前五個(gè)條目。 由于在本教程前面的部分中,我們向 Python 中的模型提供了相同的輸入,因此理想情況下,我們應(yīng)該看到相同的輸出。 讓我們通過(guò)重新編譯我們的應(yīng)用程序并以相同的序列化模型運(yùn)行它來(lái)進(jìn)行嘗試:

  1. root@4b5a67132e81:/example-app/build# make
  2. Scanning dependencies of target example-app
  3. [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
  4. [100%] Linking CXX executable example-app
  5. [100%] Built target example-app
  6. root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
  7. -0.2698 -0.0381 0.4023 -0.3010 -0.0448
  8. [ Variable[CPUFloatType]{1,5} ]

作為參考,Python 以前的輸出為:

  1. tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

看起來(lái)很不錯(cuò)!

小竅門

要將模型移至 GPU 內(nèi)存,可以編寫model.to(at::kCUDA);。 通過(guò)調(diào)用tensor.to(at::kCUDA)來(lái)確保模型的輸入也位于 CUDA 內(nèi)存中,這將在 CUDA 內(nèi)存中返回新的張量。

第 5 步:獲取幫助并探索 API

本教程有望使您對(duì) PyTorch 模型從 Python 到 C ++的路徑有一個(gè)大致的了解。 使用本教程中描述的概念,您應(yīng)該能夠從原始的“急切” PyTorch 模型,到 Python 中已編譯的ScriptModule,再到磁盤上的序列化文件,以及–關(guān)閉循環(huán)–到可執(zhí)行文件script::Module在 C ++中。

當(dāng)然,有許多我們沒(méi)有介紹的概念。 例如,您可能會(huì)發(fā)現(xiàn)自己想要擴(kuò)展使用 C ++或 CUDA 實(shí)現(xiàn)的自定義運(yùn)算符來(lái)擴(kuò)展ScriptModule,并希望在純 C ++生產(chǎn)環(huán)境中加載的ScriptModule內(nèi)執(zhí)行該自定義運(yùn)算符。 好消息是:這是可能的,并且得到了很好的支持! 現(xiàn)在,您可以瀏覽這個(gè)文件夾作為示例,我們將很快提供一個(gè)教程。 目前,以下鏈接通??赡軙?huì)有所幫助:

與往常一樣,如果您遇到任何問(wèn)題或疑問(wèn),可以使用我們的論壇或 GitHub 問(wèn)題進(jìn)行聯(lián)系。


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)