mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[Fix/Inference] Fix format of input prompts and input model in inference engine (#5395)
* Fix bugs in inference_engine * fix bugs in engine.py * rm CUDA_VISIBLE_DEVICES * add request_ids in generate * fix bug in engine.py * add logger.debug for BatchBucket
This commit is contained in:
@@ -57,6 +57,7 @@ class InferenceEngine:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
model = model.eval()
|
||||
model = model.cuda()
|
||||
model.to(self.dtype)
|
||||
|
||||
if model_policy is None:
|
||||
@@ -133,12 +134,13 @@ class InferenceEngine:
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
||||
return shard_model
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
request_ids: List[int] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> List[str]:
|
||||
@@ -148,6 +150,7 @@ class InferenceEngine:
|
||||
Args:
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
return_token_ids (bool): Whether to return output token ids. Defaults to False.
|
||||
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
@@ -157,7 +160,7 @@ class InferenceEngine:
|
||||
with torch.inference_mode():
|
||||
self.generation_config = generation_config
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
@@ -204,7 +207,7 @@ class InferenceEngine:
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
requests_id: List[int] = None,
|
||||
request_ids: List[int] = None,
|
||||
prompts: List[str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
) -> None:
|
||||
@@ -212,7 +215,7 @@ class InferenceEngine:
|
||||
Add requests.
|
||||
|
||||
Args:
|
||||
requests_id (List[int], optional): The request ID. Defaults to None.
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
@@ -223,6 +226,9 @@ class InferenceEngine:
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
@@ -245,8 +251,14 @@ class InferenceEngine:
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if requests_id:
|
||||
request_id = requests_id[i]
|
||||
if request_ids:
|
||||
if not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts == None:
|
||||
|
Reference in New Issue
Block a user