[Inference] Fix Inference Generation Config and Sampling (#5710)

* refactor and add generation config handling

* set generation config default values

* fix generation config passing

* fix RPC generation config passing
Author: Yuanheng Zhao
Date: 2024-05-19 15:08:42 +08:00
Committed by: GitHub
Parent: 8bcfe360fd
Commit: 283c407a19
6 changed files with 124 additions and 68 deletions
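The diff below changes rpcWorkerService.exposed_execute_model_forward to accept a serialized generation config from the caller instead of rebuilding one inside the worker from self.inference_config. As a rough illustration only, the payload might look like the dict sketched here; the key names follow transformers.GenerationConfig fields and are an assumption, since the exact keys the engine sends are not shown in this diff.

```python
# Hypothetical sketch of a generation_config_param payload sent over RPC.
# Key names follow transformers.GenerationConfig; the engine's actual dict may differ.
generation_config_param = {
    "max_length": 128,   # assumed cap on generated sequence length
    "do_sample": True,   # sample instead of greedy decoding
    "temperature": 0.7,  # softmax temperature
    "top_k": 50,         # keep only the k most likely tokens
    "top_p": 0.9,        # nucleus sampling threshold
}
```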


@@ -97,7 +97,9 @@ class rpcWorkerService(rpyc.Service):
         )
         logger.info("physical cache init over")

-    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
+    def exposed_execute_model_forward(
+        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+    ):
         # prepare the data for model forward
         input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
         input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@ class rpcWorkerService(rpyc.Service):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         next_tokens = search_tokens(
-            self.inference_config.to_generation_config(self.model_config),
+            generation_config_param,
             logits,
             input_meta_data.is_prompts,
             input_meta_data.batch_token_ids,
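
As a usage sketch, the caller side might flatten a GenerationConfig into a plain dict and pass it through rpyc as the new third argument. Everything below except the method name and its parameter order is an assumption: the address, port, rpyc config, and the placeholder inputs are illustrative and not taken from this PR.

```python
import rpyc
from transformers import GenerationConfig

# Assumed worker address/port. rpyc strips the "exposed_" prefix on the client,
# so the service method is reachable as conn.root.execute_model_forward(...).
conn = rpyc.connect("localhost", 18861, config={"allow_pickle": True})

# Flatten the sampling settings into a plain dict so they serialize cleanly
# over the RPC boundary instead of being proxied as a live object reference.
generation_config_param = GenerationConfig(
    do_sample=True, temperature=0.7, top_k=50, top_p=0.9
).to_dict()

input_token_ids = [1, 15043, 3186]            # placeholder prompt token ids
input_meta_data_param = {"is_prompts": True}  # placeholder; the engine builds this from its InputMetaData

# Return value handling is not shown in this diff and depends on the worker implementation.
result = conn.root.execute_model_forward(
    input_token_ids,
    input_meta_data_param,
    generation_config_param,
)
```

Passing a plain dict keeps the per-request sampling parameters under the engine's control and matches the second hunk above, where search_tokens now receives the caller-supplied config rather than one derived from self.inference_config and self.model_config inside the worker.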