mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[example] add profile util for llama
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import DistCoordinator
|
||||
@@ -27,6 +28,27 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
return tensor.item()
|
||||
|
||||
|
||||
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||
class DummyProfiler:
|
||||
def __init__(self):
|
||||
self.step_number = 0
|
||||
|
||||
def step(self):
|
||||
self.step_number += 1
|
||||
|
||||
if enable_flag:
|
||||
return profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||
on_trace_ready=tensorboard_trace_handler(save_dir),
|
||||
# record_shapes=True,
|
||||
# profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
else:
|
||||
return DummyProfiler()
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
|
Reference in New Issue
Block a user