mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[NFC] polish colossalai/cli/benchmark/utils.py code style (#4254)
This commit is contained in:
committed by
binmakeswell
parent
dee1c96344
commit
85774f0c1f
@@ -1,10 +1,11 @@
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
from typing import Callable, Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.context import Config, ParallelMode
|
||||||
from colossalai.utils import MultiTimer
|
from colossalai.utils import MultiTimer
|
||||||
from colossalai.context import ParallelMode, Config
|
|
||||||
from typing import List, Dict, Tuple, Callable
|
|
||||||
|
|
||||||
|
|
||||||
def get_time_stamp() -> int:
|
def get_time_stamp() -> int:
|
||||||
@@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
|
|||||||
Return the memory statistics.
|
Return the memory statistics.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
max_allocated (float): the allocated CUDA memory
|
max_allocated (float): the allocated CUDA memory
|
||||||
max_cached (float): the cached CUDA memory
|
max_cached (float): the cached CUDA memory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
|
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
@@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
|
|||||||
profile_steps (int): the number of steps for profiling
|
profile_steps (int): the number of steps for profiling
|
||||||
data_func (Callable): a function to generate random data
|
data_func (Callable): a function to generate random data
|
||||||
timer (colossalai.utils.Multitimer): a timer instance for time recording
|
timer (colossalai.utils.Multitimer): a timer instance for time recording
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
fwd_time (float): the average forward time taken by forward pass in second
|
fwd_time (float): the average forward time taken by forward pass in second
|
||||||
bwd_time (float): the average backward time taken by forward pass in second
|
bwd_time (float): the average backward time taken by forward pass in second
|
||||||
|
Reference in New Issue
Block a user