fix bugs in attention.py and request_handler.py

yuehuayingxueluo authored 2024-01-08 12:35:06 +08:00, committed by FrankLeeeee
parent bfd9b1b494
commit 47e53eaa1c
6 changed files with 208 additions and 60 deletions


@@ -29,47 +29,50 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
             for block_idx in range(block_num - 1):
                 cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
                 token_id += block_size
-            cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0)
+            cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
+                1, 2, 0
+            )
     elif type == "decoding":
         assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
         source = source.squeeze(1)
         slot_idx = (lengths + block_size - 1) % block_size
         for i in range(bsz):
-            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i].permute(0, 1)
+            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
     return cache
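
Two fixes land in `copy_to_cache`. In the prefill path, the last block of a sequence is usually only partially filled, so the write now slices the cache to the tail length (`: seq_len - token_id`) instead of assigning a shorter tensor into a full `[num_heads, head_size, block_size]` block. In the decoding path, `source[i]` is a 2-D `[num_heads, head_size]` slice after the squeeze, so `.permute(0, 1)` was an identity permutation and is dropped. A minimal runnable sketch of the fixed logic; the toy shapes and the `torch.testing` checks are our own, not part of the commit:

```python
import torch

def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
    # cache: [num_blocks, num_heads, head_size, block_size]
    num_blocks, num_heads, head_size, block_size = cache.shape
    bsz = block_tables.shape[0]
    needed_blocks = (lengths + block_size - 1) // block_size  # ceil(lengths / block_size)

    if type == "prefill":
        for i in range(bsz):
            seq_len = lengths[i]
            block_num = needed_blocks[i]
            token_id = 0
            # Full blocks take block_size tokens each.
            for block_idx in range(block_num - 1):
                cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
                token_id += block_size
            # The last block may be partial: slice the cache to the actual tail length.
            cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(1, 2, 0)
    elif type == "decoding":
        assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
        source = source.squeeze(1)
        # Slot of the newest token inside its last block; lengths already count that token.
        slot_idx = (lengths + block_size - 1) % block_size
        for i in range(bsz):
            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
    return cache

# Toy check: a 5-token prefill with block_size 4 spans blocks 0 and 1.
num_heads, head_size = 2, 3
cache = torch.zeros(8, num_heads, head_size, 4)
source = torch.randn(1, 5, num_heads, head_size)
copy_to_cache(source, cache, torch.tensor([5]), torch.tensor([[0, 1]]), type="prefill")
torch.testing.assert_close(cache[1, :, :, 0], source[0, 4])  # token 4 -> block 1, slot 0

# Decoding step: token 5 (length becomes 6) lands in slot (6 + 4 - 1) % 4 == 1 of block 1.
new_tok = torch.randn(1, 1, num_heads, head_size)
copy_to_cache(new_tok, cache, torch.tensor([6]), torch.tensor([[0, 1]]), type="decoding")
torch.testing.assert_close(cache[1, :, :, 1], new_tok[0, 0])
```

The same hunk continues into `convert_kvcache`: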
-def convert_kvcache(source, cache, lengths, block_tables):
+def convert_kvcache(cache, lengths, block_tables):
     """
     Func: convert key/value cache for calculation
-    Args: key/value(source): shape [bsz, 1, num_heads, head_size]
-          cache: shape [num_blocks, num_heads, head_size, block_size]
+    Args: cache: shape [num_blocks, num_heads, head_size, block_size]
           lengths: key/value length
           block_tables
     """
     num_blocks, num_heads, head_size, block_size = cache.shape
     needed_blocks = (lengths + block_size - 1) // block_size
-    num_remaing_tokens = (lengths - 1) % block_size
+    num_remaing_tokens = lengths % block_size
+    num_remaing_tokens[num_remaing_tokens == 0] += block_size
     bsz = block_tables.shape[0]
     seq_len = max(lengths)
     padded_cache = []
     for i in range(bsz):
+        cache1 = cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size)
+        cache2 = cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1)
         _cache = torch.cat(
             (
-                cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size),
-                cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0),
+                cache1,
+                cache2,
             ),
             dim=0,
         )
-        concat_cache = torch.cat((_cache, source[i]), dim=0)
-        padding = seq_len - concat_cache.size(0)
+        padding = seq_len - _cache.size(0)
         if padding > 0:
-            concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1))
-        padded_cache.append(concat_cache)
+            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
+        padded_cache.append(_cache)
     return torch.stack(padded_cache, dim=0)
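
The central fix in `convert_kvcache` is the count of valid tokens in a sequence's last block. The old `(lengths - 1) % block_size` is off by one: a length that exactly fills its blocks yields `block_size - 1` instead of `block_size`, and a length one token past a block boundary yields `0` instead of `1`. The new formula takes `lengths % block_size` and maps the exactly-full case back to a whole block. A quick check, with a block size of 4 as our toy choice:

```python
import torch

block_size = 4
lengths = torch.tensor([4, 5, 7, 8])  # tokens stored per sequence

old = (lengths - 1) % block_size  # tensor([3, 0, 2, 3]) -- wrong at both block boundaries
new = lengths % block_size
new[new == 0] += block_size       # tensor([4, 1, 3, 4]) -- valid slots in the last block
```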
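
The remaining changes follow from dropping `source` from the signature: since the commit also touches request_handler.py, the freshly decoded token is presumably written into the paged cache via `copy_to_cache` before conversion, so it no longer needs to be concatenated on afterwards. The gather permutes are corrected at the same time: full blocks flatten in token order with `(0, 3, 1, 2)` rather than `(3, 0, 1, 2)`, and the partial tail comes out token-first as `[tokens, num_heads, head_size]` with `(2, 0, 1)` rather than `(2, 1, 0)`, so `cache1` and `cache2` concatenate cleanly. A standalone sketch with a hand-built toy cache to verify the round trip; the shapes are ours, and we pad each sequence by its full deficit where the commit pads a single token:

```python
import torch
import torch.nn.functional as F

def convert_kvcache(cache, lengths, block_tables):
    """Gather a padded [bsz, max_len, num_heads, head_size] view out of the paged cache."""
    num_blocks, num_heads, head_size, block_size = cache.shape
    needed_blocks = (lengths + block_size - 1) // block_size
    num_remaining = lengths % block_size
    num_remaining[num_remaining == 0] += block_size  # an exactly-full last block holds block_size tokens
    bsz = block_tables.shape[0]
    max_len = int(max(lengths))
    padded_cache = []
    for i in range(bsz):
        # Full blocks: [n, heads, size, block] -> [n * block, heads, size], tokens in block order.
        cache1 = cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size)
        # Partial tail block: keep only its valid slots, token dimension first.
        cache2 = cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaining[i]].permute(2, 0, 1)
        _cache = torch.cat((cache1, cache2), dim=0)
        padding = max_len - _cache.size(0)
        if padding > 0:
            # Our generalization: pad the whole deficit (the commit pads one token row).
            _cache = F.pad(_cache, (0, 0, 0, 0, 0, padding))
        padded_cache.append(_cache)
    return torch.stack(padded_cache, dim=0)

# Toy cache: seq 0 holds 5 tokens in blocks [0, 1]; seq 1 holds 6 tokens in blocks [2, 3].
num_heads, head_size = 2, 3
cache = torch.zeros(8, num_heads, head_size, 4)
src0, src1 = torch.randn(5, num_heads, head_size), torch.randn(6, num_heads, head_size)
cache[0] = src0[0:4].permute(1, 2, 0)
cache[1, :, :, :1] = src0[4:5].permute(1, 2, 0)
cache[2] = src1[0:4].permute(1, 2, 0)
cache[3, :, :, :2] = src1[4:6].permute(1, 2, 0)

out = convert_kvcache(cache, torch.tensor([5, 6]), torch.tensor([[0, 1], [2, 3]]))
torch.testing.assert_close(out[0, :5], src0)  # seq 0: 5 real tokens, then one zero-padded row
torch.testing.assert_close(out[1], src1)      # seq 1: all 6 tokens round-trip exactly
```

Slicing the last-block write in `copy_to_cache` and counting its valid slots here with the same formula is what makes the round trip exact.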