PyTorch torch腳本

2020-09-15 10:39 更新

原文: PyTorch torch腳本

TorchScript 是一種從 PyTorch 代碼創(chuàng)建可序列化和可優(yōu)化模型的方法。 任何 TorchScript 程序都可以從 Python 進(jìn)程中保存并加載到?jīng)]有 Python 依賴項的進(jìn)程中。

我們提供了將模型從純 Python 程序逐步過渡到可以獨立于 Python 運(yùn)行的 TorchScript 程序的工具,例如在獨立的 C ++程序中。 這樣就可以使用 Python 中熟悉的工具在 PyTorch 中訓(xùn)練模型,然后通過 TorchScript 將模型導(dǎo)出到生產(chǎn)環(huán)境中,在該生產(chǎn)環(huán)境中 Python 程序可能由于性能和多線程原因而處于不利地位。

有關(guān) TorchScript 的簡要介紹,請參見 TorchScript 簡介教程。

有關(guān)將 PyTorch 模型轉(zhuǎn)換為 TorchScript 并在 C ++中運(yùn)行的端到端示例,請參見在 C ++中加載 PyTorch 模型教程。

創(chuàng)建 TorchScript 代碼

class torch.jit.ScriptModule?

property code?

返回forward方法的內(nèi)部圖的漂亮打印表示形式(作為有效的 Python 語法)。

property graph?

返回forward方法的內(nèi)部圖形的字符串表示形式。

save(f, _extra_files=ExtraFilesMap{})?

torch.jit.save 。

class torch.jit.ScriptFunction?

功能上與 ScriptModule 等效,但是代表單個功能,沒有任何屬性或參數(shù)。

torch.jit.script(obj)?

為函數(shù)或nn.Module編寫腳本將檢查源代碼,使用 TorchScript 編譯器將其編譯為 TorchScript 代碼,然后返回 ScriptModuleScriptFunction 。 TorchScript 本身是 Python 語言的子集,因此 Python 并非所有功能都可以使用,但是我們提供了足夠的功能來在張量上進(jìn)行計算并執(zhí)行與控制有關(guān)的操作。

torch.jit.script可用作模塊和功能的函數(shù),以及 TorchScript 類和功能的修飾器@torch.jit.script。

Scripting a function

@torch.jit.script裝飾器將通過編譯函數(shù)的主體來構(gòu)造 ScriptFunction 。

示例(編寫函數(shù)):

import torch


@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


print(type(foo))  # torch.jit.ScriptFuncion


## See the compiled graph as Python code
print(foo.code)


## Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
Scripting an nn.Module

默認(rèn)情況下,為nn.Module編寫腳本將編譯forward方法,并遞歸編譯forward調(diào)用的任何方法,子模塊和函數(shù)。 如果nn.Module僅使用 TorchScript 支持的功能,則無需更改原始模塊代碼。 script將構(gòu)建 ScriptModule ,該副本具有原始模塊的屬性,參數(shù)和方法的副本。

示例(使用參數(shù)編寫簡單模塊的腳本):

import torch


class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))


        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)


    def forward(self, input):
        output = self.weight.mv(input)


        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output


scripted_module = torch.jit.script(MyModule(2, 3))

示例(使用跟蹤的子模塊編寫模塊腳本):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))


    def forward(self, input):
      input = F.relu(self.conv1(input))
      input = F.relu(self.conv2(input))
      return input


scripted_module = torch.jit.script(MyModule())

要編譯除forward以外的方法(并遞歸編譯其調(diào)用的任何內(nèi)容),請將 @torch.jit.export 裝飾器添加到該方法。 要選擇退出編譯,請使用 @torch.jit.ignore 。

示例(模塊中的導(dǎo)出方法和忽略方法):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()


    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10


    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()


    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99


scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)?

跟蹤一個函數(shù)并返回將使用即時編譯進(jìn)行優(yōu)化的可執(zhí)行文件或 ScriptFunction 。 對于僅在TensorTensor的列表,字典和元組上運(yùn)行的代碼,跟蹤是理想的選擇。

使用torch.jit.tracetorch.jit.trace_module ,您可以將現(xiàn)有模塊或 Python 函數(shù)轉(zhuǎn)換為 TorchScript ScriptFunctionScriptModule 。 您必須提供示例輸入,然后我們運(yùn)行該函數(shù),記錄在所有張量上執(zhí)行的操作。

  • 獨立功能的最終記錄將產(chǎn)生 ScriptFunction 。
  • nn.Modulenn.Moduleforward功能的所得記錄產(chǎn)生 ScriptModule

該模塊還包含原始模塊也具有的任何參數(shù)。

警告

跟蹤僅正確記錄不依賴數(shù)據(jù)的功能和模塊(例如,對張量中的數(shù)據(jù)沒有條件)并且不包含任何未跟蹤的外部依賴項(例如,執(zhí)行輸入/輸出或訪問全局變量)。 跟蹤僅記錄在給定張量上運(yùn)行給定函數(shù)時執(zhí)行的操作。 因此,返回的 ScriptModule 將始終在任何輸入上運(yùn)行相同的跟蹤圖。 當(dāng)期望模塊根據(jù)輸入和/或模塊狀態(tài)運(yùn)行不同的操作集時,這具有重要意義。 例如,

  • 跟蹤將不會記錄任何控制流,例如 if 語句或循環(huán)。 當(dāng)整個模塊的控制流恒定時,這很好,并且通常內(nèi)聯(lián)控制流決策。 但是有時控制流實際上是模型本身的一部分。 例如,循環(huán)網(wǎng)絡(luò)是輸入序列(可能是動態(tài))長度上的循環(huán)。
  • 在返回的 ScriptModule 中,在trainingeval模式下具有不同行為的操作將始終像在跟蹤過程中一樣處于運(yùn)行狀態(tài),無論是哪種模式 ] ScriptModule 已插入。

在這種情況下,跟蹤是不合適的, scripting 是更好的選擇。 如果跟蹤此類模型,則可能在隨后的模型調(diào)用中靜默地得到不正確的結(jié)果。 在執(zhí)行可能會導(dǎo)致產(chǎn)生不正確跟蹤的操作時,跟蹤器將嘗試發(fā)出警告。

參數(shù)

  • 函數(shù)(可調(diào)用的 torch.nn.Module)– Python 函數(shù)或torch.nn.Moduleexample_inputs一起運(yùn)行。 func的參數(shù)和返回值必須是張量或包含張量的(可能是嵌套的)元組。 將模塊傳遞到 torch.jit.trace 時,僅運(yùn)行并跟蹤forward方法(有關(guān)詳細(xì)信息,參見 torch.jit.trace)。
  • example_inputs (tuple )–示例輸入的元組,將在跟蹤時傳遞給函數(shù)。 假設(shè)跟蹤的操作支持這些類型和形狀,則可以使用不同類型和形狀的輸入來運(yùn)行結(jié)果跟蹤。 example_inputs也可以是單個張量,在這種情況下,它會自動包裝在元組中。

Keyword Arguments

  • check_trace (bool,可選)–檢查通過跟蹤代碼運(yùn)行的相同輸入是否產(chǎn)生相同的輸出。 默認(rèn)值:True。 例如,如果您的網(wǎng)絡(luò)包含不確定性操作,或者即使檢查程序失敗,但您確定網(wǎng)絡(luò)正確,則可能要禁用此功能。
  • check_inputs (元組列表 , 可選)–輸入?yún)?shù)的元組列表,應(yīng)使用這些元組來檢查跟蹤內(nèi)容 是期待。 每個元組等效于example_inputs中指定的一組輸入?yún)?shù)。 為了獲得最佳結(jié)果,請傳遞一組檢查輸入,這些輸入代表您希望網(wǎng)絡(luò)看到的形狀和輸入類型的空間。 如果未指定,則使用原始的example_inputs進(jìn)行檢查
  • check_tolerance (python:float 可選)–在檢查程序中使用的浮點比較公差。 如果結(jié)果由于已知原因(例如操作員融合)而在數(shù)值上出現(xiàn)差異,則可以使用此方法來放松檢查器的嚴(yán)格性。

退貨

如果callablenn.Modulenn.Moduleforward,則trace將使用包含跟蹤代碼的單個forward方法返回 ScriptModule 對象。 返回的 ScriptModule 將具有與原始nn.Module相同的子模塊和參數(shù)集。 如果callable是獨立功能,則trace返回 ScriptFunction

示例(跟蹤函數(shù)):

import torch


def foo(x, y):
    return 2 * x + y


## Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))


## `traced_foo` can now be run with the TorchScript interpreter or saved
## and loaded in a Python-free environment

示例(跟蹤現(xiàn)有模塊):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)


    def forward(self, x):
        return self.conv(x)


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)


## Trace a specific method and construct `ScriptModule` with
## a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)


## Trace a module (implicitly traces `forward`) and construct a
## `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)?

跟蹤模塊并返回可執(zhí)行文件 ScriptModule ,該文件將使用即時編譯進(jìn)行優(yōu)化。 將模塊傳遞到 torch.jit.trace 時,僅運(yùn)行并跟蹤forward方法。 使用trace_module,您可以指定方法名稱的字典作為示例輸入,以跟蹤下面的參數(shù)(請參見example_inputs)。

有關(guān)跟蹤的更多信息,參見 torch.jit.trace

Parameters

  • mod (Torch.nn.Module)–一種torch.nn.Module,其中包含名稱在example_inputs中指定的方法。 給定的方法將被編譯為單個 <cite>ScriptModule</cite> 的一部分。
  • example_inputs (dict )–包含樣本輸入的字典,該樣本輸入由mod中的方法名稱索引。 輸入將在跟蹤時傳遞給名稱與輸入鍵對應(yīng)的方法。 { 'forward' : example_forward_input, 'method2': example_method2_input}

Keyword Arguments

  • check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non- deterministic ops or if you are sure that the network is correct despite a checker failure.
  • check_inputs (字典列表 , 可選)–輸入?yún)?shù)的字典列表,用于檢查跟蹤內(nèi)容 是期待。 每個元組等效于example_inputs中指定的一組輸入?yún)?shù)。 為了獲得最佳結(jié)果,請傳遞一組檢查輸入,這些輸入代表您希望網(wǎng)絡(luò)看到的形狀和輸入類型的空間。 如果未指定,則使用原始的example_inputs進(jìn)行檢查
  • check_tolerance (python:float__, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.

Returns

具有單個forward方法的 ScriptModule 對象,其中包含跟蹤的代碼。 當(dāng)functorch.nn.Module時,返回的 ScriptModule 將具有與func相同的子模塊和參數(shù)集。

示例(使用多種方法跟蹤模塊):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)


    def forward(self, x):
        return self.conv(x)


    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)


## Trace a specific method and construct `ScriptModule` with
## a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)


## Trace a module (implicitly traces `forward`) and construct a
## `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)


## Trace specific methods on a module (specified in `inputs`), constructs
## a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

torch.jit.save(m, f, _extra_files=ExtraFilesMap{})?

保存此模塊的脫機(jī)版本以在單獨的過程中使用。 保存的模塊將序列化此模塊的所有方法,子模塊,參數(shù)和屬性。 可以使用torch::jit::load(filename)將其加載到 C ++ API 中,或者使用 torch.jit.load 加載到 Python API 中。

為了能夠保存模塊,它不得對本地 Python 函數(shù)進(jìn)行任何調(diào)用。 這意味著所有子模塊也必須是torch.jit.ScriptModule的子類。

危險

所有模塊,無論使用哪種設(shè)備,都始終在加載期間加載到 CPU 中。 這與 load 的語義不同,并且將來可能會發(fā)生變化。

Parameters

  • m –要保存的 ScriptModule。
  • f –類似于文件的對象(必須實現(xiàn)寫入和刷新)或包含文件名的字符串。
  • _extra_files -從文件名映射到將作為“ f”的一部分存儲的內(nèi)容。

Warning

如果您使用的是 Python 2,torch.jit.save不支持StringIO.StringIO作為有效的類似文件的對象。 這是因為 write 方法應(yīng)返回寫入的字節(jié)數(shù); StringIO.write()不這樣做。

請改用io.BytesIO之類的東西。

例:

import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


m = torch.jit.script(MyModule())


## Save to file
torch.jit.save(m, 'scriptmodule.pt')
## This line is equivalent to the previous
m.save("scriptmodule.pt")


## Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)


## Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

torch.jit.load(f, map_location=None, _extra_files=ExtraFilesMap{})?

加載先前用 torch.jit.save 保存的 ScriptModuleScriptFunction

之前保存的所有模塊,無論使用何種設(shè)備,都首先加載到 CPU 中,然后再移動到保存它們的設(shè)備上。 如果失敗(例如,因為運(yùn)行時系統(tǒng)沒有某些設(shè)備),則會引發(fā)異常。

Parameters

  • f –類似于文件的對象(必須實現(xiàn)讀取,讀取行,告訴和查找),或包含文件名的字符串
  • map_location (字符串 torch設(shè)備)– torch.savemap_location的簡化版本 用于動態(tài)地將存儲重新映射到另一組設(shè)備。
  • _extra_files (文件名到內(nèi)容的字典)–映射中給定的多余文件名將被加載,其內(nèi)容將存儲在提供的映射中。

Returns

ScriptModule 對象。

Example:

import torch
import io


torch.jit.load('scriptmodule.pt')


## Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())


## Load all tensors to the original device
torch.jit.load(buffer)


## Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))


## Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')


## Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])

混合跟蹤和腳本編寫

在許多情況下,將模型轉(zhuǎn)換為 TorchScript 都可以使用跟蹤或腳本編寫。 可以組成跟蹤和腳本以適合模型一部分的特定要求。

腳本函數(shù)可以調(diào)用跟蹤函數(shù)。 當(dāng)您需要在簡單的前饋模型周圍使用控制流時,這特別有用。 例如,序列到序列模型的波束搜索通常將用腳本編寫,但是可以調(diào)用使用跟蹤生成的編碼器模塊。

示例(在腳本中調(diào)用跟蹤的函數(shù)):

import torch


def foo(x, y):
    return 2 * x + y


traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))


@torch.jit.script
def bar(x):
    return traced_foo(x, x)

跟蹤的函數(shù)可以調(diào)用腳本函數(shù)。 即使大部分模型只是前饋網(wǎng)絡(luò),當(dāng)模型的一小部分需要一些控制流時,這也很有用。 跟蹤函數(shù)調(diào)用的腳本函數(shù)內(nèi)部的控制流已正確保留。

示例(在跟蹤函數(shù)中調(diào)用腳本函數(shù)):

import torch


@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z


traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

此組合也適用于nn.Module,在這里它可用于通過跟蹤來生成子模塊,該跟蹤可以從腳本模塊的方法中調(diào)用。

示例(使用跟蹤模塊):

import torch
import torchvision


class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))


    def forward(self, input):
        return self.resnet(input - self.means)


my_script_module = torch.jit.script(MyScriptModule())

遷移到 PyTorch 1.2 遞歸腳本 API

本節(jié)詳細(xì)介紹了 PyTorch 1.2 中對 TorchScript 的更改。 如果您不熟悉 TorchScript,則可以跳過本節(jié)。 PyTorch 1.2 對 TorchScript API 進(jìn)行了兩個主要更改。

\1. torch.jit.script 現(xiàn)在將嘗試遞歸編譯遇到的函數(shù),方法和類。 調(diào)用torch.jit.script后,編譯將是“選擇退出”,而不是“選擇加入”。

2.現(xiàn)在torch.jit.script(nn_module_instance)是創(chuàng)建 ScriptModule 的首選方法,而不是從torch.jit.ScriptModule繼承。 這些更改組合在一起,提供了一個更簡單易用的 API,可將您的nn.Module轉(zhuǎn)換為 ScriptModule ,可以在非 Python 環(huán)境中進(jìn)行優(yōu)化和執(zhí)行。

新用法如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)


    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


my_model = Model()
my_scripted_model = torch.jit.script(my_model)

  • 該模塊的forward是默認(rèn)編譯的。 從forward調(diào)用的方法將按照在forward中使用的順序進(jìn)行延遲編譯。
  • 要編譯未從forward調(diào)用的forward以外的方法,請?zhí)砑?code>@torch.jit.export。
  • 要停止編譯器編譯方法,請?zhí)砑?@torch.jit.ignore@torch.jit.unused 。 @ignore離開
  • 方法作為對 python 的調(diào)用,并且@unused將其替換為異常。 @ignored無法導(dǎo)出; @unused可以。
  • 可以推斷大多數(shù)屬性類型,因此不需要torch.jit.Attribute。 對于空容器類型,請使用 PEP 526 樣式類注釋對其類型進(jìn)行注釋。
  • 可以使用Final類注釋來標(biāo)記常量,而不是將成員的名稱添加到__constants__中。
  • 可以使用 Python 3 類型提示代替torch.jit.annotate

As a result of these changes, the following items are considered deprecated and should not appear in new code:

  • @torch.jit.script_method裝飾器
  • 繼承自torch.jit.ScriptModule的類
  • torch.jit.Attribute包裝器類
  • __constants__數(shù)組
  • torch.jit.annotate功能

模塊

Warning

@torch.jit.ignore 注釋的行為在 PyTorch 1.2 中發(fā)生了變化。 在 PyTorch 1.2 之前,@ ignore 裝飾器用于使函數(shù)或方法可從導(dǎo)出的代碼中調(diào)用。 要恢復(fù)此功能,請使用@torch.jit.unused()。 @torch.jit.ignore現(xiàn)在等同于@torch.jit.ignore(drop=False)。 有關(guān)詳細(xì)信息,參見 @torch.jit.ignore@torch.jit.unused

當(dāng)傳遞給 torch.jit.script 函數(shù)時,torch.nn.Module的數(shù)據(jù)將復(fù)制到 ScriptModule ,然后 TorchScript 編譯器將編譯該模塊。 該模塊的forward默認(rèn)為編譯狀態(tài)。 從forward調(diào)用的方法以及它們在forward中使用的順序都是按延遲順序編譯的。

torch.jit.export(fn)?

此修飾符指示nn.Module上的方法用作 ScriptModule 的入口點,應(yīng)進(jìn)行編譯。

forward隱式地假定為入口點,因此不需要此裝飾器。 從forward調(diào)用的函數(shù)和方法在編譯器看到的情況下進(jìn)行編譯,因此它們也不需要此裝飾器。

示例(在方法上使用@torch.jit.export):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99


    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10


    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)


    def unused_method(self, x):
        return x - 20


## `m` will contain compiled methods:
##     `forward`
##     `another_forward`
##     `implicitly_compiled_method`
## `unused_method` will not be compiled since it was not called from
## any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

功能

功能沒有太大變化,可以根據(jù)需要用 @torch.jit.ignoretorch.jit.unused 裝飾。

## Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2


## Marks a function as ignored, if nothing
## ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2


## As with ignore, if nothing calls it then it has no effect.
## If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4


## Doesn't do anything, this function is already
## the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 類

默認(rèn)情況下,將導(dǎo)出用戶定義的 TorchScript 類中的所有內(nèi)容,可以根據(jù)需要用 @torch.jit.ignore 修飾功能。

屬性

TorchScript 編譯器需要知道模塊屬性的類型。 大多數(shù)類型可以從成員的值推斷出來。 空列表和字典不能推斷其類型,而必須使用 PEP 526 樣式類注釋來注釋其類型。 如果無法推斷類型并且未對顯式類型進(jìn)行注釋,則不會將其作為屬性添加到結(jié)果 ScriptModule

舊 API:

from typing import Dict
import torch


class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)


m = MyModule()

新 API:

from typing import Dict


class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]


    def __init__(self):
        super(MyModule, self).__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}


        # The attribute type here is inferred to be `int`
        self.my_int = 20


    def forward(self):
        pass


m = torch.jit.script(MyModule())

Python 2

如果您受制于 Python 2 并且無法使用類注釋語法,則可以使用__annotations__類成員直接應(yīng)用類型注釋。

from typing import Dict


class MyModule(torch.jit.ScriptModule):
    __annotations__ = {'my_dict': Dict[str, int]}


    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = {}
        self.my_int = 20

常數(shù)

Final類型的構(gòu)造函數(shù)可用于將成員標(biāo)記為常量。 如果成員未標(biāo)記為常量,則將其復(fù)制為結(jié)果 ScriptModule 作為屬性。 如果已知該值是固定的,則使用Final可以進(jìn)行優(yōu)化,并提供附加的類型安全性。

Old API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']


    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2


    def forward(self):
        pass
m = MyModule()

New API:

try:
    from typing_extensions import Final
except:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final


class MyModule(torch.nn.Module):


    my_constant: Final[int]


    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2


    def forward(self):
        pass


m = torch.jit.script(MyModule())

變量

假定容器的類型為Tensor,并且是非可選的(有關(guān)更多信息,請參見默認(rèn)類型)。 以前,torch.jit.annotate用來告訴 TorchScript 編譯器類型是什么。 現(xiàn)在支持 Python 3 樣式類型提示。

import torch
from typing import Dict, Optional


@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

TorchScript 語言參考

TorchScript 是 Python 的靜態(tài)類型子集,可以直接編寫(使用 @torch.jit.script 裝飾器),也可以通過跟蹤從 Python 代碼自動生成。 使用跟蹤時,通過僅在張量上記錄實際的運(yùn)算符并簡單地執(zhí)行和丟棄其他周圍的 Python 代碼,代碼會自動轉(zhuǎn)換為 Python 的此子集。

使用@torch.jit.script裝飾器直接編寫 TorchScript 時,程序員只能使用 TorchScript 支持的 Python 子集。 本節(jié)記錄了 TorchScript 支持的功能,就像它是獨立語言的語言參考一樣。 本參考中未提及的 Python 的任何功能都不屬于 TorchScript。 有關(guān)可用的 Pytorch 張量方法,模塊和功能的完整參考,請參見內(nèi)置函數(shù)。

作為 Python 的子集,任何有效的 TorchScript 函數(shù)也是有效的 Python 函數(shù)。 這樣就可以禁用 TorchScript 并使用pdb之類的標(biāo)準(zhǔn) Python 工具調(diào)試該功能。 反之則不成立:有許多有效的 Python 程序不是有效的 TorchScript 程序。 相反,TorchScript 特別專注于表示 PyTorch 中的神經(jīng)網(wǎng)絡(luò)模型所需的 Python 功能。

類型

TorchScript 與完整的 Python 語言之間的最大區(qū)別是 TorchScript 僅支持表達(dá)神經(jīng)網(wǎng)絡(luò)模型所需的一小部分類型。 特別是,TorchScript 支持:

|

類型

|

描述

| | --- | --- | | Tensor | 任何 dtype,尺寸或后端的 PyTorch 張量 | | Tuple[T0, T1, ...] | 包含子類型T0,T1等(例如Tuple[Tensor, Tensor])的元組 | | bool | 布爾值 | | int | 標(biāo)量整數(shù) | | float | 標(biāo)量浮點數(shù) | | str | 一串 | | List[T] | 所有成員均為T類型的列表 | | Optional[T] | 無或輸入T的值 | | Dict[K, V] | 鍵類型為K而值類型為V的字典。 只能將str,intfloat作為密鑰類型。 | | T | 一個 TorchScript 類 | | NamedTuple[T0, T1, ...] | collections.namedtuple元組類型 |

與 Python 不同,TorchScript 函數(shù)中的每個變量都必須具有一個靜態(tài)類型。 這使優(yōu)化 TorchScript 函數(shù)變得更加容易。

示例(類型不匹配)

import torch


@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...


Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~...  <--- HERE
        r = torch.rand(1)
    else:
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE
...

默認(rèn)類型

默認(rèn)情況下,TorchScript 函數(shù)的所有參數(shù)均假定為 Tensor。 要指定 TorchScript 函數(shù)的參數(shù)是其他類型,可以使用上面列出的類型使用 MyPy 樣式的類型注釋。

import torch


@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x


print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用typing模塊中的 Python 3 類型提示來注釋類型。

import torch
from typing import Tuple


@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x


print(foo(3, (torch.rand(3), torch.rand(3))))

在我們的示例中,我們使用基于注釋的類型提示來確保 Python 2 的兼容性。

假定空列表為List[Tensor],空字典為Dict[str, Tensor]。 要實例化其他類型的空列表或字典,請使用 Python 3 類型提示。 如果您使用的是 Python 2,則可以使用torch.jit.annotate

示例(Python 3 的類型注釋):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple


class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()


    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))


        my_dict: Dict[str, int] = {}
        return my_list, my_dict


x = torch.jit.script(EmptyDataStructures())

示例(適用于 Python 2 的torch.jit.annotate):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple


class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()


    def forward(self, x):
        # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]


        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list = torch.jit.annotate(List[Tuple[int, float]], [])
        for i in range(10):
            my_list.append((i, float(x.item())))


        my_dict = torch.jit.annotate(Dict[str, int], {})
        return my_list, my_dict


x = torch.jit.script(EmptyDataStructures())

可選類型細(xì)化

在 if 語句的條件內(nèi)或在assert中檢查與None的比較時,TorchScript 將優(yōu)化Optional[T]類型的變量的類型。 編譯器可以推理與and,ornot結(jié)合的多個None檢查。 對于未明確編寫的 if 語句的 else 塊,也會進(jìn)行優(yōu)化。

None檢查必須在 if 語句的條件內(nèi); 將None檢查分配給變量,并在 if 語句的條件下使用它,將不會優(yōu)化檢查中的變量類型。 僅局部變量將被細(xì)化,self.x之類的屬性將不會且必須分配給要細(xì)化的局部變量。

示例(優(yōu)化參數(shù)和局部變量的類型):

import torch
import torch.nn as nn
from typing import Optional


class M(nn.Module):
    z: Optional[int]


    def __init__(self, z):
        super(M, self).__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z


    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1


        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z


        # Refinement via an `assert`
        assert z is not None
        x += z
        return x


module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 類

如果 Python 類使用 @torch.jit.script 注釋,則可以在 TorchScript 中使用,類似于聲明 TorchScript 函數(shù)的方式:

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x


  def aug_add_x(self, inc):
    self.x += inc

此子集受限制:

  • 所有函數(shù)必須是有效的 TorchScript 函數(shù)(包括__init__())。

  • 這些類必須是新型類,因為我們使用__new__()和 pybind11 來構(gòu)造它們。

  • TorchScript 類是靜態(tài)類型的。 只能通過在__init__()方法中分配給 self 來聲明成員。

\> 例如,在__init__()方法之外分配給self: > > > @torch.jit.script > class Foo: > def assign_x(self): > self.x = torch.rand(2, 3) > > > > 將導(dǎo)致: > > > RuntimeError: > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: > def assign_x(self): > self.x = torch.rand(2, 3) > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE > >

  • 類的主體中不允許使用除方法定義之外的任何表達(dá)式。

  • 除了從object繼承以指定新樣式類外,不支持繼承或任何其他多態(tài)策略。

定義了一個類之后,就可以像其他任何 TorchScript 類型一樣在 TorchScript 和 Python 中互換使用該類:

## Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second


@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second


p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

命名為元組

collections.namedtuple產(chǎn)生的類型可以在 TorchScript 中使用。

import torch
import collections


Point = collections.namedtuple('Point', ['x', 'y'])


@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y


p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

表達(dá)式

支持以下 Python 表達(dá)式。

文字

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表結(jié)構(gòu)

假定一個空列表具有List[Tensor]類型。 其他列表文字的類型是從成員的類型派生的。 有關(guān)更多詳細(xì)信息,請參見默認(rèn)類型。

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元組結(jié)構(gòu)

(3, 4)
(3,)

字典結(jié)構(gòu)

假定一個空字典為Dict[str, Tensor]類型。 其他 dict 文字的類型是從成員的類型派生的。 有關(guān)更多詳細(xì)信息,請參見默認(rèn)類型。

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

變量

有關(guān)如何解析變量的信息,請參見變量分辨率。

my_variable_name

算術(shù)運(yùn)算符

a + b
a - b
a * b
a / b
a ^ b
a @ b

比較運(yùn)算符

a == b
a != b
a < b
a > b
a <= b
a >= b

邏輯運(yùn)算符

a and b
a or b
not b

下標(biāo)和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函數(shù)調(diào)用

調(diào)用內(nèi)置函數(shù)

torch.rand(3, dtype=torch.int)

調(diào)用其他腳本函數(shù):

import torch


@torch.jit.script
def foo(x):
    return x + 1


@torch.jit.script
def bar(x):
    return foo(x)

方法調(diào)用

調(diào)用諸如張量之類的內(nèi)置類型的方法:x.mm(y)

在模塊上,必須先編譯方法才能調(diào)用它們。 TorchScript 編譯器以遞歸方式編譯在編譯其他方法時看到的方法。 默認(rèn)情況下,編譯從forward方法開始。 將編譯forward調(diào)用的任何方法,以及這些方法調(diào)用的任何方法,依此類推。 要以forward以外的方法開始編譯,請使用 @torch.jit.export 裝飾器(forward隱式標(biāo)記為@torch.jit.export)。

直接調(diào)用子模塊(例如self.resnet(input))等效于調(diào)用其forward方法(例如self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))


    def helper(self, input):
        return self.resnet(input - self.means)


    def forward(self, input):
        return self.helper(input)


    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)


    def other_helper(self, input):
        return input + 10


## `my_script_module` will have the compiled methods `forward`, `helper`,
## `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表達(dá)式

x if x > y else y

演員表

float(ten)
int(3.5)
bool(ten)
str(2)``

訪問模塊參數(shù)

self.my_parameter
self.my_submodule.my_parameter

語句

TorchScript 支持以下類型的語句:

簡單分配

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配分配

a, b = tuple_or_list
a, b, *c = a_tuple

多項分配

a = b, c = tup

打印報表

print("the result of an add:", a + b)

If 語句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除布爾值外,浮點數(shù),整數(shù)和張量還可以在條件中使用,并將隱式轉(zhuǎn)換為布爾值。

While 循環(huán)

a = 0
while a < 4:
    print(a)
    a += 1

適用于范圍為的循環(huán)

x = 0
for i in range(10):
    x *= i

用于遍歷元組的循環(huán)

這些展開循環(huán),為元組的每個成員生成一個主體。 主體必須對每個成員進(jìn)行正確的類型檢查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

用于在常量 nn.ModuleList 上循環(huán)

要在已編譯方法中使用nn.ModuleList,必須通過將屬性名稱添加到__constants__列表中的類型來將其標(biāo)記為常量。 nn.ModuleList上的 for 循環(huán)將在編譯時展開循環(huán)的主體,并使用常量模塊列表的每個成員。

class SubModule(torch.nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.weight = nn.Parameter(torch.randn(2))


    def forward(self, input):
        return self.weight + input


class MyModule(torch.nn.Module):
    __constants__ = ['mods']


    def __init__(self):
        super(MyModule, self).__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])


    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

中斷并繼續(xù)

for i in range(5):
    if i == 1:
    continue
    if i == 3:
    break
    print(i)

返回

return a, b

可變分辨率

TorchScript 支持 Python 的可變分辨率(即作用域)規(guī)則的子集。 局部變量的行為與 Python 中的相同,不同之處在于,在通過函數(shù)的所有路徑上,變量必須具有相同的類型。 如果變量在 if 語句的不同分支上具有不同的類型,則在 if 語句結(jié)束后使用它是錯誤的。

同樣,如果沿函數(shù)的某些路徑僅將定義為,則不允許使用該變量。

Example:

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...


y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~...  <--- HERE
        y = 4
    print(y)
...

定義函數(shù)時,會在編譯時將非局部變量解析為 Python 值。 然后使用 Python 值使用中描述的規(guī)則將這些值轉(zhuǎn)換為 TorchScript 值。

使用 Python 值

為了使編寫 TorchScript 更加方便,我們允許腳本代碼引用周圍范圍中的 Python 值。 例如,任何時候只要引用torch,當(dāng)聲明函數(shù)時,TorchScript 編譯器實際上就會將其解析為torch Python 模塊。 這些 Python 值不是 TorchScript 的一流部分。 而是在編譯時將它們分解為 TorchScript 支持的原始類型。 這取決于編譯發(fā)生時引用的 Python 值的動態(tài)類型。 本節(jié)介紹在 TorchScript 中訪問 Python 值時使用的規(guī)則。

功能

TorchScript 可以調(diào)用 Python 函數(shù)。 當(dāng)將模型逐步轉(zhuǎn)換為 TorchScript 時,此功能非常有用。 可以將模型逐函數(shù)移至 TorchScript,而對 Python 函數(shù)的調(diào)用保留在原處。 這樣,您可以在進(jìn)行過程中逐步檢查模型的正確性。

torch.jit.ignore(drop=False, **kwargs)?

該裝飾器向編譯器指示應(yīng)忽略函數(shù)或方法,而將其保留為 Python 函數(shù)。 這使您可以將代碼保留在尚未與 TorchScript 兼容的模型中。 具有忽略功能的模型無法導(dǎo)出; 請改用 torch.jit.unused。

示例(在方法上使用@torch.jit.ignore):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    @torch.jit.ignore
    def debugger(self, x):
        import pdb
        pdb.set_trace()


    def forward(self, x):
        x += 10
        # The compiler would normally try to compile `debugger`,
        # but since it is `@ignore`d, it will be left as a call
        # to Python
        self.debugger(x)
        return x


m = torch.jit.script(MyModule())


## Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")

示例(在方法上使用@torch.jit.ignore(drop=True)):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    @torch.jit.ignore(drop=True)
    def training_method(self, x):
        import pdb
        pdb.set_trace()


    def forward(self, x):
        if self.training:
            self.training_method(x)
        return x


m = torch.jit.script(MyModule())


## This is OK since `training_method` is not saved, the call is replaced
## with a `raise`.
m.save("m.pt")

torch.jit.unused(fn)?

此裝飾器向編譯器指示應(yīng)忽略函數(shù)或方法,并用引發(fā)異常的方法代替。 這樣,您就可以在尚不兼容 TorchScript 的模型中保留代碼,并仍然可以導(dǎo)出模型。

示例(在方法上使用@torch.jit.unused):


import torch
import torch.nn as nn

class MyModule(nn.Module):
def __init__(self, use_memory_efficent):
super(MyModule, self).__init__()
self.use_memory_efficent = use_memory_efficent

@torch.jit.unused
def memory_efficient(self, x):
import pdb
pdb.set_trace()
return x + 10

def forward(self, x):
# Use not-yet-scriptable memory efficient mode
if self.use_memory_efficient:
return self.memory_efficient(x)
else:
return x + 10

m = torch.jit.script(MyModule(use_memory_efficent=False))
m.save("m.pt")

m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))

torch.jit.is_scripting()?

在編譯時返回 True 的函數(shù),否則返回 False 的函數(shù)。 這對于使用@unused 裝飾器尤其有用,可以將尚不兼容 TorchScript 的代碼保留在模型中。 .. testcode:

import torch


@torch.jit.unused
def unsupported_linear_op(x):
    return x


def linear(x):
   if not torch.jit.is_scripting():
      return torch.linear(x)
   else:
      return unsupported_linear_op(x)

Python 模塊上的屬性查找

TorchScript 可以在模塊上查找屬性。 像torch.add這樣的內(nèi)置功能可以通過這種方式訪問。 這使 TorchScript 可以調(diào)用其他模塊中定義的函數(shù)。

Python 定義的常量

TorchScript 還提供了一種使用 Python 中定義的常量的方法。 這些可用于將超參數(shù)硬編碼到函數(shù)中,或定義通用常量。 有兩種指定 Python 值應(yīng)視為常量的方式。

  1. 查找為模塊屬性的值假定為常量:

import math
import torch


@torch.jit.script
def fn():
    return math.pi

  1. 可以通過使用Final[T]注釋 ScriptModule 的屬性來將其標(biāo)記為常量。

import torch
import torch.nn as nn


class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]


    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4


    def forward(self, input):
        return self.a + input


f = torch.jit.script(Foo())

支持的常量 Python 類型是

  • int
  • float
  • bool
  • torch.device
  • torch.layout
  • torch.dtype
  • 包含受支持類型的元組
  • torch.nn.ModuleList可以在 TorchScript for 循環(huán)中使用

Note

如果您使用的是 Python 2,則可以通過將屬性名稱添加到類的__constants__屬性中來將其標(biāo)記為常量:

import torch
import torch.nn as nn


class Foo(nn.Module):
    __constants__ = ['a']


    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4


    def forward(self, input):
        return self.a + input


f = torch.jit.script(Foo())

模塊屬性

torch.nn.Parameter包裝器和register_buffer可用于將張量分配給模塊。 如果可以推斷出其他類型的值,則分配給已編譯模塊的其他值將添加到已編譯模塊中。 TorchScript 中可用的所有類型都可以用作模塊屬性。 張量屬性在語義上與緩沖區(qū)相同。 空列表和字典的類型以及None值無法推斷,必須通過 PEP 526 樣式類注釋指定。 如果無法推斷出類型并且未對其進(jìn)行顯式注釋,則不會將其作為屬性添加到結(jié)果 ScriptModule 中。

Example:

from typing import List, Dict


class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]


    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]


    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict


        # `int`s can be inferred
        self.my_int = 10


    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int


f = torch.jit.script(Foo({'hi': 2}))

Note

如果您使用的是 Python 2,則可以通過將屬性的類型添加到__annotations__類屬性中作為屬性名字典來標(biāo)記屬性的類型

from typing import List, Dict


class Foo(nn.Module):
    __annotations__ = {'words': List[str], 'some_dict': Dict[str, int]}


    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict


        # `int`s can be inferred
        self.my_int = 10


    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int


f = torch.jit.script(Foo({'hi': 2}))

調(diào)試

禁用用于調(diào)試的 JIT

PYTORCH_JIT?

設(shè)置環(huán)境變量PYTORCH_JIT=0將禁用所有腳本和跟蹤注釋。 如果您的 TorchScript 模型之一存在難以調(diào)試的錯誤,則可以使用此標(biāo)志來強(qiáng)制一切都使用本機(jī) Python 運(yùn)行。 由于此標(biāo)志禁用了 TorchScript(腳本編寫和跟蹤),因此可以使用pdb之類的工具來調(diào)試模型代碼。

給定一個示例腳本:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x


def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)


traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

除調(diào)用,@torch.jit.script,函數(shù)外,使用pdb調(diào)試此腳本是可行的。 我們可以全局禁用 JIT,以便我們可以將 @torch.jit.script 函數(shù)作為普通的 Python 函數(shù)調(diào)用,而不進(jìn)行編譯。 如果上述腳本稱為disable_jit_example.py,我們可以這樣調(diào)用它:

$ PYTORCH_JIT=0 python disable_jit_example.py

并且我們將能夠像普通的 Python 函數(shù)一樣進(jìn)入 @torch.jit.script 函數(shù)。 要為特定功能禁用 TorchScript 編譯器,請參見 @torch.jit.ignore 。

檢查碼

TorchScript 為所有 ScriptModule 實例提供了代碼漂亮的打印機(jī)。 這個漂亮的打印機(jī)可以將腳本方法的代碼解釋為有效的 Python 語法。 例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


print(foo.code)

具有單個forward方法的 ScriptModule 將具有屬性code,您可以使用該屬性檢查 ScriptModule 的代碼。 如果 ScriptModule 具有多個方法,則需要在方法本身而非模塊上訪問.code。 我們可以通過訪問.foo.code在 ScriptModule 上檢查名為foo的方法的代碼。 上面的示例產(chǎn)生以下輸出:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

這是 TorchScript 對forward方法的代碼的編譯。 您可以使用它來確保 TorchScript(跟蹤或腳本)正確捕獲了模型代碼。

解釋圖

TorchScript 還以 IR 圖的形式在比代碼漂亮打印機(jī)更低的層次上進(jìn)行表示。

TorchScript 使用靜態(tài)單分配(SSA)中間表示(IR)表示計算。 這種格式的指令由 ATen(PyTorch 的 C ++后端)運(yùn)算符和其他原始運(yùn)算符組成,包括用于循環(huán)和條件的控制流運(yùn)算符。 舉個例子:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


print(foo.graph)

graph遵循檢查代碼部分中關(guān)于forward方法查找所述的相同規(guī)則。

上面的示例腳本生成圖形:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

以指令%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10為例。

  • %rv.1 : Tensor表示我們將輸出分配給一個名為rv.1的(唯一)值,該值是Tensor類型,并且我們不知道其具體形狀。
  • aten::zeros是運(yùn)算符(與torch.zeros等效),輸入列表(%4, %6, %6, %10, %12)指定范圍中的哪些值應(yīng)作為輸入傳遞。 可以在內(nèi)置函數(shù)中找到aten::zeros等內(nèi)置函數(shù)的模式。
  • # test.py:9:10是生成此指令的原始源文件中的位置。 在這種情況下,它是第 9 行和字符 10 處名為 <cite>test.py</cite> 的文件。

請注意,運(yùn)算符也可以具有關(guān)聯(lián)的blocks,即prim::Loopprim::If運(yùn)算符。 在圖形打印輸出中,這些運(yùn)算符被格式化以反映其等效的源代碼形式,以方便進(jìn)行調(diào)試。

如下圖所示,可以檢查圖表以確認(rèn) ScriptModule 所描述的計算是正確的,無論是自動方式還是手動方式。

追蹤案例

在某些極端情況下,給定 Python 函數(shù)/模塊的跟蹤不會代表基礎(chǔ)代碼。 這些情況可以包括:

  • 跟蹤取決于輸入的控制流(例如張量形狀)
  • 跟蹤張量視圖的就地操作(例如,分配左側(cè)的索引)

請注意,這些情況實際上將來可能是可追溯的。

自動跟蹤檢查

自動捕獲跟蹤中許多錯誤的一種方法是使用torch.jit.trace() API 上的check_inputs。 check_inputs提取輸入元組的列表,這些列表將用于重新追蹤計算并驗證結(jié)果。 例如:

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result


inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]


traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

為我們提供以下診斷信息:

ERROR: Graphs differed across invocations!
Graph diff:


            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

此消息向我們表明,在我們第一次追蹤它和使用check_inputs追蹤它之間,計算有所不同。 實際上,loop_in_traced_fn主體內(nèi)的循環(huán)取決于輸入x的形狀,因此,當(dāng)我們嘗試另一種形狀不同的x時,跡線會有所不同。

在這種情況下,可以使用 torch.jit.script() 來捕獲類似于數(shù)據(jù)的控制流:

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result


inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]


scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())


for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))

產(chǎn)生:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

跟蹤器警告

跟蹤器會針對跟蹤計算中的幾種有問題的模式生成警告。 舉個例子,追蹤一個在 Tensor 的切片(視圖)上包含就地分配的函數(shù):

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x


traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

產(chǎn)生幾個警告和一個僅返回輸入的圖形:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1\. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我們可以通過修改代碼來解決此問題,使其不使用就地更新,而是使用torch.cat來錯位構(gòu)建結(jié)果張量:

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x


traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

內(nèi)置函數(shù)

TorchScript 支持 PyTorch 提供的內(nèi)置張量和神經(jīng)網(wǎng)絡(luò)功能的子集。 Tensor 上的大多數(shù)方法以及torch名稱空間中的函數(shù),torch.nn.functional中的所有函數(shù)以及torch.nn中的所有模塊在 TorchScript 中均受支持,下表中沒有列出。 對于不支持的模塊,建議使用 torch.jit.trace() 。

不支持的torch.nn模塊

torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
torch.nn.modules.normalization.CrossMapLRN2d
torch.nn.modules.rnn.RNN

有關(guān)支持的功能的完整參考,請參見 TorchScript 內(nèi)置函數(shù)。

常見問題解答

問:我想在 GPU 上訓(xùn)練模型并在 CPU 上進(jìn)行推理。 最佳做法是什么?

首先將模型從 GPU 轉(zhuǎn)換為 CPU,然后將其保存,如下所示:


cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")

traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")

# ... later, when using the model:

if use_gpu:
model = torch.jit.load("gpu.pth")
else:
model = torch.jit.load("cpu.pth")

model(input)


推薦這樣做是因為跟蹤器可能會在特定設(shè)備上見證張量的創(chuàng)建,因此強(qiáng)制轉(zhuǎn)換已加載的模型可能會產(chǎn)生意想不到的效果。 在保存之前對模型進(jìn)行轉(zhuǎn)換可確保跟蹤器具有正確的設(shè)備信息。

問:如何在 ScriptModule 上存儲屬性?

說我們有一個像這樣的模型:


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.x = 2

def forward(self):
return self.x

m = torch.jit.script(Model())


如果實例化Model,則將導(dǎo)致編譯錯誤,因為編譯器不了解x。 有四種方法可以通知編譯器 ScriptModule 的屬性:

\1. nn.Parameter-包裝在nn.Parameter中的值將像在nn.Module上一樣工作

\2. register_buffer-包裝在register_buffer中的值將像在nn.Module上一樣工作。 這等效于Tensor類型的屬性(請參見 4)。

3.常量-將類成員注釋為Final(或在類定義級別將其添加到名為__constants__的列表中)會將包含的名稱標(biāo)記為常量。 常數(shù)直接保存在模型代碼中。 有關(guān)詳細(xì)信息,請參見 Python 定義的常量。

4.屬性-可以將支持的類型的值添加為可變屬性。 可以推斷大多數(shù)類型,但可能需要指定一些類型,有關(guān)詳細(xì)信息,請參見模塊屬性。

問:我想跟蹤模塊的方法,但一直出現(xiàn)此錯誤:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

此錯誤通常表示您要跟蹤的方法使用模塊的參數(shù),并且您正在傳遞模塊的方法而不是模塊實例(例如my_module_instance.forwardmy_module_instance)。

\& 使用模塊的方法調(diào)用trace會將模塊參數(shù)(可能需要漸變)捕獲為常量。 &
\&
\&
另一方面,使用模塊實例(例如my_module)調(diào)用trace會創(chuàng)建一個新模塊,并將參數(shù)正確復(fù)制到新模塊中,以便在需要時可以累積梯度。

& 要跟蹤模塊上的特定方法,請參見 torch.jit.trace_module

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號