Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
adapted to pad_context_forward
Committed by: FrankLeeeee
Parent: 47e53eaa1c
Commit: fa4fbdbffb
@@ -51,6 +51,8 @@ class InferenceEngine:
        self.model_config = model.config
        self.device = torch.device("cuda")

        model = model.eval()

        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
            self.dtype = torch.float32
        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
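The dtype branch above is essentially a small normalization step: accept either a string alias or a torch.dtype and map both onto the torch.dtype the model should run in. A minimal standalone sketch of that idea, using a hypothetical resolve_dtype helper that is not part of ColossalAI's API:

    import torch

    # Hypothetical helper mirroring the branch above: accept "fp32"/"fp16"
    # or an actual torch.dtype and return the torch.dtype to cast to.
    def resolve_dtype(dtype):
        if dtype == "fp32" or dtype == torch.float32:
            return torch.float32
        if dtype == "fp16" or dtype == torch.float16:
            return torch.float16
        raise ValueError(f"unsupported dtype: {dtype!r}")

    assert resolve_dtype("fp16") is torch.float16
    assert resolve_dtype(torch.float32) is torch.float32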
@@ -85,12 +87,12 @@ class InferenceEngine:
        Verify the input config
        """
        if not isinstance(self.model, nn.Module):
-            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
            self.tokenizer, PreTrainedTokenizer
        ):
            raise TypeError(
-                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        assert (
            self.model.__class__.__name__ in _supported_models
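The corrected messages above come from plain isinstance checks on the model and tokenizer. A rough standalone sketch of the same validation, assuming only torch and transformers are available; the verify_inputs name is illustrative, not the engine's method:

    import torch.nn as nn
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

    # Illustrative restatement of the checks shown in the diff.
    def verify_inputs(model, tokenizer):
        if not isinstance(model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(model)}")
        if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, "
                f"but got {type(tokenizer)}"
            )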
@@ -8,6 +8,9 @@ from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)


class RunningList:
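These imports give the module a single logger created once at import time, which the scheduling code below uses for its warnings. A tiny sketch of the same pattern, using the standard library as a stand-in for colossalai.logging.get_dist_logger:

    import logging

    # Stand-in for get_dist_logger(__name__): one module-level logger,
    # created at import time and shared by every function in the file.
    logger = logging.getLogger(__name__)

    def warn_prompt_too_long(request_id):
        logger.warning(
            f"the prompt(Request id = {request_id}) length is longer than max_input_len, abort this sequence."
        )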
@@ -93,17 +96,23 @@ class RequestHandler:
        # Try to allocate cache blocks for the sequence using a priority of prompt length.
        for lst in reversed(self.waiting_list):
            if lst:
                remove_list = []
                for seq in lst:
                    if seq.input_len > self.inference_config.max_input_len:
                        # If the prompt length is longer than max_input_len, abort the sequence.
                        logger.warning(
                            f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
                        )
                        self.abort_sequence(seq.request_id)
                        break
                    remove_list.append(seq)
                    # Try to allocate cache blocks for the sequence.
                    if self.cache_manager.check_allocation(seq):
                        # If succeed, add the sequence to running list.
                        remove_list.append(seq)
                        self.running_list.append(seq)
                        self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
-                lst.clear()
+                for seq in remove_list:
+                    lst.remove(seq)
        if self.running_list.ready_for_prefill():
            for seq in self.running_list.prefill:
                seq.mark_running()
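The key change in this hunk is replacing lst.clear() with removal of only the sequences collected in remove_list: if the inner loop breaks early (for example on an over-long prompt), sequences that were never examined stay in the waiting list instead of being dropped wholesale. A self-contained toy sketch of that pattern, with integers standing in for sequences and a made-up MAX_INPUT_LEN threshold:

    # Toy model of the scheduling loop's cleanup: remove only what was handled.
    MAX_INPUT_LEN = 10

    def drain(waiting):
        admitted = []
        remove_list = []                 # items this pass actually processed
        for item in waiting:
            if item > MAX_INPUT_LEN:
                break                    # "abort": later items must stay queued
            remove_list.append(item)
            admitted.append(item)
        for item in remove_list:         # unlike waiting.clear(), unexamined items survive
            waiting.remove(item)
        return admitted

    queue = [3, 5, 42, 7]
    print(drain(queue))                  # [3, 5]
    print(queue)                         # [42, 7] -- 7 is still waiting; clear() would drop it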
@@ -130,10 +139,9 @@ class RequestHandler:
        """
        Abort the request.
        """
-        seq, priority = self._find_sequence(request_id)
+        seq, _ = self._find_sequence(request_id)
        if seq.status.is_waiting:
            seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
        elif seq.status.is_running():
            self.cache_manager.free_block_table(seq.block_table)
            self.running_list.remove(seq)
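After this change, abort_sequence no longer needs the priority returned by _find_sequence. A rough sketch of the resulting branch structure, with stub types in place of ColossalAI's Sequence, cache manager, and running list; every name below is an illustrative assumption, not the library's API:

    from dataclasses import dataclass, field
    from enum import Enum, auto

    class Status(Enum):
        WAITING = auto()
        RUNNING = auto()
        ABORTED = auto()

    @dataclass
    class StubSeq:
        request_id: int
        status: Status = Status.WAITING
        block_table: list = field(default_factory=list)

    def abort(seq, running_list, free_block_table):
        # Waiting sequences are only marked aborted; running ones also hand
        # their KV-cache blocks back and leave the running list.
        if seq.status is Status.WAITING:
            seq.status = Status.ABORTED
        elif seq.status is Status.RUNNING:
            free_block_table(seq.block_table)
            running_list.remove(seq)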