[fix] revert reward update and evaluation (#6295)

* Revert "rewrite reward fn"

This reverts commit d06042b434.

* Revert "upgrade reward math verification"

This reverts commit a6085ff676.

* Revert "fix bug"

This reverts commit 01640ebd65.

* Revert "reuse comm-group"

This reverts commit bd61918dcf.

* Revert "Support evaluation during training"

This reverts commit 57a88395fe.
This commit is contained in:
YeAnbang
2025-05-07 10:56:47 +08:00
committed by YeAnbang
parent 06b892bf4d
commit 9544c51a74
9 changed files with 82 additions and 307 deletions

View File

@@ -205,8 +205,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
  generate_config = generate_config.copy()
  generate_config.update(self.FORCE_GENERATE_CONFIG)
  generate_config.update({"n": num_generations})
- self.generate_config = generate_config
- self.sample_params = SamplingParams(**generate_config)
+ self.generate_config = SamplingParams(**generate_config)
  self.model_config = model_config
  self.tokenizer = tokenizer
  self.num_generations = num_generations
@@ -220,9 +219,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
  micro_batch_input_ids_no_padding = [
  micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
  ]
- sample_params = kwargs.get("sample_params", self.sample_params)
  outputs = self.llm.generate(
- prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
+ prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
  )
  out_tokens = []
  out_len = []
@@ -268,11 +266,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
  "response_idx": response_idx,
  }
- data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
+ data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
  if "gt_answer" in kwargs:
  # repeat gt_answer for each prompt.
- data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
+ data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
  data = {k: v.to(get_current_device()) for k, v in data.items()}
  return data