mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[Inference] Fix Inference Generation Config and Sampling (#5710)
* refactor and add
* config default values
* fix gen config passing
* fix rpc generation config
This commit is contained in:
@@ -97,7 +97,9 @@ class rpcWorkerService(rpyc.Service):
 )
 logger.info("physical cache init over")

-def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
+def exposed_execute_model_forward(
+self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+):
 # prepare the data for model forward
 input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
 input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@ class rpcWorkerService(rpyc.Service):
 if self.inference_config.pad_input:
 logits = logits[:, -1, :]
 next_tokens = search_tokens(
-self.inference_config.to_generation_config(self.model_config),
+generation_config_param,
 logits,
 input_meta_data.is_prompts,
 input_meta_data.batch_token_ids,
Reference in New Issue
Block a user