[builder] multihead attn runtime building (#2203)

* [hotfix] correct cpu_optim runtime compilation

* [builder] multihead attn

* fix bug

* fix a bug
Author: Jiarui Fang
Date: 2022-12-27 16:06:09 +08:00
Committed by: GitHub
Parent: 8e22c38b89
Commit: 1cb532ffec
7 changed files with 88 additions and 25 deletions


@@ -135,11 +135,8 @@ class MultiHeadAttention(nn.Module):
         # Load cuda modules if needed
         global colossal_multihead_attention
         if colossal_multihead_attention is None:
-            try:
-                import colossalai._C.multihead_attention
-                colossal_multihead_attention = colossalai._C.multihead_attention
-            except ImportError:
-                raise RuntimeError('MultiHeadAttention requires cuda extensions')
+            from colossalai.kernel import multihead_attention
+            colossal_multihead_attention = multihead_attention
         # create the layer in cuda kernels.
         cuda_module = colossal_multihead_attention
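
For context, the removed code required a precompiled `colossalai._C.multihead_attention` extension and raised a hard `RuntimeError` when it was missing; the new path imports from `colossalai.kernel`, which builds the CUDA kernel at runtime on first use. Below is a minimal sketch of that lazy runtime-build pattern using PyTorch's `torch.utils.cpp_extension.load`. The module name, source file paths, and helper function here are hypothetical illustrations, not Colossal-AI's actual builder:

```python
import os
from torch.utils.cpp_extension import load

_multihead_attention = None  # cached extension module, compiled at most once


def get_multihead_attention():
    """Lazily JIT-compile and cache the CUDA extension on first call."""
    global _multihead_attention
    if _multihead_attention is None:
        # Hypothetical layout: CUDA/C++ sources live next to this file.
        src_dir = os.path.join(os.path.dirname(__file__), 'csrc')
        # load() invokes the compiler on the first call and reuses the
        # cached build artifacts on subsequent runs.
        _multihead_attention = load(
            name='multihead_attention',  # hypothetical extension name
            sources=[
                os.path.join(src_dir, 'multihead_attention.cpp'),
                os.path.join(src_dir, 'multihead_attention_cuda.cu'),
            ],
            extra_cuda_cflags=['-O3'],
            verbose=False,
        )
    return _multihead_attention
```

Compared with the removed try/except around a prebuilt `_C` module, deferring compilation to first use means an environment without a prebuilt binary can still run the fused kernel, at the cost of a one-time build on the first forward pass.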