mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[cli] fixed a bug in user args and refactored the module structure
This commit is contained in:
@@ -1,77 +1,34 @@
|
||||
import click
|
||||
from colossalai.cli.launcher.run import main as col_launch
|
||||
from .launcher import run
|
||||
from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
|
||||
from colossalai.cli.benchmark.run import launch as col_benchmark
|
||||
|
||||
|
||||
class Arguments():
|
||||
|
||||
def __init__(self, arg_dict):
|
||||
for k, v in arg_dict.items():
|
||||
self.__dict__[k] = v
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
pass
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--num_gpus",
|
||||
type=int,
|
||||
default=-1)
|
||||
@click.option("--bs",
|
||||
type=int,
|
||||
default=BATCH_SIZE)
|
||||
@click.option("--seq_len",
|
||||
type=int,
|
||||
default=SEQ_LENGTH)
|
||||
@click.option("--hid_dim",
|
||||
type=int,
|
||||
default=HIDDEN_DIM)
|
||||
@click.option("--num_steps",
|
||||
type=int,
|
||||
default=ITER_TIMES)
|
||||
@click.option("--num_gpus", type=int, default=-1)
|
||||
@click.option("--bs", type=int, default=BATCH_SIZE)
|
||||
@click.option("--seq_len", type=int, default=SEQ_LENGTH)
|
||||
@click.option("--hid_dim", type=int, default=HIDDEN_DIM)
|
||||
@click.option("--num_steps", type=int, default=ITER_TIMES)
|
||||
def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
|
||||
args_dict = locals()
|
||||
args = Arguments(args_dict)
|
||||
col_benchmark(args)
|
||||
|
||||
@click.command()
|
||||
@click.option("--hostfile",
|
||||
type=str,
|
||||
default="")
|
||||
@click.option("--include",
|
||||
type=str,
|
||||
default="")
|
||||
@click.option("--exclude",
|
||||
type=str,
|
||||
default="")
|
||||
@click.option("--num_nodes",
|
||||
type=int,
|
||||
default=-1)
|
||||
@click.option("--num_gpus",
|
||||
type=int,
|
||||
default=-1)
|
||||
@click.option("--master_port",
|
||||
type=int,
|
||||
default=29500)
|
||||
@click.option("--master_addr",
|
||||
type=str,
|
||||
default="127.0.0.1")
|
||||
@click.option("--launcher",
|
||||
type=str,
|
||||
default="torch")
|
||||
@click.option("--launcher_args",
|
||||
type=str,
|
||||
default="")
|
||||
@click.argument("user_script",
|
||||
type=str)
|
||||
@click.argument('user_args', nargs=-1)
|
||||
def launch(hostfile, num_nodes, num_gpus, include, exclude, master_addr, master_port,
|
||||
launcher, launcher_args, user_script, user_args):
|
||||
args_dict = locals()
|
||||
args = Arguments(args_dict)
|
||||
args.user_args = list(args.user_args)
|
||||
col_launch(args)
|
||||
|
||||
cli.add_command(launch)
|
||||
cli.add_command(run)
|
||||
cli.add_command(benchmark)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user