Simplify vLLM preprocessing of input_ids

YeAnbang 2025-03-21 15:03:10 +08:00
parent 16e68a071d
commit 23aac43dcf


@@ -212,13 +212,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
         micro_batch_input_ids = input_ids.tolist()
-        micro_batch_input_ids_no_padding = []
-        for i in range(micro_batch_size):
-            for j in range(input_ids.size(1)):
-                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
-                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
-                    break
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
             prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
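
For illustration, the following is a minimal standalone sketch (not part of the commit) of why the vectorized argmax-based slicing produces the same result as the removed per-row loop on left-padded prompts. The pad_token_id value and the example batch are made-up assumptions, and both variants assume every row contains at least one non-padding token.

import torch

# Hypothetical values for demonstration only.
pad_token_id = 0
input_ids = torch.tensor(
    [
        [0, 0, 11, 12, 13],    # two padding tokens, then the prompt
        [0, 21, 22, 23, 24],   # one padding token
        [31, 32, 33, 34, 35],  # no padding
    ]
)

# New approach: argmax over the boolean "is not padding" mask gives the
# index of the first non-padding token in each row.
first_non_padding = (input_ids != pad_token_id).int().argmax(dim=1)
rows = input_ids.tolist()
vectorized = [rows[i][first_non_padding[i]:] for i in range(input_ids.size(0))]

# Old approach: scan each row until the first non-padding token is found.
looped = []
for row in rows:
    for j, tok in enumerate(row):
        if tok != pad_token_id:
            looped.append(row[j:])
            break

assert vectorized == looped
print(vectorized)  # [[11, 12, 13], [21, 22, 23, 24], [31, 32, 33, 34, 35]]

The equivalence relies on torch.argmax returning the index of the first maximal value, so on an integer mask of 0s and 1s it yields the first position where the token differs from the padding token.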