From 3f74dfc3d8efb7c58065a336b12d99853d275e15 Mon Sep 17 00:00:00 2001
From: Enes Bol <76845631+enesbol@users.noreply.github.com>
Date: Tue, 15 Oct 2024 23:57:50 +0200
Subject: [PATCH] community[patch]: Fix vLLM integration to filter SamplingParams (#27367)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

**Description:**

- This pull request addresses a bug in LangChain's vLLM integration, where the `use_beam_search` parameter was erroneously passed to `SamplingParams`. The `SamplingParams` class in vLLM does not support a `use_beam_search` argument, which caused a TypeError.
- This PR introduces logic to filter out unsupported parameters, ensuring that only valid parameters are passed to `SamplingParams`. As a result, the integration now functions as expected without errors.
- The bug was reproduced by running the code sample from LangChain's documentation, which triggered the error due to the invalid parameter. This fix resolves the error by filtering the parameters before constructing `SamplingParams`.

**vLLM SamplingParams class:** https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

**Issue:** I could not find an existing issue for this. Fixes the "TypeError: Unexpected keyword argument 'use_beam_search'" error when using vLLM from LangChain.

**Dependencies:** None.

**Tests and Documentation:**

- Tests: No new functionality was added, but I tested the changes by running multiple prompts through the vLLM integration with various parameter configurations. All runs completed successfully without breaking compatibility.
- Docs: No documentation changes were necessary, as this is a bug fix.

**Reproducing the error:** https://python.langchain.com/docs/integrations/llms/vllm/

The code sample from the documentation above reproduces the error:

    from langchain_community.llms import VLLM

    llm = VLLM(
        model="mosaicml/mpt-7b",
        trust_remote_code=True,  # mandatory for hf models
        max_new_tokens=128,
        top_k=10,
        top_p=0.95,
        temperature=0.8,
    )

    print(llm.invoke("What is the capital of France ?"))

![image](https://github.com/user-attachments/assets/3782d6ac-1f7b-4acc-bf2c-186216149de5)

This PR resolves the issue by ensuring that only valid parameters are passed to SamplingParams.
---
 libs/community/langchain_community/llms/vllm.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/libs/community/langchain_community/llms/vllm.py b/libs/community/langchain_community/llms/vllm.py
index b887a8b0ab2..dc8a7a76d24 100644
--- a/libs/community/langchain_community/llms/vllm.py
+++ b/libs/community/langchain_community/llms/vllm.py
@@ -123,14 +123,19 @@ class VLLM(BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-
         from vllm import SamplingParams
 
         # build sampling parameters
         params = {**self._default_params, **kwargs, "stop": stop}
-        sampling_params = SamplingParams(**params)
+
+        # filter params for SamplingParams
+        known_keys = SamplingParams.__annotations__.keys()
+        sample_params = SamplingParams(
+            **{k: v for k, v in params.items() if k in known_keys}
+        )
+
         # call the model
-        outputs = self.client.generate(prompts, sampling_params)
+        outputs = self.client.generate(prompts, sample_params)
 
         generations = []
         for output in outputs:
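
Note (not part of the patch): a minimal standalone sketch of the filtering approach used above, assuming vLLM is installed. It relies on `SamplingParams.__annotations__` exposing the class-level field annotations, which the patch uses as the allow-list of accepted keyword arguments; the parameter values below are illustrative only.

    # Sketch: drop kwargs that SamplingParams does not declare, then construct it.
    from vllm import SamplingParams

    raw_params = {
        "temperature": 0.8,
        "top_k": 10,
        "top_p": 0.95,
        "max_tokens": 128,
        "use_beam_search": False,  # rejected by newer vLLM releases
    }

    # Keep only keys that appear in SamplingParams' annotated fields.
    known_keys = SamplingParams.__annotations__.keys()
    filtered = {k: v for k, v in raw_params.items() if k in known_keys}

    sampling_params = SamplingParams(**filtered)
    print(sampling_params)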