update_bloom

This commit is contained in:
wangbluo
2025-05-13 18:21:57 +08:00
parent 46ed5d856b
commit 2237531137
2 changed files with 93 additions and 1 deletion

View File

@@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_sequence_parallel_attention_forward,
get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
@@ -48,6 +49,7 @@ class BertPolicy(Policy):
BertLayer,
BertModel,
BertOutput,
BertSdpaSelfAttention,
BertSelfOutput,
)
@@ -77,6 +79,15 @@ class BertPolicy(Policy):
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
},
policy=policy,
target_key=BertSdpaSelfAttention,
)
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0