PyTorch 分布式 Autograd 設(shè)計

2020-09-10 16:24 更新
原文: https://pytorch.org/docs/stable/notes/distributed_autograd.html

警告

分布式 RPC 框架是實驗性的,隨時可能更改。

本說明將介紹分布式自動分級的詳細設(shè)計,并逐步介紹其內(nèi)部。 在繼續(xù)之前,請確保您熟悉 Autograd 機械手和分布式 RPC 框架。

背景

假設(shè)您有兩個節(jié)點,并且在兩個節(jié)點之間劃分了一個非常簡單的模型。 可以使用 torch.distributed.rpc 如下實現(xiàn):

  1. import torch
  2. import torch.distributed.rpc as rpc
  3. def my_add(t1, t2):
  4. return torch.add(t1, t2)
  5. ## On worker 0:
  6. t1 = torch.rand((3, 3), requires_grad=True)
  7. t2 = torch.rand((3, 3), requires_grad=True)
  8. ## Perform some computation remotely.
  9. t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
  10. ## Perform some computation locally based on remote result.
  11. t4 = torch.rand((3, 3), requires_grad=True)
  12. t5 = torch.mul(t3, t4)
  13. ## Compute some loss.
  14. loss = t5.sum()

分布式 autograd 背后的主要動機是使用我們已經(jīng)計算并記錄所有需要梯度的張量的合適梯度的loss在這樣的分布式模型上運行向后傳遞。

正向通過過程中的自動分級記錄

PyTorch 在正向傳遞過程中會構(gòu)建自動分級圖,該圖用于執(zhí)行向后傳遞。 有關(guān)更多詳細信息,請參見 autograd 如何編碼歷史記錄。

對于分布式 autograd,我們需要在正向傳遞過程中跟蹤所有 RPC,以確保正確執(zhí)行向后傳遞。 為此,我們在執(zhí)行 RPC 時將sendrecv函數(shù)附加到自動縮放圖。

  • send函數(shù)附加到 RPC 的源,并且其輸出邊指向 RPC 輸入張量的 autograd 函數(shù)。 從目的地接收反向傳遞期間此功能的輸入,作為適當(dāng)?shù)?code>recv功能的輸出。
  • recv函數(shù)附加到 RPC 的目標,并且使用輸入張量從在目標上執(zhí)行的運算符檢索其輸入。 在向后傳遞過程中,此函數(shù)的輸出梯度將發(fā)送到源節(jié)點并發(fā)送到適當(dāng)?shù)?code>send函數(shù)。
  • 每個send-recv對都分配有一個全局唯一的autograd_message_id,以唯一地標識該對。 這對于在反向傳遞期間在遠程節(jié)點上查找對應(yīng)的功能很有用。
  • 對于 RRef ,每當(dāng)我們調(diào)用 torch.distributed.rpc.RRef.to_here() 時,我們都會為所涉及的張量附加一個適當(dāng)?shù)膕end-recv對。

舉例來說,這就是我們上面示例中的 autograd 圖的樣子(為簡單起見,排除了 t5.sum()):

../_images/send_recv_functions.png

分布式 Autograd 上下文

每個使用分布式 autograd 的正向和反向傳遞都分配有唯一的 torch.distributed.autograd.context ,并且此上下文具有全局唯一的autograd_context_id。 根據(jù)需要在每個節(jié)點上創(chuàng)建此上下文。

此上下文具有以下目的:

  1. 運行分布式后向遍歷的多個節(jié)點可能會在同一張量上累積梯度,因此,在我們有機會運行優(yōu)化器之前,張量的.grad字段將具有來自各種分布式后向遍歷的梯度。 這類似于在本地多次調(diào)用 torch.autograd.backward() 。 為了提供一種為每個后退通道分離梯度的方法,對于每個后退通道,梯度會累積在  torch.distributed.autograd.context 中。
  2. 在前向傳遞過程中,我們在這種情況下為每個自動分級傳遞存儲sendrecv函數(shù)。 這樣可以確保我們保留對 autograd 圖中適當(dāng)節(jié)點的引用,以使其保持活動狀態(tài)。 除此之外,在向后傳遞過程中很容易查找適當(dāng)?shù)?code>send和recv功能。
  3. 通常,我們還使用此上下文為每個分布式 autograd pass 存儲一些元數(shù)據(jù)。

從用戶的角度來看,自動分級上下文的設(shè)置如下:

  1. import torch.distributed.autograd as dist_autograd
  2. with dist_autograd.context() as context_id:
  3. loss = model.forward()
  4. dist_autograd.backward(loss)

分布式后向通行證

在本節(jié)中,我們概述了在分布式后向傳遞過程中準確計算依賴項的挑戰(zhàn),并描述了一些關(guān)于如何執(zhí)行分布式后向傳遞的算法(需要權(quán)衡)。

計算依賴

考慮以下代碼在單臺計算機上運行

  1. import torch
  2. a = torch.rand((3, 3), requires_grad=True)
  3. b = torch.rand((3, 3), requires_grad=True)
  4. c = torch.rand((3, 3), requires_grad=True)
  5. d = a + b
  6. e = b * c
  7. d.sum.().backward()

這就是上面代碼的 autograd 圖形:

../_images/local_dependencies.png

autograd 引擎作為向后傳遞的一部分執(zhí)行的第一步是計算 autograd 圖中每個節(jié)點的依賴項數(shù)量。 這有助于 autograd 引擎知道何時可以執(zhí)行圖中的節(jié)點。 add(1)mul(0)括號中的數(shù)字表示依賴項的數(shù)量。 如您所見,這意味著在向后傳遞期間,add節(jié)點需要 1 個輸入,mul節(jié)點不需要任何輸入(換句話說,不需要執(zhí)行)。 本地 autograd 引擎通過遍歷根節(jié)點中的圖來計算這些依賴性(在這種情況下為d)。

autograd 圖中的某些節(jié)點可能無法在向后傳遞中執(zhí)行的事實對分布式 autograd 提出了挑戰(zhàn)。 考慮使用 RPC 的這段代碼。

  1. import torch
  2. import torch.distributed.rpc as rpc
  3. a = torch.rand((3, 3), requires_grad=True)
  4. b = torch.rand((3, 3), requires_grad=True)
  5. c = torch.rand((3, 3), requires_grad=True)
  6. d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
  7. e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
  8. loss = d.sum()

上面的代碼的相關(guān)自動分級圖為:

../_images/distributed_dependencies.png

計算此分布式 autograd 圖的依賴項更具挑戰(zhàn)性,并且需要一些開銷(無論是在計算還是在網(wǎng)絡(luò)通信方面)。

對于性能敏感的應(yīng)用程序,我們可以通過假設(shè)每個sendrecv函數(shù)在反向傳遞中都是有效的(大多數(shù)應(yīng)用程序不執(zhí)行未使用的 RPC)來避免很多開銷。 這簡化了分布式 autograd 算法,并且效率更高,但代價是應(yīng)用程序需要意識到這些限制。 該算法稱為 FAST 模式算法,下面將對其進行詳細說明。

在一般情況下,可能不需要每個sendrecv函數(shù)都有效作為反向傳遞的一部分。 為了解決這個問題,我們還有一個 SMART 模式算法,將在后面的部分中進行介紹。

快速模式算法

該算法的關(guān)鍵假設(shè)是,當(dāng)我們運行向后傳遞時,每個send函數(shù)的相關(guān)性均為 1。 換句話說,我們假設(shè)將從另一個節(jié)點接收到 RPC 上的漸變。

算法如下:

  1. 我們從具有向后遍歷的根的工作程序開始(所有根必須是本地的)。
  2. 查找當(dāng)前分布式 Autograd 上下文的所有send功能。
  3. 從提供的根目錄和我們檢索到的所有send函數(shù)開始,本地計算依賴項。
  4. 計算依賴關(guān)系后,使用提供的根啟動本地 autograd 引擎。
  5. 當(dāng) autograd 引擎執(zhí)行recv功能時,recv功能會通過 RPC 將輸入梯度發(fā)送到適當(dāng)?shù)墓ぷ鞒绦颉?每個recv函數(shù)都知道目標工作者 ID,因為它被記錄為正向傳遞的一部分。 recv功能還將autograd_context_idautograd_message_id發(fā)送到遠程主機。
  6. 當(dāng)在遠程主機上收到此請求時,我們使用autograd_context_id和autograd_message_id查找適當(dāng)?shù)?code>send功能。
  7. 如果這是工作人員第一次收到對給定autograd_context_id的請求,則它將如上面的第 1-3 點所述在本地計算依賴性。
  8. 然后,將在 6 中檢索到的send函數(shù)排隊以便在該工作者的本地 autograd 引擎上執(zhí)行。
  9. 最后,我們不是在張量的.grad字段上累積梯度,而是根據(jù)分布式自學(xué)背景分別累積梯度。 梯度存儲在Dict[Tensor, Tensor]中,基本上是從 Tensor 到其相關(guān)梯度的映射,可以使用 get_gradients() API 檢索此映射。

例如,具有分布式 autograd 的完整代碼如下:

  1. import torch
  2. import torch.distributed.autograd as dist_autograd
  3. import torch.distributed.rpc as rpc
  4. def my_add(t1, t2):
  5. return torch.add(t1, t2)
  6. ## On worker 0:
  7. ## Setup the autograd context.
  8. with dist_autograd.context() as context_id:
  9. t1 = torch.rand((3, 3), requires_grad=True)
  10. t2 = torch.rand((3, 3), requires_grad=True)
  11. # Perform some computation remotely.
  12. t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
  13. # Perform some computation locally based on remote result.
  14. t4 = torch.rand((3, 3), requires_grad=True)
  15. t5 = torch.mul(t3, t4)
  16. # Compute some loss.
  17. loss = t5.sum()
  18. # Run the backward pass.
  19. dist_autograd.backward([loss])
  20. # Retrieve the gradients from the context.
  21. dist_autograd.get_gradients(context_id)

具有依賴關(guān)系的分布式 autograd 圖如下所示:

../_images/distributed_dependencies_computed.png

應(yīng)用于以上示例的 FAST 模式算法如下:

  1. Worker 0上,我們從根losssend1開始計算依賴關(guān)系。 結(jié)果,send1的依賴性為 1,mulWorker 0的依賴性為 1。
  2. 現(xiàn)在,我們在Worker 0上啟動本地 autograd 引擎。 我們首先執(zhí)行mul函數(shù),將其輸出在 autograd 上下文中累積為t4的梯度。 然后,我們執(zhí)行recv2,它將梯度發(fā)送到Worker 1。
  3. 由于這是Worker 1第一次聽到有關(guān)此反向傳遞的信息,因此它將開始依賴性計算并適當(dāng)?shù)貥擞?code>send2,addrecv1的依賴性。
  4. 接下來,將send2排隊在Worker 1的本地 autograd 引擎上,該引擎依次執(zhí)行addrecv1。
  5. 當(dāng)執(zhí)行recv1時,它將梯度發(fā)送到Worker 0
  6. 由于Worker 0已經(jīng)計算了此向后傳遞的依賴性,因此它僅排隊并在本地執(zhí)行send1。
  7. 最后,t1,t2t4的梯度會累積在分布式 Autograd 上下文中。

SMART 模式算法

該算法的完整細節(jié)仍在研究中,但是對于一般概念,您可以參考 RFC 中的分布式 Autograd Algorithm Smart 模式部分。

分布式優(yōu)化器

DistributedOptimizer 的操作如下:

  1. 獲取要優(yōu)化的遠程參數(shù)列表 (RRef)。 這些也可以是包裝在本地RRef中的本地參數(shù)。
  2. 將 Optimizer 類作為本地優(yōu)化器,以在所有不同的RRef所有者上運行。
  3. 分布式優(yōu)化器在每個工作程序節(jié)點上創(chuàng)建本地Optimizer的實例,并將其保存RRef。
  4. 當(dāng)調(diào)用 torch.distributed.optim.DistributedOptimizer.step() 時,分布式優(yōu)化器使用 RPC 在適當(dāng)?shù)倪h程工作器上遠程執(zhí)行所有本地優(yōu)化器。
  5. 如果多個并發(fā)的分布式優(yōu)化器正在更新工作器上的相同參數(shù),則這些更新將通過鎖序列化。

簡單的端到端示例

綜上所述,以下是使用分布式 autograd 和分布式優(yōu)化器的簡單的端到端示例。 如果將代碼放入名為“ dist_autograd_simple.py”的文件中,則可以使用命令MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py運行該代碼:

  1. import multiprocessing as mp
  2. import torch
  3. import torch.distributed.autograd as dist_autograd
  4. from torch.distributed import rpc
  5. from torch import optim
  6. from torch.distributed.optim import DistributedOptimizer
  7. def random_tensor():
  8. return torch.rand((3, 3), requires_grad=True)
  9. def _run_process(rank, dst_rank, world_size):
  10. name = "worker{}".format(rank)
  11. dst_name = "worker{}".format(dst_rank)
  12. # Initialize RPC.
  13. rpc.init_rpc(
  14. name=name,
  15. rank=rank,
  16. world_size=world_size
  17. )
  18. # Use a distributed autograd context.
  19. with dist_autograd.context() as context_id:
  20. # Forward pass (create references on remote nodes).
  21. rref1 = rpc.remote(dst_name, random_tensor)
  22. rref2 = rpc.remote(dst_name, random_tensor)
  23. loss = rref1.to_here() + rref2.to_here()
  24. # Backward pass (run distributed autograd).
  25. dist_autograd.backward([loss.sum()])
  26. # Build DistributedOptimizer.
  27. dist_optim = DistributedOptimizer(
  28. optim.SGD,
  29. [rref1, rref2],
  30. lr=0.05,
  31. )
  32. # Run the distributed optimizer step.
  33. dist_optim.step()
  34. def run_process(rank, dst_rank, world_size):
  35. _run_process(rank, dst_rank, world_size)
  36. rpc.shutdown()
  37. processes = []
  38. ## Run world_size workers.
  39. world_size = 2
  40. for i in range(world_size):
  41. p = mp.Process(target=run_process, args=(i, (i + 1) % 2, world_size))
  42. p.start()
  43. processes.append(p)
  44. for p in processes:
  45. p.join()


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號