This commit is contained in:
Jianghai
2024-01-26 15:02:12 +08:00
committed by GitHub
parent 4f28cb43c0
commit 7ddd8b37f0
4 changed files with 149 additions and 75 deletions

View File

@@ -136,7 +136,7 @@ def fused_rotary_embedding(
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
-BLOCK_SIZE = 16
+BLOCK_SIZE = 8
cumsum_lens = torch.cumsum(lengths, dim=0)
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)