mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 22:42:15 +00:00
[feat] fix testcase;
This commit is contained in:
parent
12919de424
commit
337debcf2a
@ -432,8 +432,6 @@ def _communicate(
|
||||
overlap_p2p=overlap_p2p,
|
||||
send_first=send_first if send_first != None else True,
|
||||
)
|
||||
# print(f"rank {dist.get_rank()}; recv_src {recv_src}; send_dst {send_dst}; metadata_send {metadata_send}; metadata_recv {metadata_recv};")
|
||||
|
||||
if metadata_recv is not None:
|
||||
assert isinstance(metadata_recv, P2PMetadata)
|
||||
tree_spec = metadata_recv.tree_spec
|
||||
|
@ -101,9 +101,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# init buffer
|
||||
self._free_buffers()
|
||||
|
||||
def _set_send_metadata_buffers(self, model_chunk_id):
|
||||
pass
|
||||
|
||||
def _free_buffers(self):
|
||||
# free local buffer
|
||||
# two dim array, first dim is the model chunk, second dim is the microbatch queue
|
||||
|
@ -752,13 +752,13 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
# # Pass
|
||||
# Pass
|
||||
(1, 2, 1, 1, 2),
|
||||
(1, 1, 2, 2, 1),
|
||||
(1, 2, 1, 2, 1),
|
||||
# TODO: adapt mixtral with no TP Linear
|
||||
# (1, 2, 2, 1, 1),
|
||||
# (0, 1, 4, 1, 1),
|
||||
# (1, 1, 2, 2, 1),
|
||||
# (1, 2, 1, 2, 1),
|
||||
],
|
||||
)
|
||||
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||
@ -1070,10 +1070,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||
# parall_name = ".".join(parall_name.split(".")[1:])
|
||||
# for base_name, base_param in torch_model.named_parameters():
|
||||
# if parall_name == base_name:
|
||||
# # assert weight
|
||||
# # print(f"parall_name {parall_name} parall_param.grad {parall_param.grad is not None}, base_name {base_name} base_param.grad {base_param.grad is not None}")
|
||||
# # # assert weight
|
||||
# assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
|
||||
# # assert weight.grad
|
||||
# # # assert weight.grad
|
||||
# if parall_param.grad is not None:
|
||||
# # print(f"parall_param.grad {parall_param.grad}, base_param.grad {base_param.grad}")
|
||||
# assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
|
||||
|
||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user