[fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

This commit is contained in:
duanjunwen 2024-09-09 09:27:13 +00:00
parent ce58d8e8bf
commit 8366a7855f

View File

@ -509,6 +509,15 @@ def run_fwd_bwd_iter_input(test_config):
"precision": "bf16", "precision": "bf16",
"num_model_chunk": 2, "num_model_chunk": 2,
}, },
# {
# "batch_size": 8,
# "tp_size": 1,
# "pp_size": 4,
# "num_microbatches": 8,
# "zero_stage": 1,
# "precision": "bf16",
# "num_model_chunk": 2,
# },
], ],
) )
def run_fwd_bwd_vschedule_with_optim(test_config): def run_fwd_bwd_vschedule_with_optim(test_config):
@ -593,8 +602,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
local_chunk.append(sub_model) local_chunk.append(sub_model)
# init optimizer # init optimizer
optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
after_init_memory = torch.cuda.memory_allocated() / 1024**3 after_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
@ -617,15 +626,16 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
if rank != 0: if rank != 0:
# w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
# output hid_dim * hid_dim * 4(fp32) / 1024**3 # output hid_dim * hid_dim * 4(fp32) / 1024**3
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}") # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 1024**3) print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
else: else:
# rank0 will also hold output; # rank0 will also hold output;
print( print(
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
) )
assert round((after_pp_step_memory - after_init_memory), 5) <= round( assert round((after_pp_step_memory - after_init_memory), 5) <= round(
(in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
) )
########################## ##########################
@ -681,10 +691,15 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
########################## ##########################
# assert optim state # assert optim state
########################## ##########################
optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0] optim_base_state = optimizer_base.state_dict()["state"]
optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0] optim_pp_state = optimizer_pp.state_dict()["state"]
optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
# if rank == 0:
# print(f"optim_base_state {optim_base_state}")
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()): # assert param group
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
if key_base == key_pp: if key_base == key_pp:
if key_base != "params": if key_base != "params":
assert val_base == val_pp assert val_base == val_pp
@ -694,6 +709,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# params pp: [0, 1]; # params pp: [0, 1];
assert val_base[:2] == val_pp assert val_base[:2] == val_pp
# assert state
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
# TODO:4) support Hybrid base 3) # TODO:4) support Hybrid base 3)
def run_with_hybridplugin(test_config): def run_with_hybridplugin(test_config):