[upgrade]upgrade mistral (#6296)

* upgrade mistral

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2025-05-21 16:14:45 +08:00
Committed by: GitHub
Parent: 04516bb756
Commit: 6875a8a1cf
3 changed files with 79 additions and 160 deletions


@@ -38,24 +38,10 @@ class MistralPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import (
-            MistralAttention,
-            MistralDecoderLayer,
-            MistralFlashAttention2,
-            MistralModel,
-            MistralSdpaAttention,
-        )
-        ATTN_IMPLEMENTATION = {
-            "eager": MistralAttention,
-            "flash_attention_2": MistralFlashAttention2,
-            "sdpa": MistralSdpaAttention,
-        }
+        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
         policy = {}
-        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
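Background for this hunk: newer transformers releases (roughly 4.48 and later) removed the separate MistralFlashAttention2 and MistralSdpaAttention classes and fold the eager/SDPA/flash dispatch into a single MistralAttention that reads config._attn_implementation internally, so the policy no longer needs the ATTN_IMPLEMENTATION lookup. A minimal sketch to confirm this against an installed transformers version (the tiny config values below are arbitrary, chosen only to keep the model small):

# Sketch assuming transformers >= 4.48, where the per-backend attention
# classes were merged: every decoder layer exposes plain MistralAttention
# no matter which _attn_implementation is configured.
from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralModel

config = MistralConfig(
    num_hidden_layers=2,
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,
)
config._attn_implementation = "sdpa"  # "eager" behaves the same way here
model = MistralModel(config)
assert all(isinstance(layer.self_attn, MistralAttention) for layer in model.layers)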
@@ -258,7 +244,7 @@ class MistralPolicy(Policy):
                     "forward": get_mistral_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=MistralAttention,
             )
             if self.pipeline_stage_manager is None:
                 # replace llama model forward method
@@ -316,6 +302,7 @@ class MistralPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
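Context for the new held_layers.append(module.rotary_emb) line: in the upgraded transformers modeling code the rotary embedding is a single module owned by MistralModel rather than being built inside each attention block, so every pipeline stage needs to keep a reference to it. A quick check of that assumption (again a sketch against transformers >= 4.48; config values are arbitrary):

from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralModel, MistralRotaryEmbedding

config = MistralConfig(
    num_hidden_layers=1,
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = MistralModel(config)
# rotary_emb lives on the model, not on each layer's attention module
assert isinstance(model.rotary_emb, MistralRotaryEmbedding)
assert not hasattr(model.layers[0].self_attn, "rotary_emb")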