mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 12:43:02 +00:00
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv * [feat] support chatglm2, command, deepseek for zbv * [feat] support zbv in shardformer policy: falcon,gptj,mistral,opt,qwen2,t5, vit, whisper * [feat] support GPT2FusedLinearConv1D * [feat] support GPT2FusedLinear (without tp) * [fix] debug FusedConvLinear * [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv Col and Row. * [Shardformer] support FusedLinear1D base for zbv * [shardformer] support zbv in FusedLinear1D base, Col, Row * [shardformer] support zbv in blip2 and sam policy * [shardformer] fix bug incorrect number of gradients; add fusedLinear base testcase; * [fix] fix incorrect number of gradients ; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Shardformer] add en doc for zbv; * [fix] fix typo in Model compatibility table * [fix] fix API Reference typo * [Shardformer] add zh-Han doc for zbv * [fix] fix Linear name; update en & zh doc * [fix] fix shardformer doc import err * [fix] fix shardconfig import in doc * [fix] fix shardformer doc * [fix] fix shardconfig doc * [fix] fix config * [fix] remove shardconfig * [fix] fix doc * [feat] add zbv doc string * [fix] rm doc * [fix] fix doc * [fix] empty zbv doc * [fix] ifx torch version * [fix] fix torch version * [fix] fix torch versions * [fix] fix torch versions * [fix] fix pyramid versions * [fix] fix pyramid, zope version * [fix] try fix workflow * [fix] try import ShardConfig in yml * [fix] fix workflow * [fix] fix workflow * [fix] fix workflow * [fix] fix workflow * [fix] fix ci * [fix] fix zbv doc * [fix] fix param for qkv linear, gpt2fused linear; fix requirments; * [fix] fix policy use fused_linear * [fix] fix weight grad none, err caused by weight ptr change * [fix] fix comm in WeightGradStore * [fix] fix WeightGradStore pop param * [fix] remove useless param in doc; fix gpt2 qkv test; * [shardformer] simplify execute_w_pass_grad_accum; * [fix] rm useless comments * [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass * [shardformer] Run meaningful doc test * [shadformer] fix doc test cmd; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -59,6 +59,8 @@ class BloomPolicy(Policy):
|
||||
|
||||
sp_partial_derived = sp_mode == "split_gather"
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
|
||||
@@ -78,6 +80,7 @@ class BloomPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -86,6 +89,7 @@ class BloomPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -98,6 +102,7 @@ class BloomPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
@@ -106,6 +111,7 @@ class BloomPolicy(Policy):
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
@@ -120,6 +126,52 @@ class BloomPolicy(Policy):
|
||||
},
|
||||
)
|
||||
|
||||
if use_zbv:
|
||||
policy[BloomBlock] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
@@ -247,14 +299,27 @@ class BloomPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.ln_f)
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
|
||||
return held_layers
|
||||
|
||||
@@ -328,8 +393,14 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
@@ -351,6 +422,7 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
@@ -363,6 +435,18 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
policy=policy,
|
||||
target_key=BloomForSequenceClassification,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=BloomForSequenceClassification,
|
||||
)
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=BloomForSequenceClassification,
|
||||
@@ -375,8 +459,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
@@ -389,6 +479,7 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
||||
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
@@ -407,6 +498,24 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
||||
policy=policy,
|
||||
target_key=BloomForTokenClassification,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BloomForTokenClassification,
|
||||
)
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=BloomForTokenClassification,
|
||||
@@ -420,9 +529,16 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
@@ -448,8 +564,14 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy):
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
Reference in New Issue
Block a user