upgrade command

This commit is contained in:
wangbluo
2025-05-08 16:06:05 +08:00
parent 46ed5d856b
commit a9bb7cb943
3 changed files with 58 additions and 102 deletions

View File

@@ -41,15 +41,13 @@ class CommandPolicy(Policy):
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
CohereFlashAttention2,
CohereModel,
CohereSdpaAttention,
)
ATTN_IMPLEMENTATION = {
"eager": CohereAttention,
"flash_attention_2": CohereFlashAttention2,
"sdpa": CohereSdpaAttention,
"flash_attention_2": CohereAttention,
"sdpa": CohereAttention,
}
policy = {}
@@ -60,12 +58,7 @@ class CommandPolicy(Policy):
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -280,29 +273,6 @@ class CommandPolicy(Policy):
target_key=CohereModel,
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
target_key=CohereDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=CohereModel,
)
return policy
def postprocess(self):
@@ -349,6 +319,7 @@ class CommandPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))