mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[gemini] add profiler in the demo (#2534)
This commit is contained in:
@@ -1,4 +1,17 @@
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||
|
||||
|
||||
class DummyProfiler:
    """Minimal profiler stand-in that only counts how often ``step`` is called.

    ``get_profile_context`` wraps an instance of this class in a
    ``nullcontext`` when profiling is disabled, so code written as
    ``with get_profile_context(...) as prof: prof.step()`` works either way.
    """

    def __init__(self):
        # Number of step() calls made on this instance so far.
        self.step_number = 0

    def step(self):
        """Record one more step; performs no profiling and returns ``None``."""
        self.step_number = self.step_number + 1
|
||||
|
||||
|
||||
# Randomly Generated Data
|
||||
@@ -10,3 +23,19 @@ def get_data(batch_size, seq_len, vocab_size):
|
||||
|
||||
def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Estimate throughput in units of 1e12 operations per unit of ``step_time``.

    Args:
        model_numel: Number of elements (parameters) in the model.
        batch_size: Number of samples processed in the step.
        seq_len: Sequence length of each sample.
        step_time: Duration of the step (presumably seconds -- confirm
            against the caller).

    Returns:
        ``model_numel * batch_size * seq_len * 8 / 1e12`` divided by the
        step time, with ``1e-12`` added to the time so that a zero
        ``step_time`` does not raise ``ZeroDivisionError``.
    """
    # NOTE(review): the constant 8 is presumably the per-parameter, per-token
    # operation count assumed for one training step -- confirm with the author.
    scaled_work = model_numel * batch_size * seq_len * 8 / 1e12
    guarded_time = step_time + 1e-12
    return scaled_work / guarded_time
|
||||
|
||||
|
||||
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    """Build the context manager used to (optionally) profile a training loop.

    Args:
        enable_flag: When truthy, return a real ``torch.profiler.profile``;
            otherwise return a no-op context.
        warmup_steps: Passed as ``warmup`` to the profiler schedule.
        active_steps: Passed as ``active`` to the profiler schedule.
        save_dir: Directory handed to ``tensorboard_trace_handler`` for the
            trace output.

    Returns:
        A ``torch.profiler.profile`` configured for CPU and CUDA activities
        with shape recording and memory profiling enabled, or a
        ``nullcontext`` wrapping a ``DummyProfiler`` so that the object bound
        by ``with ... as prof`` still has a ``step()`` method.
    """
    if not enable_flag:
        # Profiling disabled: hand back an object exposing the same step() API.
        return nullcontext(DummyProfiler())

    # wait=0: no idle steps are scheduled before the warmup phase.
    profiling_schedule = schedule(wait=0, warmup=warmup_steps, active=active_steps)
    trace_handler = tensorboard_trace_handler(save_dir)
    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=profiling_schedule,
        on_trace_ready=trace_handler,
        record_shapes=True,
        profile_memory=True,
    )
|
||||
|
||||
|
||||
def get_time_stamp():
    """Return the current local time as a ``DD-HH:MM`` string.

    The result holds the day of the month, hour (24-hour clock) and minute,
    e.g. ``"01-17:05"``.
    """
    # With no explicit time tuple, strftime formats the current local time.
    return time.strftime("%d-%H:%M")
|
||||
|
Reference in New Issue
Block a user