[builder] multihead attn runtime building (#2203)

* [hotfix] correct cpu_optim runtime compilation

* [builder] multihead attn

* fix bug

* fix a bug
Author: Jiarui Fang
Date: 2022-12-27 16:06:09 +08:00
Committed by: GitHub
Parent: 8e22c38b89
Commit: 1cb532ffec
7 changed files with 88 additions and 25 deletions


@@ -12,4 +12,12 @@ except ImportError:
     from colossalai.kernel.op_builder import CPUAdamBuilder
     cpu_optim = CPUAdamBuilder().load()
 
-__all__ = ["fused_optim", "cpu_optim", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
+try:
+    from colossalai._C import multihead_attention
+except ImportError:
+    from colossalai.kernel.op_builder import MultiHeadAttnBuilder
+    multihead_attention = MultiHeadAttnBuilder().load()
+
+__all__ = [
+    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
+]
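
For context, the fallback added in this diff loads the fused multi-head attention kernel through a runtime (JIT) builder whenever the prebuilt colossalai._C extension is not installed. Below is a minimal, hypothetical sketch of how such a builder can wrap torch.utils.cpp_extension.load; the actual MultiHeadAttnBuilder in colossalai.kernel.op_builder may differ in interface and source layout, and the class name and file paths shown here are illustrative only.

# Hypothetical sketch only -- not the actual MultiHeadAttnBuilder implementation.
from torch.utils.cpp_extension import load

class RuntimeKernelBuilder:
    """Compile a CUDA/C++ extension on first use instead of at install time."""

    def __init__(self, name, sources, include_dirs=None):
        self.name = name                      # extension module name, e.g. "multihead_attention"
        self.sources = sources                # .cpp / .cu files to compile (illustrative paths)
        self.include_dirs = include_dirs or []

    def load(self, verbose=False):
        # torch.utils.cpp_extension.load builds the sources with ninja on the first
        # call and caches the resulting shared library, so later imports are fast.
        return load(name=self.name,
                    sources=self.sources,
                    extra_include_paths=self.include_dirs,
                    verbose=verbose)

# Usage mirroring the fallback pattern in the diff (paths are hypothetical):
# multihead_attention = RuntimeKernelBuilder(
#     name="multihead_attention",
#     sources=["csrc/multihead_attention_1d.cpp", "csrc/kernels/softmax_kernels.cu"],
# ).load()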