Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-26 20:23:26 +00:00)
[Shardformer] Support zbv in Shardformer Policy (#6150)

* [feat] Shardformer supports zbv
* [feat] support chatglm2, command, deepseek for zbv
* [feat] support zbv in shardformer policies: falcon, gptj, mistral, opt, qwen2, t5, vit, whisper
* [feat] support GPT2FusedLinearConv1D
* [feat] support GPT2FusedLinear (without tp)
* [fix] debug FusedConvLinear
* [shardformer] support gpt2 policy for zbv; support GPT2FusedLinearConv Col and Row
* [shardformer] support FusedLinear1D base for zbv
* [shardformer] support zbv in FusedLinear1D base, Col, Row
* [shardformer] support zbv in blip2 and sam policies
* [shardformer] fix incorrect number of gradients; add FusedLinear base test case
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [shardformer] add English doc for zbv
* [fix] fix typo in model compatibility table; fix API reference typo
* [shardformer] add zh-Hans doc for zbv
* [fix] fix Linear name; update English and Chinese docs
* [fix] fix shardformer doc import error; fix ShardConfig import in doc
* [fix] fix shardformer doc; fix ShardConfig doc; remove ShardConfig from doc
* [feat] add zbv docstring
* [fix] fix torch versions; fix pyramid and zope versions
* [fix] fix workflow; try importing ShardConfig in yml; fix CI
* [fix] fix zbv doc
* [fix] fix params for qkv linear and GPT2 fused linear; fix requirements
* [fix] fix policy use of fused_linear
* [fix] fix weight grad None error caused by weight pointer change
* [fix] fix communication in WeightGradStore; fix WeightGradStore pop param
* [fix] remove unused param in doc; fix gpt2 qkv test
* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
* [shardformer] run meaningful doc test; fix doc test cmd

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
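Background note: the zbv (zero-bubble V) schedule referenced above splits each backward step into a B pass (input gradients, needed immediately by upstream stages) and a W pass (weight gradients, which can be deferred). The commit message mentions WeightGradStore and execute_w_pass; the snippet below is a self-contained, illustrative sketch of that deferral pattern only. Names such as ToyWeightGradStore are hypothetical and do not mirror ColossalAI's actual WeightGradStore API.

# Illustrative sketch of deferred weight-gradient work in the spirit of zbv.
import queue

class ToyWeightGradStore:
    """Collect weight-gradient closures during the B pass; run them later as a W pass."""

    def __init__(self):
        self._buffer = []            # closures collected for the current micro-batch
        self._pending = queue.Queue()  # micro-batches whose W pass has not run yet

    def put(self, compute_weight_grad):
        self._buffer.append(compute_weight_grad)

    def flush(self):
        # Close off one micro-batch worth of deferred weight-gradient work.
        self._pending.put(self._buffer)
        self._buffer = []

    def pop(self):
        # Execute the W pass for the oldest pending micro-batch.
        for compute_weight_grad in self._pending.get():
            compute_weight_grad()

# Usage sketch: a layer running with use_zbv=True would return grad_input right
# away and register its weight-gradient computation here instead.
store = ToyWeightGradStore()
store.put(lambda: print("accumulate grad for layer 0 weight"))
store.flush()   # end of one micro-batch's B pass
store.pop()     # later, the scheduler triggers the W pass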
@@ -67,6 +67,8 @@ class GPT2Policy(Policy):
             self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -94,12 +96,17 @@ class GPT2Policy(Policy):
                             "split_sizes": [self.model.config.hidden_size] * 3,
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_fc",
@@ -109,12 +116,17 @@ class GPT2Policy(Policy):
                             "seq_parallel_mode": sp_mode,
                             "skip_bias_add": self.enable_bias_gelu_fused,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
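The two hunks above thread `use_zbv` into every fused linear replacement. Under zbv, a linear layer's backward is expected to emit the input gradient immediately and hand the weight-gradient work to a store for a later W pass. The sketch below is a hypothetical torch.autograd.Function showing that split; ZbvLinearFunction and its signature are illustrative assumptions, not the actual GPT2FusedLinearConv1D implementation.

# Hypothetical sketch: a linear whose backward defers grad_weight when use_zbv=True.
import torch

class ZbvLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, use_zbv, grad_store):
        ctx.save_for_backward(x, weight)
        ctx.use_zbv = use_zbv
        ctx.grad_store = grad_store
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_input = grad_output @ weight  # B pass: needed by earlier layers right now

        def compute_weight_grad():
            # W pass: can run later without blocking the pipeline.
            grad = grad_output.reshape(-1, grad_output.size(-1)).t() @ x.reshape(-1, x.size(-1))
            weight.grad = grad if weight.grad is None else weight.grad + grad

        if ctx.use_zbv:
            ctx.grad_store.append(compute_weight_grad)  # deferred W pass
        else:
            compute_weight_grad()                        # eager, classic backward
        return grad_input, None, None, None

# Usage sketch
store = []
x = torch.randn(4, 8, requires_grad=True)
w = torch.nn.Parameter(torch.randn(16, 8))
y = ZbvLinearFunction.apply(x, w, True, store)
y.sum().backward()       # produces x.grad, leaves w.grad for later
for w_pass in store:     # the scheduler would trigger this during the W phase
    w_pass()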
@@ -138,6 +150,78 @@ class GPT2Policy(Policy):
                     policy=policy,
                     target_key=GPT2MLP,
                 )
+        elif use_zbv:
+            policy[GPT2Model] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="drop",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ]
+            )
+
+            policy[GPT2Block] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attn.c_attn",
+                        target_module=col_nn.GPT2FusedLinearConv,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attn.c_proj",
+                        target_module=col_nn.GPT2FusedLinearConv,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.c_fc",
+                        target_module=col_nn.GPT2FusedLinearConv,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.c_proj",
+                        target_module=col_nn.GPT2FusedLinearConv,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attn.attn_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attn.resid_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ],
+            )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_gpt2_mlp_forward(),
+                    },
+                    policy=policy,
+                    target_key=GPT2MLP,
+                )
 
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
             self.append_or_create_submodule_replacement(
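The new `elif use_zbv:` branch above applies module replacements even when tensor parallelism is off, swapping each projection to GPT2FusedLinearConv so the zbv weight-gradient deferral still takes effect. For readers unfamiliar with the policy machinery, the sketch below shows in simplified form what a suffix-based submodule replacement amounts to; replace_submodule and the toy classes are assumptions for illustration, not shardformer's actual SubModuleReplacementDescription handling.

# Simplified sketch of suffix-based module replacement.
import torch.nn as nn

def replace_submodule(root: nn.Module, suffix: str, factory) -> None:
    """Replace the module at dot-path `suffix` under `root` with `factory(old_module)`."""
    *parents, leaf = suffix.split(".")
    parent = root
    for name in parents:
        parent = getattr(parent, name)
    old = getattr(parent, leaf)
    setattr(parent, leaf, factory(old))

# Toy block imitating GPT2Block's attribute layout.
class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_attn = nn.Linear(8, 24)
        self.c_proj = nn.Linear(8, 8)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

block = ToyBlock()
# A policy entry like suffix="attn.c_attn", target_module=... boils down to:
replace_submodule(block, "attn.c_attn",
                  lambda old: nn.Linear(old.in_features, old.out_features, bias=False))
print(type(block.attn.c_attn))  # the replaced module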
@@ -352,8 +436,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
 
     def get_held_layers(self) -> List[nn.Module]:
         held_layers = super().get_held_layers()
-        if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
-            held_layers.append(self.model.lm_head)
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_interleave:
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                held_layers.append(self.model.lm_head)
+        else:
+            if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
+                held_layers.append(self.model.lm_head)
+        # if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
+        #     held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
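The repeated get_held_layers change encodes the V shape of the zbv schedule: with two model chunks per stage, the second chunk walks back from the last stage toward the first, so the final transformer layers (and therefore lm_head) end up on pipeline stage 0. That is why, under zbv, is_first_stage(ignore_chunk=True) holds the head while the classic schedules keep it on the last stage. The sketch below illustrates that chunk-to-stage mapping; zbv_segment is a hypothetical helper kept independent of ColossalAI's PipelineStageManager.

# Illustrative sketch of zbv's V-shaped layer placement.
def zbv_segment(stage: int, chunk: int, num_stages: int) -> int:
    """Return which of the 2*num_stages model segments (stage, chunk) executes."""
    if chunk == 0:
        return stage                       # forward leg of the V
    return 2 * num_stages - 1 - stage      # backward leg of the V

P = 4
for stage in range(P):
    print(f"stage {stage}: segments {[zbv_segment(stage, c, P) for c in range(2)]}")
# stage 0: segments [0, 7]   <- owns the last segment, so it holds lm_head
# stage 1: segments [1, 6]
# stage 2: segments [2, 5]
# stage 3: segments [3, 4]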
@@ -420,13 +513,24 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
 
     def get_held_layers(self) -> List[nn.Module]:
         held_layers = super().get_held_layers()
-        if self.pipeline_stage_manager.is_last_stage():
-            multiple_choice_head = self.model.multiple_choice_head
-            held_layers.append(self.model.lm_head)
-            held_layers.append(multiple_choice_head.summary)
-            held_layers.append(multiple_choice_head.activation)
-            held_layers.append(multiple_choice_head.first_dropout)
-            held_layers.append(multiple_choice_head.last_dropout)
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_interleave:
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                held_layers.append(self.model.lm_head)
+                held_layers.append(multiple_choice_head.summary)
+                held_layers.append(multiple_choice_head.activation)
+                held_layers.append(multiple_choice_head.first_dropout)
+                held_layers.append(multiple_choice_head.last_dropout)
+        else:
+            if self.pipeline_stage_manager.is_last_stage():
+                multiple_choice_head = self.model.multiple_choice_head
+                held_layers.append(self.model.lm_head)
+                held_layers.append(multiple_choice_head.summary)
+                held_layers.append(multiple_choice_head.activation)
+                held_layers.append(multiple_choice_head.first_dropout)
+                held_layers.append(multiple_choice_head.last_dropout)
 
         return held_layers
 
@@ -464,8 +568,17 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
 
     def get_held_layers(self) -> List[nn.Module]:
         held_layers = super().get_held_layers()
-        if self.pipeline_stage_manager.is_last_stage():
-            held_layers.append(self.model.qa_outputs)
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_interleave:
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                held_layers.append(self.model.qa_outputs)
+        else:
+            if self.pipeline_stage_manager.is_last_stage():
+                held_layers.append(self.model.qa_outputs)
+        # if self.pipeline_stage_manager.is_last_stage():
+        #     held_layers.append(self.model.qa_outputs)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
@@ -503,9 +616,20 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
 
     def get_held_layers(self) -> List[nn.Module]:
         held_layers = super().get_held_layers()
-        if self.pipeline_stage_manager.is_last_stage():
-            held_layers.append(self.model.dropout)
-            held_layers.append(self.model.classifier)
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_interleave:
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                held_layers.append(self.model.dropout)
+                held_layers.append(self.model.classifier)
+        else:
+            if self.pipeline_stage_manager.is_last_stage():
+                held_layers.append(self.model.dropout)
+                held_layers.append(self.model.classifier)
+        # if self.pipeline_stage_manager.is_last_stage():
+        #     held_layers.append(self.model.dropout)
+        #     held_layers.append(self.model.classifier)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
@@ -530,8 +654,18 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
 
     def get_held_layers(self) -> List[nn.Module]:
         held_layers = super().get_held_layers()
-        if self.pipeline_stage_manager.is_last_stage():
-            held_layers.append(self.model.score)
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_interleave:
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                held_layers.append(self.model.score)
+        else:
+            if self.pipeline_stage_manager.is_last_stage():
+                held_layers.append(self.model.score)
+
+        # if self.pipeline_stage_manager.is_last_stage():
+        #     held_layers.append(self.model.score)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]: