[Fix/Inference] Fix format of input prompts and input model in inference engine (#5395)

* Fix bugs in inference_engine

* fix bugs in engine.py

* rm CUDA_VISIBLE_DEVICES

* add request_ids in generate

* fix bug in engine.py

* add logger.debug for BatchBucket
Author: yuehuayingxueluo
Date: 2024-02-23 10:51:35 +08:00
Committer: GitHub
Parent: 2a718c8be8
Commit: bc1da87366
5 changed files with 27 additions and 8 deletions

engine.py

@@ -57,6 +57,7 @@ class InferenceEngine:
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.generation_config = inference_config.to_generation_config(self.model_config)
         model = model.eval()
+        model = model.cuda()
         model.to(self.dtype)
         if model_policy is None:
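Read together with the `return shard_model` change in the next hunk, this moves device placement into `__init__`: the model is put on the GPU and cast to the configured dtype before ShardFormer shards it, instead of calling `.cuda()` on the already-sharded model. A sketch of the resulting initialization order, reconstructed from the two hunks (the trailing comment lines are assumptions, not shown in this diff):

    model = model.eval()
    model = model.cuda()   # move to GPU before sharding
    model.to(self.dtype)   # then cast to the configured dtype
    # ... later in __init__, sharding operates on the GPU-resident model,
    # and _shardformer() now returns shard_model without a trailing .cuda()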
@@ -133,12 +134,13 @@ class InferenceEngine:
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
-        return shard_model.cuda()
+        return shard_model
 
     def generate(
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
@@ -148,6 +150,7 @@
         Args:
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            request_ids (List[int], optional): The request ID. Defaults to None.
             return_token_ids (bool): Whether to return output token ids. Defaults to False.
             generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
@@ -157,7 +160,7 @@
         with torch.inference_mode():
             self.generation_config = generation_config
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
+                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
 
             output_seqs_list = []
             total_tokens_list = []
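With `request_ids` threaded from `generate` into `add_request`, callers can pin their own integer IDs to prompts instead of relying on the engine's internal counter. A minimal usage sketch (the `InferenceEngine` constructor arguments here are assumptions, not shown in this diff):

    engine = InferenceEngine(model, tokenizer, inference_config)  # hypothetical construction

    # request_ids is optional; if omitted, IDs come from next(self.counter)
    outputs = engine.generate(
        prompts=["Hello, world!", "What is tensor parallelism?"],
        request_ids=[0, 1],  # one int per prompt
    )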
@@ -204,7 +207,7 @@
     def add_request(
         self,
-        requests_id: List[int] = None,
+        request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
     ) -> None:
@@ -212,7 +215,7 @@
         Add requests.
 
         Args:
-            requests_id (List[int], optional): The request ID. Defaults to None.
+            request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
         """
@@ -223,6 +226,9 @@
         block_size = self.inference_config.block_size
 
+        if prompts is not None and not isinstance(prompts, list):
+            prompts = [prompts]
+
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
@@ -245,8 +251,14 @@
         prompts_num = len(prompts_token_ids)
 
         for i in range(prompts_num):
-            if requests_id:
-                request_id = requests_id[i]
+            if request_ids:
+                if not isinstance(request_ids, list):
+                    request_ids = [request_ids]
+                assert isinstance(
+                    request_ids[0], int
+                ), f"The request_id type must be int, but got {type(request_ids[0])}"
+                assert len(request_ids) == prompts_num
+                request_id = request_ids[i]
             else:
                 request_id = next(self.counter)
             if prompts == None:
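Taken together, the wrapping added here and in the previous hunk makes `add_request` tolerant of scalar inputs while rejecting malformed IDs. A few hypothetical calls and the outcomes the new checks imply:

    engine.add_request(request_ids=7, prompts="hi")          # 7 wrapped to [7], "hi" to ["hi"]
    engine.add_request(request_ids=[1, 2], prompts=["hi"])   # AssertionError: len(request_ids) != prompts_num
    engine.add_request(request_ids=["a"], prompts=["hi"])    # AssertionError: request_id type must be int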