[Inference]Add fused rotary kernel and get cos cache kernel (#5302)
* add fused rotary and get cos cache func
* staged
* fix bugs
* fix bugs
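The commit message mentions a new "get cos cache" kernel. For orientation, here is a minimal PyTorch sketch of how a cos/sin cache with the [max_position_len, head_dim] layout documented in the hunk below is conventionally precomputed for rotary position embedding. The helper name build_cos_sin_cache and the duplicated half-split layout are assumptions for illustration, not the kernel added by this commit.

import torch

def build_cos_sin_cache(max_position_len: int, head_dim: int, base: float = 10000.0):
    # Standard RoPE frequencies: theta_i = base^(-2i / head_dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_position_len).float()
    # One rotation angle per (position, frequency) pair
    angles = torch.outer(positions, inv_freq)      # [max_position_len, head_dim // 2]
    # Duplicate the half so the cache matches the documented [max_position_len, head_dim] shape
    angles = torch.cat((angles, angles), dim=-1)   # [max_position_len, head_dim]
    return angles.cos(), angles.sin()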
@@ -98,11 +98,12 @@ def rotary_embedding(
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
         k: key tensor, [total_tokens, head_num, head_dim]
-        cos: cosine for rotary embedding, [total_tokens, head_dim]
-        sin: sine for rotary embedding, [total_tokens, head_dim]
+        cos: cosine for rotary embedding, [max_position_len, head_dim]
+        sin: sine for rotary embedding, [max_position_len, head_dim]
+        lengths [num_seqs]
     """
     q_total_tokens, q_head_num, head_dim = q.shape
-    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
     assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
     BLOCK_TOKENS = 8
     grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
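For reference, below is a plain PyTorch sketch of the interface documented above: q and k are packed as [total_tokens, head_num, head_dim], cos and sin are rows of a [max_position_len, head_dim] cache, and lengths holds one sequence length per sequence. The helper name rotary_embedding_reference and the half-split rotate_half pairing are assumptions; the fused Triton kernel may pair elements differently and writes in place rather than returning new tensors.

import torch

def rotary_embedding_reference(q, k, cos, sin, lengths):
    # Per-token position ids: 0 .. len_i - 1 for every sequence, concatenated
    # in the same packed order as the tokens in q and k.
    positions = torch.cat([torch.arange(int(l)) for l in lengths])
    cos_t = cos[positions].unsqueeze(1)  # [total_tokens, 1, head_dim]
    sin_t = sin[positions].unsqueeze(1)  # [total_tokens, 1, head_dim]

    def rotate_half(x):
        # Half-split pairing convention (an assumption, see lead-in)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_out = q * cos_t + rotate_half(q) * sin_t
    k_out = k * cos_t + rotate_half(k) * sin_t
    return q_out, k_out

The factor of two in the launch grid, 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS), suggests the fused kernel assigns one half of the program ids along that axis to q and the other half to k, which is what lets a single launch rotate both tensors.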