From d40eb26029e8c61fc2b8ef3a1b8126a229e48047 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Wed, 10 Jan 2024 10:38:53 +0800
Subject: [PATCH] fix bugs in request_handler.py and engine.py

---
 colossalai/inference/config.py                   |  5 -----
 colossalai/inference/core/engine.py              | 16 +++++++++++++---
 colossalai/inference/core/request_handler.py     |  4 ++--
 colossalai/inference/kv_cache/kvcache_manager.py |  7 ++++++-
 4 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 8ce4ce967..2c77a6e12 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -28,7 +28,6 @@ class InferenceConfig:
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
         tp_size (int): Tensor parallel size.
         pp_size (int): Pipeline parallel size.
-        max_seq_len (int): Maximum length of input sentence.
         beam_width (int): The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
         prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
@@ -46,7 +45,6 @@ class InferenceConfig:
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
-    max_seq_len: int = 512
     # TODO: beam search is not support for now
     beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
@@ -99,6 +97,3 @@ class InferenceConfig:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
-        assert (
-            self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index eaacfe0f5..84810a82c 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -1,6 +1,7 @@
 from itertools import count
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -159,7 +160,7 @@ class InferenceEngine:
         self,
         requests_id: List[int] = None,
         prompts: List[str] = None,
-        prompts_token_ids: List[int] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
         """
         Add requests.
@@ -176,9 +177,18 @@ class InferenceEngine:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
 
+        if isinstance(prompts_token_ids, list):
+            pass
+        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
+            prompts_token_ids = prompts_token_ids.tolist()
+        else:
+            raise TypeError(
+                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
+            )
+
         assert (
-            len(prompts_token_ids[0]) < self.inference_config.max_input_len
-        ), "The length of input prompts must be less than max_input_len."
+            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
+        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
 
         prompts_num = len(prompts_token_ids)
 
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index dd8591e7f..09443c92a 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -131,9 +131,9 @@ class RequestHandler:
         """
         assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
         assert (
-            req.input_len < self.inference_config.max_input_len
+            req.input_len <= self.inference_config.max_input_len
         ), f"Sequence {req.request_id} exceeds input length limit"
-        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
+        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
 
     def abort_sequence(self, request_id: str):
         """
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 419fef3fb..33edebe63 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -58,7 +58,12 @@ class KVCacheManager:
         # Parallel settings
         self.tp_size = config.tp_size
         # Model settings
-        self.dtype = config.dtype
+        if config.dtype == "fp32" or config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif config.dtype == "fp16" or config.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         # For now we focus on MHA only, TODO add handling for MQA and GQA
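
Below is a minimal, runnable sketch of the input normalization that the patched
InferenceEngine.add_request performs: prompts_token_ids supplied as a plain list, a
torch.Tensor, or an np.ndarray are all reduced to a nested Python list before the
max_input_len check. The helper name normalize_token_ids and the standalone usage
are illustrative only and not part of the patch.

from typing import List, Union

import numpy as np
import torch


def normalize_token_ids(prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray]) -> list:
    # Mirror the isinstance dispatch added to add_request.
    if isinstance(prompts_token_ids, list):
        return prompts_token_ids
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        return prompts_token_ids.tolist()
    else:
        raise TypeError(
            f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, "
            f"but got {type(prompts_token_ids)}."
        )


# All three accepted forms normalize to the same nested list.
batch = [[1, 2, 3, 4], [5, 6, 7, 8]]
assert normalize_token_ids(batch) == batch
assert normalize_token_ids(torch.tensor(batch)) == batch
assert normalize_token_ids(np.array(batch)) == batch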
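
A minimal sketch of the dtype resolution that KVCacheManager.__init__ now applies
before sizing the cache: "fp32"/"fp16" (or the matching torch dtypes) map to
torch.float32/torch.float16, and anything else falls back to torch.bfloat16. The
function name resolve_kv_cache_dtype is illustrative only and not part of the patch.

from typing import Union

import torch


def resolve_kv_cache_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    # Mirror the branch added to KVCacheManager.__init__.
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    elif dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    else:
        return torch.bfloat16


# element_size() is what the manager then uses to compute per-token cache bytes.
assert torch.tensor([], dtype=resolve_kv_cache_dtype("fp16")).element_size() == 2
assert resolve_kv_cache_dtype(torch.float32) == torch.float32
assert resolve_kv_cache_dtype("bf16") == torch.bfloat16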