commit 7ee4452f8c (parent 7795d4c50d)
Author: YeAnbang
Date:   2025-03-19 17:07:20 +08:00

5 changed files with 172 additions and 24 deletions


@@ -183,7 +183,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
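
For context: vLLM's LLM constructor takes the model name or path as its first parameter, named model, so the keyword form is behaviorally identical to the positional call it replaces but is explicit and robust to signature changes. A minimal sketch of the pattern, with a hypothetical model_config (the model name and options are placeholders, not from the commit):

    from vllm import LLM, SamplingParams

    # Hypothetical config mirroring the diff: "path" is popped out and the
    # remaining entries are forwarded as keyword arguments to LLM.
    model_config = {"path": "Qwen/Qwen2.5-7B-Instruct", "dtype": "bfloat16", "trust_remote_code": True}
    path = model_config.pop("path")
    llm = LLM(model=path, **model_config)  # keyword form, as in the patched line
    sampling_params = SamplingParams(temperature=0.8, max_tokens=256)
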
@@ -194,8 +194,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = []
+        for i in range(micro_batch_size):
+            for j in range(input_ids.size(1)):
+                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
+                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
+                    break
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
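
The second hunk stops passing left-padded rows to llm.generate: vLLM treats every entry of prompt_token_ids as a real prompt token, so pad ids would be fed to the model as part of the prompt. The added loop strips the leading pad tokens from each row before the call. A self-contained sketch of the same de-padding step, assuming a pad id of 0 in place of self.tokenizer.pad_token_id:

    import torch

    def strip_left_padding(input_ids: torch.Tensor, pad_token_id: int) -> list:
        # Drop leading pad tokens from each row, mirroring the loop in the diff.
        rows = input_ids.tolist()
        no_padding = []
        for row in rows:
            for j, token in enumerate(row):
                if token != pad_token_id:
                    no_padding.append(row[j:])
                    break
        return no_padding

    # Two left-padded prompts of different lengths; pad id 0 is hypothetical.
    batch = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
    print(strip_left_padding(batch, pad_token_id=0))  # [[5, 6], [7, 8, 9]]

One caveat the diff inherits: a row consisting entirely of padding never triggers the break, so it is silently dropped and the resulting list is shorter than micro_batch_size; the code relies on every prompt containing at least one real token.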