diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py
index 782f546dc..7bd098c19 100644
--- a/examples/language/gpt/gemini/commons/utils.py
+++ b/examples/language/gpt/gemini/commons/utils.py
@@ -1,4 +1,23 @@
+import time
+from contextlib import nullcontext
+
 import torch
+from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
+
+
+class DummyProfiler:
+    """No-op stand-in for ``torch.profiler.profile`` used when profiling is disabled.
+
+    It only counts how many times ``step()`` was called, so a training loop can
+    call ``prof.step()`` unconditionally.
+    """
+
+    def __init__(self):
+        self.step_number = 0
+
+    def step(self):
+        """Record that one step has finished."""
+        self.step_number += 1
 
 
 # Randomly Generated Data
@@ -10,3 +29,27 @@ def get_data(batch_size, seq_len, vocab_size):
 
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+    """Return a context manager that yields an object with a ``step()`` method.
+
+    Args:
+        enable_flag: when truthy, build a real ``torch.profiler.profile``;
+            otherwise return a no-op context that yields a ``DummyProfiler``.
+        warmup_steps: number of warmup steps in the profiler schedule.
+        active_steps: number of recorded steps in the profiler schedule.
+        save_dir: directory handed to ``tensorboard_trace_handler``.
+    """
+    if not enable_flag:
+        return nullcontext(DummyProfiler())
+    return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                   schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+                   on_trace_ready=tensorboard_trace_handler(save_dir),
+                   record_shapes=True,
+                   profile_memory=True)
+
+
+def get_time_stamp():
+    """Return the local time formatted as ``day-hour:minute`` (``"%d-%H:%M"``)."""
+    return time.strftime("%d-%H:%M", time.localtime())
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index ab8a65e62..f46226bce 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -6,7 +6,7 @@ import psutil
 import torch
 import torch.nn as nn
 from commons.model_zoo import model_builder
-from commons.utils import get_data, get_tflops
+from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
@@ -201,7 +201,8 @@ def main():
 
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
-    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
+    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
+    PROF_FLAG = False    # The flag of profiling, False by default
 
     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
@@ -292,7 +293,9 @@ def main():
     torch.cuda.synchronize()
     model.train()
     tflops_list = []
-    for n in range(NUM_STEPS):
+
+    def train_step(n):
+        """Run training step ``n``; steps with ``n >= WARMUP_STEPS`` record their TFLOPS."""
         # we just use randomly generated data here
         input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
         optimizer.zero_grad()
@@ -331,6 +334,17 @@ def main():
         if n >= WARMUP_STEPS:
             tflops_list.append(step_tflops)
 
+    demo_profiler = get_profile_context(PROF_FLAG,
+                                        WARMUP_STEPS,
+                                        NUM_STEPS - WARMUP_STEPS,
+                                        save_dir=f"profile/{get_time_stamp()}-demo")
+
+    with demo_profiler as prof:
+        for n in range(NUM_STEPS):
+            train_step(n)
+            prof.step()
+
     tflops_list.sort()
-    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+    # tflops_list only holds the post-warmup steps, so its median is the middle element
+    median_index = len(tflops_list) >> 1
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")