mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 20:54:55 +00:00
[fix] fix send_tensor_metadata & send_grad_metadata;
This commit is contained in:
parent
0d6d40ccc6
commit
12919de424
@ -432,6 +432,7 @@ def _communicate(
|
|||||||
overlap_p2p=overlap_p2p,
|
overlap_p2p=overlap_p2p,
|
||||||
send_first=send_first if send_first != None else True,
|
send_first=send_first if send_first != None else True,
|
||||||
)
|
)
|
||||||
|
# print(f"rank {dist.get_rank()}; recv_src {recv_src}; send_dst {send_dst}; metadata_send {metadata_send}; metadata_recv {metadata_recv};")
|
||||||
|
|
||||||
if metadata_recv is not None:
|
if metadata_recv is not None:
|
||||||
assert isinstance(metadata_recv, P2PMetadata)
|
assert isinstance(metadata_recv, P2PMetadata)
|
||||||
|
@ -64,8 +64,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
|
|
||||||
# P2PMeta cache
|
# P2PMeta cache
|
||||||
self.enable_metadata_cache = enable_metadata_cache
|
self.enable_metadata_cache = enable_metadata_cache
|
||||||
|
|
||||||
|
# check send_tensor_metadata, send_grad_metadata
|
||||||
|
# pp4 as sample, we should follow this meta strategy
|
||||||
|
# send_tensor_meta(fwd) send_grad_meta(bwd)
|
||||||
|
# chunk0 | chunk1 chunk0 | chunk 1
|
||||||
|
# stage 0 T | F F | T
|
||||||
|
# stage 1 T | T T | T
|
||||||
|
# stage 2 T | T T | T
|
||||||
|
# stage 3 F | T F | T
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata = [True, False]
|
||||||
|
self.send_grad_metadata = [False, True]
|
||||||
|
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata = [False, True]
|
||||||
|
self.send_grad_metadata = [True, False]
|
||||||
|
else:
|
||||||
self.send_tensor_metadata = [True, True]
|
self.send_tensor_metadata = [True, True]
|
||||||
self.send_grad_metadata = [True, True]
|
self.send_grad_metadata = [True, True]
|
||||||
|
|
||||||
# meta cache buffer
|
# meta cache buffer
|
||||||
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
||||||
self.grad_metadata_recv = [None, None]
|
self.grad_metadata_recv = [None, None]
|
||||||
@ -84,6 +101,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# init buffer
|
# init buffer
|
||||||
self._free_buffers()
|
self._free_buffers()
|
||||||
|
|
||||||
|
def _set_send_metadata_buffers(self, model_chunk_id):
|
||||||
|
pass
|
||||||
|
|
||||||
def _free_buffers(self):
|
def _free_buffers(self):
|
||||||
# free local buffer
|
# free local buffer
|
||||||
# two dim array, first dim is the model chunk, second dim is the microbatch queue
|
# two dim array, first dim is the model chunk, second dim is the microbatch queue
|
||||||
@ -285,7 +305,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
# return None, []
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -300,7 +319,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||||
# return output_tensor_grad, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -345,6 +363,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; hold y on local_send_forward_buffer
|
# do nothing; hold y on local_send_forward_buffer
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -368,6 +387,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -403,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; cause u are the first chunk in first stage; bwd end
|
# do nothing; cause u are the first chunk in first stage; bwd end
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -425,6 +446,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
@ -889,7 +911,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
for h in self.wait_handles:
|
for h in self.wait_handles:
|
||||||
for hh in h:
|
for hh in h:
|
||||||
hh.wait()
|
hh.wait()
|
||||||
|
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
|
||||||
# return loss & output
|
# return loss & output
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
outputs = merge_batch(outputs)
|
outputs = merge_batch(outputs)
|
||||||
|
@ -193,7 +193,7 @@ class LlamaPolicy(Policy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# not enable tp, replace layer to LinearWithGradAccum
|
# not enable tp, replace layer to LinearWithGradAccum
|
||||||
else:
|
elif use_zbv:
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
|
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
|
||||||
"self_attn.num_heads": num_q_heads,
|
"self_attn.num_heads": num_q_heads,
|
||||||
@ -514,24 +514,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
# enable tp, replace layer to LinearWithGradAccum
|
# TODO: test lora bug here
|
||||||
else:
|
# # enable tp, replace layer to LinearWithGradAccum
|
||||||
# add a new item for sequence classification
|
# else:
|
||||||
new_item = {
|
# # add a new item for sequence classification
|
||||||
LlamaForSequenceClassification: ModulePolicyDescription(
|
# new_item = {
|
||||||
sub_module_replacement=[
|
# LlamaForSequenceClassification: ModulePolicyDescription(
|
||||||
SubModuleReplacementDescription(
|
# sub_module_replacement=[
|
||||||
suffix="score",
|
# SubModuleReplacementDescription(
|
||||||
target_module=LinearWithGradAccum,
|
# suffix="score",
|
||||||
kwargs=dict(
|
# target_module=LinearWithGradAccum,
|
||||||
fp8_communication=self.shard_config.fp8_communication,
|
# kwargs=dict(
|
||||||
use_zbv=use_zbv,
|
# fp8_communication=self.shard_config.fp8_communication,
|
||||||
),
|
# use_zbv=use_zbv,
|
||||||
)
|
# ),
|
||||||
]
|
# )
|
||||||
)
|
# ]
|
||||||
}
|
# )
|
||||||
policy.update(new_item)
|
# }
|
||||||
|
# policy.update(new_item)
|
||||||
|
|
||||||
# to be confirmed
|
# to be confirmed
|
||||||
if self.pipeline_stage_manager:
|
if self.pipeline_stage_manager:
|
||||||
|
@ -916,12 +916,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||||||
@parameterize(
|
@parameterize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
# # Pass
|
# Pass
|
||||||
# (1, 2, 2, 1),
|
(1, 2, 2, 1),
|
||||||
# (1, 2, 1, 2),
|
(1, 2, 1, 2),
|
||||||
# (1, 1, 2, 2),
|
(1, 1, 2, 2),
|
||||||
# TODO: acc err in pp4
|
# TODO: acc err in pp4
|
||||||
(1, 4, 1, 1),
|
# (1, 4, 1, 1),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||||
@ -1065,16 +1065,16 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
torch_optimizer.zero_grad()
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
# assert param
|
# # assert param
|
||||||
for parall_name, parall_param in parallel_model.named_parameters():
|
# for parall_name, parall_param in parallel_model.named_parameters():
|
||||||
parall_name = ".".join(parall_name.split(".")[1:])
|
# parall_name = ".".join(parall_name.split(".")[1:])
|
||||||
for base_name, base_param in torch_model.named_parameters():
|
# for base_name, base_param in torch_model.named_parameters():
|
||||||
if parall_name == base_name:
|
# if parall_name == base_name:
|
||||||
# assert weight
|
# # assert weight
|
||||||
assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
|
# assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
|
||||||
# assert weight.grad
|
# # assert weight.grad
|
||||||
if parall_param.grad is not None:
|
# if parall_param.grad is not None:
|
||||||
assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
|
# assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
|
||||||
|
|
||||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
||||||
@ -1086,7 +1086,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
# run_with_booster_moehybridplugin()
|
run_with_booster_moehybridplugin()
|
||||||
run_with_booster_hybridplugin()
|
run_with_booster_hybridplugin()
|
||||||
|
|
||||||
|
|
||||||
|
@ -420,4 +420,4 @@ def test_llama_3d():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_llama()
|
test_llama()
|
||||||
# test_llama_3d()
|
test_llama_3d()
|
||||||
|
Loading…
Reference in New Issue
Block a user