mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-10 17:38:32 +00:00
[pipeline/fix-bug] num_microbatches support any integrate | stable chimera | launch tool for rpc pp framework (#1684)
* [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera * [pipeline/chimera] test chimera | fix bug of initializing * [pipeline/pytree] add pytree to process args and kwargs | provide to process args and kwargs after forward * [pipeline/fix-bug] num_microbatches support any integrate | stable chimera | launch tool for rpc pp framework
This commit is contained in:
parent
e5ab6be72e
commit
0df5034a36
@ -50,6 +50,7 @@ class PipelineProcessGroup:
|
|||||||
self.is_initialize = True
|
self.is_initialize = True
|
||||||
|
|
||||||
# lock
|
# lock
|
||||||
|
self.initialise_lock = threading.Lock()
|
||||||
self.chimera_lock = threading.Lock()
|
self.chimera_lock = threading.Lock()
|
||||||
|
|
||||||
def _initialize_process_group(self):
|
def _initialize_process_group(self):
|
||||||
|
@ -3,9 +3,7 @@ from enum import Enum
|
|||||||
from typing import List, Any, Tuple, Dict, Callable
|
from typing import List, Any, Tuple, Dict, Callable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import sys
|
import math
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -831,13 +829,16 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||||
batch_lengths = get_batch_lengths(batch)
|
batch_lengths = get_batch_lengths(batch)
|
||||||
|
batch_length = batch_lengths[0]
|
||||||
|
|
||||||
if labels is not None and not forward_only:
|
if labels is not None and not forward_only:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
||||||
|
|
||||||
num_microbatches = self.num_microbatches
|
num_microbatches = self.num_microbatches
|
||||||
microbatch_size = batch_lengths[0] // num_microbatches
|
|
||||||
|
assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
|
||||||
|
microbatch_size = math.ceil(batch_length / num_microbatches)
|
||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
# If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
|
# If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
|
||||||
@ -852,7 +853,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
# to prevent exceed of wait limitations
|
# to prevent exceed of wait limitations
|
||||||
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
|
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
|
||||||
batch_start = microbatch_size * microbatch_id
|
batch_start = microbatch_size * microbatch_id
|
||||||
batch_end = batch_start + microbatch_size
|
batch_end = min(batch_start + microbatch_size, batch_length)
|
||||||
|
|
||||||
# set input
|
# set input
|
||||||
microbatch = split_batch(batch, batch_start, batch_end, device)
|
microbatch = split_batch(batch, batch_start, batch_end, device)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import List, Callable, Dict
|
from typing import List, Callable, Dict
|
||||||
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -81,7 +82,8 @@ class OneFOneBWorker(WorkerBase):
|
|||||||
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
||||||
if not is_last_stage and \
|
if not is_last_stage and \
|
||||||
target_key.phase == Phase.FORWARD:
|
target_key.phase == Phase.FORWARD:
|
||||||
if target_key.microbatch_id == actual_stage_num - 1:
|
if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
|
||||||
|
# Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2
|
||||||
outstanding_min = actual_stage_num - pp_rank - 1
|
outstanding_min = actual_stage_num - pp_rank - 1
|
||||||
outstanding_max = actual_stage_num - pp_rank
|
outstanding_max = actual_stage_num - pp_rank
|
||||||
self.outstanding_range = (outstanding_min, outstanding_max)
|
self.outstanding_range = (outstanding_min, outstanding_max)
|
||||||
@ -186,6 +188,19 @@ class ChimeraWorker(WorkerBase):
|
|||||||
# init group for chimera in ppg
|
# init group for chimera in ppg
|
||||||
ppg.get_chimera_all_reduce_group(pp_rank)
|
ppg.get_chimera_all_reduce_group(pp_rank)
|
||||||
|
|
||||||
|
# lock for step sync
|
||||||
|
self.step_sync_lock = threading.Lock()
|
||||||
|
self.step_sync_lock.acquire()
|
||||||
|
|
||||||
|
self.have_grad_lock = threading.Lock()
|
||||||
|
self.have_grad_lock.acquire()
|
||||||
|
|
||||||
|
def _get_lock_gradient(self):
|
||||||
|
self.have_grad_lock.acquire()
|
||||||
|
grads = self.get_parameter_gradients()
|
||||||
|
self.step_sync_lock.release()
|
||||||
|
return grads
|
||||||
|
|
||||||
def is_first_stage(self):
|
def is_first_stage(self):
|
||||||
return (self.pp_rank % self.actual_stage_num) == 0
|
return (self.pp_rank % self.actual_stage_num) == 0
|
||||||
|
|
||||||
@ -214,27 +229,22 @@ class ChimeraWorker(WorkerBase):
|
|||||||
return local_device_pp_ranks
|
return local_device_pp_ranks
|
||||||
|
|
||||||
def _hook_before_step(self):
|
def _hook_before_step(self):
|
||||||
|
self.have_grad_lock.release()
|
||||||
pp_rank = self.pp_rank
|
pp_rank = self.pp_rank
|
||||||
|
stage_num = self.actual_stage_num
|
||||||
orders = self._get_step_order()
|
co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)
|
||||||
step_index = orders.index(pp_rank)
|
|
||||||
|
|
||||||
# if currrent pp_rank is not the first to do step
|
# if currrent pp_rank is not the first to do step
|
||||||
# wait its previous pp_rank finish step
|
# wait its previous pp_rank finish step
|
||||||
|
|
||||||
all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
|
|
||||||
grads = self.get_parameter_gradients()
|
grads = self.get_parameter_gradients()
|
||||||
|
|
||||||
# print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
|
# send
|
||||||
if step_index == 1:
|
co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
|
||||||
ppg.chimera_step_lock.acquire()
|
co_grads = co_worker.rpc_sync()._get_lock_gradient()
|
||||||
|
# sync
|
||||||
# print(f'rank_{self.pp_rank} before all reduce')
|
self.step_sync_lock.acquire()
|
||||||
dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
|
for i in range(len(grads)):
|
||||||
# print(f'rank_{self.pp_rank} after all reduce')
|
grads[i] += co_grads[i]
|
||||||
|
|
||||||
if step_index == 0:
|
|
||||||
ppg.chimera_step_lock.release()
|
|
||||||
|
|
||||||
|
|
||||||
class ChimeraPipelineEngine(PipelineEngineBase):
|
class ChimeraPipelineEngine(PipelineEngineBase):
|
||||||
@ -257,8 +267,8 @@ class ChimeraPipelineEngine(PipelineEngineBase):
|
|||||||
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||||
metric, checkpoint, data_process_func)
|
metric, checkpoint, data_process_func)
|
||||||
|
|
||||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
|
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
|
||||||
input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
|
output_pp_ranks: List[int], ret_future):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||||
|
@ -1,10 +1,18 @@
|
|||||||
from typing import List, Any, Tuple, Dict, Callable, Type, Union
|
from typing import List, Any, Tuple, Dict, Callable, Type, Union
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
from torch.futures import Future
|
from torch.futures import Future
|
||||||
|
import torch.distributed.rpc as rpc
|
||||||
|
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||||
from colorama import Back, Style
|
from colorama import Back, Style
|
||||||
|
|
||||||
|
from colossalai.initialize import launch
|
||||||
|
from colossalai.pipeline.pipeline_process_group import ppg
|
||||||
|
|
||||||
# config for debug and test
|
# config for debug and test
|
||||||
use_color_debug = False
|
use_color_debug = False
|
||||||
|
|
||||||
@ -87,3 +95,57 @@ def get_real_args_kwargs(args_or_kwargs):
|
|||||||
args_or_kwargs = flatten_args
|
args_or_kwargs = flatten_args
|
||||||
|
|
||||||
return args_or_kwargs
|
return args_or_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def run_worker(rank, args, master_func):
|
||||||
|
os.environ['MASTER_ADDR'] = args.master_addr
|
||||||
|
os.environ['MASTER_PORT'] = args.master_port
|
||||||
|
|
||||||
|
device = args.device
|
||||||
|
world_size = args.world_size
|
||||||
|
dp_degree = args.dp_degree
|
||||||
|
tp_degree = args.tp_degree
|
||||||
|
num_worker_threads = args.num_worker_threads
|
||||||
|
host = args.master_addr
|
||||||
|
port = args.master_port
|
||||||
|
backend = 'nccl' if device == 'cuda' else 'gloo'
|
||||||
|
|
||||||
|
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||||
|
ppg.set_global_info(rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
dp_degree=dp_degree,
|
||||||
|
tp_degree=tp_degree,
|
||||||
|
num_worker_threads=num_worker_threads,
|
||||||
|
device=device)
|
||||||
|
ppg.args = args
|
||||||
|
# in rpc mode, only rank 0 is needed to be coded
|
||||||
|
if rank == 0:
|
||||||
|
master_func(args)
|
||||||
|
# barrier here
|
||||||
|
if _is_current_rpc_agent_set():
|
||||||
|
rpc.shutdown()
|
||||||
|
else:
|
||||||
|
warnings.warn("RPC has not been initialized")
|
||||||
|
|
||||||
|
|
||||||
|
def rpc_run(args, master_func):
|
||||||
|
world_size = args.world_size
|
||||||
|
mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--epoch', type=int, default=1)
|
||||||
|
parser.add_argument('--world_size', type=int, default=2)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=16)
|
||||||
|
parser.add_argument('--dp_degree', type=int, default=1)
|
||||||
|
parser.add_argument('--tp_degree', type=int, default=1)
|
||||||
|
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||||
|
parser.add_argument('--chunk', type=int, default=1)
|
||||||
|
parser.add_argument('--use_checkpoint', action='store_true')
|
||||||
|
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
|
||||||
|
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
|
||||||
|
parser.add_argument('--master_addr', type=str, default='localhost')
|
||||||
|
parser.add_argument('--master_port', type=str, default='29020')
|
||||||
|
parser.add_argument('--num_worker_threads', type=str, default=128)
|
||||||
|
return parser.parse_args()
|
||||||
|
Loading…
Reference in New Issue
Block a user