diff --git a/colossalai/pipeline/rpc/PipelineBase.py b/colossalai/pipeline/rpc/_pipeline_base.py
similarity index 74%
rename from colossalai/pipeline/rpc/PipelineBase.py
rename to colossalai/pipeline/rpc/_pipeline_base.py
index 8cc90b825..c03148505 100644
--- a/colossalai/pipeline/rpc/PipelineBase.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -1,8 +1,9 @@
 import threading
 from enum import Enum
 from typing import List, Any, Tuple, Dict, Callable
-from abc import ABC
+from abc import ABC, abstractmethod
 import sys
+import os
 
 import torch
 from torch import nn
@@ -17,12 +18,10 @@
 from time import time
 from colorama import Back, Style
 
 # config for debug and test
-use_color_debug = False
-use_progress = False
+use_color_debug = True
 
 # TODO:
-# 1. replace world_size with other parameters
-# 2. adjust to args and kwargs
+# 1. adjust to args and kwargs (pytree)
 
 
 def color_debug(text, prefix=' ', color='blue'):
@@ -137,24 +136,24 @@ class BackwardCache:
         setattr(self, arg_name, locals()[arg_name])
 
 
-class Worker:
+class WorkerBase(ABC):
 
     def __init__(self,
                  module_partition: nn.Module,
                  pp_rank: int,
                  actual_stage_num: int,
                  num_microbatches: int,
-                 use_1F1B: bool,
                  device: str,
                  criterion: Callable = None,
+                 metric: Callable = None,
                  checkpoint: bool = False) -> None:
         super().__init__()
+        self.pp_rank = pp_rank
         self.actual_stage_num = actual_stage_num
         self.num_microbatches = num_microbatches
         self.checkpoint = checkpoint
         self.device = device
-        self.use_1F1B = use_1F1B
         self._initialize_outstanding_range()
 
         # variable and const for context management
@@ -172,23 +171,14 @@ class Worker:
         # module partitions
         self.module_partition = module_partition.to(device)
-        if criterion:
-            assert callable(criterion)
         self.criterion = criterion
+        self.metric = metric
 
-        # container to maintain loop
-        self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
-        self.microbatch_id_to_labels: Dict[int, Any] = dict()
-        self.work_list: Dict[UniqueKey, WorkItem] = dict()
-        self.output_list: Dict[UniqueKey, WorkItem] = dict()
+        # context to maintain loop
+        self._initialize_context_container()
 
         # lock for the list
-        self.work_list_condition_lock = threading.Condition(threading.Lock())
-        self.output_list_condition_lock = threading.Condition(threading.Lock())
-        self.label_lock = threading.Condition(threading.Lock())
-
-        self.step_lock = threading.Lock()
-        self.step_lock.acquire()
+        self._initialize_lock()
 
         # main loop
         self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
@@ -199,13 +189,23 @@ class Worker:
     def _initialize_outstanding_range(self):
         outstanding_range = None
-        if self.use_1F1B:
-            if self.pp_rank == self.actual_stage_num - 1:
-                outstanding_range = (0, 1)
-            else:
-                outstanding_range = (self.actual_stage_num, self.actual_stage_num)
+        if self.pp_rank == self.actual_stage_num - 1:
+            outstanding_range = (0, 1)
+        else:
+            outstanding_range = (self.actual_stage_num, self.actual_stage_num)
         self.outstanding_range = outstanding_range
 
+    def _initialize_context_container(self):
+        self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
+        self.microbatch_id_to_labels: Dict[int, Any] = dict()
+        self.work_list: Dict[UniqueKey, WorkItem] = dict()
+        self.output_list: Dict[UniqueKey, WorkItem] = dict()
+
+    def _initialize_lock(self):
+        self.work_list_condition_lock = threading.Condition(threading.Lock())
+        self.output_list_condition_lock = threading.Condition(threading.Lock())
+        self.label_lock = threading.Condition(threading.Lock())
+
     def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
         assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
         assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
@@ -241,12 +241,15 @@ class Worker:
                                  forward_only)
         with self.work_list_condition_lock:
             self.work_list[key] = work_item
-            color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
+            color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
+                        'magenta')
             self.work_list_condition_lock.notify_all()
 
     # just for last pp_rank
     def set_labels(self, microbatch_id: int, microlabels: Any):
-        self.microbatch_id_to_labels[microbatch_id] = microlabels
+        with self.label_lock:
+            self.microbatch_id_to_labels[microbatch_id] = microlabels
+            self.label_lock.notify_all()
 
     # just for last pp_rank
     def _begin_backward(self, microbatch_id: int):
@@ -354,51 +357,17 @@ class Worker:
             if next_rank <= self.actual_stage_num - 1:
                 self.consumer_stage_ids.append(next_rank)
 
+    @abstractmethod
     def _get_work_item_key(self) -> UniqueKey:
-        # execute backward first (if backward phase in work_list)
-        pp_rank = self.pp_rank
-        actual_stage_num = self.actual_stage_num
-        num_microbatches = self.num_microbatches
-        is_last_stage = pp_rank == actual_stage_num - 1
+        """
+        This method controls the order in which microbatches are consumed.
+        """
 
-        if self.outstanding_range:
-            if self.outstanding <= self.outstanding_range[0]:
-                target_phase = Phase.FORWARD
-                target_microbatch_id = self.forward_times
-            elif self.outstanding >= self.outstanding_range[1]:
-                target_phase = Phase.BACKWARD
-                target_microbatch_id = self.backward_times
-            else:
-                raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
+    def is_first_stage(self):
+        return self.pp_rank == 0
 
-            target_key = UniqueKey(target_microbatch_id, target_phase)
-
-            # change outstanding_range at:
-            # 1. forward times reach actual_stage_num, this is the end of continuous forward
-            # 2. forward times reach num_microbatches, this is the end of 1F1B mode
-            if not is_last_stage and \
-                target_key.phase == Phase.FORWARD:
-                if target_key.microbatch_id == actual_stage_num - 1:
-                    outstanding_min = actual_stage_num - pp_rank - 1
-                    outstanding_max = actual_stage_num - pp_rank
-                    self.outstanding_range = (outstanding_min, outstanding_max)
-                elif target_key.microbatch_id == num_microbatches - 1:
-                    self.outstanding_range = (0, 0)
-
-        else:
-            if self.forward_times < num_microbatches:
-                target_phase = Phase.FORWARD
-                target_microbatch_id = self.forward_times
-            else:
-                target_phase = Phase.BACKWARD
-                target_microbatch_id = self.backward_times
-
-            target_key = UniqueKey(target_microbatch_id, target_phase)
-
-        with self.work_list_condition_lock:
-            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
-
-        return target_key
+    def is_last_stage(self):
+        return self.pp_rank == self.actual_stage_num - 1
 
     def _consume_work_item_by_phase(self, work_item: WorkItem):
         phase = work_item.phase
@@ -408,9 +377,8 @@ class Worker:
         forward_only = work_item.forward_only
         consume_result = None
 
-        # TODO : use process manager to acquire rank info later
-        is_first_stage = (self.pp_rank == 0)
-        is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
+        is_first_stage = self.is_first_stage()
+        is_last_stage = self.is_last_stage()
 
         # if self.pp_rank == 0:
        #     print(
@@ -433,6 +401,21 @@ class Worker:
             if forward_only:
                 with torch.no_grad():
                     consume_result = self.module_partition(*args, **kwargs)
+
+                # TODO : integrate output list
+                if is_last_stage and self.criterion:
+                    with self.label_lock:
+                        self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
+                    labels = self.microbatch_id_to_labels.pop(microbatch_id)
+                    loss: torch.Tensor = self.criterion(consume_result, labels)
+                    if self.metric is not None:
+                        metric_result = self.metric(consume_result, labels)
+                        if isinstance(metric_result, torch.Tensor):
+                            metric_result = metric_result.item()
+                    else:
+                        metric_result = None
+                    consume_result = [loss.item(), metric_result]
+
             stage_outputs = None
             stage_inputs = None
             use_checkpoint = None
@@ -444,10 +427,21 @@ class Worker:
                     use_checkpoint = True
                 else:
                     consume_result = self.module_partition(*args, **kwargs)
+                    # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
+
+                if is_last_stage and self.criterion:
+                    with self.label_lock:
+                        self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
                     labels = self.microbatch_id_to_labels.pop(microbatch_id)
                     loss: torch.Tensor = self.criterion(consume_result, labels)
-                    consume_result = loss.item()
+                    if self.metric is not None:
+                        metric_result = self.metric(consume_result, labels)
+                        if isinstance(metric_result, torch.Tensor):
+                            metric_result = metric_result.item()
+                    else:
+                        metric_result = None
+
+                    consume_result = [loss.item(), metric_result]
                 else:
                     loss = consume_result
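
Note on the loss/metric hunks above: the new `metric` argument is an optional callable with the same `(output, labels)` signature as `criterion`, and whatever it returns is reduced to a plain Python number that is packed next to the loss. A minimal sketch of such a callable (illustrative only, not part of this patch):

    import torch

    def top1_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Returns a scalar tensor; the worker calls .item() on tensor results
        # before packing [loss.item(), metric_result] as the stage output.
        return (logits.argmax(dim=-1) == labels).float().mean()
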
@@ -486,6 +480,7 @@ class Worker:
             if use_checkpoint:
                 stage_outputs = [self.module_partition(*stage_inputs)]
+            # overlap recompute and future.wait
 
         grad_tensors = get_real_args(args)
@@ -513,11 +508,17 @@ class Worker:
                 grad_sum += p.grad.sum()
         return grad_sum
 
-    def _is_first_step(self, work_item) -> bool:
+    def _is_first_step(self, work_item: WorkItem) -> bool:
         return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
 
-    def _is_last_step(self, work_item) -> bool:
-        return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1
+    def _is_last_step(self, work_item: WorkItem) -> bool:
+        if work_item.forward_only:
+            last_phase = Phase.FORWARD
+        else:
+            last_phase = Phase.BACKWARD
+        is_last_phase = work_item.phase == last_phase
+        is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
+        return is_last_phase and is_last_microbatch
 
     # do the main loop to consume ready_list
     def _work_loop(self):
@@ -551,7 +552,7 @@ class Worker:
             # if is last step in one batch reset context and do step
             if self._is_last_step(work_item):
-                if hasattr(self, 'optimizer'):
+                if hasattr(self, 'optimizer') and not work_item.forward_only:
                     self.step()
                 self.forward_times = 0
                 self.backward_times = 0
@@ -560,22 +561,22 @@ class Worker:
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+        self.step_lock = threading.Lock()
+        self.step_lock.acquire()
 
     def wait_for_step(self):
         self.step_lock.acquire()
 
     def step(self):
-        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
         self.optimizer.step()
-        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
         self.optimizer.zero_grad()
-
         self.step_lock.release()
 
 
 class PipelineEngineBase(ABC, nn.Module):
 
     def __init__(self,
+                 worker_type,
                  module_partitions,
                  stage_num,
                  num_microbatches,
@@ -583,17 +584,19 @@ class PipelineEngineBase(ABC, nn.Module):
                  use_1F1B=False,
                  chunk: int = 1,
                  criterion: Callable = None,
+                 metric: Callable = None,
                  checkpoint: bool = False) -> None:
         super().__init__()
+        self.worker_type = worker_type
         self.module_partitions: List[nn.Module] = module_partitions
         self.chunk = chunk
         self.criterion = criterion
+        self.metric = metric
         self.num_microbatches = num_microbatches
         self.device = device
         self.use_1F1B = use_1F1B
         self.stage_num = stage_num
         self.checkpoint = checkpoint
-        self.use_interleave = chunk > 1
 
         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@@ -601,26 +604,24 @@ class PipelineEngineBase(ABC, nn.Module):
         self._check_argument()
         self._create_pp_rank_to_rpc_worker_id()
+        self._create_pp_rank_to_module_partition_id()
         self._init_worker()
 
-    def _check_argument(self):
+    def _check_argument(self) -> None:
         self.virtual_stage_num = self.stage_num * self.chunk
         assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
         assert self.virtual_stage_num == len(
             self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
-        if self.use_interleave:
-            assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
 
-    def _get_actual_stage_num(self):
+    def _get_actual_stage_num(self) -> int:
         return self.stage_num if self.chunk == 1 else self.virtual_stage_num
 
-    def _create_pp_rank_to_rpc_worker_id(self):
+    def _create_pp_rank_to_rpc_worker_id(self) -> None:
         """create a map from model partition to stage_id, which is useful when use_interleave is True.
         e.g. If a model is split into 4 parts, which means len(self.module_partitions) == 4,
         stage_num is 2 and chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1]: the first and third
         partitions will be moved to device 0 and the others to device 1
-
        """
         stage_num = self.stage_num
         actual_stage_num = self._get_actual_stage_num()
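
For the docstring's example (4 partitions, stage_num = 2, chunk = 2), the loop that follows produces the interleaved mapping directly; a quick check (illustrative, mirroring `_create_pp_rank_to_rpc_worker_id`):

    stage_num, chunk = 2, 2
    actual_stage_num = stage_num * chunk
    pp_rank_to_rpc_worker_id = [pp_rank % stage_num for pp_rank in range(actual_stage_num)]
    assert pp_rank_to_rpc_worker_id == [0, 1, 0, 1]    # partitions 0 and 2 on worker 0, partitions 1 and 3 on worker 1
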
@@ -628,28 +629,39 @@ class PipelineEngineBase(ABC, nn.Module):
         for pp_rank in range(actual_stage_num):
             self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
 
-    def _init_worker(self):
+    def _create_pp_rank_to_module_partition_id(self) -> None:
+        """By default (both fill-drain and 1F1B), the number of model partitions equals
+        actual_stage_num, so each model partition is allocated to its corresponding stage.
+        """
+        actual_stage_num = self._get_actual_stage_num()
+        self.pp_rank_to_module_partition_id = [0] * actual_stage_num
+        for pp_rank in range(actual_stage_num):
+            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
+
+    def _init_worker(self) -> None:
         actual_stage_num = self._get_actual_stage_num()
 
-        use_1F1B = self.use_1F1B
+        worker_type = self.worker_type
         checkpoint = self.checkpoint
         num_microbatches = self.num_microbatches
         device = self.device
         criterion = self.criterion
+        metric = self.metric
 
-        for pp_rank in range(actual_stage_num):
-            module_partition = self.module_partitions[pp_rank]
+        for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
+            module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
             rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
             if device[:4] == 'cuda':
                 device = f'cuda:{rpc_worker_id}'
+            module_partition = self.module_partitions[module_partition_id]
             self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
-                                                              Worker,
+                                                              worker_type,
                                                               args=(module_partition, pp_rank, actual_stage_num,
-                                                                    num_microbatches, use_1F1B, device, criterion,
+                                                                    num_microbatches, device, criterion, metric,
                                                                     checkpoint))
 
         # let each worker know global worker rref (include itself)
-        for pp_rank in range(actual_stage_num):
+        for pp_rank in self.pp_rank_to_worker_rref:
             self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
 
     def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
@@ -670,65 +682,110 @@ class PipelineEngineBase(ABC, nn.Module):
                 grads[stage_id].append(grad)
         return grads
 
+    def get_input_pp_ranks(self) -> List[int]:
+        return [0]
+
+    def get_output_pp_ranks(self) -> List[int]:
+        return [self._get_actual_stage_num() - 1]
+
+    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
+                            output_pp_ranks: List[int], ret_future):
+        actual_stage_num = self._get_actual_stage_num()
+        use_1F1B = self.use_1F1B
+        if microbatch_id >= actual_stage_num:
+            if forward_only or not use_1F1B:
+                for pp_rank in output_pp_ranks:
+                    ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
+            else:
+                key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
+                for pp_rank in input_pp_ranks:
+                    worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                    worker_rref.rpc_sync().get_output_by_key(key)
+
+    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
+        num_microbatches = self.num_microbatches
+        return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
+
+    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
+        for pp_rank in input_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            # TODO : add relationship between input_pp_ranks and parts of microbatch
+            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+
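
`_create_ret_future` above pre-allocates one Future slot per (output rank, microbatch); each slot is later filled by `_subscribe_forward` and read both by `_consume_constraint` (flow control) and `_collect_forward_result`. For the default single-output pipeline its shape is simply (illustrative values):

    num_microbatches = 4
    output_pp_ranks = [3]    # last stage, as returned by get_output_pp_ranks() when actual_stage_num == 4
    ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
    assert ret_future == {3: [None, None, None, None]}
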
+    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
+        for pp_rank in output_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            # TODO : add relationship between output_pp_ranks and parts of microlabels
+            worker_rref.remote().set_labels(microbatch_id, microlabels)
+
+    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
+        key = UniqueKey(microbatch_id, Phase.FORWARD)
+        for pp_rank in output_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            ret_future[pp_rank][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
+
+    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
+        if not forward_only:
+            for pp_rank in input_pp_ranks:
+                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
+                worker_rref.rpc_sync().get_output_by_key(key)
+
+    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
+        forward_result = []
+        for pp_rank in output_pp_ranks:
+            worker_forward_result = [None] * self.num_microbatches
+            for microbatch_id in range(self.num_microbatches):
+                ret = ret_future[pp_rank][microbatch_id].wait()
+                worker_forward_result[microbatch_id] = ret
+            worker_forward_result = list(zip(*worker_forward_result))
+            forward_result.extend(worker_forward_result)
+
+        return forward_result
+
     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
         if labels is not None:
             assert len(batch) == len(labels)
+            if not forward_only:
+                assert hasattr(self, 'optimizer_class')
 
         num_microbatches = self.num_microbatches
         microbatch_size = len(batch) // num_microbatches
-        actual_stage_num = self._get_actual_stage_num()
 
-        first_worker_rref = self.pp_rank_to_worker_rref[0]
-        last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
+        # In Chimera mode, ranks of the down pipeline are excluded from 'input_pp_ranks' and 'output_pp_ranks'
+        input_pp_ranks = self.get_input_pp_ranks()
+        output_pp_ranks = self.get_output_pp_ranks()
 
-        microbatch_iter = range(num_microbatches)
-        if use_progress:
-            microbatch_iter = tqdm(microbatch_iter)
+        # a cache to collect data and control flow
+        ret_future = self._create_ret_future(output_pp_ranks)
 
-        ret_future: List[Future] = [None] * num_microbatches
-        for microbatch_id in microbatch_iter:
-            # control data input speed
+        for microbatch_id in range(num_microbatches):
+            # control data input speed
             # to prevent exceeding wait limitations
-            if microbatch_id >= actual_stage_num:
-                if forward_only or not self.use_1F1B:
-                    ret_future[microbatch_id - actual_stage_num].wait()
-                else:
-                    key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
-                    first_worker_rref.rpc_sync().get_output_by_key(key)
+            self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
 
             # set input
             microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
             microbatch = microbatch.cuda()
-            first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+            self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)
+
             # set labels
-            if not forward_only and labels is not None:
+            if labels is not None:
                 microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
                 microlabels = microlabels.cuda()
-                last_worker_rref.remote().set_labels(microbatch_id, microlabels)
+                self._set_labels(output_pp_ranks, microbatch_id, microlabels)
 
-            key = UniqueKey(microbatch_id, Phase.FORWARD)
-            ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
+            # get data asynchronously
+            self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future)
 
-        # wait for last backward in rank0
-        if not forward_only:
-            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-            first_worker_rref.rpc_sync().get_output_by_key(key)
+        # wait for first rank to ensure all backwards are done
+        self._ensure_backward(forward_only, input_pp_ranks)
 
         # collect forward result
-        # TODO : all the node to output
-        forward_result = None
+        forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
 
-        for microbatch_id in range(self.num_microbatches):
-            key = UniqueKey(microbatch_id, Phase.FORWARD)
-            ret = ret_future[microbatch_id].wait()
-            if forward_result is None:
-                forward_result = [[]] * len(ret)
-            for i in range(len(forward_result)):
-                forward_result[i].append(ret[i])
-
-        if hasattr(self, 'optimizer_class'):
+        if not forward_only and labels is not None:
             # wait for all step
-            # TODO : more elegant ?
             for pp_rank in self.pp_rank_to_worker_rref:
                 worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                 worker_rref.rpc_sync().wait_for_step()
@@ -751,31 +808,3 @@ class PipelineEngineBase(ABC, nn.Module):
         for fut in self.step_futs:
             fut.wait()
-
-
-class FillDrainPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 module_partitions: List[nn.Module],
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 chunk: int = 1,
-                 criterion: Callable = None,
-                 checkpoint: bool = False) -> None:
-        use_1F1B = False
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
-
-
-class OneFOneBPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 module_partitions: List[nn.Module],
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 chunk: int = 1,
-                 criterion: Callable = None,
-                 checkpoint: bool = False) -> None:
-        use_1F1B = True
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
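
Taken together, a concrete engine from the new _pipeline_schedule.py below is driven entirely through `forward_backward`. A rough usage sketch, mirroring the new performance test further down; `module_partitions`, `train_dataloader` and `top1_accuracy` are placeholders supplied by the caller, not part of this patch:

    import torch
    from torch import nn
    from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine

    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                    stage_num=2,
                                    num_microbatches=4,
                                    device='cuda',
                                    criterion=nn.CrossEntropyLoss(),
                                    metric=top1_accuracy,    # optional, any (output, labels) callable
                                    checkpoint=False)
    engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    for batch, labels in train_dataloader:
        # each call schedules num_microbatches forward/backward passes and waits for the optimizer steps
        results = engine.forward_backward(batch, labels=labels, forward_only=False)
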
diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py
new file mode 100644
index 000000000..991588bae
--- /dev/null
+++ b/colossalai/pipeline/rpc/_pipeline_schedule.py
@@ -0,0 +1,277 @@
+from typing import List, Callable, Dict
+
+import torch.nn as nn
+from torch.futures import Future
+from torch._C._distributed_rpc import PyRRef
+
+from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
+
+# Implementations of different pipeline schedules.
+# Each Worker defines the behaviour of one stage.
+# Each PipelineEngine is the class for users to use.
+
+
+class FillDrainWorker(WorkerBase):
+
+    def _get_work_item_key(self) -> UniqueKey:
+        # execute backward first (if backward phase in work_list)
+        num_microbatches = self.num_microbatches
+
+        if self.forward_times < num_microbatches:
+            target_phase = Phase.FORWARD
+            target_microbatch_id = self.forward_times
+        else:
+            target_phase = Phase.BACKWARD
+            target_microbatch_id = self.backward_times
+
+        target_key = UniqueKey(target_microbatch_id, target_phase)
+
+        with self.work_list_condition_lock:
+            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
+
+        return target_key
+
+
+class FillDrainPipelineEngine(PipelineEngineBase):
+
+    def __init__(self,
+                 module_partitions: List[nn.Module],
+                 stage_num: int,
+                 num_microbatches: int,
+                 device: str,
+                 chunk: int = 1,
+                 criterion: Callable = None,
+                 metric: Callable = None,
+                 checkpoint: bool = False) -> None:
+
+        if chunk > 1:
+            assert num_microbatches % stage_num == 0, \
+                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
+        use_1F1B = False
+
+        super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
+                         criterion, metric, checkpoint)
+
+
+class OneFOneBWorker(WorkerBase):
+
+    def _get_work_item_key(self) -> UniqueKey:
+        # execute backward first (if backward phase in work_list)
+        pp_rank = self.pp_rank
+        actual_stage_num = self.actual_stage_num
+        num_microbatches = self.num_microbatches
+        is_last_stage = pp_rank == actual_stage_num - 1
+
+        if self.outstanding <= self.outstanding_range[0]:
+            target_phase = Phase.FORWARD
+            target_microbatch_id = self.forward_times
+        elif self.outstanding >= self.outstanding_range[1]:
+            target_phase = Phase.BACKWARD
+            target_microbatch_id = self.backward_times
+        else:
+            raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
+
+        target_key = UniqueKey(target_microbatch_id, target_phase)
+
+        # change outstanding_range at:
+        # 1. forward times reach actual_stage_num, this is the end of continuous forward
+        # 2. forward times reach num_microbatches, this is the end of 1F1B mode
+        if not is_last_stage and \
+            target_key.phase == Phase.FORWARD:
+            if target_key.microbatch_id == actual_stage_num - 1:
+                outstanding_min = actual_stage_num - pp_rank - 1
+                outstanding_max = actual_stage_num - pp_rank
+                self.outstanding_range = (outstanding_min, outstanding_max)
+            elif target_key.microbatch_id == num_microbatches - 1:
+                self.outstanding_range = (0, 0)
+
+        with self.work_list_condition_lock:
+            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
+
+        return target_key
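
For intuition, the `outstanding_range` update rule above walks a non-last stage through warm-up, steady 1F1B and drain phases; for example, for `pp_rank = 1` of `actual_stage_num = 4` with 8 microbatches (an illustrative trace, not part of this patch):

    actual_stage_num, pp_rank, num_microbatches = 4, 1, 8
    outstanding_range = (actual_stage_num, actual_stage_num)    # warm-up: keep issuing FORWARD
    # after the forward of microbatch actual_stage_num - 1 (== 3) is scheduled:
    outstanding_range = (actual_stage_num - pp_rank - 1, actual_stage_num - pp_rank)    # (2, 3): steady 1F1B
    # after the forward of microbatch num_microbatches - 1 (== 7) is scheduled:
    outstanding_range = (0, 0)    # drain: only BACKWARD work remains
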
+
+
+class OneFOneBPipelineEngine(PipelineEngineBase):
+
+    def __init__(self,
+                 module_partitions: List[nn.Module],
+                 stage_num: int,
+                 num_microbatches: int,
+                 device: str,
+                 chunk: int = 1,
+                 criterion: Callable = None,
+                 metric: Callable = None,
+                 checkpoint: bool = False) -> None:
+
+        if chunk > 1:
+            assert num_microbatches % stage_num == 0, \
+                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
+        use_1F1B = True
+
+        super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
+                         criterion, metric, checkpoint)
+
+
+class ChimeraWorker(WorkerBase):
+
+    def _get_producer_consumer(self) -> None:
+        rank = self.pp_rank
+        min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
+        max_pp_rank = min_pp_rank + self.actual_stage_num - 1
+
+        assert self.producer_stage_ids is None, f"all the producers of rank {rank} have been subscribed"
+        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} have been subscribed"
+
+        # should be arranged in order, the order of the input of current forward
+        self.producer_stage_ids = []
+        self.consumer_stage_ids = []
+
+        # Just for demo
+        prev_rank = rank - 1
+        next_rank = rank + 1
+        if prev_rank >= min_pp_rank:
+            self.producer_stage_ids.append(prev_rank)
+        if next_rank <= max_pp_rank:
+            self.consumer_stage_ids.append(next_rank)
+
+    def _get_work_item_key(self) -> UniqueKey:
+        pp_rank = self.pp_rank
+        stage_num = self.actual_stage_num
+        real_microbatch_num = self.num_microbatches // 2
+
+        if self.forward_times < real_microbatch_num:
+            if (pp_rank + 1) % stage_num == 0:    # last rank
+                forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
+                if forward_blocks > self.backward_times:
+                    target_phase = Phase.BACKWARD
+                    target_microbatch_id = self.backward_times
+                else:
+                    target_phase = Phase.FORWARD
+                    target_microbatch_id = self.forward_times
+            else:    # others
+                target_phase = Phase.FORWARD
+                target_microbatch_id = self.forward_times
+        else:
+            target_phase = Phase.BACKWARD
+            target_microbatch_id = self.backward_times
+
+        # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
+        # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
+        real_target_microbatch_id = target_microbatch_id * 2
+        if pp_rank >= stage_num:
+            real_target_microbatch_id += 1
+        target_key = UniqueKey(real_target_microbatch_id, target_phase)
+
+        with self.work_list_condition_lock:
+            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
+
+        return target_key
+
+    def is_first_stage(self):
+        return (self.pp_rank % self.actual_stage_num) == 0
+
+    def is_last_stage(self):
+        return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
+
+
+class ChimeraPipelineEngine(PipelineEngineBase):
+
+    def __init__(self,
+                 module_partitions,
+                 stage_num,
+                 num_microbatches,
+                 device: str,
+                 criterion: Callable = None,
+                 metric: Callable = None,
+                 checkpoint: bool = False) -> None:
+
+        assert num_microbatches % stage_num == 0, \
+            "In Chimera, num_microbatches must be a multiple of stage_num!"
+        use_1F1B = False
+        chunk = 1
+        super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
+                         criterion, metric, checkpoint)
+
+    def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
+                            input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
+        pass
+
+    def _create_pp_rank_to_rpc_worker_id(self) -> None:
+        stage_num = self.stage_num
+        self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
+        for pp_rank in range(stage_num):
+            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
+            self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1
+
+    def _create_pp_rank_to_module_partition_id(self) -> None:
+        stage_num = self.stage_num
+        self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
+        for pp_rank in range(stage_num):
+            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
+            self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank
+
+    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
+        num_microbatches = self.num_microbatches
+        stage_num = self.stage_num
+        up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
+        down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
+        # merge up and down
+        return {**up_ret_future, **down_ret_future}
+
+    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
+        # offset is 0 for all the ranks in up pipeline
+        # offset is stage_num for all the ranks in down pipeline
+        offset = (microbatch_id % 2) * self.stage_num
+        for pp_rank in input_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
+            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+
+    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
+        # offset is 0 for all the ranks in up pipeline
+        # offset is stage_num for all the ranks in down pipeline
+        offset = (microbatch_id % 2) * self.stage_num
+        for pp_rank in output_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
+            worker_rref.remote().set_labels(microbatch_id, microlabels)
+
+    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
+        key = UniqueKey(microbatch_id, Phase.FORWARD)
+        offset = (microbatch_id % 2) * self.stage_num
+        for pp_rank in output_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
+            ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
+
+    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
+        stage_num = self.stage_num
+        num_microbatches = self.num_microbatches
+        if not forward_only:
+            for pp_rank in input_pp_ranks:
+                up_last_microbatch_id = num_microbatches - 2
+                down_last_microbatch_id = num_microbatches - 1
+
+                up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]
+
+                up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
+                down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
+
+                up_worker_rref.rpc_sync().get_output_by_key(up_key)
+                down_worker_rref.rpc_sync().get_output_by_key(down_key)
+
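
The two mapping overrides defined earlier in this engine double the rank space: ranks 0..stage_num-1 form the up pipeline and stage_num..2*stage_num-1 the down pipeline, with the down pipeline placed on devices in reverse order while both reuse the same partition split. For stage_num = 4 (illustrative check, not part of this patch):

    stage_num = 4
    rpc_worker_id = [0] * (stage_num * 2)
    partition_id = [0] * (stage_num * 2)
    for pp_rank in range(stage_num):
        rpc_worker_id[pp_rank] = pp_rank
        rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1
        partition_id[pp_rank] = pp_rank
        partition_id[pp_rank + stage_num] = pp_rank
    assert rpc_worker_id == [0, 1, 2, 3, 3, 2, 1, 0]    # down pipeline reverses device placement
    assert partition_id == [0, 1, 2, 3, 0, 1, 2, 3]     # both pipelines share the same partition indices
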
+    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]):
+        """Collect forward results in Chimera.
+        Currently, only single-input, single-output models are supported.
+        """
+        stage_num = self.stage_num
+        forward_result = []
+        for pp_rank in output_pp_ranks:
+            worker_forward_result = [None] * self.num_microbatches
+            for microbatch_id in range(self.num_microbatches):
+                offset = (microbatch_id % 2) * stage_num
+                ret = ret_future[pp_rank + offset][microbatch_id].wait()
+                worker_forward_result[microbatch_id] = ret
+
+            worker_forward_result = list(zip(*worker_forward_result))
+            forward_result.extend(worker_forward_result)
+
+        return forward_result
diff --git a/data/cifar-10-python.tar.gz b/data/cifar-10-python.tar.gz
new file mode 100644
index 000000000..048005dfd
Binary files /dev/null and b/data/cifar-10-python.tar.gz differ
diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_pipeline/test_cuda_rpc_chimera.py
new file mode 100644
index 000000000..98caf5913
--- /dev/null
+++ b/tests/test_pipeline/test_cuda_rpc_chimera.py
@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
+
+
+def run_master(args):
+    torch.manual_seed(100)
+
+    epoch = args.epoch
+    device = args.device
+    stage_num = 4
+    chunk = 1
+    num_microbatches = 4
+    actual_stage_num = 4
+    use_checkpoint = False
+
+    sample_num = 1024
+    feat_num = 10
+    h = 10
+    batch_size = 1024
+
+    assert sample_num % batch_size == 0
+
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+    engine = ChimeraPipelineEngine(module_partitions=module_partitions,
+                                   stage_num=stage_num,
+                                   num_microbatches=num_microbatches,
+                                   device=device,
+                                   checkpoint=use_checkpoint)
+
+    input_sample = torch.randn((sample_num, feat_num), device=device)
+
+    for _ in range(epoch):
+        _ = engine.forward_backward(input_sample, forward_only=False)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.world_size = 4
+    args.num_microbatches = 4
+    rpc_run(args, run_master)
diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py
index 5d3c193e3..ce0b646a3 100644
--- a/tests/test_pipeline/test_cuda_rpc_optimizer.py
+++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py
@@ -3,7 +3,7 @@
 from torch import nn
 from torch import autograd
 from torch.optim import SGD, Adam, RMSprop, Optimizer
 
-from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
 from colossalai.testing import assert_close
 from rpc_test_utils import rpc_run, parse_args, RpcTestModel
diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py
new file mode 100644
index 000000000..ab16c3b2a
--- /dev/null
+++ b/tests/test_pipeline/test_cuda_rpc_performance.py
@@ -0,0 +1,102 @@
+import os
+from typing import Callable, List, Optional, Type, Union
+import time
+
+import pytest
+import torch
+import torch.nn as nn
+from titans.dataloader.cifar10 import build_cifar
+from torchvision.models import resnet50
+from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
+from tqdm import tqdm
+
+from rpc_test_utils import rpc_run, parse_args
+import colossalai
+import colossalai.nn as col_nn
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.context import ParallelMode
+from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
+from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+
+
+def flatten(x):
+    return torch.flatten(x, 1)
+
+
+class Flatten(nn.Module):
+
+    def forward(self, x):
+        return torch.flatten(x, start_dim=1)
+
+
+def run_master(args):
+    batch_size = args.batch_size
+    chunk = args.chunk
+    device = args.device
+    world_size = args.world_size
+    stage_num = world_size
+    num_microbatches = args.num_microbatches
+
+    assert chunk == 1
+
+    pipelinable = PipelinableContext()
+
+    # build model partitions
+    with pipelinable:
+        # input : [B, 3, 32, 32]
+        model = resnet50()
+
+    exec_seq = [
+        'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
+    ]
+    pipelinable.to_layer_list(exec_seq)
+    module_partitions: List[PipelinableModel] = [
+        pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size)
+    ]
+
+    # build dataloader
+    root = os.environ.get('DATA', './data')
+    train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
+    criterion = nn.CrossEntropyLoss()
+
+    partition_1 = module_partitions[0]
+    partition_2 = []
+    for model in module_partitions[1]._module_list:
+        partition_2.append(model)
+    partition_2.insert(len(partition_2) - 1, Flatten())
+    partition_2 = nn.Sequential(*partition_2)
+    module_partitions = [partition_1, partition_2]
+
+    pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+                                       stage_num=stage_num,
+                                       num_microbatches=num_microbatches,
+                                       device=device,
+                                       chunk=chunk,
+                                       criterion=criterion,
+                                       checkpoint=False)
+
+    pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
+    s = time.time()
+
+    for bx, by in tqdm(train_dataloader):
+        pp_engine.forward_backward(bx, labels=by, forward_only=False)
+
+    cost_time = time.time() - s
+
+    print("total cost time :", cost_time)
+    print("cost time per batch:", cost_time / len(train_dataloader))
+
+
+@pytest.mark.skip("Test for performance, no need for CI")
+def main():
+    args = parse_args()
+    # this is due to a limitation of the partition function
+    args.world_size = 2
+    args.chunk = 1
+    rpc_run(args, run_master)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py
index e4099401e..e7c82045c 100644
--- a/tests/test_pipeline/test_cuda_rpc_pipeline.py
+++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 
-from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
 
 from rpc_test_utils import rpc_run, parse_args, RpcTestModel
diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py
index c8b1a9c14..98085726f 100644
--- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py
+++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py
@@ -2,7 +2,7 @@
 import torch
 from torch import nn
 from torch import autograd
 
-from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
 from colossalai.testing import assert_close
 from rpc_test_utils import rpc_run, parse_args, RpcTestModel
@@ -36,7 +36,7 @@ def run_master(args):
                                      chunk=chunk,
                                      checkpoint=use_checkpoint)
 
-    forward_result = engine.forward_backward(input_sample)
+    forward_result = engine.forward_backward(input_sample)[0]
 
     cuda_rpc_result = []
     single_result = []