mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
fix (#5311)
This commit is contained in:
@@ -136,7 +136,7 @@ def fused_rotary_embedding(
|
||||
q_total_tokens, q_head_num, head_dim = q.shape
|
||||
assert q.size(0) == k.size(0)
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_SIZE = 16
|
||||
BLOCK_SIZE = 8
|
||||
cumsum_lens = torch.cumsum(lengths, dim=0)
|
||||
|
||||
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)
|
||||
|
Reference in New Issue
Block a user