Mirror of https://github.com/hpcaitech/ColossalAI.git

commit 8366a7855f
parent ce58d8e8bf

[fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
@@ -509,6 +509,15 @@ def run_fwd_bwd_iter_input(test_config):
             "precision": "bf16",
             "num_model_chunk": 2,
         },
+        # {
+        #     "batch_size": 8,
+        #     "tp_size": 1,
+        #     "pp_size": 4,
+        #     "num_microbatches": 8,
+        #     "zero_stage": 1,
+        #     "precision": "bf16",
+        #     "num_model_chunk": 2,
+        # },
     ],
 )
 def run_fwd_bwd_vschedule_with_optim(test_config):
@@ -593,8 +602,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
             local_chunk.append(sub_model)
 
     # init optimizer
-    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
-    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
+    optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
 
     after_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
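Note: the only functional change in this hunk is the added momentum=0.1. A minimal stand-alone sketch (not part of the diff) of why that matters for the new state assertions: torch.optim.SGD only records a per-parameter "momentum_buffer" in its state dict when momentum is non-zero, so without it there would be no optimizer state to compare.

import torch

layer = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(layer.parameters(), lr=1e-5, momentum=0.1)

layer(torch.randn(2, 4)).sum().backward()
opt.step()

state = opt.state_dict()["state"]
print(sorted(state.keys()))               # one entry per parameter, e.g. [0, 1]
print(state[0]["momentum_buffer"].shape)  # matches the corresponding parameter's shape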
@@ -617,15 +626,16 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     if rank != 0:
         # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
         # output hid_dim * hid_dim * 4(fp32) / 1024**3
-        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}")
-        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 1024**3)
+        # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
+        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
+        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
     else:
         # rank0 will also hold output;
         print(
-            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
+            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
         )
         assert round((after_pp_step_memory - after_init_memory), 5) <= round(
-            (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
+            (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
         )
 
     ##########################
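The memory bound grows from a factor of 3 to 5 because each stage now also holds an SGD momentum buffer for each of its two layers, on top of the two weight gradients and one output already counted in the comments. A rough accounting sketch, assuming hid_dim equals in_dim as the comments do (values are illustrative):

in_dim = 4096  # hypothetical hidden size; the test uses its own in_dim
bytes_fp32 = 4

grads = 2 * in_dim * in_dim * bytes_fp32     # w.grad for the 2 layers on a stage
output = 1 * in_dim * in_dim * bytes_fp32    # activation output held on the stage
momentum = 2 * in_dim * in_dim * bytes_fp32  # new: SGD momentum buffers for the 2 layers

bound_gb = (grads + output + momentum) / 1024**3
assert bound_gb == in_dim * in_dim * 4 * 5 / 1024**3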
@@ -681,10 +691,15 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     ##########################
     # assert optim state
     ##########################
-    optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0]
-    optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0]
+    optim_base_state = optimizer_base.state_dict()["state"]
+    optim_pp_state = optimizer_pp.state_dict()["state"]
+    optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
+    optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
+    # if rank == 0:
+    #     print(f"optim_base_state {optim_base_state}")
 
-    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()):
+    # assert param group
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
         if key_base == key_pp:
             if key_base != "params":
                 assert val_base == val_pp
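For reference, a small self-contained sketch (made-up values, not taken from the test) of what the param-group comparison above tolerates: hyperparameters must match exactly, while "params" only needs to agree on the local slice, since each pp rank owns just the first two flat parameter indices of its chunk.

base_group = {"lr": 1e-5, "momentum": 0.1, "params": [0, 1, 2, 3, 4, 5, 6, 7]}
pp_group = {"lr": 1e-5, "momentum": 0.1, "params": [0, 1]}

for (key_base, val_base), (key_pp, val_pp) in zip(base_group.items(), pp_group.items()):
    if key_base == key_pp:
        if key_base != "params":
            assert val_base == val_pp      # hyperparameters must agree exactly
        else:
            assert val_base[:2] == val_pp  # the pp chunk only owns params [0, 1]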
@@ -694,6 +709,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
                 # params pp: [0, 1];
                 assert val_base[:2] == val_pp
 
+    # assert state
+    assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
+    assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
+
 
 # TODO:4) support Hybrid base 3)
 def run_with_hybridplugin(test_config):
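The two new assert_close calls rely on the layout of torch.optim state dicts: "state" is keyed by flat parameter index, and each pp rank holds 2 layers, so local entries [0, 1] should line up with the base model's entries [2 * rank, 2 * rank + 1]. A self-contained sketch with stand-in buffers (the real test compares the actual momentum buffers):

import torch
from torch.testing import assert_close

rank = 1  # hypothetical pp rank
buffers = {i: torch.full((4, 4), float(i)) for i in range(8)}  # stand-in momentum buffers

optim_base_state = {i: {"momentum_buffer": buffers[i]} for i in range(8)}
optim_pp_state = {
    0: {"momentum_buffer": buffers[2 * rank]},
    1: {"momentum_buffer": buffers[2 * rank + 1]},
}

assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])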