diff --git a/colossalai/pipeline/rpc/PipelineBase.py b/colossalai/pipeline/rpc/PipelineBase.py index 9bb548ff6..ebcbe0ea5 100644 --- a/colossalai/pipeline/rpc/PipelineBase.py +++ b/colossalai/pipeline/rpc/PipelineBase.py @@ -68,7 +68,7 @@ class UniqueKey: class WorkItem: __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', - 'num_microbatches') + 'num_microbatches', 'forward_only') stage_id: int phase: Phase @@ -81,6 +81,7 @@ class WorkItem: batch_id: int num_microbatches: int + forward_only: bool def __init__(self, stage_id, @@ -91,6 +92,7 @@ class WorkItem: microbatch_id, batch_id, num_microbatches, + forward_only, refcount=0) -> None: for attr_name in self.__slots__: setattr(self, attr_name, locals()[attr_name]) @@ -129,36 +131,39 @@ class Worker: pp_rank: int, actual_stage_num: int, num_microbatches: int, - max_outstanding: int, + use_1F1B: bool, device: str, checkpoint: bool = False) -> None: super().__init__() self.pp_rank = pp_rank self.actual_stage_num = actual_stage_num self.num_microbatches = num_microbatches - self.max_outstanding = max_outstanding - self.outstanding = 0 self.checkpoint = checkpoint self.device = device + self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B) - self.future_devices = None if device is None or device == 'cpu' else [device] + # variable and const for context management + self.outstanding = 0 + self.forward_times = 0 + self.backward_times = 0 + self.reset_key = UniqueKey(0, Phase.FORWARD) + # rref of other workers self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None + + # topology info self.producer_stage_ids: List[int] = None self.consumer_stage_ids: List[int] = None - # module + # module partitions self.module_partition = module_partition.to(device) - self.debug_list = [None] * num_microbatches - + # container to maintain loop self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict() - self.work_list: Dict[UniqueKey, 
WorkItem] = dict() self.output_list: Dict[UniqueKey, WorkItem] = dict() - # Why must a Lock instead of RLock ? - # Because RLock cannot be pickled + # lock for the list self.work_list_condition_lock = threading.Condition(threading.Lock()) self.output_list_condition_lock = threading.Condition(threading.Lock()) @@ -168,6 +173,15 @@ class Worker: def _get_future_by_device(self): return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]: + outstanding_range = None + if use_1F1B: + if pp_rank == actual_stage_num - 1: + outstanding_range = (0, 1) + else: + outstanding_range = (actual_stage_num, actual_stage_num) + return outstanding_range + def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None: assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs" assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" @@ -197,8 +211,15 @@ class Worker: def get_parameter_gradients(self) -> List[torch.Tensor]: return [p.grad for p in self.module_partition.parameters()] + def reset_pp_context(self): + self.forward_times = 0 + self.backward_times = 0 + self.outstanding = 0 + self.microbatch_id_to_backward_cache.clear() + self.output_list.clear() + # just for first pp_rank - def set_input(self, microbatch_id: int, microbatch: Tuple[Any]): + def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): with self.work_list_condition_lock: assert self.consumer_stage_ids is not None consumer_num = len(self.consumer_stage_ids) @@ -207,11 +228,10 @@ class Worker: args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, - self.num_microbatches, consumer_num) + self.num_microbatches, forward_only) 
self.work_list[key] = work_item color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta') - self.work_list_condition_lock.notify_all() # just for last pp_rank @@ -224,24 +244,22 @@ class Worker: grad_wrt_loss = torch.tensor(1, device=self.device) work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, - self.num_microbatches, producer_num) + self.num_microbatches, False) color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta') self.work_list[key] = work_item self.work_list_condition_lock.notify_all() - def subscribe_producer(self, microbatch_id: int): + def subscribe_producer(self, microbatch_id: int, forward_only: bool): """ You should call this function asynchronously """ assert self.producer_stage_ids is not None producer_num = len(self.producer_stage_ids) - consumer_num = len(self.consumer_stage_ids) assert producer_num > 0, "only stage that has producers can subscribe producers" stage_id = self.pp_rank - subscribe_forward_futures: List[Future] = [None] * producer_num output = self._get_future_by_device() @@ -259,9 +277,8 @@ class Worker: producer_args = subscribe_forward_futures[i].wait() args.extend(producer_args) - # TODO : not only args work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None, - self.num_microbatches, consumer_num) + self.num_microbatches, forward_only) color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') # add work_item to work_list @@ -279,13 +296,10 @@ class Worker: You should call this function asynchronously """ assert self.producer_stage_ids is not None - producer_num = len(self.producer_stage_ids) consumer_num = len(self.consumer_stage_ids) assert consumer_num > 0, "only stage that has consumers can subscribe comsumers" - # TODO : is this right? 
stage_id = self.pp_rank - subscribe_backward_futures: List[Future] = [None] * consumer_num output = self._get_future_by_device() @@ -305,7 +319,7 @@ class Worker: # flatten args work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None, - self.num_microbatches, producer_num) + self.num_microbatches, False) color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') @@ -341,32 +355,57 @@ class Worker: while len(self.work_list) == 0: self.work_list_condition_lock.wait() - # execute backward first (if backward phase in work_list) - select_work_list_key = None - for key in self.work_list: - work_item = self.work_list[key] - if work_item.phase == Phase.FORWARD and \ - self.max_outstanding is not None and \ - self.outstanding >= self.max_outstanding: - continue - else: - if select_work_list_key is not None and \ - select_work_list_key.phase == Phase.FORWARD and \ - key.phase == Phase.BACKWARD: - continue + # each stage must do Key(microbatch_id=0, phase=FORWARD) first + # before doing the operation, reset the context first + if self.reset_key in self.work_list: + self.reset_pp_context() - if select_work_list_key is None: - select_work_list_key = key - else: - phase_pair = (select_work_list_key.phase, key.phase) - # choose forward first - if phase_pair == (Phase.BACKWARD, Phase.FORWARD): - select_work_list_key = key - elif phase_pair == (Phase.FORWARD, Phase.BACKWARD): - continue - # choose work_item which has a smaller microbactch_id first - elif key.microbatch_id < select_work_list_key.microbatch_id: - select_work_list_key = key + # execute backward first (if backward phase in work_list) + pp_rank = self.pp_rank + actual_stage_num = self.actual_stage_num + num_microbatches = self.num_microbatches + is_last_stage = pp_rank == actual_stage_num - 1 + select_work_list_key: UniqueKey = None + + if self.outstanding_range: + if self.outstanding <= self.outstanding_range[0]: + target_phase 
= Phase.FORWARD + target_microbatch_id = self.forward_times + elif self.outstanding >= self.outstanding_range[1]: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: + raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]") + + target_key = UniqueKey(target_microbatch_id, target_phase) + if target_key in self.work_list: + select_work_list_key = target_key + + # change outstanding_range at: + # 1. forward times reach actual_stage_num, this is the end of continuous forward + # 2. forward times reach num_microbatches, this is the end of 1F1B mode + if not is_last_stage and \ + select_work_list_key is not None and \ + select_work_list_key.phase == Phase.FORWARD: + if select_work_list_key.microbatch_id == actual_stage_num - 1: + outstanding_min = actual_stage_num - pp_rank - 1 + outstanding_max = actual_stage_num - pp_rank + self.outstanding_range = (outstanding_min, outstanding_max) + elif select_work_list_key.microbatch_id == num_microbatches - 1: + self.outstanding_range = (0, 0) + + else: + if self.forward_times < num_microbatches: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + else: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + + target_key = UniqueKey(target_microbatch_id, target_phase) + + if target_key in self.work_list: + select_work_list_key = target_key return select_work_list_key @@ -375,15 +414,28 @@ class Worker: args = work_item.args kwargs = work_item.kwargs microbatch_id = work_item.microbatch_id + forward_only = work_item.forward_only consume_result = None - # if self.pp_rank == 0: - # print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list)) + # TODO : use process manager to acquire rank info later + is_first_stage = (self.pp_rank == 0) + is_last_stage = (self.pp_rank == self.actual_stage_num - 1) - # color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue') + # if 
self.pp_rank == 3: + # print( + # f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}' + # ) if phase == Phase.FORWARD: - self.outstanding += 1 + # remind its consumer to get data before forward + if not is_last_stage: + for stage_id in self.consumer_stage_ids: + consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id] + consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only) + self.forward_times += 1 + + if not forward_only: + self.outstanding += 1 # TODO : more elegant ? for i in range(len(args)): @@ -391,35 +443,46 @@ class Worker: if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad: args[i] = arg_obj.requires_grad_() - # TODO : use process manager to acquire rank info later - is_last_stage = (self.pp_rank == self.actual_stage_num - 1) - # last stage doesn't need to do checkpoint, for it will do backward instantly - if self.checkpoint and not is_last_stage: + if forward_only: + with torch.no_grad(): + consume_result = self.module_partition(*args, **kwargs) + stage_outputs = None + stage_inputs = None + use_checkpoint = None + elif self.checkpoint and not is_last_stage: with torch.no_grad(): consume_result = self.module_partition(*args, **kwargs) stage_outputs = None stage_inputs = args - self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, - stage_outputs, - checkpoint=True) + use_checkpoint = True else: consume_result = self.module_partition(*args, **kwargs) - stage_outputs = consume_result stage_inputs = args + use_checkpoint = False + + if not forward_only: self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, stage_outputs, - checkpoint=False) + checkpoint=use_checkpoint) consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result - # if it is the last stage, trigger backward automatic - if is_last_stage: - 
self._begin_backward(microbatch_id) + # if not forward_only, do the backward + if not forward_only: + if is_last_stage: # if it is the last stage, trigger backward automatic + self._begin_backward(microbatch_id) elif phase == Phase.BACKWARD: + # remind its producer to get data before backward + if not is_first_stage: + for stage_id in self.producer_stage_ids: + producer_worker_rref = self.pp_rank_to_worker_rref[stage_id] + producer_worker_rref.remote().subscribe_consumer(microbatch_id) + self.backward_times += 1 self.outstanding -= 1 + assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) @@ -445,6 +508,9 @@ class Worker: return consume_result + def _get_store_len(self): + return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}' + # do the main loop to consume ready_list def _work_loop(self): # for init @@ -461,7 +527,7 @@ class Worker: work_item = self.work_list.pop(work_item_key) color_debug( - f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}', + f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}', 'work loop', 'green') with self.output_list_condition_lock: @@ -472,7 +538,7 @@ class Worker: consume_result = self._consume_work_item_by_phase(work_item) color_debug( - f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}', + f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}', 'work loop', 'green') work_item.output.set_result(consume_result) @@ -489,9 +555,6 @@ class Worker: self.optimizer.zero_grad() -# TODO -# 1. chunk -# 2. 
checkpoint class PipelineEngineBase(ABC, nn.Module): def __init__(self, @@ -499,19 +562,18 @@ class PipelineEngineBase(ABC, nn.Module): stage_num, num_microbatches, device: str, - max_outstanding=None, + use_1F1B=False, chunk: int = 1, - use_interleave: bool = False, checkpoint: bool = False) -> None: super().__init__() self.module_partitions: List[nn.Module] = module_partitions self.chunk = chunk self.num_microbatches = num_microbatches self.device = device - self.max_outstanding = max_outstanding + self.use_1F1B = use_1F1B self.stage_num = stage_num self.checkpoint = checkpoint - self.use_interleave = use_interleave + self.use_interleave = chunk > 1 self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() @@ -547,7 +609,7 @@ class PipelineEngineBase(ABC, nn.Module): def _init_worker(self): actual_stage_num = self._get_actual_stage_num() - max_outstanding = self.max_outstanding + use_1F1B = self.use_1F1B checkpoint = self.checkpoint num_microbatches = self.num_microbatches device = self.device @@ -560,8 +622,7 @@ class PipelineEngineBase(ABC, nn.Module): self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, Worker, args=(module_partition, pp_rank, actual_stage_num, - num_microbatches, max_outstanding, device, - checkpoint)) + num_microbatches, use_1F1B, device, checkpoint)) # let each worker know global worker rref (include itself) for pp_rank in range(actual_stage_num): @@ -585,46 +646,55 @@ class PipelineEngineBase(ABC, nn.Module): grads[stage_id].append(grad) return grads - def forward_backward(self, batch: torch.Tensor): - first_stage_worker = self.pp_rank_to_worker_rref[0] - microbatch_size = len(batch) // self.num_microbatches + def forward_backward(self, batch: torch.Tensor, forward_only: bool = False): + num_microbatches = self.num_microbatches + microbatch_size = len(batch) // num_microbatches actual_stage_num = self._get_actual_stage_num() - microbatch_iter = range(self.num_microbatches) + first_stage_worker = self.pp_rank_to_worker_rref[0] + 
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1] + + microbatch_iter = range(num_microbatches) if use_progress: microbatch_iter = tqdm(microbatch_iter) + ret_future: List[Future] = [None] * num_microbatches + from time import sleep + for microbatch_id in microbatch_iter: microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] - # forward subscribe asynchronously - for pp_rank in range(1, actual_stage_num, 1): - worker_rref = self.pp_rank_to_worker_rref[pp_rank] - worker_rref.rpc_async().subscribe_producer(microbatch_id) - - # backward subscribe asynchronously - for pp_rank in range(actual_stage_num - 2, -1, -1): - worker_rref = self.pp_rank_to_worker_rref[pp_rank] - worker_rref.rpc_async().subscribe_consumer(microbatch_id) + # control data input speed + # to avoid exceeding wait limitations + if microbatch_id >= actual_stage_num: + if forward_only or not self.use_1F1B: + ret_future[microbatch_id - actual_stage_num].wait() + else: + key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD) + first_stage_worker.rpc_sync().get_output_by_key(key) # run one microbatch - first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch) + first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only) + + key = UniqueKey(microbatch_id, Phase.FORWARD) + ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key) # wait forward # TODO : all the node to output forward_result = None - last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1] + for microbatch_id in range(self.num_microbatches): key = UniqueKey(microbatch_id, Phase.FORWARD) - ret = last_worker_rref.rpc_sync().get_output_by_key(key) + ret = ret_future[microbatch_id].wait() if forward_result is None: forward_result = [[]] * len(ret) for i in range(len(forward_result)): forward_result[i].append(ret[i]) # wait for last backward in rank0 - key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) - 
first_stage_worker.rpc_sync().get_output_by_key(key) + if not forward_only: + key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) + first_stage_worker.rpc_sync().get_output_by_key(key) return forward_result def initialize_optimizer(self, optimizer_class: type, **kwargs): @@ -654,11 +724,9 @@ class FillDrainPipelineEngine(PipelineEngineBase): num_microbatches: int, device: str, chunk: int = 1, - use_interleave: bool = False, checkpoint: bool = False) -> None: - max_outstanding = None - super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave, - checkpoint) + use_1F1B = False + super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint) class OneFOneBPipelineEngine(PipelineEngineBase): @@ -668,11 +736,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase): stage_num: int, num_microbatches: int, device: str, - max_outstanding=None, chunk: int = 1, - use_interleave: bool = False, checkpoint: bool = False) -> None: - if max_outstanding is None: - max_outstanding = len(module_partitions) - super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave, - checkpoint) + use_1F1B = True + super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint) diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index e7caea0bd..bcfeb1760 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -5,13 +5,9 @@ import torch from torch import nn import torch.multiprocessing as mp import torch.distributed.rpc as rpc -from torch import autograd from torch.optim import SGD, Adam, RMSprop, Optimizer from colorama import Back, Style -from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine -from colossalai.testing import assert_close - def color_debug(text, prefix=' ', color='blue'): 
color = color.upper() @@ -43,13 +39,13 @@ class RpcTestModel(nn.Module): def parse_args(): parser = argparse.ArgumentParser() + parser.add_argument('--epoch', type=int, default=1) parser.add_argument('--world_size', type=int, default=2) parser.add_argument('--num_microbatches', type=int, default=2) parser.add_argument('--chunk', type=int, default=1) parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--use_interleave', action='store_true') parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') parser.add_argument('--master_addr', type=str, default='localhost') parser.add_argument('--master_port', type=str, default='29020') parser.add_argument('--num_worker_threads', type=str, default=128) diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py index 12db694fa..b740da578 100644 --- a/tests/test_pipeline/test_cuda_rpc_optimizer.py +++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py @@ -1,13 +1,7 @@ -import os -import argparse - import torch from torch import nn -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc from torch import autograd from torch.optim import SGD, Adam, RMSprop, Optimizer -from colorama import Back, Style from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.testing import assert_close @@ -21,7 +15,6 @@ def run_master(args): stage_num = args.world_size chunk = args.chunk actual_stage_num = stage_num * chunk - use_interleave = args.use_interleave use_checkpoint = args.use_checkpoint num_microbatches = args.num_microbatches optimizer_class = globals()[args.optimizer] @@ -45,7 +38,6 @@ def run_master(args): num_microbatches=num_microbatches, device=device, chunk=chunk, - use_interleave=use_interleave, 
checkpoint=use_checkpoint) engine.initialize_optimizer(optimizer_class, lr=lr) diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py index 9dc19f13d..e4099401e 100644 --- a/tests/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -1,10 +1,5 @@ -import os -import argparse - import torch from torch import nn -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine from rpc_test_utils import rpc_run, parse_args, RpcTestModel @@ -13,12 +8,12 @@ from rpc_test_utils import rpc_run, parse_args, RpcTestModel def run_master(args): torch.manual_seed(100) + epoch = args.epoch device = args.device stage_num = args.world_size chunk = args.chunk num_microbatches = args.num_microbatches actual_stage_num = stage_num * chunk - use_interleave = args.use_interleave use_checkpoint = args.use_checkpoint sample_num = 1024 @@ -38,10 +33,10 @@ def run_master(args): num_microbatches=num_microbatches, device=device, chunk=chunk, - use_interleave=use_interleave, checkpoint=use_checkpoint) - _ = engine.forward_backward(input_sample) + for _ in range(epoch): + _ = engine.forward_backward(input_sample, forward_only=False) if __name__ == "__main__": diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py index c7a439c37..c8b1a9c14 100644 --- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py +++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py @@ -1,12 +1,6 @@ -import os -import argparse - import torch from torch import nn -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc from torch import autograd -from colorama import Back, Style from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.testing import assert_close @@ -20,7 +14,6 
@@ def run_master(args): stage_num = args.world_size chunk = args.chunk actual_stage_num = stage_num * chunk - use_interleave = args.use_interleave use_checkpoint = args.use_checkpoint num_microbatches = args.num_microbatches @@ -41,7 +34,6 @@ def run_master(args): num_microbatches=num_microbatches, device=device, chunk=chunk, - use_interleave=use_interleave, checkpoint=use_checkpoint) forward_result = engine.forward_backward(input_sample)