This commit is contained in:
Tong Li
2025-02-28 10:16:42 +08:00
parent f736d747e3
commit 070907dd7f
6 changed files with 74 additions and 26 deletions

View File

@@ -154,6 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
n=4,
)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
@@ -166,19 +167,24 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config.update(self.FORCE_GENERATE_CONFIG)
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
micro_batch_size = input_ids.size(0)
response_start_idx = input_ids.size(1)
outputs = self.llm.generate(
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
)
out_tokens = []
out_len = []
log_probs = []
response_idx = []
for out in outputs:
for output_i in out.outputs:
out_len.append(len(output_i.token_ids))
out_tokens.append(list(output_i.token_ids))
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
assert len(output_i.logprobs) == len(output_i.token_ids)
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
log_probs.append(p)
@@ -195,6 +201,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
out_tokens = torch.tensor(out_tokens)
log_probs = torch.tensor(log_probs)
response_idx = torch.tensor(response_idx)
if attention_mask.size(0) != action_mask.size(0):
assert action_mask.size(0) % attention_mask.size(0) == 0
num_returns = action_mask.size(0) // attention_mask.size(0)
@@ -209,9 +217,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"attention_mask": attention_mask,
"action_log_probs": log_probs,
"action_mask": action_mask,
"response_idx": response_idx,
}
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
if "gt_answer" in kwargs:
data["gt_answer"] = kwargs["gt_answer"]
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data