Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 13:59:08 +00:00)
[builder] multihead attn runtime building (#2203)
* [hotfix] correct cpu_optim runtime compilation
* [builder] multihead attn
* fix bug
* fix a bug
@@ -135,11 +135,8 @@ class MultiHeadAttention(nn.Module):
         # Load cuda modules if needed
         global colossal_multihead_attention
         if colossal_multihead_attention is None:
-            try:
-                import colossalai._C.multihead_attention
-                colossal_multihead_attention = colossalai._C.multihead_attention
-            except ImportError:
-                raise RuntimeError('MultiHeadAttention requires cuda extensions')
+            from colossalai.kernel import multihead_attention
+            colossal_multihead_attention = multihead_attention

         # create the layer in cuda kernels.
         cuda_module = colossal_multihead_attention
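For context, the new `from colossalai.kernel import multihead_attention` path defers kernel compilation to first use instead of requiring a prebuilt `colossalai._C` extension to ship with the wheel. Below is a minimal sketch of how such runtime (JIT) building is typically done with PyTorch's standard `torch.utils.cpp_extension.load`; the `csrc` source layout and the `_build_multihead_attention` helper name are hypothetical illustrations, not ColossalAI's actual builder API.

# Sketch of runtime CUDA extension building, assuming PyTorch's standard
# torch.utils.cpp_extension.load API. The 'csrc' layout and the helper
# name are hypothetical; ColossalAI's real builder may differ.
import os

from torch.utils.cpp_extension import load

_multihead_attention = None  # module-level cache: compile at most once per process


def _build_multihead_attention():
    """Compile and load the CUDA kernel on first call, then return the cached module."""
    global _multihead_attention
    if _multihead_attention is None:
        src_dir = os.path.join(os.path.dirname(__file__), 'csrc')  # hypothetical path
        _multihead_attention = load(
            name='multihead_attention',
            sources=[
                os.path.join(src_dir, 'multihead_attention.cpp'),
                os.path.join(src_dir, 'multihead_attention_cuda.cu'),
            ],
            extra_cuda_cflags=['-O3'],
            verbose=False,
        )
    return _multihead_attention

Compared to importing a prebuilt extension, this trades a one-time compilation delay on first use for a kernel built against the CUDA toolkit and PyTorch version actually present at runtime, which the removed `ImportError`-based path could not guarantee.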