Mirror of https://github.com/hpcaitech/ColossalAI.git
[upgrade]upgrade mistral (#6296)
* upgrade mistral

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -38,24 +38,10 @@ class MistralPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import (
-            MistralAttention,
-            MistralDecoderLayer,
-            MistralFlashAttention2,
-            MistralModel,
-            MistralSdpaAttention,
-        )
-
-        ATTN_IMPLEMENTATION = {
-            "eager": MistralAttention,
-            "flash_attention_2": MistralFlashAttention2,
-            "sdpa": MistralSdpaAttention,
-        }
+        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
 
         policy = {}
 
-        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
-
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
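Note on this hunk: recent transformers releases removed the separate MistralFlashAttention2 and MistralSdpaAttention classes and fold every backend into the single MistralAttention class, so the ATTN_IMPLEMENTATION lookup keyed on config._attn_implementation no longer has anything to select between. A rough compatibility shim along these lines (illustrative only, not part of this commit) shows the difference the upgrade absorbs:

```python
# Illustrative sketch, not ColossalAI code: resolving the attention target
# across transformers versions. Older releases expose one class per backend;
# newer releases expose only the unified MistralAttention.
try:
    from transformers.models.mistral.modeling_mistral import (  # older transformers
        MistralAttention,
        MistralFlashAttention2,
        MistralSdpaAttention,
    )

    ATTN_IMPLEMENTATION = {
        "eager": MistralAttention,
        "flash_attention_2": MistralFlashAttention2,
        "sdpa": MistralSdpaAttention,
    }
except ImportError:  # newer transformers: one class serves every backend
    from transformers.models.mistral.modeling_mistral import MistralAttention

    ATTN_IMPLEMENTATION = {impl: MistralAttention for impl in ("eager", "flash_attention_2", "sdpa")}
```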
@@ -258,7 +244,7 @@ class MistralPolicy(Policy):
                     "forward": get_mistral_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=MistralAttention,
             )
         if self.pipeline_stage_manager is None:
             # replace llama model forward method
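With attention unified into one class, the method replacement that installs the flash-attention forward can target MistralAttention directly instead of a per-backend attn_cls resolved from the model config. Conceptually (a minimal toy version, not ColossalAI's actual policy/SubModuleReplacementDescription machinery), targeting by class works like this:

```python
# Minimal sketch of the "target_key by class" idea: bind a replacement forward
# onto every sub-module whose class matches the target. Toy code for
# illustration; ColossalAI's policy framework does this with richer metadata.
import types

import torch.nn as nn


def patch_forward_by_class(model: nn.Module, target_cls: type, new_forward) -> None:
    """Bind new_forward as the forward method of every sub-module that is an
    instance of target_cls."""
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```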
@@ -316,6 +302,7 @@ class MistralPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
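The third hunk reflects that in recent transformers versions the rotary embedding is a single module on MistralModel (module.rotary_emb) rather than being rebuilt inside each attention layer, so every pipeline stage has to hold it to compute position embeddings for its slice of decoder layers. A rough sketch of the held-layers idea, assuming a simple contiguous split of the layers (toy partitioning, not the stage manager's actual distribute_layers logic):

```python
# Illustrative sketch: each pipeline stage keeps its slice of decoder layers
# plus the shared rotary embedding module. Toy partitioning for clarity only.
from typing import List

import torch.nn as nn


def held_layers_for_stage(
    layers: nn.ModuleList, rotary_emb: nn.Module, stage: int, num_stages: int
) -> List[nn.Module]:
    """Return the modules one pipeline stage must hold: the shared rotary
    embedding plus a contiguous chunk of the decoder layers."""
    per_stage = (len(layers) + num_stages - 1) // num_stages  # ceil division
    start = stage * per_stage
    end = min(start + per_stage, len(layers))
    held: List[nn.Module] = [rotary_emb]  # shared by every stage
    held.extend(layers[start:end])  # this stage's contiguous slice
    return held
```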