Mirror of https://github.com/hpcaitech/ColossalAI.git
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn". This reverts commit d06042b434.
* Revert "upgrade reward math verification". This reverts commit a6085ff676.
* Revert "fix bug". This reverts commit 01640ebd65.
* Revert "reuse comm-group". This reverts commit bd61918dcf.
* Revert "Support evaluation during training". This reverts commit 57a88395fe.
```diff
@@ -205,8 +205,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = generate_config
-        self.sample_params = SamplingParams(**generate_config)
+        self.generate_config = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
```
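This hunk restores the pre-evaluation behavior: the merged generation config is wrapped in a single vLLM `SamplingParams` object stored as `self.generate_config`, instead of keeping both a raw dict and a separate `self.sample_params`. Below is a minimal sketch of that construction path, assuming a vLLM install; `build_sampling_params` and the `FORCE_GENERATE_CONFIG` value shown here are illustrative, not the repository's actual constants.

```python
# Sketch only: mirrors the config-merging order in the hunk above.
from vllm import SamplingParams

FORCE_GENERATE_CONFIG = {"logprobs": 0}  # hypothetical backend-mandated overrides


def build_sampling_params(generate_config: dict, num_generations: int) -> SamplingParams:
    generate_config = generate_config.copy()        # avoid mutating the caller's dict
    generate_config.update(FORCE_GENERATE_CONFIG)   # forced settings win over user settings
    generate_config.update({"n": num_generations})  # request n completions per prompt
    return SamplingParams(**generate_config)


params = build_sampling_params({"temperature": 0.7, "max_tokens": 512}, num_generations=4)
```

Copying the dict first keeps the caller's config unmodified, and applying the forced overrides after the copy guarantees user-supplied keys cannot shadow them.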
```diff
@@ -220,9 +219,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
-        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
```
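Reverting the `sample_params = kwargs.get(...)` line means generation always uses `self.generate_config`; callers can no longer pass per-call sampling parameters, which was the hook the evaluation feature relied on. The surrounding context strips left padding before handing raw token lists to `llm.generate`. Here is a minimal sketch of that stripping step, assuming left-padded `input_ids` and a known `pad_token_id`; `strip_left_padding` is an illustrative helper, not the repository's code.

```python
import torch


def strip_left_padding(input_ids: torch.Tensor, pad_token_id: int) -> list[list[int]]:
    # For each row, find the index of the first token that is not padding,
    # then keep only the suffix starting there.
    first_non_pad = (input_ids != pad_token_id).int().argmax(dim=1)
    return [row[first_non_pad[i]:].tolist() for i, row in enumerate(input_ids)]


batch = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])  # pad_token_id = 0, left-padded
print(strip_left_padding(batch, pad_token_id=0))     # [[5, 6], [7, 8, 9]]
```

vLLM expects variable-length token lists per prompt, so the padding added for batched tensor storage has to be removed before generation.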
```diff
@@ -268,11 +266,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "response_idx": response_idx,
         }
 
-        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
 
         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
```
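The last hunk swaps the inferred `-1` dimension back to an explicit `self.num_generations` when grouping the flat `(batch * n)` outputs per prompt, and repeats `gt_answer` once per generation to match. A minimal sketch of those two tensor operations, with illustrative shapes:

```python
import torch

micro_batch_size, num_generations, seq_len = 2, 4, 8
# Flat outputs: one row per (prompt, generation) pair, shape (8, 8).
flat = torch.arange(micro_batch_size * num_generations * seq_len).view(-1, seq_len)

# Group the n generations of each prompt together: (batch, n, seq_len).
grouped = flat.view(micro_batch_size, num_generations, flat.size(-1))
print(grouped.shape)  # torch.Size([2, 4, 8])

# One ground-truth answer per prompt, repeated once per generation.
gt_answer = torch.tensor([[10], [20]])                          # shape (2, 1)
repeated = gt_answer.repeat_interleave(num_generations, dim=1)  # shape (2, 4)
print(repeated)  # tensor([[10, 10, 10, 10], [20, 20, 20, 20]])
```

Using the explicit `num_generations` instead of `-1` makes the reshape fail loudly if the number of returned sequences ever disagrees with the configured `n`, rather than silently producing a wrong grouping.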