[pipeline/chimera] test chimera | fix bug of initializing (#1615)

* [pipeline/tuning] improve dispatch performance both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing
This commit is contained in:
Kirigaya Kazuto
2022-09-20 18:00:39 +08:00
committed by GitHub
parent 504ff1d101
commit 170fa81095
13 changed files with 342 additions and 144 deletions

View File

@@ -0,0 +1,3 @@
from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']

View File

@@ -139,7 +139,8 @@ class BackwardCache:
class WorkerBase(ABC):
def __init__(self,
module_partition: nn.Module,
partition_fn: Callable,
partition_args: tuple,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
@@ -165,21 +166,22 @@ class WorkerBase(ABC):
# rref of other workers
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
# lock for the list
self._initialize_lock()
# topology info
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
# module partitions
self.module_partition = module_partition.to(device)
self.partition_fn = partition_fn
self.partition_args = partition_args
self.criterion = criterion
self.metric = metric
# context to maintain loop
self._initialize_context_container()
# lock for the list
self._initialize_lock()
# main loop
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
self.main_loop_thread.start()
@@ -202,20 +204,37 @@ class WorkerBase(ABC):
self.output_list: Dict[UniqueKey, WorkItem] = dict()
def _initialize_lock(self):
self.partition_condition_lock = threading.Condition(threading.Lock())
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
def _initialize_partition(self):
partition_fn = self.partition_fn
partition_args = self.partition_args
device = self.device
with self.partition_condition_lock:
self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
self.partition_condition_lock.notify_all()
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
# for some schedule need the other worker's info to initialise partition (like Chimera)
# construction of partition is executed after the registion of pp_rank_to_worker_rref
self._initialize_partition()
def get_output_by_key(self, key: UniqueKey) -> Any:
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
output = output_work_item.output.wait()
output = output_work_item.output
if isinstance(output, Future):
output = output.wait()
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
output_work_item.refcount += 1
@@ -231,6 +250,16 @@ class WorkerBase(ABC):
def get_parameter_gradients(self) -> List[torch.Tensor]:
return [p.grad for p in self.module_partition.parameters()]
def get_partition(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition
def get_partition_state_dict(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition.state_dict()
# just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
assert self.consumer_stage_ids is not None
@@ -520,6 +549,15 @@ class WorkerBase(ABC):
is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
return is_last_phase and is_last_microbatch
def _hook_before_step(self):
pass
def _reset_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
# do the main loop to consume ready_list
def _work_loop(self):
# for init
@@ -545,19 +583,17 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item)
color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
'work loop', 'green')
work_item.output.set_result(consume_result)
# if is last step in one batch reset context and do step
if self._is_last_step(work_item):
self._hook_before_step()
if hasattr(self, 'optimizer') and not work_item.forward_only:
self.step()
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
self._reset_context()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
@@ -577,7 +613,7 @@ class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
worker_type,
module_partitions,
partition_fn: Callable,
stage_num,
num_microbatches,
device: str,
@@ -588,7 +624,7 @@ class PipelineEngineBase(ABC, nn.Module):
checkpoint: bool = False) -> None:
super().__init__()
self.worker_type = worker_type
self.module_partitions: List[nn.Module] = module_partitions
self.partition_fn: Callable = partition_fn
self.chunk = chunk
self.criterion = criterion
self.metric = metric
@@ -609,18 +645,15 @@ class PipelineEngineBase(ABC, nn.Module):
def _check_argument(self) -> None:
self.virtual_stage_num = self.stage_num * self.chunk
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
assert self.virtual_stage_num == len(
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
def _get_actual_stage_num(self) -> int:
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
def _create_pp_rank_to_rpc_worker_id(self) -> None:
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
of partitions will be moved to device 0 and the others to device 1
"""
stage_num = self.stage_num
@@ -647,26 +680,34 @@ class PipelineEngineBase(ABC, nn.Module):
device = self.device
criterion = self.criterion
metric = self.metric
partition_fn = self.partition_fn
chunk = self.chunk
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
partition_args = (partition_id, chunk, actual_stage_num)
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
module_partition = self.module_partitions[module_partition_id]
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
worker_type,
args=(module_partition, pp_rank, actual_stage_num,
num_microbatches, device, criterion, metric,
checkpoint))
args=(partition_fn, partition_args, pp_rank,
actual_stage_num, num_microbatches, device,
criterion, metric, checkpoint))
# let each worker know global worker rref (include itself)
sync_futs = []
for pp_rank in self.pp_rank_to_worker_rref:
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
sync_futs.append(fut)
for fut in sync_futs:
fut.wait()
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
parameters = {}
for stage_id in self.pp_rank_to_worker_rref:
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
parameters[stage_id] = []
worker_rref = self.pp_rank_to_worker_rref[stage_id]
for p in worker_rref.rpc_sync().get_parameters():
@@ -675,7 +716,8 @@ class PipelineEngineBase(ABC, nn.Module):
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
grads = {}
for stage_id in self.pp_rank_to_worker_rref:
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
grads[stage_id] = []
worker_rref = self.pp_rank_to_worker_rref[stage_id]
for grad in worker_rref.rpc_sync().get_parameter_gradients():
@@ -784,7 +826,7 @@ class PipelineEngineBase(ABC, nn.Module):
# collect forward result
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
if not forward_only and labels is not None:
if not forward_only and hasattr(self, 'optimizer_class'):
# wait for all step
for pp_rank in self.pp_rank_to_worker_rref:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
@@ -793,9 +835,8 @@ class PipelineEngineBase(ABC, nn.Module):
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):
actual_stage_num = self._get_actual_stage_num()
self.optimizer_class = optimizer_class
for pp_rank in range(actual_stage_num):
for pp_rank in self.pp_rank_to_worker_rref:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)

View File

@@ -1,10 +1,12 @@
from typing import List, Callable, Dict
import torch.nn as nn
import torch
import torch.distributed as dist
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
from colossalai.pipeline.pipeline_process_group import ppg
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
@@ -35,7 +37,7 @@ class FillDrainWorker(WorkerBase):
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
@@ -49,8 +51,8 @@ class FillDrainPipelineEngine(PipelineEngineBase):
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
use_1F1B = False
super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
criterion, metric, checkpoint)
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint)
class OneFOneBWorker(WorkerBase):
@@ -94,7 +96,7 @@ class OneFOneBWorker(WorkerBase):
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
@@ -106,10 +108,11 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
use_1F1B = True
super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
criterion, metric, checkpoint)
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint)
class ChimeraWorker(WorkerBase):
@@ -139,21 +142,16 @@ class ChimeraWorker(WorkerBase):
stage_num = self.actual_stage_num
real_microbatch_num = self.num_microbatches // 2
if self.forward_times < real_microbatch_num:
if (pp_rank + 1) % stage_num == 0: # last rank
forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
if forward_blocks > self.backward_times:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else: # others
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
forward_block_num = self.forward_times // forward_block_size
if self.forward_times >= real_microbatch_num or \
((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else: # others
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
@@ -164,22 +162,85 @@ class ChimeraWorker(WorkerBase):
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
def _initialize_partition(self):
# In order to ensure the down pipeline share the same parameter
# with the up pipeline, partition of down partition will be copied
# from corresponding up stage
pp_rank = self.pp_rank
stage_num = self.actual_stage_num
device = self.device
if pp_rank < stage_num:
super()._initialize_partition()
else:
# if it is down pipeline, create partition by origin method
co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
# get the coresponding model state dict and wait for its init
state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
super()._initialize_partition()
self.module_partition.load_state_dict(state_dict)
# init group for chimera in ppg
ppg.get_chimera_all_reduce_group(pp_rank)
def is_first_stage(self):
return (self.pp_rank % self.actual_stage_num) == 0
def is_last_stage(self):
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
def _is_last_step(self, work_item: WorkItem) -> bool:
if work_item.forward_only:
last_phase = Phase.FORWARD
else:
last_phase = Phase.BACKWARD
is_last_phase = work_item.phase == last_phase
last_microbatch_id = self.num_microbatches - 1
if self.pp_rank < self.actual_stage_num:
last_microbatch_id -= 1
is_last_microbatch = work_item.microbatch_id == last_microbatch_id
return is_last_phase and is_last_microbatch
def _get_step_order(self) -> List[int]:
# TODO : If you want to extend it to multi head chimera, overwrite here
stage_num = self.actual_stage_num
pp_rank = self.pp_rank
# pp_rank in the same device
local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
return local_device_pp_ranks
def _hook_before_step(self):
pp_rank = self.pp_rank
orders = self._get_step_order()
step_index = orders.index(pp_rank)
# if currrent pp_rank is not the first to do step
# wait its previous pp_rank finish step
all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
grads = self.get_parameter_gradients()
# print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
if step_index == 1:
ppg.chimera_step_lock.acquire()
print(f'rank_{self.pp_rank} before all reduce')
dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
print(f'rank_{self.pp_rank} after all reduce')
if step_index == 0:
ppg.chimera_step_lock.release()
class ChimeraPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions,
stage_num,
num_microbatches,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
@@ -189,11 +250,12 @@ class ChimeraPipelineEngine(PipelineEngineBase):
"In Chimera, num_microbatches must be the multiply of stage_num!"
use_1F1B = False
chunk = 1
super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
criterion, metric, checkpoint)
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint)
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
pass
def _create_pp_rank_to_rpc_worker_id(self) -> None:
@@ -254,7 +316,6 @@ class ChimeraPipelineEngine(PipelineEngineBase):
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
up_worker_rref.rpc_sync().get_output_by_key(up_key)
down_worker_rref.rpc_sync().get_output_by_key(down_key)