[Inference] Adapted to the triton attn kernels (#5264)
* adapted to the triton attn kernels
* fix pad input
* adapted to copy_kv_to_blocked_cache
* fix ci test
* update kv memcpy
* remove print
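For context on the "kv memcpy" items: `copy_kv_to_blocked_cache` writes each sequence's newly computed key/value vectors into the paged (blocked) KV cache at the slot its block table points to. The Triton kernel itself is not part of this diff; below is a plain-PyTorch sketch of the same memcpy semantics, with every tensor shape and name assumed for illustration rather than taken from the ColossalAI API.

    import torch

    # Sketch only: shapes and names are assumptions, not ColossalAI's actual API.
    #   k_cache:      [num_blocks, num_heads, block_size, head_dim]  (paged KV cache)
    #   new_k:        [batch_size, num_heads, head_dim]              (one new token per sequence)
    #   block_tables: [batch_size, max_blocks_per_seq]               (logical -> physical block ids)
    #   seq_lengths:  [batch_size]                                   (lengths including the new token)
    def copy_new_token_to_blocked_cache(new_k, k_cache, block_tables, seq_lengths, block_size):
        for i in range(new_k.size(0)):
            pos = int(seq_lengths[i]) - 1                    # position of the new token
            block_id = block_tables[i, pos // block_size]    # physical block holding that position
            k_cache[block_id, :, pos % block_size, :] = new_k[i]

A real kernel would parallelize this loop over sequences and heads; the indexing arithmetic is the part the sketch is meant to show.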
@@ -57,9 +57,6 @@ class RunningList:
     def is_empty(self):
         return not self.decoding and not self.prefill

     def total_seq_num(self):
         return len(self.decoding) + len(self.prefill)


 class RequestHandler:
     """
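For orientation, the two helpers above imply that RunningList keeps separate prefill and decoding collections and answers aggregate queries over both. A minimal sketch assuming list-backed storage (the real class may hold richer state):

    class RunningList:
        """Sketch: tracks scheduled sequences, split by inference phase."""

        def __init__(self):
            self.prefill = []    # sequences still processing their prompt
            self.decoding = []   # sequences generating tokens step by step

        def is_empty(self):
            return not self.decoding and not self.prefill

        def total_seq_num(self):
            return len(self.decoding) + len(self.prefill)

        def append(self, seq):
            # Assumption: newly admitted sequences enter in the prefill phase.
            self.prefill.append(seq)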
@@ -81,6 +78,7 @@ class RequestHandler:
         device = torch.cuda.current_device()
         self.running_batch = BatchInfo(is_prompts=False, device=device)
         self.prefill_batch = BatchInfo(is_prompts=True, device=device)
+        self.max_batch_size = inference_config.max_batch_size

     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
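Judging by the hunk counts (6 lines before, 7 after), the single added line here is `self.max_batch_size = inference_config.max_batch_size`: the handler snapshots the configured cap so the scheduling code below can enforce it. A minimal sketch of that wiring, with the `InferenceConfig` field assumed:

    from dataclasses import dataclass

    @dataclass
    class InferenceConfig:       # assumed shape, for illustration only
        max_batch_size: int = 8

    class RequestHandler:
        def __init__(self, inference_config: InferenceConfig):
            self.inference_config = inference_config
            # Cached on the handler so the scheduler can cap admissions cheaply.
            self.max_batch_size = inference_config.max_batch_size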
@@ -108,20 +106,18 @@ class RequestHandler:
                         )
                         self.abort_sequence(seq.request_id)
                         break

                     # Stop feeding new sequences into the running list, to ensure enough cache blocks stay available.
                     if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
                         break

                     # Try to allocate cache blocks for the sequence.
-                    if self.cache_manager.check_allocation(seq):
+                    if (
+                        self.cache_manager.check_allocation(seq)
+                        and (len(self.running_list.prefill) + len(self.running_list.decoding))
+                        < self.max_batch_size  # There are some bugs in continuous batching, so we disable it here.
+                    ):
                         # If allocation succeeds, add the sequence to the running list.
                         remove_list.append(seq)
                         self.running_list.append(seq)
                         self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
                 for seq in remove_list:
                     lst.remove(seq)

         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
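The rewritten condition gates admission on two things at once: the cache manager must be able to allocate blocks for the sequence, and the combined prefill-plus-decoding population must stay under max_batch_size (per the inline comment, continuous batching is deliberately capped while its bugs are outstanding). A standalone sketch of that admission test, with all names assumed:

    def can_admit(seq, cache_manager, running_list, max_batch_size):
        """Sketch: True if `seq` may move from the waiting list to the running list."""
        has_cache_room = cache_manager.check_allocation(seq)
        under_batch_cap = (
            len(running_list.prefill) + len(running_list.decoding)
        ) < max_batch_size
        return has_cache_room and under_batch_cap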
@@ -130,12 +126,7 @@ class RequestHandler:

         if not self.running_batch.is_empty:
             for seq in self.running_batch.sequences_set:
-                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
-                if recycle:
-                    seq.recycle()
-                    self.running_batch.remove(seq)
-                    self.waiting_list[-1].append(seq)
-                    # The recycled sequences are handled with the highest priority.
+                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)

         return self.running_batch
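The replacement drops the recycle path: previously, when allocate_token_from_block_table signaled block exhaustion, the sequence was recycled back onto the highest-priority waiting list; now every running sequence simply reserves one token's worth of KV-cache space per decode step. A minimal sketch of the simplified step, assuming the method allocates room for the token at position sentence_len:

    def prepare_decode_step(running_batch, cache_manager):
        # Sketch: each sequence in the running batch reserves cache space for its
        # next token; the recycle/requeue branch from the old code is gone.
        if not running_batch.is_empty:
            for seq in running_batch.sequences_set:
                cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
        return running_batch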