diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py index 825b795f2..ee7d92d6e 100644 --- a/colossalai/cli/benchmark/utils.py +++ b/colossalai/cli/benchmark/utils.py @@ -1,10 +1,11 @@ import math import time +from typing import Callable, Dict, List, Tuple + import torch +from colossalai.context import Config, ParallelMode from colossalai.utils import MultiTimer -from colossalai.context import ParallelMode, Config -from typing import List, Dict, Tuple, Callable def get_time_stamp() -> int: @@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]: Return the memory statistics. Returns: - max_allocated (float): the allocated CUDA memory - max_cached (float): the cached CUDA memory + max_allocated (float): the allocated CUDA memory + max_cached (float): the cached CUDA memory """ max_allocated = torch.cuda.max_memory_allocated() / (1024**3) @@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, profile_steps (int): the number of steps for profiling data_func (Callable): a function to generate random data timer (colossalai.utils.Multitimer): a timer instance for time recording - + Returns: fwd_time (float): the average forward time taken by forward pass in second bwd_time (float): the average backward time taken by forward pass in second