mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Fix/Inference] Fix CUDA Rotary Embedding GQA (#5623)
* fix rotary embedding GQA
* change test_rotary_embdding_unpad.py KH
This commit is contained in:
@@ -115,7 +115,7 @@ __device__ void apply_k_rotary_emb_compute(
|
||||
(head_offset % shard_block_size) / VecSize;
|
||||
const int64_t addr_offset =
|
||||
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
|
||||
const int64_t target_id = block_id * head_num * head_dim * block_size +
|
||||
const int64_t target_id = block_id * kv_head_num * head_dim * block_size +
|
||||
(i / half_head_dim) * block_size * head_dim +
|
||||
block_offset * head_dim + head_offset;
|
||||
|
||||
@@ -137,7 +137,7 @@ __device__ void apply_k_rotary_emb_compute(
|
||||
|
||||
// apply value memcopy
|
||||
apply_kv_memcopy<scalar_t, VecSize>(
|
||||
value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
|
||||
value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim,
|
||||
block_size, block_offset, head_dim, half_head_dim);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user