From 6159d4541736eed834ec9bb9730e20b2d8c59aa3 Mon Sep 17 00:00:00 2001
From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com>
Date: Wed, 7 Sep 2022 19:01:06 +0800
Subject: [PATCH] [pipeline/tuning] improve dispatch performance in both time
 and space (#1544)

---
 colossalai/pipeline/rpc/PipelineBase.py       | 347 ++++++++++--------
 tests/test_pipeline/rpc_test_utils.py         |   1 +
 .../test_pipeline/test_cuda_rpc_optimizer.py  |   1 -
 3 files changed, 194 insertions(+), 155 deletions(-)

diff --git a/colossalai/pipeline/rpc/PipelineBase.py b/colossalai/pipeline/rpc/PipelineBase.py
index ebcbe0ea5..34c428b9d 100644
--- a/colossalai/pipeline/rpc/PipelineBase.py
+++ b/colossalai/pipeline/rpc/PipelineBase.py
@@ -1,7 +1,8 @@
 import threading
 from enum import Enum
-from typing import List, Any, Tuple, Dict
+from typing import List, Any, Tuple, Dict, Callable
 from abc import ABC
+import sys
 
 import torch
 from torch import nn
@@ -11,6 +12,7 @@ from torch._C._distributed_rpc import PyRRef
 from torch import autograd
 from torch import optim
 from tqdm import tqdm
+from time import time
 
 from colorama import Back, Style
@@ -30,6 +32,10 @@ def color_debug(text, prefix=' ', color='blue'):
 
 
 def tensor_shape_list(tensors):
+    if tensors is None:
+        return None
+    if isinstance(tensors, (int, float)):
+        return tensors
     if isinstance(tensors, torch.Tensor):
         return tensors.shape
     shapes = []
@@ -41,6 +47,25 @@ def tensor_shape_list(tensors):
     return shapes
 
 
+def get_real_args(args):
+    if isinstance(args, torch.Tensor):
+        return args
+    elif isinstance(args, list):
+        real_args = []
+        for arg in args:
+            if isinstance(arg, Future):
+                value = arg.wait()
+            else:
+                value = arg
+            if isinstance(value, list):
+                real_args.extend(value)
+            else:
+                real_args.append(value)
+        return real_args
+    else:
+        raise TypeError(f"Expected a tensor or a list, but got {type(args)}")
+
+
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
@@ -112,18 +137,6 @@ class BackwardCache:
         setattr(self, arg_name, locals()[arg_name])
 
 
-class RemoteExecutor:
-
-    def __init__(self) -> None:
-        pass
-
-
-class RemoteOptimizer:
-
-    def __init__(self) -> None:
-        pass
-
-
 class Worker:
 
     def __init__(self,
                  module_partition: nn.Module,
                  pp_rank: int,
                  actual_stage_num: int,
                  num_microbatches: int,
                  use_1F1B: bool,
                  device: str,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         super().__init__()
         self.pp_rank = pp_rank
         self.actual_stage_num = actual_stage_num
         self.num_microbatches = num_microbatches
         self.checkpoint = checkpoint
         self.device = device
-        self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B)
+        self.use_1F1B = use_1F1B
+        self._initialize_outstanding_range()
 
         # variable and const for context management
         self.outstanding = 0
@@ -157,30 +172,39 @@ class Worker:
 
         # module partitions
         self.module_partition = module_partition.to(device)
+        if criterion:
+            assert callable(criterion)
+        self.criterion = criterion
 
         # container to maintain loop
         self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
+        self.microbatch_id_to_labels: Dict[int, Any] = dict()
         self.work_list: Dict[UniqueKey, WorkItem] = dict()
         self.output_list: Dict[UniqueKey, WorkItem] = dict()
 
         # lock for the list
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
+        self.label_lock = threading.Condition(threading.Lock())
 
+        self.step_lock = threading.Lock()
+        self.step_lock.acquire()
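+        # NOTE: step_lock starts out held; Worker.step() releases it after the
+        # optimizer update, so wait_for_step() blocks callers until the update
+        # for the current batch has finished.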
+
+        # main loop
         self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}',
                                                  daemon=True)
         self.main_loop_thread.start()
 
     def _get_future_by_device(self):
         return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
 
-    def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]:
+    def _initialize_outstanding_range(self):
         outstanding_range = None
-        if use_1F1B:
-            if pp_rank == actual_stage_num - 1:
+        if self.use_1F1B:
+            if self.pp_rank == self.actual_stage_num - 1:
                 outstanding_range = (0, 1)
             else:
-                outstanding_range = (actual_stage_num, actual_stage_num)
-        return outstanding_range
+                outstanding_range = (self.actual_stage_num, self.actual_stage_num)
+        self.outstanding_range = outstanding_range
 
     def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
         assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has already synced global worker rrefs"
@@ -189,20 +213,16 @@ class Worker:
         self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
 
     def get_output_by_key(self, key: UniqueKey) -> Any:
         with self.output_list_condition_lock:
-            while key not in self.output_list:
-                self.output_list_condition_lock.wait()
-
+            self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
-
         output = output_work_item.output.wait()
         # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
 
         output_work_item.refcount += 1
 
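+        # NOTE: every consumer stage fetches a given output exactly once, so
+        # the refcount tells us when the cached WorkItem can be released;
+        # freeing it keeps output_list from growing with the number of
+        # microbatches.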
         # all consumers have been satisfied, the work_item can be released
         with self.output_list_condition_lock:
-            if output_work_item.refcount == len(self.consumer_stage_ids):
+            if output_work_item.refcount >= len(self.consumer_stage_ids):
                 self.output_list.pop(key)
-
         return output
 
@@ -211,34 +231,28 @@ class Worker:
     def get_parameters(self) -> List[torch.Tensor]:
         return [p for p in self.module_partition.parameters()]
 
     def get_parameter_gradients(self) -> List[torch.Tensor]:
         return [p.grad for p in self.module_partition.parameters()]
 
-    def reset_pp_context(self):
-        self.forward_times = 0
-        self.backward_times = 0
-        self.outstanding = 0
-        self.microbatch_id_to_backward_cache.clear()
-        self.output_list.clear()
-
     # just for first pp_rank
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
+        assert self.consumer_stage_ids is not None
+        key = UniqueKey(microbatch_id, Phase.FORWARD)
+        output = self._get_future_by_device()
+        args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
+        work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
+                             forward_only)
         with self.work_list_condition_lock:
-            assert self.consumer_stage_ids is not None
-            consumer_num = len(self.consumer_stage_ids)
-            key = UniqueKey(microbatch_id, Phase.FORWARD)
-            output = self._get_future_by_device()
-            args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
-
-            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
-                                 self.num_microbatches, forward_only)
             self.work_list[key] = work_item
-            color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
 
+    # just for last pp_rank
+    def set_labels(self, microbatch_id: int, microlabels: Any):
+        self.microbatch_id_to_labels[microbatch_id] = microlabels
+
     # just for last pp_rank
     def _begin_backward(self, microbatch_id: int):
         with self.work_list_condition_lock:
             assert self.producer_stage_ids is not None
-            producer_num = len(self.producer_stage_ids)
+
             key = UniqueKey(microbatch_id, Phase.BACKWARD)
             output = self._get_future_by_device()
             grad_wrt_loss = torch.tensor(1, device=self.device)
@@ -272,15 +286,10 @@ class Worker:
         color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
                     'data dispatch', 'magenta')
 
-        args = []
-        for i in range(producer_num):
-            producer_args = subscribe_forward_futures[i].wait()
-            args.extend(producer_args)
+        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
+                                           microbatch_id, None, self.num_microbatches, forward_only)
 
-        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
-                                           self.num_microbatches, forward_only)
-
-        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
+        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
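+        # NOTE: the producer futures are stored in the WorkItem without wait();
+        # get_real_args resolves them later inside the worker loop, so the RPC
+        # transfer overlaps with computation instead of blocking this thread.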
 
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
@@ -312,16 +321,11 @@ class Worker:
             consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
             subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
 
-        args = []
-        for i in range(consumer_num):
-            consumer_args = subscribe_backward_futures[i].wait()
-            args.extend(consumer_args)
-
-        # flatten args
-        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
-                                           self.num_microbatches, False)
+        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
+                                           microbatch_id, None, self.num_microbatches, False)
 
-        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
+        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
 
         # add work_item to work_list
         with self.work_list_condition_lock:
@@ -351,63 +355,50 @@ class Worker:
                 self.consumer_stage_ids.append(next_rank)
 
     def _get_work_item_key(self) -> UniqueKey:
-        with self.work_list_condition_lock:
-            while len(self.work_list) == 0:
-                self.work_list_condition_lock.wait()
-
-            # each stage must do Key(microbatch_id=0, phase=FORWARD) first
-            # before doing the operation, reset the context first
-            if self.reset_key in self.work_list:
-                self.reset_pp_context()
-
-            # execute backward first (if backward phase in work_list)
-            pp_rank = self.pp_rank
-            actual_stage_num = self.actual_stage_num
-            num_microbatches = self.num_microbatches
-            is_last_stage = pp_rank == actual_stage_num - 1
-            select_work_list_key: UniqueKey = None
-
-            if self.outstanding_range:
-                if self.outstanding <= self.outstanding_range[0]:
-                    target_phase = Phase.FORWARD
-                    target_microbatch_id = self.forward_times
-                elif self.outstanding >= self.outstanding_range[1]:
-                    target_phase = Phase.BACKWARD
-                    target_microbatch_id = self.backward_times
-                else:
-                    raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
-
-                target_key = UniqueKey(target_microbatch_id, target_phase)
-                if target_key in self.work_list:
-                    select_work_list_key = target_key
-
-                # change outstanding_range at:
-                # 1. forward times reach actual_stage_num, this is the end of continuous forward
-                # 2. forward times reach num_microbatches, this is the end of 1F1B mode
-                if not is_last_stage and \
-                    select_work_list_key is not None and \
-                    select_work_list_key.phase == Phase.FORWARD:
-                    if select_work_list_key.microbatch_id == actual_stage_num - 1:
-                        outstanding_min = actual_stage_num - pp_rank - 1
-                        outstanding_max = actual_stage_num - pp_rank
-                        self.outstanding_range = (outstanding_min, outstanding_max)
-                    elif select_work_list_key.microbatch_id == num_microbatches - 1:
-                        self.outstanding_range = (0, 0)
-            else:
-                if self.forward_times < num_microbatches:
-                    target_phase = Phase.FORWARD
-                    target_microbatch_id = self.forward_times
-                else:
-                    target_phase = Phase.BACKWARD
-                    target_microbatch_id = self.backward_times
-
-                target_key = UniqueKey(target_microbatch_id, target_phase)
-
-                if target_key in self.work_list:
-                    select_work_list_key = target_key
-
-        return select_work_list_key
+        # execute backward first (if backward phase in work_list)
+        pp_rank = self.pp_rank
+        actual_stage_num = self.actual_stage_num
+        num_microbatches = self.num_microbatches
+        is_last_stage = pp_rank == actual_stage_num - 1
+
+        if self.outstanding_range:
+            if self.outstanding <= self.outstanding_range[0]:
+                target_phase = Phase.FORWARD
+                target_microbatch_id = self.forward_times
+            elif self.outstanding >= self.outstanding_range[1]:
+                target_phase = Phase.BACKWARD
+                target_microbatch_id = self.backward_times
+            else:
+                raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
+
+            target_key = UniqueKey(target_microbatch_id, target_phase)
+
+            # change outstanding_range at:
+            # 1. forward times reach actual_stage_num, this is the end of continuous forward
+            # 2. forward times reach num_microbatches, this is the end of 1F1B mode
+            if not is_last_stage and \
+                target_key.phase == Phase.FORWARD:
+                if target_key.microbatch_id == actual_stage_num - 1:
+                    outstanding_min = actual_stage_num - pp_rank - 1
+                    outstanding_max = actual_stage_num - pp_rank
+                    self.outstanding_range = (outstanding_min, outstanding_max)
+                elif target_key.microbatch_id == num_microbatches - 1:
+                    self.outstanding_range = (0, 0)
+        else:
+            if self.forward_times < num_microbatches:
+                target_phase = Phase.FORWARD
+                target_microbatch_id = self.forward_times
+            else:
+                target_phase = Phase.BACKWARD
+                target_microbatch_id = self.backward_times
+
+            target_key = UniqueKey(target_microbatch_id, target_phase)
+
+        with self.work_list_condition_lock:
+            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
+
+        return target_key
 
     def _consume_work_item_by_phase(self, work_item: WorkItem):
         phase = work_item.phase
@@ -421,7 +412,7 @@ class Worker:
         is_first_stage = (self.pp_rank == 0)
         is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
 
-        # if self.pp_rank == 3:
+        # if self.pp_rank == 0:
         #     print(
         #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
         #     )
@@ -436,12 +427,7 @@ class Worker:
 
             if not forward_only:
                 self.outstanding += 1
-
-            # TODO : more elegant ?
-            for i in range(len(args)):
-                arg_obj = args[i]
-                if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad:
-                    args[i] = arg_obj.requires_grad_()
+            args = get_real_args(args)
 
             # the last stage doesn't need checkpointing, since it runs backward immediately
             if forward_only:
                 with torch.no_grad():
                     consume_result = self.module_partition(*args, **kwargs)
@@ -458,7 +444,14 @@ class Worker:
                 use_checkpoint = True
             else:
                 consume_result = self.module_partition(*args, **kwargs)
-                stage_outputs = consume_result
+                if is_last_stage and self.criterion:
+                    labels = self.microbatch_id_to_labels.pop(microbatch_id)
+                    loss: torch.Tensor = self.criterion(consume_result, labels)
+                    consume_result = loss.item()
+                else:
+                    loss = consume_result
+
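+                # NOTE: cache the loss as this stage's output (on the last
+                # stage it is a scalar), so backward can start directly from
+                # it; only loss.item() is shipped back to the master process.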
+                stage_outputs = loss
                 stage_inputs = args
                 use_checkpoint = False
 
@@ -467,7 +460,8 @@ class Worker:
             self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
                                                                                 stage_outputs,
                                                                                 checkpoint=use_checkpoint)
 
-            consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result
+            consume_result = [consume_result] if isinstance(consume_result,
+                                                            (torch.Tensor, int, float)) else consume_result
 
             # if not forward_only, do the backward
             if not forward_only:
@@ -488,20 +482,21 @@ class Worker:
             stage_outputs = backward_cache.stage_outputs
             stage_inputs = backward_cache.stage_inputs
-            grad_tensors = args
-
             use_checkpoint = backward_cache.checkpoint
 
             if use_checkpoint:
                 stage_outputs = [self.module_partition(*stage_inputs)]
 
+            # overlap recompute and future.wait
+            grad_tensors = get_real_args(args)
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
             # collect grad of input tensor
             consume_result = []
-            for input_node in stage_inputs:
-                if isinstance(input_node, torch.Tensor):
-                    consume_result.append(input_node.grad)
+            if not is_first_stage:
+                for input_node in stage_inputs:
+                    if isinstance(input_node, torch.Tensor):
+                        consume_result.append(input_node.grad)
 
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
 
@@ -509,7 +504,20 @@ class Worker:
         return consume_result
 
     def _get_store_len(self):
-        return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}'
+        return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
+
+    def _get_parameter_grad_sum(self):
+        grad_sum = 0
+        for p in self.module_partition.parameters():
+            if p.grad is not None:
+                grad_sum += p.grad.sum()
+        return grad_sum
+
+    def _is_first_step(self, work_item) -> bool:
+        return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
+
+    def _is_last_step(self, work_item) -> bool:
+        return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1
 
     # do the main loop to consume ready_list
     def _work_loop(self):
@@ -519,8 +527,6 @@ class Worker:
 
         # main loop
         while True:
             work_item_key = self._get_work_item_key()
-            if work_item_key is None:
-                continue
 
             # move current work item to output_list to activate subscribe in advance
             with self.work_list_condition_lock:
@@ -540,20 +546,32 @@ class Worker:
             color_debug(
                 f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
                 'work loop', 'green')
+
             work_item.output.set_result(consume_result)
 
+            # if this is the last step in one batch, reset context and do step
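+            # NOTE: each worker steps its own optimizer as soon as the backward
+            # of its final microbatch finishes, instead of waiting for an
+            # explicit engine.step() call from the master.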
+            if self._is_last_step(work_item):
+                if hasattr(self, 'optimizer'):
+                    self.step()
+                self.forward_times = 0
+                self.backward_times = 0
+                self.outstanding = 0
+                self._initialize_outstanding_range()
+
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
 
-    def step(self):
-        assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
-        self.work_list.clear()
-        self.output_list.clear()
-        self.microbatch_id_to_backward_cache.clear()
+    def wait_for_step(self):
+        self.step_lock.acquire()
 
+    def step(self):
+        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
         self.optimizer.step()
+        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
         self.optimizer.zero_grad()
 
+        self.step_lock.release()
+
 
 class PipelineEngineBase(ABC, nn.Module):
 
@@ -564,10 +582,12 @@ class PipelineEngineBase(ABC, nn.Module):
     def __init__(self,
                  module_partitions,
                  stage_num,
                  num_microbatches,
                  device: str,
                  use_1F1B=False,
                  chunk: int = 1,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         super().__init__()
         self.module_partitions: List[nn.Module] = module_partitions
         self.chunk = chunk
+        self.criterion = criterion
         self.num_microbatches = num_microbatches
         self.device = device
         self.use_1F1B = use_1F1B
@@ -577,6 +597,8 @@ class PipelineEngineBase(ABC, nn.Module):
 
         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
 
+        self.step_futs: List[Future] = []
+
         self._check_argument()
         self._create_pp_rank_to_rpc_worker_id()
         self._init_worker()
@@ -613,6 +635,7 @@ class PipelineEngineBase(ABC, nn.Module):
         checkpoint = self.checkpoint
         num_microbatches = self.num_microbatches
         device = self.device
+        criterion = self.criterion
 
         for pp_rank in range(actual_stage_num):
             module_partition = self.module_partitions[pp_rank]
@@ -622,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module):
             self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
                                                               Worker,
                                                               args=(module_partition, pp_rank, actual_stage_num,
-                                                                    num_microbatches, use_1F1B, device, checkpoint))
+                                                                    num_microbatches, use_1F1B, device, criterion,
+                                                                    checkpoint))
 
         # let each worker know global worker rref (include itself)
         for pp_rank in range(actual_stage_num):
@@ -646,12 +670,15 @@ class PipelineEngineBase(ABC, nn.Module):
                 grads[stage_id].append(grad)
         return grads
 
-    def forward_backward(self, batch: torch.Tensor, forward_only: bool = False):
+    def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
+        if labels is not None:
+            assert len(batch) == len(labels)
+
         num_microbatches = self.num_microbatches
         microbatch_size = len(batch) // num_microbatches
         actual_stage_num = self._get_actual_stage_num()
 
-        first_stage_worker = self.pp_rank_to_worker_rref[0]
+        first_worker_rref = self.pp_rank_to_worker_rref[0]
         last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
 
         microbatch_iter = range(num_microbatches)
@@ -659,11 +686,7 @@ class PipelineEngineBase(ABC, nn.Module):
             microbatch_iter = tqdm(microbatch_iter)
 
         ret_future: List[Future] = [None] * num_microbatches
-        from time import sleep
-
         for microbatch_id in microbatch_iter:
-            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
-
             # control data input speed
             # to prevent exceeding wait limitations
@@ -671,15 +694,27 @@ class PipelineEngineBase(ABC, nn.Module):
             if microbatch_id >= actual_stage_num:
                 if forward_only:
                     ret_future[microbatch_id - actual_stage_num].wait()
                 else:
                     key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
-                    first_stage_worker.rpc_sync().get_output_by_key(key)
+                    first_worker_rref.rpc_sync().get_output_by_key(key)
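+            # NOTE: the wait above bounds the number of in-flight microbatches
+            # to actual_stage_num, which keeps every worker's work_list and
+            # output_list from growing with the total number of microbatches.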
 
-            # run one microbatch
-            first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only)
+            # set input
+            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
+            microbatch = microbatch.cuda()
+            first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+
+            # set labels
+            if not forward_only and labels is not None:
+                microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
+                microlabels = microlabels.cuda()
+                last_worker_rref.remote().set_labels(microbatch_id, microlabels)
+
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
 
-        # wait forward
+        # wait for last backward in rank0
+        if not forward_only:
+            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
+            first_worker_rref.rpc_sync().get_output_by_key(key)
+
+        # collect forward result
         # TODO : all the node to output
         forward_result = None
@@ -691,28 +726,30 @@ class PipelineEngineBase(ABC, nn.Module):
             for i in range(len(forward_result)):
                 forward_result[i].append(ret[i])
 
-        # wait for last backward in rank0
-        if not forward_only:
-            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-            first_stage_worker.rpc_sync().get_output_by_key(key)
+        if hasattr(self, 'optimizer_class'):
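+            # NOTE: wait_for_step blocks on each worker's step_lock, which is
+            # only released once that worker's optimizer.step() has run, so
+            # forward_backward does not return while gradients are still being
+            # applied.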
+            # wait for all workers to finish their step
+            # TODO : more elegant ?
+            for pp_rank in self.pp_rank_to_worker_rref:
+                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                worker_rref.rpc_sync().wait_for_step()
+
         return forward_result
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         actual_stage_num = self._get_actual_stage_num()
+        self.optimizer_class = optimizer_class
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-            worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
+            worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)
 
     def step(self):
-        step_futs: List[Future] = []
         actual_stage_num = self._get_actual_stage_num()
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             fut = worker_rref.rpc_async().step()
-            step_futs.append(fut)
+            self.step_futs.append(fut)
 
-        # wait for all optimizers
-        for fut in step_futs:
+        for fut in self.step_futs:
             fut.wait()
 
@@ -724,9 +761,10 @@ class FillDrainPipelineEngine(PipelineEngineBase):
     def __init__(self,
                  module_partitions: List[nn.Module],
                  stage_num: int,
                  num_microbatches: int,
                  device: str,
                  chunk: int = 1,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         use_1F1B = False
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
+        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
 
 
 class OneFOneBPipelineEngine(PipelineEngineBase):
 
@@ -737,6 +775,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
     def __init__(self,
                  module_partitions: List[nn.Module],
                  stage_num: int,
                  num_microbatches: int,
                  device: str,
                  chunk: int = 1,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         use_1F1B = True
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
+        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
\ No newline at end of file
diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py
index cb65f3eec..1a8472820 100644
--- a/tests/test_pipeline/rpc_test_utils.py
+++ b/tests/test_pipeline/rpc_test_utils.py
@@ -45,6 +45,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--epoch', type=int, default=1)
     parser.add_argument('--world_size', type=int, default=2)
+    parser.add_argument('--batch_size', type=int, default=16)
     parser.add_argument('--num_microbatches', type=int, default=2)
     parser.add_argument('--chunk', type=int, default=1)
     parser.add_argument('--use_checkpoint', action='store_true')
diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py
index b740da578..5d3c193e3 100644
--- a/tests/test_pipeline/test_cuda_rpc_optimizer.py
+++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py
@@ -43,7 +43,6 @@ def run_master(args):
     engine.initialize_optimizer(optimizer_class, lr=lr)
 
     _ = engine.forward_backward(input_sample)
-    engine.step()
 
     cuda_rpc_result = []
     single_result = []