mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 09:21:26 +00:00
[NFC] polish colossalai/cli/benchmark/utils.py code style (#4254)
This commit is contained in:
parent
dee1c96344
commit
85774f0c1f
@ -1,10 +1,11 @@
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
from typing import Callable, Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.context import Config, ParallelMode
|
||||||
from colossalai.utils import MultiTimer
|
from colossalai.utils import MultiTimer
|
||||||
from colossalai.context import ParallelMode, Config
|
|
||||||
from typing import List, Dict, Tuple, Callable
|
|
||||||
|
|
||||||
|
|
||||||
def get_time_stamp() -> int:
|
def get_time_stamp() -> int:
|
||||||
@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
|
|||||||
Return the memory statistics.
|
Return the memory statistics.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
max_allocated (float): the allocated CUDA memory
|
max_allocated (float): the allocated CUDA memory
|
||||||
max_cached (float): the cached CUDA memory
|
max_cached (float): the cached CUDA memory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
|
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
|
|||||||
profile_steps (int): the number of steps for profiling
|
profile_steps (int): the number of steps for profiling
|
||||||
data_func (Callable): a function to generate random data
|
data_func (Callable): a function to generate random data
|
||||||
timer (colossalai.utils.Multitimer): a timer instance for time recording
|
timer (colossalai.utils.Multitimer): a timer instance for time recording
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
fwd_time (float): the average forward time taken by forward pass in second
|
fwd_time (float): the average forward time taken by forward pass in second
|
||||||
bwd_time (float): the average backward time taken by forward pass in second
|
bwd_time (float): the average backward time taken by forward pass in second
|
||||||
|
Loading…
Reference in New Issue
Block a user