[Inference] Adapted to the Triton attn kernels (#5264)

* adapted to the Triton attn kernels

* fix pad input

* adapted to copy_kv_to_blocked_cache (see the sketch below)

* fix CI test

* update KV memcpy

* remove print
Author: yuehuayingxueluo
Date: 2024-01-17 16:03:10 +08:00
Committed by: GitHub
Parent: 0f2b46a41c
Commit: 86b63f720c
7 changed files with 221 additions and 101 deletions
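
The commit message mentions copy_kv_to_blocked_cache and a KV memcpy update, while the diff shown on this page only covers the request handler. As background, here is a minimal sketch of what writing one decoding step's key/value vectors into a blocked (paged) KV cache can look like. The function name, tensor shapes, and block-table layout below are assumptions for illustration only, not ColossalAI's actual kernel or API (the real addressing is done inside the Triton kernels the commit adapts to).

import torch

# Hypothetical blocked (paged) KV-cache layout, assumed for illustration:
#   k_cache, v_cache: (num_blocks, num_kv_heads, block_size, head_dim)
# A token's logical position is mapped to a (physical block, offset) pair through
# the sequence's block table, so a new token is copied into its slot in place
# rather than appended to a contiguous per-sequence tensor.
def copy_new_token_to_blocked_cache(
    new_k: torch.Tensor,        # (num_kv_heads, head_dim) for the newly decoded token
    new_v: torch.Tensor,        # (num_kv_heads, head_dim)
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    block_table: torch.Tensor,  # (max_blocks_per_seq,), logical block -> physical block
    seq_len: int,               # sequence length including the new token
) -> None:
    block_size = k_cache.shape[2]
    token_idx = seq_len - 1                                # position of the new token
    block_id = int(block_table[token_idx // block_size])   # physical block holding it
    offset = token_idx % block_size                        # slot inside that block
    k_cache[block_id, :, offset, :] = new_k
    v_cache[block_id, :, offset, :] = new_v

The actual kernels perform this addressing for a whole batch on the GPU; the sketch only shows the indexing scheme a block table provides.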


@@ -57,9 +57,6 @@ class RunningList:
     def is_empty(self):
         return not self.decoding and not self.prefill
 
-    def total_seq_num(self):
-        return len(self.decoding) + len(self.prefill)
-
 
 class RequestHandler:
     """
@@ -81,6 +78,7 @@ class RequestHandler:
         device = torch.cuda.current_device()
         self.running_batch = BatchInfo(is_prompts=False, device=device)
         self.prefill_batch = BatchInfo(is_prompts=True, device=device)
+        self.max_batch_size = inference_config.max_batch_size
 
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -108,20 +106,18 @@ class RequestHandler:
                             )
                             self.abort_sequence(seq.request_id)
                             break
-                        # stop feeding new sequence into running list to assure
-                        if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
-                            break
-
                         # Try to allocate cache blocks for the sequence.
-                        if self.cache_manager.check_allocation(seq):
+                        if (
+                            self.cache_manager.check_allocation(seq)
+                            and (len(self.running_list.prefill) + len(self.running_list.decoding))
+                            < self.max_batch_size  # There are some bugs in continuous batching, so we disable it here.
+                        ):
                             # If succeed, add the sequence to running list.
                             remove_list.append(seq)
                             self.running_list.append(seq)
                             self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
                     for seq in remove_list:
                         lst.remove(seq)
 
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
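
The hunk above swaps the global num_available_blocks check for a per-sequence admission test: a waiting sequence is moved into the running list only when the cache manager can allocate blocks for its prompt and the number of in-flight sequences stays below the new max_batch_size cap. A minimal sketch of that condition in isolation (check_allocation, prefill, and decoding mirror the diff; the dataclass and helper function around them are assumptions for illustration):

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class RunningListSketch:
    # Mirrors the prefill/decoding split used by RunningList in the diff.
    prefill: List[Any] = field(default_factory=list)
    decoding: List[Any] = field(default_factory=list)


def can_admit(seq, cache_manager, running_list: RunningListSketch, max_batch_size: int) -> bool:
    # 1) The cache manager must be able to allocate KV blocks for the prompt.
    if not cache_manager.check_allocation(seq):
        return False
    # 2) Cap the number of in-flight sequences; the commit's comment notes that
    #    continuous batching is disabled here because of bugs, so this cap is
    #    what bounds the batch instead.
    return len(running_list.prefill) + len(running_list.decoding) < max_batch_size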
@@ -130,12 +126,7 @@ class RequestHandler:
         if not self.running_batch.is_empty:
             for seq in self.running_batch.sequences_set:
-                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
-                if recycle:
-                    seq.recycle()
-                    self.running_batch.remove(seq)
-                    self.waiting_list[-1].append(seq)
-                    # the recycled sequences are handled with highest priority.
+                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
 
         return self.running_batch
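
The last hunk drops the recycle path: previously a failed token allocation pulled the sequence out of the running batch and pushed it back onto the waiting list, whereas now each decoding sequence simply gets a slot reserved for its next token every step. In a blocked cache, reserving that slot roughly means "append a fresh block when the last one is full". The toy allocator below is an assumption for illustration, not ColossalAI's KVCacheManager:

from typing import List


class BlockAllocatorSketch:
    """Toy allocator illustrating the idea behind allocate_token_from_block_table."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))

    def allocate_token_from_block_table(self, block_table: List[int], sentence_len: int) -> None:
        # Slot index of the token generated this step; whether sentence_len already
        # counts the new token is an assumption made here for simplicity.
        token_idx = sentence_len
        if token_idx % self.block_size == 0:
            # All slots in the sequence's current blocks are used: grab a new block.
            if not self.free_blocks:
                raise RuntimeError("out of KV-cache blocks")
            block_table.append(self.free_blocks.pop())
        # Otherwise the last block still has a free slot, so nothing to allocate.

For example, with block_size=16 a sequence needs a new physical block only every 16 generated tokens; on all other steps the call leaves the block table unchanged.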