[Inference] Fix Inference Generation Config and Sampling (#5710)

* refactor and add

* config default values

* fix gen config passing

* fix rpc generation config
This commit is contained in:
Yuanheng Zhao
2024-05-19 15:08:42 +08:00
committed by GitHub
parent 8bcfe360fd
commit 283c407a19
6 changed files with 124 additions and 68 deletions

View File

@@ -257,7 +257,12 @@ class RPCInferenceEngine(InferenceEngine):
assert len(self.workers) == self.tp_size, "init workers first"
init_tasks = [
-self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
+self.async_parallel_wrapper(
+worker.execute_model_forward,
+input_token_ids,
+input_meta_data.to_rpc_param(),
+self.generation_config_dict,
+)
for worker in self.workers
]
ret = await asyncio.gather(*init_tasks)