mirror of https://github.com/hpcaitech/ColossalAI.git
commit d40eb26029 (parent 10e3c9f923)

    fix bugs in request_handler.py and engine.py
@@ -28,7 +28,6 @@ class InferenceConfig:
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
         tp_size (int): Tensor parallel size.
         pp_size (int): Pipeline parallel size.
-        max_seq_len (int): Maximum length of input sentence.
         beam_width (int): The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
         prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
@@ -46,7 +45,6 @@ class InferenceConfig:
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
-    max_seq_len: int = 512
     # TODO: beam search is not support for now
     beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
@@ -99,6 +97,3 @@ class InferenceConfig:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
-        assert (
-            self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."
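These hunks drop the max_seq_len field together with the post-init check that tied it to max_input_len and max_output_len. For reference, a standalone sketch of the invariant being removed (field names taken from the diff; this is not the actual InferenceConfig class):

```python
from dataclasses import dataclass


@dataclass
class _ConfigSketch:
    # Assumed fields, copied from the assertion removed in the hunk above.
    max_input_len: int = 512
    max_output_len: int = 128
    max_seq_len: int = 512

    def __post_init__(self):
        # This is the consistency check the commit removes along with max_seq_len.
        assert (
            self.max_input_len + self.max_output_len <= self.max_seq_len
        ), "max_input_len + max_output_len must not exceed max_seq_len"


_ConfigSketch(max_input_len=400, max_output_len=100, max_seq_len=512)  # passes the removed check
```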
@@ -1,6 +1,7 @@
 from itertools import count
 from typing import List, Optional, Union

+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -159,7 +160,7 @@ class InferenceEngine:
         self,
         requests_id: List[int] = None,
         prompts: List[str] = None,
-        prompts_token_ids: List[int] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
         """
         Add requests.
@@ -176,9 +177,18 @@ class InferenceEngine:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

+        if isinstance(prompts_token_ids, list):
+            pass
+        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
+            prompts_token_ids = prompts_token_ids.tolist()
+        else:
+            raise TypeError(
+                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
+            )
+
         assert (
-            len(prompts_token_ids[0]) < self.inference_config.max_input_len
-        ), "The length of input prompts must be less than max_input_len."
+            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
+        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."

         prompts_num = len(prompts_token_ids)

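With the signature widened to Union[List[int], torch.Tensor, np.ndarray], add_request now normalizes tensors and arrays to nested Python lists via .tolist() before the length check. A minimal standalone sketch of that normalization (the helper name is ours, not part of InferenceEngine):

```python
import numpy as np
import torch


def _normalize_token_ids(prompts_token_ids):
    # Mirrors the isinstance chain added in the hunk above.
    if isinstance(prompts_token_ids, list):
        pass
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    else:
        raise TypeError(
            f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, "
            f"but got {type(prompts_token_ids)}."
        )
    return prompts_token_ids


# All three input forms end up as the same nested Python list.
assert _normalize_token_ids([[1, 2, 3]]) == [[1, 2, 3]]
assert _normalize_token_ids(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]
assert _normalize_token_ids(np.array([[1, 2, 3]])) == [[1, 2, 3]]
```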
@@ -131,9 +131,9 @@ class RequestHandler:
         """
         assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
         assert (
-            req.input_len < self.inference_config.max_input_len
+            req.input_len <= self.inference_config.max_input_len
         ), f"Sequence {req.request_id} exceeds input length limit"
-        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
+        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)

     def abort_sequence(self, request_id: str):
         """
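The waiting-list change pairs with the relaxed assert above: the list appears to hold three priority buckets (hence the * 3), and once input_len == max_input_len is allowed, the old expression input_len * 3 // max_input_len evaluates to 3 for a full-length prompt, which is out of range; dividing by max_input_len + 1 keeps the bucket index in 0..2. A small worked example, assuming max_input_len = 512:

```python
max_input_len = 512

old_bucket = lambda input_len: input_len * 3 // max_input_len
new_bucket = lambda input_len: input_len * 3 // (max_input_len + 1)

print(old_bucket(511), new_bucket(511))  # 2 2 -> both indices are valid
print(old_bucket(512), new_bucket(512))  # 3 2 -> the old index overflows a 3-bucket list,
                                         #        which matters now that input_len == max_input_len passes the assert
```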
@@ -58,7 +58,12 @@ class KVCacheManager:
         # Parallel settings
         self.tp_size = config.tp_size
         # Model settings
-        self.dtype = config.dtype
+        if config.dtype == "fp32" or config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif config.dtype == "fp16" or config.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         # For now we focus on MHA only, TODO add handling for MQA and GQA
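KVCacheManager now resolves config.dtype, which may arrive as a string ("fp32"/"fp16") or as a torch.dtype, to a concrete torch.dtype before computing the per-element size, falling back to bfloat16 for everything else. A minimal sketch of that mapping (the helper name is ours, not part of the class):

```python
import torch


def _resolve_kv_cache_dtype(dtype) -> torch.dtype:
    # Mirrors the branch added in the hunk above: fp32/fp16 (as string or
    # torch.dtype) map directly, anything else falls back to bfloat16.
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    elif dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    else:
        return torch.bfloat16


for d in ("fp32", "fp16", "bf16", torch.float16):
    resolved = _resolve_kv_cache_dtype(d)
    # Same trick the manager uses to get bytes per element for cache sizing.
    elem_size = torch.tensor([], dtype=resolved).element_size()
    print(d, resolved, elem_size)  # e.g. "fp16" -> torch.float16, 2 bytes per element
```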