[kernel] Add RMSLayerNorm triton kernel (#5262)

* add layerrmsnorm triton kernel

* add layerrmsnorm kernel

* modify the atol and rtol in the test file

* Remove the logic of mean computations, and update the name of the kernel functions and files

* add benchmark of rms norm
This commit is contained in:
Yaozheng Fang
2024-01-18 10:21:03 +08:00
committed by GitHub
parent 86b63f720c
commit 5ae9099f92
4 changed files with 103 additions and 62 deletions

View File

@@ -10,7 +10,7 @@ except ImportError:
if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_fwd
from .fused_layernorm import layer_norm
from .rms_layernorm import rms_layernorm
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import rotary_embedding
@@ -21,7 +21,7 @@ if HAS_TRITON:
"flash_decoding_fwd",
"copy_kv_to_blocked_cache",
"softmax",
"layer_norm",
"rms_layernorm",
"gptq_fused_linear_triton",
"rotary_embedding",
]