fix transformers backend

YeAnbang
2025-03-14 18:12:35 +08:00
parent e224673c44
commit 35dabd718e
3 changed files with 34 additions and 10 deletions


@@ -61,12 +61,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
+        self.num_generations = 8
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+        gt_answer = None
+        if "gt_answer" in kwargs:
+            gt_answer = kwargs.pop("gt_answer")
+        if self.num_generations > 1:
+            input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
+            attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        out = self.model.generate(
+            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+        )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]
         # get log probs
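For context on the new sampling path: repeat_interleave duplicates each prompt row num_generations times, so a single model.generate call produces num_generations completions per prompt. A minimal standalone sketch of the shape bookkeeping (tensor values and sizes here are illustrative, not from the repo):

    import torch

    num_generations = 8
    input_ids = torch.randint(0, 1000, (2, 16))   # (micro_batch_size, prompt_len)
    attention_mask = torch.ones_like(input_ids)

    # Rows 0..7 become copies of prompt 0, rows 8..15 copies of prompt 1,
    # so each prompt is decoded num_generations times in one generate() call.
    input_ids = input_ids.repeat_interleave(num_generations, dim=0)
    attention_mask = attention_mask.repeat_interleave(num_generations, dim=0)
    assert input_ids.shape == (2 * num_generations, 16)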
@@ -76,10 +86,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
 
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
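To make the masking logic above concrete: action_mask keeps every generated token up to and including the first EOS and zeroes the rest, and response_idx records each response's first and last valid token positions in full-sequence coordinates. A toy reproduction of exactly that loop (eos_token_id and the token values are made up):

    import torch

    eos_token_id = 2
    input_len = 4
    new_token_ids = torch.tensor([[5, 7, 2, 9],   # EOS at offset 2
                                  [6, 8, 3, 1]])  # no EOS
    action_mask = torch.ones_like(new_token_ids)
    for indices in torch.nonzero(new_token_ids == eos_token_id):
        action_mask[indices[0], indices[1] + 1 :] = 0
    # action_mask == [[1, 1, 1, 0], [1, 1, 1, 1]]: the EOS token itself stays valid.

    response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int)
    response_idx[:, 0] = input_len                               # response start
    response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1  # last valid token
    # response_idx == [[4, 6], [4, 7]] in full-sequence coordinates.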
@@ -91,7 +104,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+        if gt_answer is not None:
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
 
     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
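One more note on the final reshape: viewing each tensor as (micro_batch_size, num_generations, seq_len) groups the flat rows back by prompt, so downstream consumers can score all completions of one prompt together. A standalone sketch, assuming gt_answer arrives with shape (micro_batch_size, 1, answer_len) so that repeat_interleave along dim=1 yields one copy per generation (the diff only shows the call, so that shape is an assumption):

    import torch

    micro_batch_size, num_generations, seq_len = 2, 8, 32
    flat = torch.zeros(micro_batch_size * num_generations, seq_len)

    data = {"input_ids": flat, "action_mask": flat.clone()}
    data = {k: v.view(micro_batch_size, num_generations, v.size(-1)) for k, v in data.items()}
    assert data["input_ids"].shape == (2, 8, 32)   # (prompt, generation, tokens)

    # Assumed shape (micro_batch_size, 1, answer_len) -> one copy per generation.
    gt_answer = torch.zeros(micro_batch_size, 1, 16)
    data["gt_answer"] = gt_answer.repeat_interleave(num_generations, dim=1)
    assert data["gt_answer"].shape == (2, 8, 16)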