[feat] Support meta cache, meta_grad_send, and meta_tensor_send; fix excessively long runtime in Recv Bwd; add Llama benchmark with Hybrid (TP+PP) parallelism

Author: duanjunwen
Date: 2024-10-24 07:30:19 +00:00
Parent: 705b18e1e7
Commit: 2eca112c90
8 changed files with 184 additions and 63 deletions

@@ -40,6 +40,7 @@ MODEL_CONFIGS = {
     ),
     "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
     "7b": LlamaConfig(max_position_embeddings=4096),
+    # "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096),
     "13b": LlamaConfig(
         hidden_size=5120,
         intermediate_size=13824,
@@ -127,9 +128,12 @@ def main():
         {
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
+                # num_ckpt_layers_per_stage=[48, 48, 48, 48],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            "pp_style": "interleaved",
+            # "num_layers_per_stage": [48, 48, 48, 48],
+            # "pp_style": "interleaved",
+            "pp_style": "1f1b",
         }
         if args.custom_ckpt
         else {}
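The hunk above switches the custom layout from interleaved scheduling to plain 1f1b while keeping the hand-tuned per-stage layer and checkpoint counts. As a reading aid, here is a small, hypothetical sanity check for such a layout; the helper name and the pp_size=4 / 80-layer assumptions are illustrative and not part of this commit.

# Hypothetical helper, not part of this commit: validate a custom pipeline
# layout before passing it to HybridParallelPlugin.
def check_stage_layout(num_layers_per_stage, num_ckpt_layers_per_stage, pp_size, total_layers):
    # Under "1f1b" every pipeline rank owns exactly one stage.
    assert len(num_layers_per_stage) == pp_size
    assert len(num_ckpt_layers_per_stage) == pp_size
    # The per-stage layer counts must cover the whole model.
    assert sum(num_layers_per_stage) == total_layers
    # A stage cannot checkpoint more layers than it holds.
    for n_layers, n_ckpt in zip(num_layers_per_stage, num_ckpt_layers_per_stage):
        assert 0 <= n_ckpt <= n_layers

# The values above sum to 80 layers (19 + 20 + 20 + 21), i.e. an 80-layer model:
check_stage_layout([19, 20, 20, 21], [19, 19, 19, 13], pp_size=4, total_layers=80)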
@@ -227,12 +231,14 @@ def main():
             b_cost=1000,
             w_cost=1000,
             c_cost=1,
-            f_mem=mem_f,
-            b_mem=mem_b,
-            w_mem=mem_w,
+            f_mem=mem_f * 1.5,
+            b_mem=mem_b * 1.5,
+            w_mem=mem_w * 1.5,
         ).get_v_schedule()
     else:
         scheduler_nodes = None
+    # print(f"{dist.get_rank()} {scheduler_nodes[]} ")
     plugin = HybridParallelPlugin(
         tp_size=args.tp,
         pp_size=args.pp,
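The memory arguments fed to the v-schedule builder are now inflated by 1.5x. As an illustration of the idea (the meaning of mem_f/mem_b/mem_w and the GB numbers below are assumptions, not taken from this commit), the factor acts as a safety margin so the planned schedule stays under the real budget even when the per-pass estimates are optimistic.

# Illustrative sketch only: pad the per-pass memory estimates handed to the
# zero-bubble v-schedule solver by the same 1.5x factor used above.
SAFETY_FACTOR = 1.5

def padded_mem_estimates(mem_f, mem_b, mem_w, factor=SAFETY_FACTOR):
    # mem_f: memory delta of a forward pass; mem_b / mem_w: deltas of the
    # backward-for-input and backward-for-weight passes (often negative,
    # since they release activations). Scaling all three keeps their ratio.
    return mem_f * factor, mem_b * factor, mem_w * factor

# Example: a forward pass estimated at 2.0 GB is planned as if it needed 3.0 GB.
f_mem, b_mem, w_mem = padded_mem_estimates(2.0, -1.2, -0.8)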
@@ -267,7 +273,7 @@ def main():
         microbatch_size=args.mbs,
         initial_scale=2**8,
         precision="bf16",
-        overlap_p2p=args.overlap,
+        overlap_p2p=True,
         use_fp8=args.use_fp8,
         fp8_communication=args.use_fp8_comm,
     )
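The change above hard-codes overlap_p2p=True instead of reading args.overlap. As a rough illustration of the idea only (this is not HybridParallelPlugin's actual implementation), overlapping p2p means issuing the pipeline send/recv asynchronously and waiting on it only after local compute has been launched, using standard torch.distributed primitives:

import torch.distributed as dist

# Illustrative sketch only: overlap a pipeline p2p exchange with local compute.
def overlapped_p2p_step(tensor_to_send, recv_buffer, peer_rank, local_compute):
    send_op = dist.P2POp(dist.isend, tensor_to_send, peer_rank)
    recv_op = dist.P2POp(dist.irecv, recv_buffer, peer_rank)
    reqs = dist.batch_isend_irecv([send_op, recv_op])  # communication starts here
    out = local_compute()  # compute overlaps with the in-flight p2p transfer
    for req in reqs:
        req.wait()  # block only once the received data is actually needed
    return out, recv_buffer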
@@ -328,7 +334,7 @@ def main():
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
-    torch.set_default_dtype(torch.float)
+    # torch.set_default_dtype(torch.float)
     coordinator.print_on_master(
         f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
     )
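With the restore line commented out, the bf16 default dtype set before booster.boost now stays in effect for everything created afterwards. For comparison, a minimal, self-contained sketch of the original set-boost-restore pattern using only standard torch calls (the boost call itself is elided):

import torch

# Original pattern: parameters materialized inside booster.boost() default to
# bf16, then the previous default dtype is restored. The commit drops the
# restore step, so later allocations also default to bf16.
prev_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    ...  # model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
finally:
    torch.set_default_dtype(prev_dtype)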
@@ -340,7 +346,7 @@ def main():
         args.profile,
         args.ignore_steps,
         1,  # avoid creating massive log files
-        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+        save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
         nsys=args.nsys,
     ) as prof:
         if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
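The save_dir change only makes the profile directory explicitly relative to the current working directory. A small illustration of how the f-string expands; the plugin and config values below are assumed example inputs, not taken from the diff:

import time

# Assumed example values for illustration; the real values come from argparse.
plugin_name, config_name = "3d", "70b"
save_dir = f"./profile/{time.strftime('%H:%M', time.localtime())}-{plugin_name}-llama-{config_name}"
# A run of the "70b" config with the "3d" plugin started at 07:30 would write to:
#   ./profile/07:30-3d-llama-70b
print(save_dir)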