diff --git a/libs/community/langchain_community/llms/vllm.py b/libs/community/langchain_community/llms/vllm.py
index dc8a7a76d24..66a0f17756b 100644
--- a/libs/community/langchain_community/llms/vllm.py
+++ b/libs/community/langchain_community/llms/vllm.py
@@ -125,6 +125,8 @@ class VLLM(BaseLLM):
         """Run the LLM on the given prompt and input."""
         from vllm import SamplingParams
 
+        lora_request = kwargs.pop("lora_request", None)
+
         # build sampling parameters
         params = {**self._default_params, **kwargs, "stop": stop}
 
@@ -135,7 +137,12 @@ class VLLM(BaseLLM):
         )
 
         # call the model
-        outputs = self.client.generate(prompts, sample_params)
+        if lora_request:
+            outputs = self.client.generate(
+                prompts, sample_params, lora_request=lora_request
+            )
+        else:
+            outputs = self.client.generate(prompts, sample_params)
 
         generations = []
         for output in outputs:
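
Minimal usage sketch for the new `lora_request` kwarg introduced by this diff. The base model name, adapter name, and adapter path below are placeholders, and enabling LoRA on the underlying engine via `vllm_kwargs` is an assumption about how the wrapper is configured, not part of this change.

# Usage sketch (assumes vLLM is installed and a LoRA adapter exists at the given path).
from langchain_community.llms import VLLM
from vllm.lora.request import LoRARequest

# Assumption: LoRA support on the engine is enabled through `vllm_kwargs`,
# which the wrapper forwards to vllm.LLM at construction time.
llm = VLLM(
    model="meta-llama/Llama-2-7b-hf",  # placeholder base model
    vllm_kwargs={"enable_lora": True},
)

# LoRARequest(name, int_id, path) is vLLM's handle for a loaded adapter.
lora = LoRARequest("my-adapter", 1, "/path/to/lora/adapter")  # hypothetical adapter

# The kwarg is popped in _generate and forwarded to self.client.generate.
print(llm.invoke("Tell me a joke.", lora_request=lora))

With this change, callers that do not pass `lora_request` keep the previous behavior, since the plain `self.client.generate(prompts, sample_params)` path is taken when the kwarg is absent.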