diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 5039d89f5..17c71c8a8 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -212,13 +212,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
         micro_batch_input_ids = input_ids.tolist()
-        micro_batch_input_ids_no_padding = []
-        for i in range(micro_batch_size):
-            for j in range(input_ids.size(1)):
-                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
-                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
-                    break
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
             prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
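
For illustration, a minimal standalone sketch of the vectorized left-padding strip introduced by the diff. The `pad_token_id` value and the dummy prompts below are assumptions for the example, not values from the PR:

```python
import torch

# Assumed pad token id for this sketch; in the PR it comes from self.tokenizer.pad_token_id.
pad_token_id = 0

# Two left-padded prompts of different effective lengths.
input_ids = torch.tensor(
    [
        [0, 0, 0, 11, 12, 13],
        [0, 21, 22, 23, 24, 25],
    ]
)

# (input_ids != pad_token_id) is a boolean mask; casting to int and taking
# argmax along dim=1 yields the position of the first non-padding token per row.
first_non_padding_token_idx = (input_ids != pad_token_id).int().argmax(dim=1)
print(first_non_padding_token_idx)  # tensor([3, 1])

# Slice each row from its first real token onward, mirroring the list
# comprehension that replaces the original nested Python loop.
input_ids_list = input_ids.tolist()
no_padding = [
    input_ids_list[i][first_non_padding_token_idx[i] :] for i in range(input_ids.size(0))
]
print(no_padding)  # [[11, 12, 13], [21, 22, 23, 24, 25]]
```

The trick relies on `argmax` returning the index of the first maximal element in each row, so a single tensor op replaces the per-row scan; note that a row consisting entirely of padding would yield index 0 and be passed through unsliced.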