fix bugs in request_handler.py and engine.py

yuehuayingxueluo 2024-01-10 10:38:53 +08:00 committed by FrankLeeeee
parent 10e3c9f923
commit d40eb26029
4 changed files with 21 additions and 11 deletions

View File

@@ -28,7 +28,6 @@ class InferenceConfig:
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
         tp_size (int): Tensor parallel size.
         pp_size (int): Pipeline parallel size.
-        max_seq_len (int): Maximum length of input sentence.
         beam_width (int): The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
         prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
@@ -46,7 +45,6 @@ class InferenceConfig:
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
-    max_seq_len: int = 512
     # TODO: beam search is not support for now
     beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
@@ -99,6 +97,3 @@ class InferenceConfig:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
-        assert (
-            self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."

View File

@@ -1,6 +1,7 @@
 from itertools import count
 from typing import List, Optional, Union

+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -159,7 +160,7 @@ class InferenceEngine:
         self,
         requests_id: List[int] = None,
         prompts: List[str] = None,
-        prompts_token_ids: List[int] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
         """
         Add requests.
@@ -176,9 +177,18 @@ class InferenceEngine:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

+        if isinstance(prompts_token_ids, list):
+            pass
+        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
+            prompts_token_ids = prompts_token_ids.tolist()
+        else:
+            raise TypeError(
+                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
+            )
+
-        assert (
-            len(prompts_token_ids[0]) < self.inference_config.max_input_len
-        ), "The length of input prompts must be less than max_input_len."
+        assert (
+            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
+        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."

         prompts_num = len(prompts_token_ids)
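
Note (editor's sketch, not part of this commit): the new isinstance branch in add_request normalizes prompts_token_ids so that plain lists pass through unchanged while torch.Tensor and np.ndarray inputs are converted with .tolist(). The standalone helper below mirrors that logic; the function name is ours.

    from typing import List, Union

    import numpy as np
    import torch

    def normalize_token_ids(prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray]) -> list:
        if isinstance(prompts_token_ids, list):
            return prompts_token_ids
        if isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
            return prompts_token_ids.tolist()
        raise TypeError(
            f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, "
            f"but got {type(prompts_token_ids)}."
        )

    # All three accepted input types end up as plain nested lists of token ids:
    assert normalize_token_ids([[1, 2, 3]]) == [[1, 2, 3]]
    assert normalize_token_ids(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]
    assert normalize_token_ids(np.array([[1, 2, 3]])) == [[1, 2, 3]]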

View File

@@ -131,9 +131,9 @@ class RequestHandler:
         """
         assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
         assert (
-            req.input_len < self.inference_config.max_input_len
+            req.input_len <= self.inference_config.max_input_len
         ), f"Sequence {req.request_id} exceeds input length limit"
-        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
+        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)

     def abort_sequence(self, request_id: str):
         """

View File

@@ -58,7 +58,12 @@ class KVCacheManager:
         # Parallel settings
         self.tp_size = config.tp_size
         # Model settings
-        self.dtype = config.dtype
+        if config.dtype == "fp32" or config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif config.dtype == "fp16" or config.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         # For now we focus on MHA only, TODO add handling for MQA and GQA
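
Note (editor's sketch, not part of this commit): KVCacheManager now resolves config.dtype whether it arrives as a string ("fp32"/"fp16") or a torch.dtype, falling back to bfloat16 for anything else, so element_size() is always computed from a real torch.dtype. The helper below restates that mapping; the function name is ours.

    import torch

    def resolve_kv_cache_dtype(dtype) -> torch.dtype:
        if dtype == "fp32" or dtype == torch.float32:
            return torch.float32
        if dtype == "fp16" or dtype == torch.float16:
            return torch.float16
        return torch.bfloat16

    for d in ("fp32", "fp16", torch.bfloat16):
        resolved = resolve_kv_cache_dtype(d)
        # element_size() is what the manager uses to budget KV-cache memory
        print(resolved, torch.tensor([], dtype=resolved).element_size())
    # fp32 -> 4 bytes per element, fp16 and bf16 -> 2 bytes per element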