[Inference] Fix API server, test and example (#5712)

* fix api server

* fix generation config

* fix api server

* fix comments

* fix inference hanging bug

* resolve comments, change backend to free port
This commit is contained in:
Jianghai
2024-05-15 15:47:31 +08:00
committed by GitHub
parent 74c47921fa
commit f47f2fbb24
5 changed files with 73 additions and 32 deletions

View File

@@ -154,7 +154,6 @@ class InferenceEngine:
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
@@ -589,7 +588,7 @@ class InferenceEngine:
def add_request(
self,
request_ids: Union[List[int], int] = None,
-        prompts: List[str] = None,
+        prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
) -> None: