Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 22:42:15 +00:00)
[cli] added micro benchmarking for tp (#789)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
  This reverts commit df7e6506d4.
* [CLI] add CLI benchmark feature
* fix CodeFactor issues
* refactor the module structure
This commit is contained in:
parent cfadc9df8e
commit de2f581d43
colossalai/cli/benchmark/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .utils import *
from .run import *
colossalai/cli/benchmark/run.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import torch
import inspect
import os
import subprocess
import sys

from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import print_rank_0
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.cli.benchmark import build_args_parser, build_configs, \
    build_input_tensor, profile_1d, profile_2d, profile_2p5d, profile_3d, \
    BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES


def launch(args=None):
    # Re-launch this script under a distributed launcher so that main()
    # runs once per process (rank).
    train_script = inspect.getfile(inspect.currentframe())
    assert args is not None, "args should not be None"
    env = os.environ.copy()
    if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
        nproc_per_node = torch.cuda.device_count()
    else:
        nproc_per_node = args.num_gpus

    # Forward only the options that differ from their defaults.
    train_args = [f"--num_gpus={nproc_per_node}"]
    if args.bs != BATCH_SIZE:
        train_args.append(f"--bs={args.bs}")
    if args.hid_dim != HIDDEN_DIM:
        train_args.append(f"--hid_dim={args.hid_dim}")
    if args.num_steps != ITER_TIMES:
        train_args.append(f"--num_steps={args.num_steps}")
    if args.seq_len != SEQ_LENGTH:
        train_args.append(f"--seq_len={args.seq_len}")

    master_port = free_port()
    # Older PyTorch ships torch.distributed.launch; newer versions provide torchrun.
    if torch.__version__ <= "1.09":
        cmd = [sys.executable, "-u", "-m",
               "torch.distributed.launch",
               f"--nproc_per_node={nproc_per_node}",
               f"--master_port={master_port}"] + [train_script] + train_args
    else:
        cmd = ["torchrun",
               f"--nproc_per_node={nproc_per_node}",
               f"--master_port={master_port}"] + [train_script] + train_args

    result = subprocess.Popen(cmd, env=env)
    result.wait()
    if result.returncode > 0:
        sys.exit(result.returncode)


def main():
    parser = build_args_parser()
    args = parser.parse_args()
    disable_existing_loggers()
    logger = get_dist_logger()
    launch_from_torch(config={}, verbose=False)
    input_tensor = build_input_tensor(args)
    config_dict = build_configs(args)
    if len(config_dict) == 0:
        print_rank_0("WARNING: We need at least two devices to profile TP strategies performance.")
        gpc.destroy()
        return
    for parallel_mode, config in config_dict.items():
        if parallel_mode == "1d":
            result_1d = profile_1d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 1D TP is {result_1d}.")
        if parallel_mode == "2d":
            result_2d = profile_2d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 2D TP is {result_2d}.")
        if parallel_mode == "2p5d":
            result_2p5d = profile_2p5d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 2P5D TP is {result_2p5d}.")
        if parallel_mode == "3d":
            result_3d = profile_3d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 3D TP is {result_3d}.")
    if "2d" not in config_dict:
        print_rank_0("WARNING: To use 2D tensor parallel, you have to provide at least 4 computing devices.")
    if "2p5d" not in config_dict:
        print_rank_0("WARNING: To use 2P5D tensor parallel, you have to provide at least 8 computing devices.")
        print_rank_0("WARNING: To use 3D tensor parallel, you have to provide at least 8 computing devices.")
    gpc.destroy()


if __name__ == "__main__":
    main()
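For reference, a minimal sketch of how launch() above can be driven programmatically, outside the click layer. SimpleNamespace stands in for the Arguments wrapper the CLI builds, and the option values shown are simply the defaults from utils.py; this is an illustration, not part of the commit.

# Sketch only: drive launch() directly with a namespace mimicking the CLI's Arguments object.
# launch() re-executes run.py under torchrun (or torch.distributed.launch on older PyTorch),
# so main() runs once per rank.
from types import SimpleNamespace
from colossalai.cli.benchmark.run import launch

args = SimpleNamespace(num_gpus=2, bs=8, seq_len=120, hid_dim=1024, num_steps=2000)
launch(args)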
colossalai/cli/benchmark/simple_model.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import torch
import colossalai
import colossalai.nn as col_nn


class MLP(torch.nn.Module):

    def __init__(self, dim: int = 256):
        super().__init__()
        intermediate_dim = dim * 4
        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
        self.activation = torch.nn.GELU()
        self.dense_2 = col_nn.Linear(intermediate_dim, dim)
        self.dropout = col_nn.Dropout(0.1)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x
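As a point of reference, the benchmark model is a standard transformer-style feed-forward block. A single-device sketch using plain torch.nn in place of the col_nn layers (which are sharded according to the active tensor-parallel mode) could look like this; it is only an illustration of the shapes involved.

# Plain-PyTorch sketch of the same block, for readers unfamiliar with col_nn.
# It preserves the input shape: (bs, seq_len, hid_dim) -> (bs, seq_len, hid_dim).
import torch

class PlainMLP(torch.nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.dense_1 = torch.nn.Linear(dim, dim * 4)
        self.activation = torch.nn.GELU()
        self.dense_2 = torch.nn.Linear(dim * 4, dim)
        self.dropout = torch.nn.Dropout(0.1)

    def forward(self, x):
        return self.dropout(self.dense_2(self.activation(self.dense_1(x))))

x = torch.rand(8, 120, 256)
assert PlainMLP()(x).shape == x.shape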
colossalai/cli/benchmark/utils.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import torch
from .simple_model import MLP
from colossalai.utils import Timer, synchronize
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from argparse import ArgumentParser

BATCH_SIZE = 8
SEQ_LENGTH = 120
HIDDEN_DIM = 1024
ITER_TIMES = 2000


def build_args_parser() -> ArgumentParser:
    """Helper function parsing the command line options."""
    parser = ArgumentParser(description="colossal benchmark")
    parser.add_argument("--num_gpus",
                        type=int,
                        default=-1,
                        help="Total number of devices to use.")
    parser.add_argument("--bs",
                        type=int,
                        default=BATCH_SIZE,
                        help="Batch size of the input tensor.")
    parser.add_argument("--seq_len",
                        type=int,
                        default=SEQ_LENGTH,
                        help="Sequence length of the input tensor.")
    parser.add_argument("--hid_dim",
                        type=int,
                        default=HIDDEN_DIM,
                        help="Hidden dimension of the input tensor.")
    parser.add_argument("--num_steps",
                        type=int,
                        default=ITER_TIMES,
                        help="The number of iteration times.")
    return parser


def build_input_tensor(args):
    return torch.rand(args.bs, args.seq_len, args.hid_dim)


def build_configs_helper(device_cnt: int):
    # Pick the set of TP strategies that fit on the available devices.
    config_dict = {}

    if device_cnt < 2:
        return config_dict

    if device_cnt < 4:
        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
    elif device_cnt < 8:
        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=4, mode='1d')))
        config_dict["2d"] = dict(parallel=dict(tensor=dict(size=4, mode='2d')))
    else:
        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=8, mode='1d')))
        config_dict["2d"] = dict(parallel=dict(data=2, tensor=dict(size=4, mode='2d')))
        config_dict["2p5d"] = dict(parallel=dict(tensor=dict(size=8, mode='2.5d', depth=2)))
        config_dict["3d"] = dict(parallel=dict(tensor=dict(size=8, mode='3d')))

    return config_dict


def build_configs(args):
    total_device_cnt = torch.cuda.device_count()
    if args.num_gpus == -1:
        config_dict = build_configs_helper(total_device_cnt)
    else:
        valid_device_cnt = min(args.num_gpus, total_device_cnt)
        config_dict = build_configs_helper(valid_device_cnt)
    return config_dict


def profile_1d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_1D)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_1d = timer.stop()
    return result_1d


def profile_2d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL)
    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    # Shard the input along the batch and hidden dimensions to match the 2D layout.
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_2d = timer.stop()
    return result_2d


def profile_2p5d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL)
    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW)
    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_2p5d = timer.stop()
    return result_2p5d


def profile_3d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_WEIGHT)
    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_INPUT)
    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_OUTPUT)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_3d = timer.stop()
    return result_3d
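To make the TP settings concrete, this sketch writes out the dictionary that build_configs_helper(8) above produces for an 8-GPU node (each value is a ColossalAI parallel config); it is for illustration only and mirrors the branches of the helper.

# Sketch: expected output of build_configs_helper(8), written out explicitly.
from colossalai.cli.benchmark.utils import build_configs_helper

expected = {
    "1d": dict(parallel=dict(tensor=dict(size=8, mode='1d'))),
    "2d": dict(parallel=dict(data=2, tensor=dict(size=4, mode='2d'))),
    "2p5d": dict(parallel=dict(tensor=dict(size=8, mode='2.5d', depth=2))),
    "3d": dict(parallel=dict(tensor=dict(size=8, mode='3d'))),
}
assert build_configs_helper(8) == expected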
@@ -1,15 +1,38 @@
 import click
 from colossalai.cli.launcher.run import main as col_launch
+from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
+from colossalai.cli.benchmark.run import launch as col_benchmark

 class Arguments():
-    def __init__(self, dict):
-        for k, v in dict.items():
+    def __init__(self, arg_dict):
+        for k, v in arg_dict.items():
             self.__dict__[k] = v

 @click.group()
 def cli():
     pass

+@click.command()
+@click.option("--num_gpus",
+              type=int,
+              default=-1)
+@click.option("--bs",
+              type=int,
+              default=BATCH_SIZE)
+@click.option("--seq_len",
+              type=int,
+              default=SEQ_LENGTH)
+@click.option("--hid_dim",
+              type=int,
+              default=HIDDEN_DIM)
+@click.option("--num_steps",
+              type=int,
+              default=ITER_TIMES)
+def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
+    args_dict = locals()
+    args = Arguments(args_dict)
+    col_benchmark(args)
+
 @click.command()
 @click.option("--hostfile",
               type=str,
@@ -49,6 +72,7 @@ def launch(hostfile, num_nodes, num_gpus, include, exclude, master_addr, master_
     col_launch(args)

 cli.add_command(launch)
+cli.add_command(benchmark)

 if __name__ == '__main__':
     cli()
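With benchmark registered on the same cli group as launch, the feature becomes available from the command line (equivalently, colossalai benchmark --num_gpus 2 --num_steps 100 from a shell, assuming the colossalai console script the package installs). A minimal in-process sketch using click's test runner, with the import path of the group assumed for illustration:

# Sketch only: exercising the new subcommand in-process via click's test runner.
# The import path of the click group is an assumption; the installed package
# exposes the same group through its console entry point.
from click.testing import CliRunner
from colossalai.cli import cli  # assumed import path for the click group

result = CliRunner().invoke(cli, ["benchmark", "--num_gpus", "2", "--num_steps", "100"])
print(result.exit_code)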