diff --git a/libs/community/langchain_community/llms/vllm.py b/libs/community/langchain_community/llms/vllm.py
index b887a8b0ab2..dc8a7a76d24 100644
--- a/libs/community/langchain_community/llms/vllm.py
+++ b/libs/community/langchain_community/llms/vllm.py
@@ -123,12 +123,18 @@ class VLLM(BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         from vllm import SamplingParams
 
         # build sampling parameters
         params = {**self._default_params, **kwargs, "stop": stop}
-        sampling_params = SamplingParams(**params)
+
+        # filter params for SamplingParams
+        known_keys = SamplingParams.__annotations__.keys()
+        sample_params = SamplingParams(
+            **{k: v for k, v in params.items() if k in known_keys}
+        )
+
         # call the model
-        outputs = self.client.generate(prompts, sampling_params)
+        outputs = self.client.generate(prompts, sample_params)
         generations = []
         for output in outputs:
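For context on why the filtering works: Python classes expose their declared, annotated fields through `__annotations__`, and `vllm.SamplingParams` rejects unknown keyword arguments with a `TypeError`, so dropping every key outside that mapping before construction keeps the call safe when callers pass extra kwargs meant for other layers. Below is a minimal, self-contained sketch of the same technique; the `Params` dataclass stands in for `vllm.SamplingParams`, and the `filter_known_kwargs` helper and the `lora_request` key are illustrative assumptions, not part of this patch.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Params:
    """Stand-in for vllm.SamplingParams: a class with annotated fields."""

    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 16


def filter_known_kwargs(cls: type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs that appear in the class's field annotations."""
    # __annotations__ maps each declared field name to its type; any key
    # outside it would raise a TypeError in the constructor.
    known_keys = cls.__annotations__.keys()
    return {k: v for k, v in kwargs.items() if k in known_keys}


raw = {"temperature": 0.7, "max_tokens": 64, "lora_request": "adapter-1"}
params = Params(**filter_known_kwargs(Params, raw))
print(params)  # Params(temperature=0.7, top_p=1.0, max_tokens=64)
```

One trade-off worth noting: `cls.__annotations__` lists only fields declared directly on the class, not inherited ones, and the filter drops unsupported kwargs silently rather than raising, which is what lets existing callers keep passing extra parameters without breaking.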