Fix a possible bug: honor a `generate_config` passed via kwargs (falling back to `self.generate_config`) instead of always using `self.generate_config` in `model.generate`

This commit is contained in:
YeAnbang
2025-05-05 18:48:42 +08:00
parent 6fff36dd63
commit 4d18e7d772
3 changed files with 49 additions and 12 deletions

View File

@@ -80,8 +80,9 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         if self.num_generations > 1:
             input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
             attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        generate_config = kwargs.get("generate_config", self.generate_config)
         out = self.model.generate(
-            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+            input_ids, attention_mask=attention_mask, **kwargs, **generate_config, tokenizer=self.tokenizer
         )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]