From c3dc9b4dba80f7f9948a89463ee97d96e20e641f Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 23 Jul 2024 12:56:58 +0000
Subject: [PATCH] [deepseek] replace attn (a workaround for bug in transformers)

---
 colossalai/shardformer/policies/deepseek.py        | 34 ++++++++++++++++---
 .../test_model/test_shard_deepseek_ghz.py          |  1 +
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 1e44aba4e..d1d004ed5 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
             )
 
         if self.shard_config.enable_flash_attention:
-            warnings.warn(
-                "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
-            )
-            self.shard_config.enable_flash_attention = False
+            # NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for deepseek right now
+            from transformers.dynamic_module_utils import get_class_from_dynamic_module
+
+            flash_attn_cls = get_class_from_dynamic_module(
+                "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
+                "deepseek-ai/deepseek-moe-16b-base",
+            )
+
+            class TargetFlashAttn:
+                def __init__(self):
+                    raise RuntimeError("This class should not be instantiated")
+
+                @staticmethod
+                def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
+                    flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
+                    flash_attn_module.q_proj = original_attn.q_proj
+                    flash_attn_module.k_proj = original_attn.k_proj
+                    flash_attn_module.v_proj = original_attn.v_proj
+                    flash_attn_module.o_proj = original_attn.o_proj
+                    flash_attn_module.rotary_emb = original_attn.rotary_emb
+                    return flash_attn_module
+
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="self_attn",
+                    target_module=TargetFlashAttn,
+                ),
+                policy=policy,
+                target_key="DeepseekDecoderLayer",
+            )
         return policy
 
     def postprocess(self):

diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
index fdca11005..fe834a4f6 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
@@ -220,6 +220,7 @@ def check_deepseek(rank, world_size, port):
     run_deepseek_test()
 
 
+@pytest.mark.skip("redundant")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()