diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/pipeline/pipeline_process_group.py index e85f45548..c61d97eba 100644 --- a/colossalai/pipeline/pipeline_process_group.py +++ b/colossalai/pipeline/pipeline_process_group.py @@ -50,6 +50,7 @@ class PipelineProcessGroup: self.is_initialize = True # lock + self.initialise_lock = threading.Lock() self.chimera_lock = threading.Lock() def _initialize_process_group(self): diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 96357c476..fd5b1b2d1 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -3,9 +3,7 @@ from enum import Enum from typing import List, Any, Tuple, Dict, Callable from functools import partial from abc import ABC, abstractmethod -import sys -import os -import time +import math import inspect import torch @@ -831,13 +829,16 @@ class PipelineEngineBase(ABC, nn.Module): def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False): batch_lengths = get_batch_lengths(batch) + batch_length = batch_lengths[0] if labels is not None and not forward_only: assert hasattr( self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" num_microbatches = self.num_microbatches - microbatch_size = batch_lengths[0] // num_microbatches + + assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal" + microbatch_size = math.ceil(batch_length / num_microbatches) device = self.device # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks' @@ -852,7 +853,7 @@ class PipelineEngineBase(ABC, nn.Module): # to prevent exceed of wait limitations self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future) batch_start = microbatch_size * microbatch_id - batch_end = batch_start + microbatch_size + batch_end = min(batch_start + microbatch_size, batch_length) # set input microbatch = split_batch(batch, batch_start, batch_end, device) diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index e534943e0..fb4feb26d 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -1,4 +1,5 @@ from typing import List, Callable, Dict +import threading import torch import torch.distributed as dist @@ -81,7 +82,8 @@ class OneFOneBWorker(WorkerBase): # 2. forward times reach num_microbatches, this is the end of 1F1B mode if not is_last_stage and \ target_key.phase == Phase.FORWARD: - if target_key.microbatch_id == actual_stage_num - 1: + if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: + # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 outstanding_min = actual_stage_num - pp_rank - 1 outstanding_max = actual_stage_num - pp_rank self.outstanding_range = (outstanding_min, outstanding_max) @@ -186,6 +188,19 @@ class ChimeraWorker(WorkerBase): # init group for chimera in ppg ppg.get_chimera_all_reduce_group(pp_rank) + # lock for step sync + self.step_sync_lock = threading.Lock() + self.step_sync_lock.acquire() + + self.have_grad_lock = threading.Lock() + self.have_grad_lock.acquire() + + def _get_lock_gradient(self): + self.have_grad_lock.acquire() + grads = self.get_parameter_gradients() + self.step_sync_lock.release() + return grads + def is_first_stage(self): return (self.pp_rank % self.actual_stage_num) == 0 @@ -214,27 +229,22 @@ class ChimeraWorker(WorkerBase): return local_device_pp_ranks def _hook_before_step(self): + self.have_grad_lock.release() pp_rank = self.pp_rank - - orders = self._get_step_order() - step_index = orders.index(pp_rank) + stage_num = self.actual_stage_num + co_pp_rank = (pp_rank + stage_num) % (2 * stage_num) # if currrent pp_rank is not the first to do step # wait its previous pp_rank finish step - - all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank) grads = self.get_parameter_gradients() - # print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank())) - if step_index == 1: - ppg.chimera_step_lock.acquire() - - # print(f'rank_{self.pp_rank} before all reduce') - dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False) - # print(f'rank_{self.pp_rank} after all reduce') - - if step_index == 0: - ppg.chimera_step_lock.release() + # send + co_worker = self.pp_rank_to_worker_rref[co_pp_rank] + co_grads = co_worker.rpc_sync()._get_lock_gradient() + # sync + self.step_sync_lock.acquire() + for i in range(len(grads)): + grads[i] += co_grads[i] class ChimeraPipelineEngine(PipelineEngineBase): @@ -257,8 +267,8 @@ class ChimeraPipelineEngine(PipelineEngineBase): super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, metric, checkpoint, data_process_func) - def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]], - input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]): + def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], + output_pp_ranks: List[int], ret_future): pass def _create_pp_rank_to_rpc_worker_id(self) -> None: diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py index c4d6897f6..887166467 100644 --- a/colossalai/pipeline/rpc/utils.py +++ b/colossalai/pipeline/rpc/utils.py @@ -1,10 +1,18 @@ from typing import List, Any, Tuple, Dict, Callable, Type, Union +import os +import warnings +import argparse import torch +import torch.multiprocessing as mp from torch.futures import Future - +import torch.distributed.rpc as rpc +from torch._C._distributed_rpc import _is_current_rpc_agent_set from colorama import Back, Style +from colossalai.initialize import launch +from colossalai.pipeline.pipeline_process_group import ppg + # config for debug and test use_color_debug = False @@ -87,3 +95,57 @@ def get_real_args_kwargs(args_or_kwargs): args_or_kwargs = flatten_args return args_or_kwargs + + +def run_worker(rank, args, master_func): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = 'nccl' if device == 'cuda' else 'gloo' + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device) + ppg.args = args + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if _is_current_rpc_agent_set(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--epoch', type=int, default=1) + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--dp_degree', type=int, default=1) + parser.add_argument('--tp_degree', type=int, default=1) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=str, default=128) + return parser.parse_args()