mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 22:52:25 +00:00
[pipeline/rpc] implement distributed optimizer | test with assert_close (#1486)
* support p2p communication with any type of object | pass test
* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [pipeline/rpc] implement a demo for PP with cuda rpc framework
* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatching data in work_list to ensure steady 1F1B
* [pipeline/rpc] implement distributed optimizer | test with assert_close
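The title's "test with assert_close" presumably means checking the pipelined run against an ordinary single-process run. A minimal sketch of that kind of check, not the repository's actual test (`run_local` and `run_pipeline` are hypothetical stand-ins for the test harness):

import torch
from torch.testing import assert_close

def check_outputs_match(run_local, run_pipeline, data):
    expected = run_local(data)       # plain nn.Module forward on one process
    actual = run_pipeline(data)      # same model split across pipeline stages
    # assert_close compares the tensors element-wise within
    # dtype-dependent rtol/atol defaults
    assert_close(actual, expected)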
@@ -9,6 +9,7 @@ import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef
 from torch import autograd
+from torch import optim
 from tqdm import tqdm

 from colorama import Back, Style
@@ -43,8 +44,7 @@ def tensor_shape_list(tensors):
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
-    ACCUM_GRAD = 2
-    SYNC = 3
+    UPDATE = 2


 class UniqueKey:
@@ -440,8 +440,6 @@ class Worker:
             if isinstance(input_node, torch.Tensor):
                 consume_result.append(input_node.grad)

-        elif phase == Phase.SYNC:
-            pass
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")

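For context, the backward branch above hands the previous stage the gradients that autograd left on the stage's input tensors. A standalone sketch of that mechanism, with illustrative names rather than the Worker's real attributes:

import torch
from torch import autograd

# The stage received `grad_from_next_stage` for its output; running
# autograd.backward populates .grad on the stage's leaf inputs, which is
# exactly what gets appended to consume_result and sent upstream.
stage_input = torch.randn(4, 8, requires_grad=True)
stage_output = stage_input @ torch.randn(8, 2, requires_grad=True)

grad_from_next_stage = torch.ones_like(stage_output)
autograd.backward(stage_output, grad_tensors=grad_from_next_stage)

grad_to_prev_stage = stage_input.grad   # analogous to input_node.grad above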
@@ -478,6 +476,18 @@ class Worker:
                              'work loop', 'green')
             work_item.output.set_result(consume_result)

+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+
+    def step(self):
+        assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
+        self.work_list.clear()
+        self.output_list.clear()
+        self.microbatch_id_to_backward_cache.clear()
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+

 # TODO
 # 1. chunk
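Worker.initialize_optimizer builds a per-stage optimizer from an optimizer class plus keyword arguments, so each rank only ever optimizes the parameters of its own module partition. A minimal sketch of that pattern outside the Worker class, with a placeholder partition:

import torch
from torch import nn, optim

def build_local_optimizer(module_partition: nn.Module,
                          optimizer_class: type, **kwargs) -> optim.Optimizer:
    # Same pattern as Worker.initialize_optimizer: the class and its kwargs
    # travel over RPC, the parameters stay local to the stage.
    return optimizer_class(module_partition.parameters(), **kwargs)

partition = nn.Linear(8, 2)                       # placeholder stage partition
opt = build_local_optimizer(partition, optim.SGD, lr=1e-3, momentum=0.9)

# one update cycle, mirroring Worker.step()
partition(torch.randn(4, 8)).sum().backward()
opt.step()
opt.zero_grad()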
@@ -617,6 +627,24 @@ class PipelineEngineBase(ABC, nn.Module):
             first_stage_worker.rpc_sync().get_output_by_key(key)
         return forward_result

+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
+
+    def step(self):
+        step_futs: List[Future] = []
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().step()
+            step_futs.append(fut)
+
+        # wait for all optimizers
+        for fut in step_futs:
+            fut.wait()
+

 class FillDrainPipelineEngine(PipelineEngineBase):
