[feat] fix testcase;

This commit is contained in:
duanjunwen 2024-11-11 11:34:29 +00:00
parent 12919de424
commit 337debcf2a
3 changed files with 7 additions and 10 deletions

View File

@ -432,8 +432,6 @@ def _communicate(
overlap_p2p=overlap_p2p,
send_first=send_first if send_first != None else True,
)
# print(f"rank {dist.get_rank()}; recv_src {recv_src}; send_dst {send_dst}; metadata_send {metadata_send}; metadata_recv {metadata_recv};")
if metadata_recv is not None:
assert isinstance(metadata_recv, P2PMetadata)
tree_spec = metadata_recv.tree_spec

View File

@ -101,9 +101,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# init buffer
self._free_buffers()
def _set_send_metadata_buffers(self, model_chunk_id):
pass
def _free_buffers(self):
# free local buffer
# two dim array, first dim is the model chunk, second dim is the microbatch queue

View File

@ -752,13 +752,13 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
@parameterize(
"config",
[
# # Pass
# Pass
(1, 2, 1, 1, 2),
(1, 1, 2, 2, 1),
(1, 2, 1, 2, 1),
# TODO: adapt mixtral with no TP Linear
# (1, 2, 2, 1, 1),
# (0, 1, 4, 1, 1),
# (1, 1, 2, 2, 1),
# (1, 2, 1, 2, 1),
],
)
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@ -1070,10 +1070,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
# parall_name = ".".join(parall_name.split(".")[1:])
# for base_name, base_param in torch_model.named_parameters():
# if parall_name == base_name:
# # assert weight
# # print(f"parall_name {parall_name} parall_param.grad {parall_param.grad is not None}, base_name {base_name} base_param.grad {base_param.grad is not None}")
# # # assert weight
# assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
# # assert weight.grad
# # # assert weight.grad
# if parall_param.grad is not None:
# # print(f"parall_param.grad {parall_param.grad}, base_param.grad {base_param.grad}")
# assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)