diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py index d55da0fce..4d37c9833 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/pipeline/pipelinable.py @@ -110,7 +110,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): name_list.append((name, param)) for name, param in name_list: - delattr(module, name) + if hasattr(module, name): + delattr(module, name) setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad)) def to_layer_list(self, exec_seq=None): diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/pipeline/pipeline_process_group.py index d6fe47bc4..e85f45548 100644 --- a/colossalai/pipeline/pipeline_process_group.py +++ b/colossalai/pipeline/pipeline_process_group.py @@ -1,5 +1,6 @@ from typing import List, Dict, Tuple import os +import threading from torch.distributed import rpc import torch.distributed as dist @@ -10,13 +11,17 @@ from colossalai.tensor import ProcessGroup class PipelineProcessGroup: # TODO : flexible API for DP size and TP size # In the future design mode, dp_degree and tp_degree should be removed - def __init__(self, - rank: int, - world_size: int, - dp_degree: int = 1, - tp_degree: int = 1, - num_worker_threads: int = 1, - device: str = "cuda") -> None: + def __init__(self) -> None: + self.is_initialize = False + + def set_global_info(self, + rank: int, + world_size: int, + dp_degree: int = 1, + tp_degree: int = 1, + num_worker_threads: int = 1, + device: str = "cuda") -> None: + device_mesh_size = dp_degree * tp_degree assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!" self._num_worker_threads = num_worker_threads @@ -42,6 +47,11 @@ class PipelineProcessGroup: self._is_first_pp_rank = self._pp_rank == 0 self._is_last_pp_rank = self._pp_rank == self._stage_num - 1 + self.is_initialize = True + + # lock + self.chimera_lock = threading.Lock() + def _initialize_process_group(self): stage_num = self.get_stage_num() if stage_num == 1: @@ -133,3 +143,25 @@ class PipelineProcessGroup: def get_tp_global_ranks(self): pass + + def get_chimera_all_reduce_group(self, pp_rank: int): + with self.chimera_lock: + if not hasattr(self, 'chimera_groups'): + world_size = self.get_world_size() + stage_num = self.get_stage_num() + assert world_size % 2 == 0, 'world_size must be even in chimera!' 
+            self.chimera_groups = {}
+            for rank in range(world_size // 2):
+                pair = [rank, world_size - 1 - rank]
+                group = dist.new_group(pair)
+                self.chimera_groups[pair[0]] = group
+                self.chimera_groups[pair[1]] = group
+                self.chimera_groups[pair[0] + stage_num] = group
+                self.chimera_groups[pair[1] + stage_num] = group
+            self.chimera_step_lock = threading.Lock()
+            self.chimera_step_lock.acquire()
+
+        return self.chimera_groups[pp_rank]
+
+
+ppg = PipelineProcessGroup()
diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py
new file mode 100644
index 000000000..5e0726456
--- /dev/null
+++ b/colossalai/pipeline/rpc/__init__.py
@@ -0,0 +1,3 @@
+from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
+
+__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']
\ No newline at end of file
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index c03148505..58071dc26 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -139,7 +139,8 @@ class BackwardCache:
 class WorkerBase(ABC):
 
     def __init__(self,
-                 module_partition: nn.Module,
+                 partition_fn: Callable,
+                 partition_args: tuple,
                  pp_rank: int,
                  actual_stage_num: int,
                  num_microbatches: int,
@@ -165,21 +166,22 @@ class WorkerBase(ABC):
         # rref of other workers
         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
 
+        # lock for the list
+        self._initialize_lock()
+
         # topology info
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None
 
         # module partitions
-        self.module_partition = module_partition.to(device)
+        self.partition_fn = partition_fn
+        self.partition_args = partition_args
         self.criterion = criterion
         self.metric = metric
 
         # context to maintain loop
         self._initialize_context_container()
 
-        # lock for the list
-        self._initialize_lock()
-
         # main loop
         self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
         self.main_loop_thread.start()
@@ -202,20 +204,37 @@ class WorkerBase(ABC):
         self.output_list: Dict[UniqueKey, WorkItem] = dict()
 
     def _initialize_lock(self):
+        self.partition_condition_lock = threading.Condition(threading.Lock())
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
 
+    def _initialize_partition(self):
+        partition_fn = self.partition_fn
+        partition_args = self.partition_args
+        device = self.device
+        with self.partition_condition_lock:
+            self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
+            self.partition_condition_lock.notify_all()
+
     def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
         assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
         assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
         self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
 
+        # some schedules need the other workers' info to initialize their partition (e.g. Chimera),
+        # so construction of the partition is executed after the registration of pp_rank_to_worker_rref
+        self._initialize_partition()
+
     def get_output_by_key(self, key: UniqueKey) -> Any:
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
-        output =
output_work_item.output.wait() + + output = output_work_item.output + if isinstance(output, Future): + output = output.wait() + # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red') output_work_item.refcount += 1 @@ -231,6 +250,16 @@ class WorkerBase(ABC): def get_parameter_gradients(self) -> List[torch.Tensor]: return [p.grad for p in self.module_partition.parameters()] + def get_partition(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + return self.module_partition + + def get_partition_state_dict(self): + with self.partition_condition_lock: + self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + return self.module_partition.state_dict() + # just for first pp_rank def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): assert self.consumer_stage_ids is not None @@ -520,6 +549,15 @@ class WorkerBase(ABC): is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1 return is_last_phase and is_last_microbatch + def _hook_before_step(self): + pass + + def _reset_context(self): + self.forward_times = 0 + self.backward_times = 0 + self.outstanding = 0 + self._initialize_outstanding_range() + # do the main loop to consume ready_list def _work_loop(self): # for init @@ -545,19 +583,17 @@ class WorkerBase(ABC): consume_result = self._consume_work_item_by_phase(work_item) color_debug( - f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}', + f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}', 'work loop', 'green') work_item.output.set_result(consume_result) # if is last step in one batch reset context and do step if self._is_last_step(work_item): + self._hook_before_step() if hasattr(self, 'optimizer') and not work_item.forward_only: self.step() - self.forward_times = 0 - self.backward_times = 0 - self.outstanding = 0 - self._initialize_outstanding_range() + self._reset_context() def initialize_optimizer(self, optimizer_class: type, **kwargs): self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs) @@ -577,7 +613,7 @@ class PipelineEngineBase(ABC, nn.Module): def __init__(self, worker_type, - module_partitions, + partition_fn: Callable, stage_num, num_microbatches, device: str, @@ -588,7 +624,7 @@ class PipelineEngineBase(ABC, nn.Module): checkpoint: bool = False) -> None: super().__init__() self.worker_type = worker_type - self.module_partitions: List[nn.Module] = module_partitions + self.partition_fn: Callable = partition_fn self.chunk = chunk self.criterion = criterion self.metric = metric @@ -609,18 +645,15 @@ class PipelineEngineBase(ABC, nn.Module): def _check_argument(self) -> None: self.virtual_stage_num = self.stage_num * self.chunk - assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!" - assert self.virtual_stage_num == len( - self.module_partitions), "stage_num * chunk must be equal to length of model partition!" def _get_actual_stage_num(self) -> int: return self.stage_num if self.chunk == 1 else self.virtual_stage_num def _create_pp_rank_to_rpc_worker_id(self) -> None: """create a map from model partition to stage_id, which is useful when use_interleave is True. - e.g. 
If a model is splited into 4 parts, which means len(self.module_partitions) == 3. - stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part + e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then + pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part of partitions will be moved to device 0 and the others to device 1 """ stage_num = self.stage_num @@ -647,26 +680,34 @@ class PipelineEngineBase(ABC, nn.Module): device = self.device criterion = self.criterion metric = self.metric + partition_fn = self.partition_fn + chunk = self.chunk for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)): - module_partition_id = self.pp_rank_to_module_partition_id[pp_rank] + partition_id = self.pp_rank_to_module_partition_id[pp_rank] + partition_args = (partition_id, chunk, actual_stage_num) rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] if device[:4] == 'cuda': device = f'cuda:{rpc_worker_id}' - module_partition = self.module_partitions[module_partition_id] self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, worker_type, - args=(module_partition, pp_rank, actual_stage_num, - num_microbatches, device, criterion, metric, - checkpoint)) + args=(partition_fn, partition_args, pp_rank, + actual_stage_num, num_microbatches, device, + criterion, metric, checkpoint)) # let each worker know global worker rref (include itself) + sync_futs = [] for pp_rank in self.pp_rank_to_worker_rref: - self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + sync_futs.append(fut) + + for fut in sync_futs: + fut.wait() def remote_parameters(self) -> Dict[int, List[torch.Tensor]]: parameters = {} - for stage_id in self.pp_rank_to_worker_rref: + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): parameters[stage_id] = [] worker_rref = self.pp_rank_to_worker_rref[stage_id] for p in worker_rref.rpc_sync().get_parameters(): @@ -675,7 +716,8 @@ class PipelineEngineBase(ABC, nn.Module): def remote_grad(self) -> Dict[int, List[torch.Tensor]]: grads = {} - for stage_id in self.pp_rank_to_worker_rref: + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): grads[stage_id] = [] worker_rref = self.pp_rank_to_worker_rref[stage_id] for grad in worker_rref.rpc_sync().get_parameter_gradients(): @@ -784,7 +826,7 @@ class PipelineEngineBase(ABC, nn.Module): # collect forward result forward_result = self._collect_forward_result(output_pp_ranks, ret_future) - if not forward_only and labels is not None: + if not forward_only and hasattr(self, 'optimizer_class'): # wait for all step for pp_rank in self.pp_rank_to_worker_rref: worker_rref = self.pp_rank_to_worker_rref[pp_rank] @@ -793,9 +835,8 @@ class PipelineEngineBase(ABC, nn.Module): return forward_result def initialize_optimizer(self, optimizer_class: type, **kwargs): - actual_stage_num = self._get_actual_stage_num() self.optimizer_class = optimizer_class - for pp_rank in range(actual_stage_num): + for pp_rank in self.pp_rank_to_worker_rref: worker_rref = self.pp_rank_to_worker_rref[pp_rank] worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs) diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index 991588bae..523d2d807 100644 --- 
a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -1,10 +1,12 @@ from typing import List, Callable, Dict -import torch.nn as nn +import torch +import torch.distributed as dist from torch.futures import Future from torch._C._distributed_rpc import PyRRef -from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase +from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem +from colossalai.pipeline.pipeline_process_group import ppg # Implementation of different Pipeline schedule # Worker defines the worker for each stage @@ -35,7 +37,7 @@ class FillDrainWorker(WorkerBase): class FillDrainPipelineEngine(PipelineEngineBase): def __init__(self, - module_partitions: List[nn.Module], + partition_fn: Callable, stage_num: int, num_microbatches: int, device: str, @@ -49,8 +51,8 @@ class FillDrainPipelineEngine(PipelineEngineBase): "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" use_1F1B = False - super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, - criterion, metric, checkpoint) + super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, + metric, checkpoint) class OneFOneBWorker(WorkerBase): @@ -94,7 +96,7 @@ class OneFOneBWorker(WorkerBase): class OneFOneBPipelineEngine(PipelineEngineBase): def __init__(self, - module_partitions: List[nn.Module], + partition_fn: Callable, stage_num: int, num_microbatches: int, device: str, @@ -106,10 +108,11 @@ class OneFOneBPipelineEngine(PipelineEngineBase): if chunk > 1: assert num_microbatches % stage_num == 0, \ "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" 
+        assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
         use_1F1B = True
 
-        super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
-                         criterion, metric, checkpoint)
+        super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
+                         metric, checkpoint)
 
 
 class ChimeraWorker(WorkerBase):
@@ -139,21 +142,16 @@ class ChimeraWorker(WorkerBase):
         stage_num = self.actual_stage_num
         real_microbatch_num = self.num_microbatches // 2
 
-        if self.forward_times < real_microbatch_num:
-            if (pp_rank + 1) % stage_num == 0:    # last rank
-                forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
-                if forward_blocks > self.backward_times:
-                    target_phase = Phase.BACKWARD
-                    target_microbatch_id = self.backward_times
-                else:
-                    target_phase = Phase.FORWARD
-                    target_microbatch_id = self.forward_times
-            else:    # others
-                target_phase = Phase.FORWARD
-                target_microbatch_id = self.forward_times
-        else:
+        forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
+        forward_block_num = self.forward_times // forward_block_size
+
+        if self.forward_times >= real_microbatch_num or \
+            ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
             target_phase = Phase.BACKWARD
             target_microbatch_id = self.backward_times
+        else:    # others
+            target_phase = Phase.FORWARD
+            target_microbatch_id = self.forward_times
 
         # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
         # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
@@ -164,22 +162,85 @@ class ChimeraWorker(WorkerBase):
         with self.work_list_condition_lock:
             self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
-
         return target_key
 
+    def _initialize_partition(self):
+        # to ensure that the down pipeline shares the same parameters
+        # as the up pipeline, the down partition is copied
+        # from the corresponding up stage
+        pp_rank = self.pp_rank
+        stage_num = self.actual_stage_num
+        device = self.device
+        if pp_rank < stage_num:
+            super()._initialize_partition()
+        else:
+            # the down pipeline builds its partition by the original method and then loads the up stage's weights
+            co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
+            # get the corresponding model state dict and wait for its initialization
+            state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
+            super()._initialize_partition()
+            self.module_partition.load_state_dict(state_dict)
+
+        # init the all-reduce group for chimera in ppg
+        ppg.get_chimera_all_reduce_group(pp_rank)
+
     def is_first_stage(self):
         return (self.pp_rank % self.actual_stage_num) == 0
 
     def is_last_stage(self):
         return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
 
+    def _is_last_step(self, work_item: WorkItem) -> bool:
+        if work_item.forward_only:
+            last_phase = Phase.FORWARD
+        else:
+            last_phase = Phase.BACKWARD
+        is_last_phase = work_item.phase == last_phase
+        last_microbatch_id = self.num_microbatches - 1
+        if self.pp_rank < self.actual_stage_num:
+            last_microbatch_id -= 1
+        is_last_microbatch = work_item.microbatch_id == last_microbatch_id
+        return is_last_phase and is_last_microbatch
+
+    def _get_step_order(self) -> List[int]:
+        # TODO : If you want to extend it to multi head chimera, overwrite here
+        stage_num = self.actual_stage_num
+        pp_rank = self.pp_rank
+        # pp_ranks in the same device
+        local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
+        local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
+        return local_device_pp_ranks
+
+    def _hook_before_step(self):
+        pp_rank = self.pp_rank
+
+        orders = self._get_step_order()
+        step_index = orders.index(pp_rank)
+
+        # if the current pp_rank is not the first one to step,
+        # wait for its predecessor pp_rank to finish its step
+
+        all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
+        grads = self.get_parameter_gradients()
+
+        # print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
+        if step_index == 1:
+            ppg.chimera_step_lock.acquire()
+
+        print(f'rank_{self.pp_rank} before all reduce')
+        dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
+        print(f'rank_{self.pp_rank} after all reduce')
+
+        if step_index == 0:
+            ppg.chimera_step_lock.release()
+
 
 class ChimeraPipelineEngine(PipelineEngineBase):
 
     def __init__(self,
-                 module_partitions,
-                 stage_num,
-                 num_microbatches,
+                 partition_fn: Callable,
+                 stage_num: int,
+                 num_microbatches: int,
                  device: str,
                  criterion: Callable = None,
                  metric: Callable = None,
@@ -189,11 +250,12 @@ class ChimeraPipelineEngine(PipelineEngineBase):
             "In Chimera, num_microbatches must be the multiply of stage_num!"
         use_1F1B = False
         chunk = 1
-        super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
-                         criterion, metric, checkpoint)
+
+        super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
+                         metric, checkpoint)
 
     def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
-                            input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
+                            input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
         pass
 
     def _create_pp_rank_to_rpc_worker_id(self) -> None:
@@ -254,7 +316,6 @@ class ChimeraPipelineEngine(PipelineEngineBase):
 
         up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
         down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
 
-        up_worker_rref.rpc_sync().get_output_by_key(up_key)
         down_worker_rref.rpc_sync().get_output_by_key(down_key)
diff --git a/data/cifar-10-python.tar.gz b/data/cifar-10-python.tar.gz
deleted file mode 100644
index 048005dfd..000000000
Binary files a/data/cifar-10-python.tar.gz and /dev/null differ
diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py
index 1a8472820..5c332f270 100644
--- a/tests/test_pipeline/rpc_test_utils.py
+++ b/tests/test_pipeline/rpc_test_utils.py
@@ -8,8 +8,13 @@ import torch.multiprocessing as mp
 import torch.distributed.rpc as rpc
 from torch.optim import SGD, Adam, RMSprop, Optimizer
 from torch._C._distributed_rpc import _is_current_rpc_agent_set
+import torch.distributed as dist
 from colorama import Back, Style
 
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.logging import disable_existing_loggers
+from colossalai import launch
+
 rpc_is_initialized = _is_current_rpc_agent_set
@@ -25,12 +30,15 @@ class RpcTestModel(nn.Module):
         self.rank = stage_id
         self.is_last_rank = stage_id == actual_stage_num - 1
         self.linear_name = f'linear_{stage_id}'
+
         if stage_id == 0:
-            setattr(self, self.linear_name, nn.Linear(feat_num, h))
+            linear = nn.Linear(feat_num, h)
         elif stage_id == actual_stage_num - 1:
-            setattr(self, self.linear_name, nn.Linear(h, 1))
+            linear = nn.Linear(h, 1)
         else:
-            setattr(self, self.linear_name, nn.Linear(h, h))
+            linear = nn.Linear(h, h)
+
+
setattr(self, self.linear_name, linear) def forward(self, x) -> torch.Tensor: linear: nn.Module = getattr(self, self.linear_name) @@ -46,6 +54,8 @@ def parse_args(): parser.add_argument('--epoch', type=int, default=1) parser.add_argument('--world_size', type=int, default=2) parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) parser.add_argument('--num_microbatches', type=int, default=2) parser.add_argument('--chunk', type=int, default=1) parser.add_argument('--use_checkpoint', action='store_true') @@ -74,16 +84,24 @@ def run_worker(rank, args, master_func): os.environ['MASTER_ADDR'] = args.master_addr os.environ['MASTER_PORT'] = args.master_port - # config rpc - # if cuda is used, set_device_map is a must is configured - # for cuda is not supported in torch rpc by default - options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads) - + device = args.device world_size = args.world_size - for rank_idx in range(world_size): - options.set_device_map(f'work{rank_idx}', {rank: rank_idx}) + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' - rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + disable_existing_loggers() + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) # in rpc mode, only rank 0 is needed to be coded if rank == 0: diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_pipeline/test_cuda_rpc_chimera.py index 98caf5913..cf9e4114f 100644 --- a/tests/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_pipeline/test_cuda_rpc_chimera.py @@ -1,9 +1,21 @@ import torch from torch import nn +import torch.autograd as autograd -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine +from colossalai.pipeline.rpc import ChimeraPipelineEngine +from colossalai.testing import assert_close from rpc_test_utils import rpc_run, parse_args, RpcTestModel +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + def run_master(args): torch.manual_seed(100) @@ -17,23 +29,51 @@ def run_master(args): use_checkpoint = False sample_num = 1024 - feat_num = 10 - h = 10 batch_size = 1024 assert sample_num % batch_size == 0 - module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] - engine = ChimeraPipelineEngine(module_partitions=module_partitions, + engine = ChimeraPipelineEngine(partition_fn=partition, stage_num=stage_num, num_microbatches=num_microbatches, device=device, checkpoint=use_checkpoint) + engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) input_sample = torch.randn((sample_num, feat_num), device=device) - for _ in range(epoch): - _ = engine.forward_backward(input_sample, forward_only=False) + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # 
compute forward result and backward grad of parameters in cuda rpc + cuda_rpc_result.append(sum(forward_result[0])) + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + # input_sample = input_sample[len(input_sample) // 2:] + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + single_result.append(out_val) + for p in test_model.parameters(): + single_result.append(p.grad) + + # print("my") + # print(cuda_rpc_result[1]) + # print("answer:") + # print(single_result[1]) + + # assert len(cuda_rpc_result) == len(single_result) + # for r_c, r_s in zip(cuda_rpc_result, single_result): + # assert_close(r_c, r_s, 0.001, 0.001) if __name__ == "__main__": diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py index ce0b646a3..842566730 100644 --- a/tests/test_pipeline/test_cuda_rpc_optimizer.py +++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py @@ -7,6 +7,16 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, from colossalai.testing import assert_close from rpc_test_utils import rpc_run, parse_args, RpcTestModel +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + def run_master(args): torch.manual_seed(100) @@ -20,20 +30,14 @@ def run_master(args): optimizer_class = globals()[args.optimizer] lr = 1e-3 - sample_num = 1024 - feat_num = 100 - h = 100 batch_size = 1024 assert sample_num % batch_size == 0 - batch_num = sample_num // batch_size input_sample = torch.randn((sample_num, feat_num), device=device) - module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] - - engine = OneFOneBPipelineEngine(module_partitions=module_partitions, + engine = OneFOneBPipelineEngine(partition_fn=partition, stage_num=stage_num, num_microbatches=num_microbatches, device=device, @@ -55,7 +59,8 @@ def run_master(args): cuda_rpc_result.append(p) # compute forward result and backward grad of parameters just in rank_0 - test_model = nn.Sequential(*module_partitions).to(device) + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py index ab16c3b2a..6a0509555 100644 --- a/tests/test_pipeline/test_cuda_rpc_performance.py +++ b/tests/test_pipeline/test_cuda_rpc_performance.py @@ -18,17 +18,30 @@ from colossalai.trainer import Trainer, hooks from colossalai.utils import MultiTimer, get_dataloader from colossalai.context import ParallelMode from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel -from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine +from colossalai.pipeline.pipeline_process_group 
import ppg def flatten(x): return torch.flatten(x, 1) -class Flatten(nn.Module): +def partition(pp_rank: int, chunk: int, stage_num: int): + pipelinable = PipelinableContext() - def forward(self, x): - return torch.flatten(x, start_dim=1) + # build model partitions + with pipelinable: + # input : [B, 3, 32, 32] + _ = resnet50() + + pipelinable.policy = "customized" + + exec_seq = [ + 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc' + ] + pipelinable.to_layer_list(exec_seq) + partition = pipelinable.partition(chunk, stage_num, pp_rank) + return partition def run_master(args): @@ -39,37 +52,12 @@ def run_master(args): stage_num = world_size num_microbatches = args.num_microbatches - assert chunk == 1 - - pipelinable = PipelinableContext() - - # build model partitions - with pipelinable: - # input : [B, 3, 32, 32] - model = resnet50() - - exec_seq = [ - 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc' - ] - pipelinable.to_layer_list(exec_seq) - module_partitions: List[PipelinableModel] = [ - pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size) - ] - # build dataloader root = os.environ.get('DATA', './data') train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32) criterion = nn.CrossEntropyLoss() - partition_1 = module_partitions[0] - partition_2 = [] - for model in module_partitions[1]._module_list: - partition_2.append(model) - partition_2.insert(len(partition_2) - 1, Flatten()) - partition_2 = nn.Sequential(*partition_2) - module_partitions = [partition_1, partition_2] - - pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions, + pp_engine = OneFOneBPipelineEngine(partition_fn=partition, stage_num=stage_num, num_microbatches=num_microbatches, device=device, diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py index e7c82045c..8d03e7981 100644 --- a/tests/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -4,6 +4,16 @@ from torch import nn from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine from rpc_test_utils import rpc_run, parse_args, RpcTestModel +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + def run_master(args): torch.manual_seed(100) @@ -13,22 +23,16 @@ def run_master(args): stage_num = args.world_size chunk = args.chunk num_microbatches = args.num_microbatches - actual_stage_num = stage_num * chunk use_checkpoint = args.use_checkpoint sample_num = 1024 - feat_num = 10 - h = 10 batch_size = 1024 assert sample_num % batch_size == 0 - batch_num = sample_num // batch_size input_sample = torch.randn((sample_num, feat_num), device=device) - module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] - - engine = OneFOneBPipelineEngine(module_partitions=module_partitions, + engine = OneFOneBPipelineEngine(partition_fn=partition, stage_num=stage_num, num_microbatches=num_microbatches, device=device, diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py index 98085726f..e6713478b 100644 --- 
a/tests/test_pipeline/test_cuda_rpc_value_correctness.py +++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py @@ -6,6 +6,15 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, from colossalai.testing import assert_close from rpc_test_utils import rpc_run, parse_args, RpcTestModel +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + def run_master(args): torch.manual_seed(100) @@ -18,25 +27,20 @@ def run_master(args): num_microbatches = args.num_microbatches sample_num = 1024 - feat_num = 100 - h = 100 batch_size = 1024 assert sample_num % batch_size == 0 - batch_num = sample_num // batch_size input_sample = torch.randn((sample_num, feat_num), device=device) - module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] - - engine = OneFOneBPipelineEngine(module_partitions=module_partitions, + engine = OneFOneBPipelineEngine(partition_fn=partition, stage_num=stage_num, num_microbatches=num_microbatches, device=device, chunk=chunk, checkpoint=use_checkpoint) - forward_result = engine.forward_backward(input_sample)[0] + forward_result = engine.forward_backward(input_sample) cuda_rpc_result = [] single_result = [] @@ -50,7 +54,8 @@ def run_master(args): cuda_rpc_result.append(p) # compute forward result and backward grad of parameters just in rank_0 - test_model = nn.Sequential(*module_partitions).to(device) + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() autograd.backward(out_val) diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py index c0aff8c10..c67e4175d 100644 --- a/tests/test_pipeline/test_pipeline_process_group.py +++ b/tests/test_pipeline/test_pipeline_process_group.py @@ -4,7 +4,7 @@ import torch.distributed.rpc as rpc import torch.multiprocessing as mp import pytest -from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup +from colossalai.pipeline.pipeline_process_group import ppg from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from rpc_test_utils import pg_parse_args, rpc_is_initialized @@ -26,12 +26,12 @@ def run_worker(rank, args): disable_existing_loggers() launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - pg = PipelineProcessGroup(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) if rpc_is_initialized(): rpc.shutdown()
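For reference, the end-to-end usage implied by the test changes above is roughly the following. This is a minimal, illustrative sketch assembled from rpc_test_utils.py and test_cuda_rpc_pipeline.py in this diff (ranks, sizes and the RpcTestModel helper are taken from those files), not part of the patch: the global ppg singleton is configured once per process via set_global_info(), and the engine receives a partition_fn that is called remotely with (pp_rank, chunk, stage_num) instead of a list of pre-built module_partitions. The launcher and RPC setup are assumed to be handled as in run_worker() above.

import torch

from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc import OneFOneBPipelineEngine
from rpc_test_utils import RpcTestModel    # test helper shipped in this diff

feat_num, h = 100, 100    # sizes copied from the tests above


def partition(pp_rank: int, chunk: int, stage_num: int):
    # runs on the worker that owns this stage, so the partition is built
    # directly on that worker instead of being pickled over from rank 0
    torch.manual_seed(1024)
    return RpcTestModel(pp_rank, stage_num, feat_num, h)


# per-process setup, normally done inside run_worker() after colossalai's launch()
ppg.set_global_info(rank=0, world_size=2, dp_degree=1, tp_degree=1,
                    num_worker_threads=16, device='cuda')

# in rpc mode only rank 0 drives the schedule
engine = OneFOneBPipelineEngine(partition_fn=partition,
                                stage_num=2,
                                num_microbatches=4,
                                device='cuda',
                                chunk=1,
                                checkpoint=False)
engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

input_sample = torch.randn((1024, feat_num), device='cuda')
forward_result = engine.forward_backward(input_sample)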
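The pairing built by get_chimera_all_reduce_group() is easier to see in isolation. The snippet below is illustrative only: dist.new_group() is replaced by a plain tuple, and dp_degree == tp_degree == 1 is assumed so that stage_num == world_size, matching the pure-pipeline tests above.

world_size = 4
stage_num = world_size    # holds when dp_degree == tp_degree == 1

chimera_groups = {}
for rank in range(world_size // 2):
    # each device pair [rank, world_size - 1 - rank] backs one all-reduce group
    pair = [rank, world_size - 1 - rank]
    chimera_groups[pair[0]] = tuple(pair)
    chimera_groups[pair[1]] = tuple(pair)
    chimera_groups[pair[0] + stage_num] = tuple(pair)
    chimera_groups[pair[1] + stage_num] = tuple(pair)

print(chimera_groups)
# {0: (0, 3), 3: (0, 3), 4: (0, 3), 7: (0, 3), 1: (1, 2), 2: (1, 2), 5: (1, 2), 6: (1, 2)}

Each up-pipeline pp_rank s and its weight-sharing down-pipeline copy s + stage_num resolve to the same group, which is what _hook_before_step relies on when it all-reduces the coalesced gradients.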