Support evaluation during training

commit 47a7dc7142
parent b920af427b
Author: YeAnbang
Date:   2025-04-30 18:13:40 +08:00

9 changed files with 234 additions and 65 deletions


@@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
- self.generate_config = SamplingParams(**generate_config)
+ self.generate_config = generate_config
+ self.sample_params = SamplingParams(**generate_config)
self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
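
Note: this hunk keeps the raw generate_config dict on the backend alongside the compiled SamplingParams. A minimal sketch of why that is useful, assuming the stored dict is later copied and overridden to build a second SamplingParams for evaluation (the config values below are placeholders, not taken from this repository):

from vllm import SamplingParams

# Placeholder stand-in for the backend's stored generate_config dict (assumed values).
generate_config = {"n": 8, "temperature": 1.0, "top_p": 0.9, "max_tokens": 512}

# Derive an evaluation config from the training one: a single greedy sample per prompt.
eval_config = dict(generate_config)
eval_config.update({"n": 1, "temperature": 0.0})
eval_sample_params = SamplingParams(**eval_config)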
@@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
micro_batch_input_ids_no_padding = [
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]
+ sample_params = kwargs.get("sample_params", self.sample_params)
outputs = self.llm.generate(
- prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
+ prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
)
out_tokens = []
out_len = []
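
Note: this hunk lets callers override the sampling parameters per generate() call via a sample_params kwarg, falling back to the backend-wide default otherwise. A rough sketch of the fallback pattern, using placeholder values and a stub function rather than the repository's generate implementation:

from vllm import SamplingParams

# Assumed defaults: multi-sample training rollouts vs. greedy single-sample evaluation.
train_params = SamplingParams(n=8, temperature=1.0)
eval_params = SamplingParams(n=1, temperature=0.0)

def generate(prompts, **kwargs):
    # Same fallback as in the diff: prefer the caller's sample_params, else the default.
    sample_params = kwargs.get("sample_params", train_params)
    print(f"decoding {len(prompts)} prompts with n={sample_params.n}, T={sample_params.temperature}")

generate(["q1", "q2"])                             # rollout: uses the training config
generate(["q1", "q2"], sample_params=eval_params)  # evaluation: overrides per call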
@@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"response_idx": response_idx,
}
- data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+ data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
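
Note: the last hunk stops hard-coding self.num_generations when regrouping the outputs and when repeating gt_answer, inferring the per-prompt group size from the tensor instead. A toy shape check (sizes made up) showing that the same code then handles both a training group of generations and a single-sample evaluation pass:

import torch

micro_batch_size, seq_len, answer_len = 2, 5, 4
for n in (8, 1):  # training group size vs. a single evaluation sample
    flat_input_ids = torch.zeros(micro_batch_size * n, seq_len, dtype=torch.long)
    # view(..., -1, ...) infers the per-prompt group size instead of assuming num_generations.
    grouped = flat_input_ids.view(micro_batch_size, -1, flat_input_ids.size(-1))
    # gt_answer starts as one answer per prompt and is repeated to match the inferred group size.
    gt_answer = torch.zeros(micro_batch_size, 1, answer_len, dtype=torch.long)
    gt_answer = gt_answer.repeat_interleave(grouped.size(1), dim=1)
    print(n, grouped.shape, gt_answer.shape)  # (2, n, 5) and (2, n, 4)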