Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
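For context, "ignore cuda for clang-format" is typically done with an exclude pattern on the clang-format hook. The excerpt below is a hypothetical .pre-commit-config.yaml sketch, not the actual configuration from this commit; the mirrors-clang-format repo and the version pin are assumptions for illustration.

repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1  # assumed pin, not taken from the commit
    hooks:
      - id: clang-format
        # Skip CUDA sources so clang-format leaves .cu/.cuh files untouched.
        exclude: \.(cu|cuh)$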
@@ -9,12 +9,14 @@ SUPPORT_XFORMERS = False
 SUPPORT_FLASH2 = False
 try:
     import xformers.ops as xops
+
     SUPPORT_XFORMERS = True
 except ImportError:
     pass
 
 try:
     from flash_attn import flash_attn_func
+
     SUPPORT_FLASH2 = True
 except ImportError:
     pass
@@ -62,10 +64,9 @@ def llama_flash_attention(
     if SUPPORT_FLASH2:
         attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
     else:
-        attn_output = xops.memory_efficient_attention(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      attn_bias=xops.LowerTriangularMask())
+        attn_output = xops.memory_efficient_attention(
+            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
+        )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
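Taken together, the two hunks show an optional-dependency pattern: each backend is imported inside a try/except, a module-level flag records whether it is available, and the attention call dispatches on those flags. The standalone sketch below restates that pattern; the function name causal_attention and the plain-PyTorch fallback are illustrative additions, not part of the commit, and the input layout (batch, seq_len, num_heads, head_dim) is assumed from the flash_attn and xFormers APIs.

import torch

SUPPORT_XFORMERS = False
SUPPORT_FLASH2 = False

try:
    import xformers.ops as xops

    SUPPORT_XFORMERS = True
except ImportError:
    pass

try:
    from flash_attn import flash_attn_func

    SUPPORT_FLASH2 = True
except ImportError:
    pass


def causal_attention(query_states, key_states, value_states):
    """Dispatch causal attention to the fastest available backend (illustrative)."""
    if SUPPORT_FLASH2:
        # FlashAttention-2 fuses the causal mask into the kernel.
        return flash_attn_func(query_states, key_states, value_states, causal=True)
    if SUPPORT_XFORMERS:
        # xFormers memory-efficient attention with an explicit lower-triangular mask.
        return xops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
        )
    # Fallback added here for illustration only (not in the diff): PyTorch SDPA
    # expects (batch, num_heads, seq_len, head_dim), so transpose around the call.
    q, k, v = (t.transpose(1, 2) for t in (query_states, key_states, value_states))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2)

Note that flash_attn_func requires CUDA tensors in fp16/bf16, while the PyTorch fallback also runs on CPU; as in the patched function above, the result can then be reshaped to (bsz, q_len, hidden_size).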