From 85774f0c1fd1425149a61a463f726b25d49ec420 Mon Sep 17 00:00:00 2001
From: ocd_with_naming <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:54:27 +0800
Subject: [PATCH] [NFC] polish colossalai/cli/benchmark/utils.py code style
 (#4254)

---
 colossalai/cli/benchmark/utils.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
index 825b795f2..ee7d92d6e 100644
--- a/colossalai/cli/benchmark/utils.py
+++ b/colossalai/cli/benchmark/utils.py
@@ -1,10 +1,11 @@
 import math
 import time
+from typing import Callable, Dict, List, Tuple
+
 import torch
 
+from colossalai.context import Config, ParallelMode
 from colossalai.utils import MultiTimer
-from colossalai.context import ParallelMode, Config
-from typing import List, Dict, Tuple, Callable
 
 
 def get_time_stamp() -> int:
@@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
     Return the memory statistics.
 
     Returns:
-        max_allocated (float): the allocated CUDA memory 
-        max_cached (float): the cached CUDA memory 
+        max_allocated (float): the allocated CUDA memory
+        max_cached (float): the cached CUDA memory
     """
     max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
@@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
         profile_steps (int): the number of steps for profiling
         data_func (Callable): a function to generate random data
         timer (colossalai.utils.Multitimer): a timer instance for time recording
-    
+
     Returns:
         fwd_time (float): the average forward time taken by forward pass in second
         bwd_time (float): the average backward time taken by forward pass in second
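
Note (not part of the patch): for reviewers unfamiliar with this module, the utilities touched above can be exercised roughly as follows. This is a minimal sketch based only on the signatures and docstrings visible in the hunks; the exact argument expectations of data_func and the full return tuple of profile_model are assumptions, not confirmed by this diff.

    # Hedged usage sketch of colossalai/cli/benchmark/utils.py; any names and
    # call patterns beyond what the hunks show are assumptions for illustration.
    import torch

    from colossalai.cli.benchmark.utils import get_memory_states, profile_model
    from colossalai.utils import MultiTimer

    model = torch.nn.Linear(1024, 1024).cuda()
    timer = MultiTimer()

    def data_func():
        # hypothetical zero-argument generator of random input batches; the
        # real benchmark CLI wires in its own data function
        return torch.randn(8, 1024, device="cuda")

    # profile_model's docstring documents fwd_time and bwd_time; any further
    # return values are not visible in this diff, hence the catch-all unpacking.
    fwd_time, bwd_time, *rest = profile_model(model, warmup_steps=5, profile_steps=20,
                                              data_func=data_func, timer=timer)

    # per the hunk above, these are peak CUDA memory figures converted to GiB
    # via the 1024**3 division
    max_allocated, max_cached = get_memory_states()
    print(f"fwd {fwd_time:.4f}s, bwd {bwd_time:.4f}s, "
          f"alloc {max_allocated:.2f} GiB, cached {max_cached:.2f} GiB")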