mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[shardformer] bert support sequence parallel. (#4455)
* [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel
This commit is contained in:
@@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.bert import (
|
||||
BertPipelineForwards,
|
||||
bert_sequence_parallel_forward_fn,
|
||||
get_bert_flash_attention_forward,
|
||||
get_jit_fused_bert_output_forward,
|
||||
get_jit_fused_bert_self_output_forward,
|
||||
@@ -47,13 +48,14 @@ class BertPolicy(Policy):
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertEmbeddings,
|
||||
BertLayer,
|
||||
BertModel,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
|
||||
policy = {}
|
||||
|
||||
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
|
||||
"attention.self.all_head_size":
|
||||
@@ -69,14 +71,17 @@ class BertPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.query",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.key",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.dropout",
|
||||
@@ -85,6 +90,7 @@ class BertPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
@@ -93,10 +99,12 @@ class BertPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
@@ -115,6 +123,12 @@ class BertPolicy(Policy):
|
||||
)
|
||||
])
|
||||
|
||||
if use_sequence_parallel:
|
||||
self.append_or_create_method_replacement(
|
||||
description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
|
||||
policy=policy,
|
||||
target_key=BertModel)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# Handle bert layer
|
||||
@@ -205,7 +219,13 @@ class BertPolicy(Policy):
|
||||
|
||||
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
method_replacement = {
|
||||
'forward':
|
||||
partial(new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=self.shard_config)
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=policy,
|
||||
target_key=model_cls)
|
||||
|
Reference in New Issue
Block a user