[NFC] polish colossalai/cli/benchmark/utils.py code style (#4254)

This commit is contained in:
ocd_with_naming 2023-07-18 10:54:27 +08:00 committed by binmakeswell
parent dee1c96344
commit 85774f0c1f

View File

@ -1,10 +1,11 @@
import math import math
import time import time
from typing import Callable, Dict, List, Tuple
import torch import torch
from colossalai.context import Config, ParallelMode
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
from colossalai.context import ParallelMode, Config
from typing import List, Dict, Tuple, Callable
def get_time_stamp() -> int: def get_time_stamp() -> int:
@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
Return the memory statistics. Return the memory statistics.
Returns: Returns:
max_allocated (float): the allocated CUDA memory max_allocated (float): the allocated CUDA memory
max_cached (float): the cached CUDA memory max_cached (float): the cached CUDA memory
""" """
max_allocated = torch.cuda.max_memory_allocated() / (1024**3) max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
profile_steps (int): the number of steps for profiling profile_steps (int): the number of steps for profiling
data_func (Callable): a function to generate random data data_func (Callable): a function to generate random data
timer (colossalai.utils.Multitimer): a timer instance for time recording timer (colossalai.utils.Multitimer): a timer instance for time recording
Returns: Returns:
fwd_time (float): the average forward time taken by forward pass in second fwd_time (float): the average forward time taken by forward pass in second
bwd_time (float): the average backward time taken by forward pass in second bwd_time (float): the average backward time taken by forward pass in second