[Inference]Add fused rotary kernel and get cos cache kernel (#5302)

* add fused rotary and get cos cache func

* staged

* fix bugs

* fix bugs
Jianghai
2024-01-24 16:20:42 +08:00
committed by GitHub
parent 3da9993b0d
commit c647e00e3c
6 changed files with 477 additions and 5 deletions


@@ -98,11 +98,12 @@ def rotary_embedding(
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
         k: key tensor, [total_tokens, head_num, head_dim]
-        cos: cosine for rotary embedding, [total_tokens, head_dim]
-        sin: sine for rotary embedding, [total_tokens, head_dim]
+        cos: cosine for rotary embedding, [max_position_len, head_dim]
+        sin: sine for rotary embedding, [max_position_len, head_dim]
+        lengths [num_seqs]
     """
     q_total_tokens, q_head_num, head_dim = q.shape
-    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
     assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
     BLOCK_TOKENS = 8
     grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
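For context, the shape change in this hunk means `cos`/`sin` are now a position-indexed cache (`[max_position_len, head_dim]`) rather than tensors pre-gathered per token (`[total_tokens, head_dim]`), which is why the per-token shape assert drops out. Below is a minimal pure-PyTorch sketch of building such a cache. The function name, the rotary base of 10000.0, and the half-split (GPT-NeoX-style) channel layout are assumptions for illustration; they are not taken from this PR's Triton cache kernel, which may lay the cache out differently.

```python
import torch

def build_cos_sin_cache(max_position_len: int, head_dim: int, base: float = 10000.0):
    """Reference (non-Triton) rotary cache: returns cos and sin tensors of shape
    [max_position_len, head_dim], matching the updated docstring above."""
    # One rotation frequency per pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(max_position_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # [max_position_len, head_dim // 2]
    # Half-split layout (assumed): channel i and channel i + head_dim // 2 form a pair.
    emb = torch.cat((freqs, freqs), dim=-1)   # [max_position_len, head_dim]
    return emb.cos(), emb.sin()

# Hypothetical usage: under the old contract the caller gathered cos/sin per
# token before the call; under the new one the kernel receives the whole cache
# plus per-sequence lengths and resolves each token's position itself.
cos_cache, sin_cache = build_cos_sin_cache(max_position_len=4096, head_dim=128)
```

Handing the kernel the full cache plus `lengths` avoids materializing a `[total_tokens, head_dim]` gather on every step and, as the commit title suggests, lets the rotary application be fused into a single kernel launch.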