[pipeline/rpc] implement distributed optimizer | test with assert_close (#1486)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test

* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with cuda rpc framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] implement distributed optimizer | test with assert_close
Kirigaya Kazuto authored on 2022-08-25 10:49:01 +08:00 (committed by GitHub)
parent 3da68d6b1b
commit 9145aef2b4
5 changed files with 220 additions and 157 deletions
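The commit title and the last bullet mention verifying the distributed optimizer with torch.testing.assert_close. A minimal sketch of that style of check, assuming a hypothetical engine object with a forward_backward entry point (the names below are placeholders, not the repository's actual test code):

import torch
from torch import nn
from torch.testing import assert_close

def check_pipeline_against_reference(model: nn.Module, engine, batch, labels) -> None:
    # Reference: ordinary forward/backward on a single process.
    criterion = nn.MSELoss()
    ref_out = model(batch)
    criterion(ref_out, labels).backward()

    # Pipelined run through the RPC engine (forward_backward is an assumed name).
    pp_out = engine.forward_backward(batch, labels)

    # The forward output (and, after stepping, the updated parameters) should
    # match the single-process reference within floating-point tolerance.
    assert_close(pp_out, ref_out)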


@@ -9,6 +9,7 @@ import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef
 from torch import autograd
+from torch import optim
 from tqdm import tqdm
 from colorama import Back, Style
@@ -43,8 +44,7 @@ def tensor_shape_list(tensors):
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
-    ACCUM_GRAD = 2
-    SYNC = 3
+    UPDATE = 2


 class UniqueKey:
@@ -440,8 +440,6 @@ class Worker:
                 if isinstance(input_node, torch.Tensor):
                     consume_result.append(input_node.grad)
-        elif phase == Phase.SYNC:
-            pass
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -478,6 +476,18 @@ class Worker:
                              'work loop', 'green')

         work_item.output.set_result(consume_result)

+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+
+    def step(self):
+        assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
+        self.work_list.clear()
+        self.output_list.clear()
+        self.microbatch_id_to_backward_cache.clear()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+

 # TODO
 # 1. chunk
@@ -617,6 +627,24 @@ class PipelineEngineBase(ABC, nn.Module):
             first_stage_worker.rpc_sync().get_output_by_key(key)

         return forward_result

+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
+
+    def step(self):
+        step_futs: List[Future] = []
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().step()
+            step_futs.append(fut)
+
+        # wait for all optimizers
+        for fut in step_futs:
+            fut.wait()
+

 class FillDrainPipelineEngine(PipelineEngineBase):
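Taken together, the two added blocks form a thin distributed-optimizer layer: the engine broadcasts initialize_optimizer to every stage worker over rpc_sync, and step() fans out asynchronous per-worker steps (each worker clears its work/output caches, then runs optimizer.step() and zero_grad()) and waits on the returned futures. A rough usage sketch, assuming an already constructed engine and dataloader; forward_backward is a placeholder for the engine's training entry point, not a name taken from this diff:

import torch

def run_pipeline_training(engine, dataloader, num_epochs: int = 1) -> None:
    # One-time setup: each stage worker builds a local optimizer over its
    # own module partition from the given class and kwargs.
    engine.initialize_optimizer(torch.optim.SGD, lr=1e-3)

    for _ in range(num_epochs):
        for batch, labels in dataloader:
            engine.forward_backward(batch, labels)   # assumed training entry point
            engine.step()   # per-stage optimizer.step() + zero_grad(), awaited via futures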