Files
ColossalAI/colossalai/shardformer/policies/blip2.py
duanjunwen a9bedc7a43 [Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv

* [feat] support chatglm2, command, deepseek for zbv

* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper

* [feat] support GPT2FusedLinearConv1D

* [feat] support GPT2FusedLinear (without tp)

* [fix] debug FusedConvLinear

* [shardformer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.

* [Shardformer] support FusedLinear1D base for zbv

* [shardformer] support zbv in FusedLinear1D base, Col, Row

* [shardformer] support zbv in blip2 and sam policy

* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;

* [fix] fix incorrect number of gradients ;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Shardformer] add en doc for zbv;

* [fix] fix typo in Model compatibility table

* [fix] fix API Reference typo

* [Shardformer] add zh-Han doc for zbv

* [fix] fix Linear name; update en & zh doc

* [fix] fix shardformer doc import err

* [fix] fix shardconfig import in doc

* [fix] fix shardformer doc

* [fix] fix shardconfig doc

* [fix] fix config

* [fix] remove shardconfig

* [fix] fix doc

* [feat] add zbv doc string

* [fix] rm doc

* [fix] fix doc

* [fix] empty zbv doc

* [fix] fix torch version

* [fix] fix torch version

* [fix] fix torch versions

* [fix] fix torch versions

* [fix] fix pyramid versions

* [fix] fix pyramid, zope version

* [fix] try fix workflow

* [fix] try import ShardConfig in yml

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix ci

* [fix] fix zbv doc

* [fix] fix param for qkv linear, gpt2fused linear; fix requirements;

* [fix] fix policy use fused_linear

* [fix] fix weight grad none, err caused by  weight ptr change

* [fix] fix comm in WeightGradStore

* [fix] fix WeightGradStore pop param

* [fix] remove useless param in doc; fix gpt2 qkv test;

* [shardformer] simplify execute_w_pass_grad_accum;

* [fix] rm useless comments

* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass

* [shardformer] Run meaningful doc test

* [shardformer] fix doc test cmd;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 10:22:26 +08:00

703 lines
30 KiB
Python

import colossalai.shardformer.layer as col_nn
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
get_jit_fused_blip2_mlp_forward,
get_jit_fused_blip2_QFormer_output_forward,
get_jit_fused_blip2_QFormer_self_output_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
# Public names exported by ``from <this module> import *``. Every entry must be
# a name that this module actually defines; an undefined entry makes the
# star-import raise AttributeError.
__all__ = ["BlipPolicy", "Blip2ModelPolicy", "Blip2ForConditionalGenerationPolicy"]
class BlipPolicy(Policy):
    """Shardformer policy describing which BLIP-2 sub-modules get replaced.

    ``module_policy`` returns a mapping from transformers BLIP-2 / OPT module
    classes to ``ModulePolicyDescription`` objects. The mapping covers:

    * linear and dropout layers, replaced either with tensor-parallel layers
      (when ``shard_config.enable_tensor_parallelism`` is set) or, failing
      that, with the ``use_zbv`` variants when the pipeline stage manager
      reports ``use_zbv``;
    * the OPT token embedding and LM head;
    * optional optimizations: layer-norm class, flash attention and JIT-fused
      forwards.
    """

    def config_sanity_check(self):
        """Intentionally a no-op: this policy performs no configuration checks."""

    def preprocess(self):
        """Record the flags ``module_policy`` relies on and return the model.

        Sets ``self.tie_weight`` from ``tie_weight_check()`` and
        ``self.enable_bias_gelu_fused``, which is truthy only when
        ``shard_config.enable_jit_fused`` is truthy and the vision config's
        ``hidden_act`` equals ``"gelu"``.

        Returns:
            ``self.model``; this method does not reassign it.
        """
        self.tie_weight = self.tie_weight_check()
        self.enable_bias_gelu_fused = (
            self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
        )
        return self.model

    @staticmethod
    def _dropout(suffix):
        """Describe replacing the dropout at ``suffix`` with ``DropoutForParallelInput``."""
        return SubModuleReplacementDescription(
            suffix=suffix,
            target_module=col_nn.DropoutForParallelInput,
        )

    def _linear(self, suffix, target_module, use_zbv, **leading_kwargs):
        """Describe replacing the linear layer at ``suffix`` with ``target_module``.

        The replacement kwargs are ``leading_kwargs`` (in the order given)
        followed by ``fp8_communication`` (read from the shard config) and
        ``use_zbv``. A fresh dict is built on every call so descriptions never
        share mutable state.
        """
        kwargs = dict(leading_kwargs)
        kwargs["fp8_communication"] = self.shard_config.fp8_communication
        kwargs["use_zbv"] = use_zbv
        return SubModuleReplacementDescription(
            suffix=suffix,
            target_module=target_module,
            kwargs=kwargs,
        )

    @staticmethod
    def _describe(sub_modules, attributes=None):
        """Build a ``ModulePolicyDescription``; pass attribute overrides only when given."""
        if attributes is None:
            return ModulePolicyDescription(sub_module_replacement=sub_modules)
        return ModulePolicyDescription(
            attribute_replacement=attributes,
            sub_module_replacement=sub_modules,
        )

    def _vision_layer_replacements(self, qkv_replacement, column_cls, row_cls, use_zbv):
        """Sub-module replacements for ``Blip2EncoderLayer`` (attention + MLP)."""
        return [
            self._dropout("self_attn.dropout"),
            qkv_replacement,
            self._linear("self_attn.projection", row_cls, use_zbv),
            # fc1 skips its own bias add when the JIT-fused MLP forward is installed.
            self._linear("mlp.fc1", column_cls, use_zbv, skip_bias_add=self.enable_bias_gelu_fused),
            self._linear("mlp.fc2", row_cls, use_zbv),
        ]

    def _qformer_layer_replacements(self, column_cls, row_cls, use_zbv):
        """Sub-module replacements for ``Blip2QFormerLayer``.

        Self-attention and cross-attention share the same sub-module layout, so
        both are generated from one template, followed by the query
        feed-forward entries.
        """
        replacements = []
        for attn in ("attention", "crossattention"):
            for proj in ("query", "key", "value"):
                replacements.append(self._linear(f"{attn}.attention.{proj}", column_cls, use_zbv))
            replacements.append(self._dropout(f"{attn}.attention.dropout"))
            replacements.append(self._linear(f"{attn}.output.dense", row_cls, use_zbv))
            replacements.append(self._dropout(f"{attn}.output.dropout"))
        replacements.append(self._linear("intermediate_query.dense", column_cls, use_zbv))
        replacements.append(self._linear("output_query.dense", row_cls, use_zbv))
        replacements.append(self._dropout("output_query.dropout"))
        return replacements

    def _opt_layer_replacements(self, column_cls, row_cls, use_zbv):
        """Sub-module replacements for ``OPTDecoderLayer`` (attention projections + MLP)."""
        return [
            self._linear("self_attn.q_proj", column_cls, use_zbv),
            self._linear("self_attn.k_proj", column_cls, use_zbv),
            self._linear("self_attn.v_proj", column_cls, use_zbv),
            self._linear("self_attn.out_proj", row_cls, use_zbv),
            self._linear("fc1", column_cls, use_zbv),
            self._linear("fc2", row_cls, use_zbv),
        ]

    def module_policy(self):
        """Build the ``{module class: ModulePolicyDescription}`` mapping for BLIP-2.

        Reads ``self.tie_weight`` and ``self.enable_bias_gelu_fused``, which
        are assigned in ``preprocess``.

        Returns:
            dict keyed by transformers module classes.

        Raises:
            AssertionError: tensor parallelism is enabled and the vision
                config's ``num_attention_heads`` is not divisible by
                ``tensor_parallel_size``.
        """
        from transformers.models.blip_2.modeling_blip_2 import (
            Blip2Attention,
            Blip2EncoderLayer,
            Blip2MLP,
            Blip2QFormerLayer,
            Blip2QFormerModel,
            Blip2QFormerOutput,
            Blip2QFormerSelfOutput,
            Blip2VisionModel,
        )
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM

        policy = {}
        shard_config = self.shard_config
        tp_enabled = shard_config.enable_tensor_parallelism

        # Token embedding: vocab-parallel under tensor parallelism, padded when
        # weights are tied, otherwise left untouched (None).
        embedding_cls = None
        if tp_enabled:
            embedding_cls = col_nn.VocabParallelEmbedding1D
        elif self.tie_weight:
            embedding_cls = col_nn.PaddingEmbedding

        norm_cls = col_nn.FusedLayerNorm if shard_config.enable_fused_normalization else col_nn.LayerNorm

        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

        # Linear/dropout replacement happens under tensor parallelism, or
        # without it when zbv is requested. The two modes share every suffix
        # and differ only in the target layer classes and in whether per-rank
        # attribute overrides are required.
        if tp_enabled or use_zbv:
            if tp_enabled:
                tp_size = shard_config.tensor_parallel_size
                model_config = self.model.config
                assert (
                    model_config.vision_config.num_attention_heads % tp_size == 0
                ), "The number of attention heads must be divisible by tensor parallel size."
                vision_cfg = model_config.vision_config
                qformer_cfg = model_config.qformer_config
                text_cfg = model_config.text_config

                column_cls, row_cls = col_nn.Linear1D_Col, col_nn.Linear1D_Row
                qkv_replacement = self._linear(
                    "self_attn.qkv",
                    col_nn.FusedLinear1D_Col,
                    use_zbv,
                    split_sizes=[vision_cfg.hidden_size] * 3,
                )
                # Each rank holds 1/tp_size of the heads and hidden features.
                vision_attrs = {
                    "self_attn.num_heads": vision_cfg.num_attention_heads // tp_size,
                    "self_attn.embed_dim": vision_cfg.hidden_size // tp_size,
                }
                qformer_attrs = {
                    "attention.attention.num_attention_heads": qformer_cfg.num_attention_heads // tp_size,
                    "attention.attention.all_head_size": qformer_cfg.hidden_size // tp_size,
                    "crossattention.attention.num_attention_heads": qformer_cfg.num_attention_heads // tp_size,
                    "crossattention.attention.all_head_size": qformer_cfg.hidden_size // tp_size,
                }
                opt_attrs = {
                    "self_attn.embed_dim": text_cfg.hidden_size // tp_size,
                    "self_attn.num_heads": text_cfg.num_attention_heads // tp_size,
                }
            else:
                # zbv without tensor parallelism: one layer class serves both
                # the "column" and "row" positions, and no attributes change.
                column_cls = row_cls = col_nn.LinearWithGradAccum
                # Unlike the other replacements, FusedLinear is not given
                # ``fp8_communication`` here.
                qkv_replacement = SubModuleReplacementDescription(
                    suffix="self_attn.qkv",
                    target_module=col_nn.FusedLinear,
                    kwargs={
                        "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                        "use_zbv": use_zbv,
                    },
                )
                vision_attrs = qformer_attrs = opt_attrs = None

            policy[Blip2EncoderLayer] = self._describe(
                self._vision_layer_replacements(qkv_replacement, column_cls, row_cls, use_zbv),
                vision_attrs,
            )
            policy[Blip2QFormerModel] = ModulePolicyDescription(
                sub_module_replacement=[self._dropout("dropout")],
            )
            policy[Blip2QFormerLayer] = self._describe(
                self._qformer_layer_replacements(column_cls, row_cls, use_zbv),
                qformer_attrs,
            )
            policy[OPTDecoderLayer] = self._describe(
                self._opt_layer_replacements(column_cls, row_cls, use_zbv),
                opt_attrs,
            )

            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
            if self.enable_bias_gelu_fused:
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_jit_fused_blip2_mlp_forward(),
                    },
                    policy=policy,
                    target_key=Blip2MLP,
                )

        if embedding_cls is not None:
            embedding_kwargs = {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
            if tp_enabled:
                # Only the vocab-parallel embedding receives the fp8 flag.
                embedding_kwargs["fp8_communication"] = shard_config.fp8_communication
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="model.decoder.embed_tokens",
                        target_module=embedding_cls,
                        kwargs=embedding_kwargs,
                    ),
                ],
                policy=policy,
                target_key=OPTForCausalLM,
            )

        # LM head: vocab-parallel under tensor parallelism, padded otherwise.
        if tp_enabled:
            lm_head = SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=col_nn.VocabParallelLMHead1D,
                kwargs={
                    "gather_output": True,
                    "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
                    "fp8_communication": shard_config.fp8_communication,
                },
            )
        else:
            lm_head = SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=col_nn.PaddingLMHead,
                kwargs={"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by},
            )
        self.append_or_create_submodule_replacement(
            description=[lm_head],
            policy=policy,
            target_key=OPTForCausalLM,
        )

        # optimization configuration
        # Every layer norm is swapped for ``norm_cls``; entries are registered
        # in this order, one call per owning module class.
        norm_targets = (
            (Blip2EncoderLayer, ("layer_norm1", "layer_norm2")),
            (Blip2VisionModel, ("post_layernorm",)),
            (Blip2QFormerModel, ("layernorm",)),
            (
                Blip2QFormerLayer,
                (
                    "attention.output.LayerNorm",
                    "crossattention.output.LayerNorm",
                    "output_query.LayerNorm",
                ),
            ),
            (OPTForCausalLM, ("model.decoder.final_layer_norm",)),
            (OPTDecoderLayer, ("self_attn_layer_norm", "final_layer_norm")),
        )
        for target_key, suffixes in norm_targets:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix=suffix, target_module=norm_cls) for suffix in suffixes
                ],
                policy=policy,
                target_key=target_key,
            )

        # use flash attention
        if shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_blip2_flash_attention_forward(),
                },
                policy=policy,
                target_key=Blip2Attention,
            )

        # use jit operator
        if shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_blip2_QFormer_self_output_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=Blip2QFormerSelfOutput,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_blip2_QFormer_output_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=Blip2QFormerOutput,
            )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged; no post-sharding fix-ups are applied here."""
        return self.model
# Blip2Model
class Blip2ModelPolicy(BlipPolicy):
    """Policy class for ``Blip2Model``.

    Adds nothing on top of ``BlipPolicy``: the constructor takes no arguments
    and simply defers to the parent initializer.
    """

    def __init__(self) -> None:
        # No model-specific state; all behaviour is inherited from BlipPolicy.
        super().__init__()
# Blip2ForConditionalGeneration
class Blip2ForConditionalGenerationPolicy(BlipPolicy):
    """Policy class for ``Blip2ForConditionalGeneration``.

    Adds nothing on top of ``BlipPolicy``: the constructor takes no arguments
    and simply defers to the parent initializer.
    """

    def __init__(self) -> None:
        # No model-specific state; all behaviour is inherited from BlipPolicy.
        super().__init__()