Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-19 08:27:23 +00:00)
simplify vllm preprocessing input ids
This commit is contained in:
parent 16e68a071d
commit 23aac43dcf
@@ -212,13 +212,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
         micro_batch_input_ids = input_ids.tolist()
-        micro_batch_input_ids_no_padding = []
-        for i in range(micro_batch_size):
-            for j in range(input_ids.size(1)):
-                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
-                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
-                    break
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
             prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
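For context, a minimal standalone sketch of the change: the commit replaces a per-row Python scan for the first non-padding token with a single vectorized argmax over the boolean pad mask. The pad id and sample tensors below are invented for illustration; only the two stripping strategies come from the diff.

import torch

pad_token_id = 0  # hypothetical pad id; the real code uses self.tokenizer.pad_token_id

# Left-padded micro batch: row 0 has two pad tokens, row 1 has one.
input_ids = torch.tensor(
    [
        [0, 0, 11, 12, 13],
        [0, 21, 22, 23, 24],
    ]
)
micro_batch_size = input_ids.size(0)

# New path: argmax over the (input_ids != pad) mask returns the index of the
# first non-padding token in each row (here tensor([2, 1])).
first_non_padding_token_idx = (input_ids != pad_token_id).int().argmax(dim=1)
micro_batch_input_ids = input_ids.tolist()
micro_batch_input_ids_no_padding = [
    micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]

# Removed path: scan each row token by token until the first non-pad token.
expected = []
for row in micro_batch_input_ids:
    for j, token in enumerate(row):
        if token != pad_token_id:
            expected.append(row[j:])
            break

assert micro_batch_input_ids_no_padding == expected  # [[11, 12, 13], [21, 22, 23, 24]]

One edge case worth noting: for a row consisting entirely of padding, argmax returns 0, so the new code keeps the whole row, whereas the removed loop appended nothing for such a row and silently shrank the batch.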
|
Loading…
Reference in New Issue
Block a user