mirror of https://github.com/hpcaitech/ColossalAI.git
fix bugs in request_handler.py and engine.py
committed by FrankLeeeee
parent 10e3c9f923
commit d40eb26029
engine.py

@@ -1,6 +1,7 @@
 from itertools import count
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -159,7 +160,7 @@ class InferenceEngine:
         self,
         requests_id: List[int] = None,
         prompts: List[str] = None,
-        prompts_token_ids: List[int] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
         """
         Add requests.
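The widened annotation by itself changes nothing at runtime, since Python does not enforce type hints; that is why the next hunk adds an explicit isinstance check. A minimal sketch of the point (ToyEngine is a hypothetical stand-in and the method name add_request is inferred from the docstring; only the parameter list comes from the diff):

from typing import List, Union

import numpy as np
import torch

class ToyEngine:
    # Hypothetical stand-in; only the parameter list mirrors the diff.
    def add_request(
        self,
        requests_id: List[int] = None,
        prompts: List[str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
    ) -> None:
        # Annotations are documentation only: a tensor arrives untouched.
        print(type(prompts_token_ids).__name__)

ToyEngine().add_request(prompts_token_ids=torch.tensor([[1, 2, 3]]))  # prints "Tensor"
ToyEngine().add_request(prompts_token_ids=np.array([[1, 2, 3]]))      # prints "ndarray"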
@@ -176,9 +177,18 @@ class InferenceEngine:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
 
+        if isinstance(prompts_token_ids, list):
+            pass
+        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
+            prompts_token_ids = prompts_token_ids.tolist()
+        else:
+            raise TypeError(
+                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
+            )
+
         assert (
-            len(prompts_token_ids[0]) < self.inference_config.max_input_len
-        ), "The length of input prompts must be less than max_input_len."
+            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
+        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
 
         prompts_num = len(prompts_token_ids)
 
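Two fixes land in this hunk. First, token ids passed as a torch.Tensor or np.ndarray are normalized to a plain list via .tolist(), matching what batch_encode_plus already returns on the prompts path, so the rest of the method can index prompts_token_ids[0] uniformly. Second, the length check is relaxed from < to <=, so a prompt of exactly max_input_len tokens is now accepted, and the assert message reports the offending lengths. A standalone sketch of the normalization (the function name normalize_token_ids is illustrative, not from the patch):

import numpy as np
import torch

def normalize_token_ids(prompts_token_ids):
    # Coerce token ids to a plain nested list, mirroring the patched logic.
    if isinstance(prompts_token_ids, list):
        return prompts_token_ids
    if isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        return prompts_token_ids.tolist()
    raise TypeError(
        f"The dtype of prompts_token_ids must be one of list, torch.Tensor, "
        f"np.ndarray, but got {type(prompts_token_ids)}."
    )

# All three input kinds normalize to the same nested list:
assert normalize_token_ids([[1, 2, 3]]) == [[1, 2, 3]]
assert normalize_token_ids(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]
assert normalize_token_ids(np.array([[1, 2, 3]])) == [[1, 2, 3]]

Note that the assert still inspects only prompts_token_ids[0], so only the first prompt's length is validated.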
request_handler.py

@@ -131,9 +131,9 @@ class RequestHandler:
         """
         assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
         assert (
-            req.input_len < self.inference_config.max_input_len
+            req.input_len <= self.inference_config.max_input_len
         ), f"Sequence {req.request_id} exceeds input length limit"
-        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
+        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
 
     def abort_sequence(self, request_id: str):
         """
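The (max_input_len + 1) denominator is the substantive fix here. waiting_list is indexed with input_len * 3 // max_input_len, which evaluates to 3 when input_len equals max_input_len, a length now permitted by the relaxed <= assert, and that overruns a three-bucket list. Dividing by max_input_len + 1 keeps the bucket index in 0..2 for every legal length. A quick check (max_input_len = 512 is an arbitrary value for illustration; the three buckets are implied by the * 3):

max_input_len = 512

for input_len in (1, 171, 342, 511, 512):
    old_bucket = input_len * 3 // max_input_len          # 512 -> 3: IndexError on 3 buckets
    new_bucket = input_len * 3 // (max_input_len + 1)    # 512 -> 2: always within 0..2
    print(f"input_len={input_len:3d}  old={old_bucket}  new={new_bucket}")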