mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-24 15:25:34 +00:00
Support evaluation during training
This commit is contained in:
@@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
generate_config = generate_config.copy()
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
generate_config.update({"n": num_generations})
|
||||
- self.generate_config = SamplingParams(**generate_config)
|
||||
+ self.generate_config = generate_config
|
||||
+ self.sample_params = SamplingParams(**generate_config)
|
||||
self.model_config = model_config
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
@@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
micro_batch_input_ids_no_padding = [
|
||||
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
|
||||
]
|
||||
+ sample_params = kwargs.get("sample_params", self.sample_params)
|
||||
outputs = self.llm.generate(
|
||||
- prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
|
||||
+ prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
|
||||
)
|
||||
out_tokens = []
|
||||
out_len = []
|
||||
@@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
- data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
+ data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if "gt_answer" in kwargs:
|
||||
# repeat gt_answer for each prompt.
|
||||
- data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
|
||||
+ data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user