diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py index 43632b150..f40f8f2f9 100644 --- a/colossalai/cli/benchmark/benchmark.py +++ b/colossalai/cli/benchmark/benchmark.py @@ -1,16 +1,17 @@ -import colossalai +from functools import partial +from typing import Dict, List + import click import torch.multiprocessing as mp -from functools import partial -from typing import List, Dict - +import colossalai +from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model from colossalai.context import Config from colossalai.context.random import reset_seeds from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.utils import free_port, MultiTimer -from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data +from colossalai.utils import MultiTimer, free_port + from .models import MLP @@ -53,7 +54,7 @@ def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_ port_list (List[int]): a list of free ports for initializing distributed networks config_list (List[Dict]): a list of configuration hyperparams (Config): the hyperparameters given by the user - + """ # disable logging for clean output