[kernel] update triton init #4740 (#4740)

This commit is contained in:
Xuanlei Zhao 2023-09-18 09:44:27 +08:00 committed by GitHub
parent d151dcab74
commit 32e7f99416
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,7 @@
try:
import triton
HAS_TRITON = True
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm from .fused_layernorm import layer_norm
@ -10,3 +14,7 @@ __all__ = [
"llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
"copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
] ]
except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")