diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 6ad93e6cb..3fbbe6ed0 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -509,6 +509,15 @@ def run_fwd_bwd_iter_input(test_config):
             "precision": "bf16",
             "num_model_chunk": 2,
         },
+        # {
+        #     "batch_size": 8,
+        #     "tp_size": 1,
+        #     "pp_size": 4,
+        #     "num_microbatches": 8,
+        #     "zero_stage": 1,
+        #     "precision": "bf16",
+        #     "num_model_chunk": 2,
+        # },
     ],
 )
 def run_fwd_bwd_vschedule_with_optim(test_config):
@@ -593,8 +602,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         local_chunk.append(sub_model)
 
     # init optimizer
-    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
-    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
+    optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
 
     after_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
@@ -617,15 +626,16 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     if rank != 0:
         # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
         # output hid_dim * hid_dim * 4(fp32) / 1024**3
-        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}")
-        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 1024**3)
+        # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
+        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
+        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
     else:
         # rank0 will also hold output;
         print(
-            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
+            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
         )
         assert round((after_pp_step_memory - after_init_memory), 5) <= round(
-            (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
+            (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
         )
 
     ##########################
@@ -681,10 +691,15 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     ##########################
     # assert optim state
     ##########################
-    optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0]
-    optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0]
+    optim_base_state = optimizer_base.state_dict()["state"]
+    optim_pp_state = optimizer_pp.state_dict()["state"]
+    optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
+    optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
+    # if rank == 0:
+    #     print(f"optim_base_state {optim_base_state}")
 
-    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()):
+    # assert param group
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
         if key_base == key_pp:
             if key_base != "params":
                 assert val_base == val_pp
@@ -694,6 +709,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
                 # params pp: [0, 1];
                 assert val_base[:2] == val_pp
 
+    # assert state
+    assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
+    assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
+
 
 # TODO:4) support Hybrid base 3)
 def run_with_hybridplugin(test_config):
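
The new `momentum_buffer` assertions and the looser memory bound (5 weight-sized tensors instead of 3) both follow from giving `torch.optim.SGD` a nonzero `momentum`: after the first `step()`, SGD keeps one fp32 buffer per parameter in `state_dict()["state"]`, so each stage holds two extra weight-sized tensors on top of the gradients and output. A minimal standalone sketch of that behavior (not taken from the test file; the toy `Linear` layer and shapes are illustrative only):

```python
import torch

# Toy weight-only layer; any parameter behaves the same way.
layer = torch.nn.Linear(8, 8, bias=False)
opt = torch.optim.SGD(layer.parameters(), lr=1e-5, momentum=0.1)

# SGD state is lazily initialized: no per-parameter entries before the first step.
assert opt.state_dict()["state"] == {}

layer(torch.randn(4, 8)).sum().backward()
opt.step()

# With momentum > 0, SGD now stores one momentum_buffer per parameter,
# shaped like the parameter itself -- the extra memory the test accounts for.
buf = opt.state_dict()["state"][0]["momentum_buffer"]
print(buf.shape)  # torch.Size([8, 8]), same as layer.weight
```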