[chore] refactor & sync

hxwang committed 2024-05-16 07:22:10 +00:00
parent 4148ceed9f, commit 2e68eebdfe
7 changed files with 82 additions and 46 deletions


@@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
on_trace_ready=tensorboard_trace_handler(save_dir),
-record_shapes=True,
-profile_memory=True,
+# record_shapes=True,
+# profile_memory=True,
+with_stack=True,
)
else:
return nullcontext(DummyProfiler())
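
For context, this hunk sits inside a small helper that hands the training loop either a real torch.profiler context or a do-nothing stand-in, so the loop can call prof.step() unconditionally. Below is a minimal sketch reconstructed from the visible lines; the DummyProfiler body and the exact function layout are assumptions, not the file's verbatim contents.

from contextlib import nullcontext

from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


class DummyProfiler:
    # Assumed no-op stand-in: exposes step() so callers need no branching.
    def step(self):
        pass


def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    if enable_flag:
        # Real profiler: CPU + CUDA activities, TensorBoard traces, stack capture.
        # Shape and memory recording stay commented out to keep overhead down.
        return profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
            on_trace_ready=tensorboard_trace_handler(save_dir),
            # record_shapes=True,
            # profile_memory=True,
            with_stack=True,
        )
    else:
        # nullcontext(DummyProfiler()) yields the dummy on __enter__.
        return nullcontext(DummyProfiler())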


@@ -129,7 +129,7 @@ def main():
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
-PROF_FLAG = False # The flag of profiling, False by default
+PROF_FLAG = True # The flag of profiling, False by default
disable_existing_loggers()
colossalai.launch_from_torch()
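
Flipping PROF_FLAG to True makes this benchmark run collect traces through the helper shown in the first hunk. A hedged sketch of a typical call site follows; get_profile_context and the constants come from this diff, while the save directory and the per-step body are hypothetical placeholders.

# Hypothetical wiring -- the benchmark's actual call site is not part of this hunk.
with get_profile_context(
    PROF_FLAG,
    warmup_steps=WARMUP_STEPS,
    active_steps=NUM_STEPS - WARMUP_STEPS,
    save_dir="./tb_log",              # placeholder trace directory
) as prof:
    for step in range(NUM_STEPS):
        train_one_step()              # placeholder for forward/backward/optimizer work
        prof.step()                   # no-op via DummyProfiler when PROF_FLAG is False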
@@ -166,7 +166,7 @@ def main():
stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
)
elif args.distplan == "CAI_Gemini":
-plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
+plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
else:
raise RuntimeError
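
The new max_prefetch=1 presumably caps how many parameter chunks Gemini fetches ahead of the compute stream. A minimal sketch of the usual plugin-to-Booster flow is below, assuming the standard ColossalAI Booster API and that colossalai.launch_from_torch() has already initialized the process group; the GPT-2 model and optimizer here are illustrative stand-ins for the benchmark's own construction.

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import GPT2Config, GPT2LMHeadModel

# Illustrative model/optimizer; the benchmark builds its own GPT model earlier in main().
config = GPT2Config()
model = GPT2LMHeadModel(config)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(
    search_range_m=128,
    hidden_dim=config.n_embd,
    max_prefetch=1,   # added in this commit: prefetch at most one chunk ahead
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)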
@@ -248,7 +248,7 @@ def main():
prof.step()
tflops_list.sort()
-median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
torch.cuda.synchronize()
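
The min() clamp keeps the median lookup in range if fewer throughput samples were recorded than NUM_STEPS suggests (for example when a run is cut short). A small worked example of the index arithmetic, with illustrative numbers only:

# Illustrative values -- not measurements from the benchmark.
NUM_STEPS, WARMUP_STEPS = 10, 1
tflops_list = sorted([149.8, 150.4, 150.9, 151.2])             # suppose only 4 samples were recorded

raw_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS   # (9 >> 1) + 1 = 5
median_index = min(raw_index, len(tflops_list) - 1)            # min(5, 3) = 3, stays in bounds
print(f"Median TFLOPS is {tflops_list[median_index]:.3f}")     # the old expression would raise IndexError here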