mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[cli] refactored micro-benchmarking cli and added more metrics (#858)
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import click
|
||||
from .launcher import run
|
||||
from .check import check
|
||||
from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
|
||||
from colossalai.cli.benchmark.run import launch as col_benchmark
|
||||
from .benchmark import benchmark
|
||||
|
||||
|
||||
class Arguments():
|
||||
@@ -17,18 +16,6 @@ def cli():
|
||||
pass
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--num_gpus", type=int, default=-1)
|
||||
@click.option("--bs", type=int, default=BATCH_SIZE)
|
||||
@click.option("--seq_len", type=int, default=SEQ_LENGTH)
|
||||
@click.option("--hid_dim", type=int, default=HIDDEN_DIM)
|
||||
@click.option("--num_steps", type=int, default=ITER_TIMES)
|
||||
def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
|
||||
args_dict = locals()
|
||||
args = Arguments(args_dict)
|
||||
col_benchmark(args)
|
||||
|
||||
|
||||
cli.add_command(run)
|
||||
cli.add_command(check)
|
||||
cli.add_command(benchmark)
|
||||
|
Reference in New Issue
Block a user