[Inference] add logit processor and request handler (#5166)

* add logit processor and request handler

* add

* add

* add

* fix

* add search tokens and update func

* finish request handler

* add running list test

* fix test

* fix some bugs

* add

* add

* fix bugs

* fix some bugs

* fix bug

* fix

* fix

* add copy function

* delete unused attention code

* fix request status

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

Author: Jianghai
Date: 2023-12-25 12:15:15 +08:00
Committed by: FrankLeeeee
Parent commit: 8daee26989
Commit: 0e616462a7
10 changed files with 463 additions and 66 deletions
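The diff excerpt below only covers the KVCacheManager test updates; the logit processor and request handler named in the title live in files not shown here. As a rough illustration of what a registry-based logit processor can look like, here is a minimal sketch — the registry dict, decorator name, and temperature example are assumptions for illustration, not necessarily this PR's actual API:

    import torch

    # Hypothetical registry of logit processors, keyed by sampling-parameter
    # name. Everything below is an illustrative sketch, not this PR's code.
    _LOGIT_PROCESSORS = {}


    def register_logit_processor(process_type):
        """Decorator that registers a processor function under a string key."""

        def register(func):
            _LOGIT_PROCESSORS[process_type] = func
            return func

        return register


    @register_logit_processor("temperature")
    def temperature_logit_process(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Scale logits by 1/temperature; values near 0 approach greedy decoding."""
        if not 0.0 < temperature <= 1.0:
            raise ValueError(f"temperature should be in (0, 1], got {temperature}")
        return logits if temperature == 1.0 else logits / temperature


    def logit_processor(process_type: str, logits: torch.Tensor, *args) -> torch.Tensor:
        """Look up and apply a registered processor; unknown keys leave logits untouched."""
        func = _LOGIT_PROCESSORS.get(process_type)
        return logits if func is None else func(logits, *args)

A request handler could then apply such processors to the model's output logits once per decoding step, before sampling the next token.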


@@ -88,7 +88,7 @@ def check_cache_manager(test_config):
     )
     cache_manager = KVCacheManager(inference_config, model_config)

-    num_blocks = cache_manager.get_total_num_blocks()
+    num_blocks = cache_manager.total_num_blocks
     assert num_blocks > 0
     assert len(cache_manager._cache_blocks) == num_blocks
     key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
@@ -114,7 +114,7 @@ def check_cache_manager(test_config):
     last_allocated_idx = (cur_seq_len - 1) // block_size
     assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
     cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
-    assert cache_manager.get_num_available_blocks() == num_blocks - cnt_blocks_used
+    assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used

     # Mock Decoding
     for req_i in range(max_batch_size):
@@ -136,9 +136,9 @@ def check_cache_manager(test_config):
     req_i = random.randint(0, max_batch_size - 1)
     context_length = context_lengths[req_i]
     blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
-    prev_available_blocks = cache_manager.get_num_available_blocks()
+    prev_available_blocks = cache_manager.num_available_blocks
     cache_manager.free_block_table(block_tables[req_i])
-    assert cache_manager.get_num_available_blocks() == blocks_used_by_req + prev_available_blocks
+    assert cache_manager.num_available_blocks == blocks_used_by_req + prev_available_blocks

     k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
     k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
@@ -146,7 +146,7 @@ def check_cache_manager(test_config):
     expected_stride = block_size * num_attention_heads * head_size * elem_size
     assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride

     cache_manager.clear_all()
-    assert cache_manager.get_num_available_blocks() == num_blocks
+    assert cache_manager.num_available_blocks == num_blocks

 def run_dist(rank, world_size, port):
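All four hunks make the same change: the test stops calling the get_total_num_blocks() / get_num_available_blocks() getters and reads total_num_blocks / num_available_blocks as properties instead. A minimal sketch of the property-style accessors the updated test assumes — the class name and internal attributes here are illustrative, not the manager's real fields:

    from typing import List


    class KVCacheManagerSketch:
        """Illustrative block accounting with read-only properties.

        Only the two public properties mirror what the updated test exercises;
        the internals (_total_num_blocks, _free_block_ids) are assumptions.
        """

        def __init__(self, total_num_blocks: int) -> None:
            self._total_num_blocks = total_num_blocks
            self._free_block_ids: List[int] = list(range(total_num_blocks))

        @property
        def total_num_blocks(self) -> int:
            # Replaces the former get_total_num_blocks() call site.
            return self._total_num_blocks

        @property
        def num_available_blocks(self) -> int:
            # Replaces the former get_num_available_blocks() call site.
            return len(self._free_block_ids)

Exposing these as properties keeps call sites such as cache_manager.num_available_blocks == num_blocks terse and signals a cheap, side-effect-free read. The stride assertion in the last hunk is unchanged: it still checks that the key pointers of consecutive blocks within a layer sit block_size * num_attention_heads * head_size * elem_size bytes apart.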