diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 936fd2d24..621982f29 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -18,9 +18,9 @@ from colossalai.shardformer.layer import ( from ..modeling.mistral import ( MistralForwards, + get_lm_forward_with_dist_cross_entropy, get_mistral_flash_attention_forward, get_mistral_model_forward_for_flash_attn, - get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription