mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
fix (#5311)
This commit is contained in:
@@ -5,9 +5,11 @@ import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def prefill_cache_kernel(
|
||||
CaChe,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
cumsum_lengths,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
total_length,
|
||||
@@ -22,15 +24,31 @@ def prefill_cache_kernel(
|
||||
# original seq_idx and pos
|
||||
cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
|
||||
ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
|
||||
_cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride)
|
||||
tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length)
|
||||
cos_cache_part = tl.load(
|
||||
cos_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length
|
||||
)
|
||||
sin_cache_part = tl.load(
|
||||
sin_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length
|
||||
)
|
||||
tl.store(
|
||||
cos_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,
|
||||
cos_cache_part,
|
||||
mask=idx < total_length,
|
||||
)
|
||||
tl.store(
|
||||
sin_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,
|
||||
sin_cache_part,
|
||||
mask=idx < total_length,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def decoding_cache_kernel(
|
||||
CaChe,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
lengths,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
HIDDEN_DIM: tl.constexpr,
|
||||
@@ -39,16 +57,28 @@ def decoding_cache_kernel(
|
||||
):
|
||||
idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,]
|
||||
_cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride)
|
||||
cos_cache_part = tl.load(
|
||||
cos_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
sin_cache_part = tl.load(
|
||||
sin_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
tl.store(
|
||||
output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
|
||||
_cache,
|
||||
cos_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
|
||||
cos_cache_part,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
tl.store(
|
||||
sin_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
|
||||
sin_cache_part,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
|
||||
def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False):
|
||||
"""
|
||||
Transform cos/sin cache into no pad sequence, with two different modes.
|
||||
Args:
|
||||
@@ -60,28 +90,33 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool
|
||||
For decoding mode:
|
||||
cos/sin cache is only needed for the last token.
|
||||
"""
|
||||
|
||||
_, hidden_dim = cache.shape
|
||||
assert cos_cache.shape[1] == sin_cache.shape[1]
|
||||
_, hidden_dim = cos_cache.shape
|
||||
num_seqs = lengths.numel()
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
if hidden_dim >= 128:
|
||||
if hidden_dim >= 256:
|
||||
num_warps = 16
|
||||
elif hidden_dim >= 128:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
cache_stride = cache.stride(0)
|
||||
hidden_stride = cache.stride(1)
|
||||
cache_stride = cos_cache.stride(0)
|
||||
hidden_stride = cos_cache.stride(1)
|
||||
|
||||
if is_prompts:
|
||||
BLOCK_SIZE = 16
|
||||
total_length = lengths.sum().item()
|
||||
cumsum_lens = torch.cumsum(lengths, dim=0)
|
||||
output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device)
|
||||
cos_output = torch.empty((total_length, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)
|
||||
sin_output = torch.empty((total_length, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)
|
||||
grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
|
||||
prefill_cache_kernel[grid](
|
||||
cache,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
cumsum_lens,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
total_length,
|
||||
@@ -91,14 +126,17 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool
|
||||
num_warps=num_warps,
|
||||
)
|
||||
else:
|
||||
# BUG: get memory access error whe using a deepcopy lengths to replace lengths
|
||||
BLOCK_SIZE = 4
|
||||
nlengths = torch.as_tensor(lengths) - 1
|
||||
output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device)
|
||||
cos_output = torch.empty((num_seqs, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)
|
||||
sin_output = torch.empty((num_seqs, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)
|
||||
grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
|
||||
decoding_cache_kernel[grid](
|
||||
cache,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
nlengths,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
HIDDEN_DIM=hidden_dim,
|
||||
@@ -107,4 +145,4 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
return output
|
||||
return cos_output, sin_output
|
||||
|
Reference in New Issue
Block a user