mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 04:33:04 +00:00
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv * [feat] support chatglm2, command, deepseek for zbv * [feat] support zbv in shardformer policy: falcon,gptj,mistral,opt,qwen2,t5, vit, whisper * [feat] support GPT2FusedLinearConv1D * [feat] support GPT2FusedLinear (without tp) * [fix] debug FusedConvLinear * [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv Col and Row. * [Shardformer] support FusedLinear1D base for zbv * [shardformer] support zbv in FusedLinear1D base, Col, Row * [shardformer] support zbv in blip2 and sam policy * [shardformer] fix bug incorrect number of gradients; add fusedLinear base testcase; * [fix] fix incorrect number of gradients ; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Shardformer] add en doc for zbv; * [fix] fix typo in Model compatibility table * [fix] fix API Reference typo * [Shardformer] add zh-Han doc for zbv * [fix] fix Linear name; update en & zh doc * [fix] fix shardformer doc import err * [fix] fix shardconfig import in doc * [fix] fix shardformer doc * [fix] fix shardconfig doc * [fix] fix config * [fix] remove shardconfig * [fix] fix doc * [feat] add zbv doc string * [fix] rm doc * [fix] fix doc * [fix] empty zbv doc * [fix] ifx torch version * [fix] fix torch version * [fix] fix torch versions * [fix] fix torch versions * [fix] fix pyramid versions * [fix] fix pyramid, zope version * [fix] try fix workflow * [fix] try import ShardConfig in yml * [fix] fix workflow * [fix] fix workflow * [fix] fix workflow * [fix] fix workflow * [fix] fix ci * [fix] fix zbv doc * [fix] fix param for qkv linear, gpt2fused linear; fix requirments; * [fix] fix policy use fused_linear * [fix] fix weight grad none, err caused by weight ptr change * [fix] fix comm in WeightGradStore * [fix] fix WeightGradStore pop param * [fix] remove useless param in doc; fix gpt2 qkv test; * [shardformer] simplify execute_w_pass_grad_accum; * [fix] rm useless comments * [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass * [shardformer] Run meaningful doc test * [shadformer] fix doc test cmd; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -51,6 +51,8 @@ class BlipPolicy(Policy):
|
||||
else:
|
||||
norm_cls = col_nn.LayerNorm
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
@@ -73,6 +75,7 @@ class BlipPolicy(Policy):
|
||||
kwargs={
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -80,6 +83,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -88,6 +92,7 @@ class BlipPolicy(Policy):
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -95,6 +100,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
@@ -126,6 +132,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -133,6 +140,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -140,6 +148,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -151,6 +160,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -162,6 +172,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -169,6 +180,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -176,6 +188,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -187,6 +200,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -198,6 +212,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -205,6 +220,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -227,6 +243,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -234,6 +251,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -241,6 +259,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -248,6 +267,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -255,6 +275,7 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -262,6 +283,226 @@ class BlipPolicy(Policy):
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_blip2_mlp_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Blip2MLP,
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[Blip2EncoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.qkv",
|
||||
target_module=col_nn.FusedLinear,
|
||||
kwargs={
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.projection",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[Blip2QFormerModel] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy[Blip2QFormerLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.query",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.key",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate_query.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[OPTDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
Reference in New Issue
Block a user