Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-18 09:21:26 +00:00)
[gemini] add profiler in the demo (#2534)
commit 6e0faa70e0 (parent df437ca039)
@@ -1,4 +1,17 @@
+import time
+from contextlib import nullcontext
+
 import torch
+from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
+
+
+class DummyProfiler:
+
+    def __init__(self):
+        self.step_number = 0
+
+    def step(self):
+        self.step_number += 1
 
 
 # Randomly Generated Data
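DummyProfiler is a no-op stand-in that exposes the same step() method as torch.profiler's profile object, so the training loop can call prof.step() unconditionally whether profiling is enabled or not. A minimal sketch of that interface (the loop count of 3 is arbitrary):

    # DummyProfiler only counts step() calls; it records nothing.
    prof = DummyProfiler()
    for _ in range(3):
        prof.step()
    assert prof.step_number == 3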
@@ -10,3 +23,19 @@ def get_data(batch_size, seq_len, vocab_size):
 
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+    if enable_flag:
+        return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                       schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+                       on_trace_ready=tensorboard_trace_handler(save_dir),
+                       record_shapes=True,
+                       profile_memory=True)
+    else:
+        return nullcontext(DummyProfiler())
+
+
+def get_time_stamp():
+    cur_time = time.strftime("%d-%H:%M", time.localtime())
+    return cur_time
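get_profile_context returns a real torch.profiler context when enable_flag is set, and a nullcontext wrapping DummyProfiler otherwise, so the calling code is identical in both cases. With schedule(wait=0, warmup=warmup_steps, active=active_steps), the first warmup_steps calls to prof.step() are treated as warmup and discarded, the next active_steps are recorded, and tensorboard_trace_handler writes the resulting trace under save_dir. A minimal usage sketch, assuming the module is importable as commons.utils (as in the demo below) and using a matrix multiply as a stand-in for a training step; the 2-warmup/3-active split is arbitrary:

    import torch

    from commons.utils import get_profile_context, get_time_stamp

    with get_profile_context(True, 2, 3, save_dir=f"profile/{get_time_stamp()}-demo") as prof:
        for _ in range(5):    # 2 warmup steps + 3 active steps
            torch.randn(512, 512) @ torch.randn(512, 512)    # stand-in for one training step
            prof.step()    # advance the profiler schedule

The trace can then be inspected by pointing TensorBoard at the profile/ directory. The remaining hunks wire these helpers into the demo training script that imports them.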
@@ -6,7 +6,7 @@ import psutil
 import torch
 import torch.nn as nn
 from commons.model_zoo import model_builder
-from commons.utils import get_data, get_tflops
+from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
@@ -201,7 +201,8 @@ def main():
 
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
-    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
+    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
+    PROF_FLAG = False    # The flag of profiling, False by default
 
     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
@@ -292,7 +293,8 @@ def main():
     torch.cuda.synchronize()
     model.train()
     tflops_list = []
-    for n in range(NUM_STEPS):
+
+    def train_step():
         # we just use randomly generated data here
         input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
         optimizer.zero_grad()
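This hunk only refactors: the body of the former for n in range(NUM_STEPS) loop becomes a local train_step() function, otherwise unchanged, so that the loop itself can be moved inside the profiler context in the next hunk and call prof.step() after every iteration.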
@@ -331,6 +333,16 @@ def main():
         if n >= WARMUP_STEPS:
             tflops_list.append(step_tflops)
 
+    demo_profiler = get_profile_context(PROF_FLAG,
+                                        WARMUP_STEPS,
+                                        NUM_STEPS - WARMUP_STEPS,
+                                        save_dir=f"profile/{get_time_stamp()}-demo")
+
+    with demo_profiler as prof:
+        for n in range(NUM_STEPS):
+            train_step()
+            prof.step()
+
     tflops_list.sort()
     median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
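With PROF_FLAG left at its default of False, demo_profiler is the nullcontext/DummyProfiler pair and the run behaves exactly as before. Setting PROF_FLAG = True makes the same loop write a trace under profile/<day-HH:MM>-demo (the format produced by get_time_stamp), which can be opened in TensorBoard, e.g. with tensorboard --logdir profile. Note that the warmup/active split passed to get_profile_context mirrors WARMUP_STEPS, so only post-warmup iterations are captured, matching the steps used for the TFLOPS median.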