mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
fix bugs in attention.py and request_handler.py
committed by FrankLeeeee
parent bfd9b1b494
commit 47e53eaa1c
@@ -29,47 +29,50 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
             for block_idx in range(block_num - 1):
                 cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
                 token_id += block_size
-            cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0)
+            cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
+                1, 2, 0
+            )
     elif type == "decoding":
         assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
         source = source.squeeze(1)
         slot_idx = (lengths + block_size - 1) % block_size
         for i in range(bsz):
-            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i].permute(0, 1)
+            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]

     return cache


-def convert_kvcache(source, cache, lengths, block_tables):
+def convert_kvcache(cache, lengths, block_tables):
     """
     Func: convert key/value cache for calculation

-    Args: key/value(source): shape [bsz, 1, num_heads, head_size]
-          cache: shape [num_blocks, num_heads, head_size, block_size]
+    Args: cache: shape [num_blocks, num_heads, head_size, block_size]
           lengths: key/value length
           block_tables
     """
     num_blocks, num_heads, head_size, block_size = cache.shape

     needed_blocks = (lengths + block_size - 1) // block_size
-    num_remaing_tokens = (lengths - 1) % block_size
+    num_remaing_tokens = lengths % block_size
+    num_remaing_tokens[num_remaing_tokens == 0] += block_size
     bsz = block_tables.shape[0]
     seq_len = max(lengths)
     padded_cache = []
     for i in range(bsz):
+        cache1 = cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size)
+        cache2 = cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1)
+
         _cache = torch.cat(
             (
-                cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size),
-                cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0),
+                cache1,
+                cache2,
             ),
             dim=0,
         )
-        concat_cache = torch.cat((_cache, source[i]), dim=0)
-        padding = seq_len - concat_cache.size(0)
+        padding = seq_len - _cache.size(0)
         if padding > 0:
-            concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1))
-        padded_cache.append(concat_cache)
+            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
+        padded_cache.append(_cache)
     return torch.stack(padded_cache, dim=0)