mirror of https://github.com/hpcaitech/ColossalAI.git

fix vllm
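The change passes the model path to vLLM's LLM constructor by keyword (model=path) and strips left-padding from each prompt's token IDs before calling llm.generate, presumably so vLLM does not treat pad tokens as prompt content.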
@@ -183,7 +183,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
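For context, vLLM's LLM constructor takes the model path through its `model` parameter, so the keyword form above keeps the call unambiguous when the remaining options arrive via **model_config. A minimal sketch of the same construction; the model name and options here are placeholders, not values from this repo:

    from vllm import LLM, SamplingParams

    # Placeholder stand-ins for the backend's model_config / generate_config.
    model_config = {"dtype": "float16", "trust_remote_code": True}
    path = "facebook/opt-125m"

    llm = LLM(model=path, **model_config)  # keyword form, as in the fix above
    sampling_params = SamplingParams(temperature=0.8, max_tokens=64)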
@@ -194,8 +194,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = []
+        for i in range(micro_batch_size):
+            for j in range(input_ids.size(1)):
+                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
+                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
+                    break
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
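The inserted loop matters because input_ids arrives left-padded to a common length, while llm.generate(prompt_token_ids=...) treats every token in each list as prompt content; without the strip, vLLM would condition on pad tokens. A standalone sketch of the same logic, assuming left padding; the helper name and pad_token_id value are illustrative, not from this repo:

    from typing import List

    def strip_left_padding(batch_ids: List[List[int]], pad_token_id: int) -> List[List[int]]:
        """Drop leading pad tokens from each sequence, mirroring the loop in the diff."""
        unpadded = []
        for seq in batch_ids:
            for j, tok in enumerate(seq):
                if tok != pad_token_id:
                    unpadded.append(seq[j:])  # keep from the first real token onward
                    break
        return unpadded

    # Two left-padded prompts, pad_token_id = 0 (illustrative value).
    batch = [[0, 0, 11, 12, 13], [0, 21, 22, 23, 24]]
    print(strip_left_padding(batch, pad_token_id=0))  # [[11, 12, 13], [21, 22, 23, 24]]

Note that, as in the original loop, a row consisting entirely of pad tokens contributes nothing to the output.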