[pipeline/fix-bug] num_microbatches supports any integer | stable Chimera | launch tool for the RPC pipeline-parallel framework (#1684)

* [pipeline/tuning] improve dispatch performance in both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera

* [pipeline/chimera] test Chimera | fix an initialization bug

* [pipeline/pytree] add pytree to process args and kwargs | provide … to process args and kwargs after forward (see the pytree sketch after this list)

* [pipeline/fix-bug] num_microbatches supports any integer | stable Chimera | launch tool for the RPC pipeline-parallel framework
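
The pytree item above lost the helper's name in extraction; as a hedged illustration of the general technique only, here is one way to flatten nested args/kwargs with `torch.utils._pytree` (whether this PR uses that module or its own pytree utility is not shown in this diff; `pack`/`unpack` are illustrative names):

    # Hedged sketch (not from this PR): flatten arbitrarily nested (args, kwargs)
    # into a flat list of leaves so a pipeline stage can ship them around, then
    # restore the original structure afterwards.
    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    def pack(args, kwargs):
        flat, spec = tree_flatten((args, kwargs))
        return flat, spec                  # flat: list of leaves (tensors, None, ...)

    def unpack(flat, spec):
        args, kwargs = tree_unflatten(flat, spec)
        return args, kwargs

    flat, spec = pack((torch.ones(2), {"mask": None}), {"labels": torch.zeros(2)})
    args, kwargs = unpack(flat, spec)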
Author:    Kirigaya Kazuto
Date:      2022-10-10 16:01:02 +08:00
Committer: GitHub
Parent:    e5ab6be72e
Commit:    0df5034a36
4 changed files with 98 additions and 24 deletions


@@ -1,4 +1,5 @@
 from typing import List, Callable, Dict
+import threading
 import torch
 import torch.distributed as dist
@@ -81,7 +82,8 @@ class OneFOneBWorker(WorkerBase):
             # 2. forward times reach num_microbatches, this is the end of 1F1B mode
             if not is_last_stage and \
                 target_key.phase == Phase.FORWARD:
-                if target_key.microbatch_id == actual_stage_num - 1:
+                if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
+                    # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2
                     outstanding_min = actual_stage_num - pp_rank - 1
                     outstanding_max = actual_stage_num - pp_rank
                     self.outstanding_range = (outstanding_min, outstanding_max)
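
For context on the guard added above: in the usual 1F1B schedule a stage performs some warmup forwards before the steady one-forward-one-backward phase, and with only one or two microbatches that warmup can consume everything, so there is no steady phase for `outstanding_range` to regulate. A standalone back-of-the-envelope sketch (standard 1F1B accounting, not code from this diff):

    # Illustrative 1F1B phase accounting (assumption: the common Megatron-style
    # split of num_microbatches into warmup + steady per pipeline rank).
    def one_f_one_b_phases(num_microbatches, stage_num, pp_rank):
        warmup = min(stage_num - pp_rank - 1, num_microbatches)
        steady = num_microbatches - warmup
        return warmup, steady

    for mb in (2, 4, 8):
        # e.g. with 4 stages and 2 microbatches, ranks 0-1 get steady == 0
        print(mb, [one_f_one_b_phases(mb, stage_num=4, pp_rank=r) for r in range(4)])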
@@ -186,6 +188,19 @@ class ChimeraWorker(WorkerBase):
         # init group for chimera in ppg
         ppg.get_chimera_all_reduce_group(pp_rank)
 
+        # lock for step sync
+        self.step_sync_lock = threading.Lock()
+        self.step_sync_lock.acquire()
+
+        self.have_grad_lock = threading.Lock()
+        self.have_grad_lock.acquire()
+
+    def _get_lock_gradient(self):
+        self.have_grad_lock.acquire()
+        grads = self.get_parameter_gradients()
+        self.step_sync_lock.release()
+        return grads
+
     def is_first_stage(self):
         return (self.pp_rank % self.actual_stage_num) == 0
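
The two pre-acquired locks above form a cross-thread handshake: the peer's RPC call into `_get_lock_gradient` blocks until this worker has produced its gradients, and this worker later blocks on `step_sync_lock` until the peer has picked them up. A standalone sketch of that pattern with plain threads (no RPC; `data_ready`, `picked_up`, and `box` are illustrative names):

    import threading

    # Illustrative handshake with pre-acquired locks (plain threads, no RPC).
    data_ready = threading.Lock()
    data_ready.acquire()            # starts "closed": the consumer must wait
    picked_up = threading.Lock()
    picked_up.acquire()             # starts "closed": the producer must wait

    box = {}

    def producer():
        box["grads"] = [1.0, 2.0]   # stand-in for parameter gradients
        data_ready.release()        # like have_grad_lock.release() in _hook_before_step
        picked_up.acquire()         # like blocking on step_sync_lock before summing

    def consumer():
        data_ready.acquire()        # like _get_lock_gradient() taking have_grad_lock
        grads = box["grads"]
        picked_up.release()         # like releasing step_sync_lock
        return grads

    t = threading.Thread(target=producer)
    t.start()
    print(consumer())               # [1.0, 2.0]
    t.join()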
@@ -214,27 +229,22 @@ class ChimeraWorker(WorkerBase):
        return local_device_pp_ranks

    def _hook_before_step(self):
        self.have_grad_lock.release()
        pp_rank = self.pp_rank
        orders = self._get_step_order()
        step_index = orders.index(pp_rank)
        stage_num = self.actual_stage_num
        co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)

        # if current pp_rank is not the first to do step
        # wait its previous pp_rank finish step
        all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
        grads = self.get_parameter_gradients()
        # print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
        if step_index == 1:
            ppg.chimera_step_lock.acquire()

        # print(f'rank_{self.pp_rank} before all reduce')
        dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
        # print(f'rank_{self.pp_rank} after all reduce')
        if step_index == 0:
            ppg.chimera_step_lock.release()

        # send
        co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
        co_grads = co_worker.rpc_sync()._get_lock_gradient()
        # sync
        self.step_sync_lock.acquire()
        for i in range(len(grads)):
            grads[i] += co_grads[i]
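
A note on the co-rank computation above: if the indexing is read as `is_first_stage` suggests (stage index = pp_rank % actual_stage_num), then `co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)` pairs each worker with the worker holding the same stage in Chimera's opposite-direction pipeline, which is why their gradients are summed before the step. A tiny illustration (values chosen for the example, not from the diff):

    # Illustration: with stage_num = 4 there are 2 * 4 = 8 workers; worker i and
    # worker (i + 4) % 8 share a stage index (i % 4), one per pipeline direction.
    stage_num = 4
    for pp_rank in range(2 * stage_num):
        co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)
        print(pp_rank, pp_rank % stage_num, "<->", co_pp_rank, co_pp_rank % stage_num)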
class ChimeraPipelineEngine(PipelineEngineBase):
@@ -257,8 +267,8 @@ class ChimeraPipelineEngine(PipelineEngineBase):
         super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                          metric, checkpoint, data_process_func)
 
-    def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
-                            input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
+    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
+                            output_pp_ranks: List[int], ret_future):
         pass
 
     def _create_pp_rank_to_rpc_worker_id(self) -> None: