Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] add logit processor and request handler (#5166)
* add logit processor and request handler
* add search tokens and update func
* finish request handler
* add running list test
* fix tests and assorted bugs
* add copy function
* delete unused attn
* fix request status

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
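For context on the first item: a logit processor post-processes the raw next-token logits (temperature scaling, top-k filtering, and the like) before the request handler samples from them. Below is a minimal sketch of that general pattern in PyTorch; the function names and signatures are illustrative assumptions, not the actual API added by this PR.

```python
import torch

# Sketch of the logit-processor pattern; names are illustrative,
# not ColossalAI's actual implementation.
def temperature_logit_process(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Rescale next-token logits by a sampling temperature."""
    if temperature <= 0.0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    return logits / temperature


def top_k_logit_process(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Mask all but the k highest-scoring tokens before sampling."""
    top_k = min(top_k, logits.size(-1))
    # Value of the k-th largest logit, kept as a broadcastable column.
    kth_value = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_value, -float("inf"))
```

A handler would chain such processors over the logits of each running request, then sample from the filtered distribution.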
```diff
@@ -88,7 +88,7 @@ def check_cache_manager(test_config):
     )
     cache_manager = KVCacheManager(inference_config, model_config)

-    num_blocks = cache_manager.get_total_num_blocks()
+    num_blocks = cache_manager.total_num_blocks
     assert num_blocks > 0
     assert len(cache_manager._cache_blocks) == num_blocks
     key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
@@ -114,7 +114,7 @@ def check_cache_manager(test_config):
         last_allocated_idx = (cur_seq_len - 1) // block_size
         assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
         cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
-    assert cache_manager.get_num_available_blocks() == num_blocks - cnt_blocks_used
+    assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used

     # Mock Decoding
     for req_i in range(max_batch_size):
@@ -136,9 +136,9 @@ def check_cache_manager(test_config):
     req_i = random.randint(0, max_batch_size - 1)
     context_length = context_lengths[req_i]
     blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
-    prev_available_blocks = cache_manager.get_num_available_blocks()
+    prev_available_blocks = cache_manager.num_available_blocks
     cache_manager.free_block_table(block_tables[req_i])
-    assert cache_manager.get_num_available_blocks() == blocks_used_by_req + prev_available_blocks
+    assert cache_manager.num_available_blocks == blocks_used_by_req + prev_available_blocks

     k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
     k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
@@ -146,7 +146,7 @@ def check_cache_manager(test_config):
     expected_stride = block_size * num_attention_heads * head_size * elem_size
     assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride
     cache_manager.clear_all()
-    assert cache_manager.get_num_available_blocks() == num_blocks
+    assert cache_manager.num_available_blocks == num_blocks


 def run_dist(rank, world_size, port):
```
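The test changes above track an API change in KVCacheManager: the getter methods get_total_num_blocks() and get_num_available_blocks() become read-only properties total_num_blocks and num_available_blocks, so callers drop the parentheses. A minimal sketch of that pattern, assuming the manager tracks its blocks in internal lists (the _free_blocks attribute here is an illustrative assumption):

```python
class KVCacheManager:
    """Sketch of the property-based accessors the diff switches to.

    Only the two accessors touched by this diff are shown; the
    internals (_cache_blocks, _free_blocks) are assumptions made
    for illustration, not the real implementation.
    """

    def __init__(self, num_blocks: int) -> None:
        self._cache_blocks = list(range(num_blocks))  # all allocated block ids
        self._free_blocks = list(self._cache_blocks)  # ids not yet bound to a sequence

    @property
    def total_num_blocks(self) -> int:
        # Replaces get_total_num_blocks(): read-only, computed on access.
        return len(self._cache_blocks)

    @property
    def num_available_blocks(self) -> int:
        # Replaces get_num_available_blocks().
        return len(self._free_blocks)
```

The property form keeps call sites terse (cache_manager.total_num_blocks rather than cache_manager.get_total_num_blocks()) while the counts are still computed on access, so the test assertions above stay valid as blocks are allocated and freed.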