[pipeline/fix-bug] num_microbatches support any integrate | stable chimera | launch tool for rpc pp framework (#1684)

* [pipeline/tuning] improve dispatch performance both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing

* [pipeline/pytree] add pytree to process args and kwargs | provide  to process args and kwargs after forward

* [pipeline/fix-bug] num_microbatches support any integrate | stable chimera | launch tool for rpc pp framework
This commit is contained in:
Kirigaya Kazuto
2022-10-10 16:01:02 +08:00
committed by GitHub
parent e5ab6be72e
commit 0df5034a36
4 changed files with 98 additions and 24 deletions

View File

@@ -1,10 +1,18 @@
from typing import List, Any, Tuple, Dict, Callable, Type, Union
import os
import warnings
import argparse
import torch
import torch.multiprocessing as mp
from torch.futures import Future
import torch.distributed.rpc as rpc
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colorama import Back, Style
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
# config for debug and test
use_color_debug = False
@@ -87,3 +95,57 @@ def get_real_args_kwargs(args_or_kwargs):
args_or_kwargs = flatten_args
return args_or_kwargs
def run_worker(rank, args, master_func):
os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_PORT'] = args.master_port
device = args.device
world_size = args.world_size
dp_degree = args.dp_degree
tp_degree = args.tp_degree
num_worker_threads = args.num_worker_threads
host = args.master_addr
port = args.master_port
backend = 'nccl' if device == 'cuda' else 'gloo'
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
ppg.set_global_info(rank=rank,
world_size=world_size,
dp_degree=dp_degree,
tp_degree=tp_degree,
num_worker_threads=num_worker_threads,
device=device)
ppg.args = args
# in rpc mode, only rank 0 is needed to be coded
if rank == 0:
master_func(args)
# barrier here
if _is_current_rpc_agent_set():
rpc.shutdown()
else:
warnings.warn("RPC has not been initialized")
def rpc_run(args, master_func):
world_size = args.world_size
mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--dp_degree', type=int, default=1)
parser.add_argument('--tp_degree', type=int, default=1)
parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--chunk', type=int, default=1)
parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=str, default=128)
return parser.parse_args()