Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-23 02:06:35 +00:00)
[feat] support mixtral policy with zbv tp_Linear & non_tp_Linear
parent 337debcf2a
commit 80b04d7855
@@ -45,7 +45,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_model_chunks: int,
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
-        enable_metadata_cache: bool = True,
+        enable_metadata_cache: bool = False,
         overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
@@ -679,6 +679,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                accum_loss=accum_loss,
                outputs=outputs,
            )
+        # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")

        # Step3:
        # 3-1:detach output; detach output for send fwd;
@@ -194,15 +194,7 @@ class LlamaPolicy(Policy):

         # not enable tp, replace layer to LinearWithGradAccum
         elif use_zbv:
-            decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
-                "self_attn.num_heads": num_q_heads,
-            }
-            if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
-
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
@@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    LinearWithGradAccum,
     PaddingEmbedding,
     PaddingLMHead,
     VocabParallelEmbedding1D,
@@ -62,6 +63,8 @@ class MistralPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding

+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
@@ -90,6 +93,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -97,6 +101,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -104,6 +109,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -111,6 +117,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -118,6 +125,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -125,6 +133,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -132,6 +141,68 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                ],
+            )
+        elif use_zbv:
+            policy[MistralDecoderLayer] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                 ],
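The elif use_zbv: branch above repeats the same SubModuleReplacementDescription for every projection in the decoder layer. A minimal sketch of that pattern, factored into a loop, follows; build_non_tp_zbv_description is a hypothetical helper (not part of this commit), and the base_policy import path is assumed to match this ColossalAI revision.

    from colossalai.shardformer.layer import LinearWithGradAccum
    from colossalai.shardformer.policies.base_policy import (
        ModulePolicyDescription,
        SubModuleReplacementDescription,
    )


    def build_non_tp_zbv_description(fp8_communication: bool, use_zbv: bool) -> ModulePolicyDescription:
        # Hypothetical helper mirroring the `elif use_zbv:` branch above: without tensor
        # parallelism every projection keeps its original shape and is only wrapped for
        # gradient accumulation, so no attribute_replacement is needed.
        suffixes = [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
        ]
        return ModulePolicyDescription(
            sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix=suffix,
                    target_module=LinearWithGradAccum,
                    kwargs={"fp8_communication": fp8_communication, "use_zbv": use_zbv},
                )
                for suffix in suffixes
            ]
        )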
@@ -36,6 +36,24 @@ NUM_HEADS = 4
 TOP_K = 1


+def register_hooks(module: torch.nn.Module):
+
+    def fwd_hook(module, input, output):
+        torch.cuda.synchronize()
+        name = module._name if hasattr(module, "_name") else module
+        print(f"Fwd hook {name} \n output {output}")
+
+    def bwd_hook(module, grad_input, grad_output):
+        torch.cuda.synchronize()
+
+    def bwd_pre_hook(module, grad_output):
+        torch.cuda.synchronize()
+
+    module.register_forward_hook(fwd_hook)
+    # module.register_backward_hook(bwd_hook)
+    # module.register_full_backward_pre_hook(bwd_pre_hook)
+
+
 class MlpModel(nn.Module):
     def __init__(
         self,
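The register_hooks helper added above is a debugging aid: only the forward hook is actually registered, and it synchronizes CUDA before printing each module's output. A small usage sketch with a hypothetical toy model (it relies on the register_hooks definition above and assumes a CUDA device, since the hook calls torch.cuda.synchronize()):

    import torch
    import torch.nn as nn

    # Toy model for illustration only; in the test it would be the base model itself.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).cuda()

    # Attach the forward hook to every submodule, as the commented-out
    # `torch_model.apply(register_hooks)` line later in this diff would do.
    model.apply(register_hooks)

    # Every submodule now prints its output during the forward pass.
    out = model(torch.randn(4, 8, device="cuda"))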
@@ -756,9 +774,9 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
        (1, 2, 1, 1, 2),
        (1, 1, 2, 2, 1),
        (1, 2, 1, 2, 1),
-        # TODO: adapt mixtral with no TP Linear
-        # (1, 2, 2, 1, 1),
-        # (0, 1, 4, 1, 1),
+        (1, 2, 2, 1, 1),
+        # # TODO: adapt mixtral with no TP Linear
+        (0, 1, 4, 1, 1),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@@ -872,7 +890,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
            return_outputs=True,
        )
        # stage 0 chunk 0
-        parallel_output = None
        if (
            booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
            and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
@@ -880,6 +897,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
            parallel_output = sharded_output["loss"]
        else:
            parallel_output = torch.tensor(12345.0, device="cuda")
+        print(f"rank {dist.get_rank()} parallel_output {parallel_output}")
        # broadcast along pp axis
        dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)

@@ -920,8 +938,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
        (1, 2, 2, 1),
        (1, 2, 1, 2),
        (1, 1, 2, 2),
-        # TODO: acc err in pp4
-        # (1, 4, 1, 1),
+        # TODO: support overlap p2p in pp4
+        (1, 4, 1, 1),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
@@ -1030,7 +1048,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
            return_outputs=True,
        )
        # stage 0 chunk 0
-        parallel_output = None
        if (
            booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
            and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
@@ -1054,6 +1071,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
        all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
+        # torch_model.apply(register_hooks)  # register hook for base model
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
@@ -1065,19 +1083,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
        torch_optimizer.step()
        torch_optimizer.zero_grad()

-        # # assert param
-        # for parall_name, parall_param in parallel_model.named_parameters():
-        #     parall_name = ".".join(parall_name.split(".")[1:])
-        #     for base_name, base_param in torch_model.named_parameters():
-        #         if parall_name == base_name:
-        #             # print(f"parall_name {parall_name} parall_param.grad {parall_param.grad is not None}, base_name {base_name} base_param.grad {base_param.grad is not None}")
-        #             # # assert weight
-        #             assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
-        #             # # assert weight.grad
-        #             if parall_param.grad is not None:
-        #                 # print(f"parall_param.grad {parall_param.grad}, base_param.grad {base_param.grad}")
-        #                 assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
-
+        print(f"parallel_output {parallel_output}, torch_output_sum {torch_output_sum}")
        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
        print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
        clear_layout_converter()
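The test's final check compares the pipeline-parallel loss against the summed single-process loss with assert_loose_close, whose implementation is not part of this diff. A self-contained stand-in using torch.testing.assert_close is sketched below; the dtype-dependent tolerance values are assumptions, not the repo's actual thresholds.

    import torch


    def loose_close(actual: torch.Tensor, expected: torch.Tensor, dtype: torch.dtype, name: str = "") -> None:
        # Stand-in for the repo's assert_loose_close helper; tolerances widened for low-precision dtypes.
        rtol, atol = (4e-3, 4e-3) if dtype in (torch.float16, torch.bfloat16) else (1e-5, 1e-8)
        torch.testing.assert_close(
            actual.float(), expected.float(), rtol=rtol, atol=atol, msg=f"{name} not close"
        )


    # Example: a bf16-level mismatch passes, an fp32-level mismatch would raise.
    loose_close(torch.tensor(1.002), torch.tensor(1.0), dtype=torch.bfloat16, name="loss")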