Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-10 21:35:08 +00:00)
**Description:**
- This pull request addresses a bug in LangChain's vLLM integration, where the `use_beam_search` parameter was erroneously passed to `SamplingParams`. The `SamplingParams` class in vLLM does not accept a `use_beam_search` argument, which caused a `TypeError`.
- This PR introduces logic to filter out unsupported parameters, ensuring that only valid parameters are passed to `SamplingParams`. As a result, the integration now works as expected without errors.
- The bug was reproduced by running the code sample from LangChain's documentation, which triggered the error due to the invalid parameter. This fix resolves the error by filtering the parameters properly.

**vLLM SamplingParams class:** https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

**Issue:** I could not find an existing issue for this. Fixes the "TypeError: Unexpected keyword argument 'use_beam_search'" error when using vLLM from LangChain.

**Dependencies:** None.

**Tests and Documentation:**
- Tests: No new functionality was added, but I verified the change by running multiple prompts through the vLLM integration with various parameter configurations. All runs completed successfully without breaking compatibility.
- Docs: No documentation changes were necessary, as this is a bug fix.

**Reproducing the error:** https://python.langchain.com/docs/integrations/llms/vllm/

The code sample from the documentation above reproduces the error:

```python
from langchain_community.llms import VLLM

llm = VLLM(
    model="mosaicml/mpt-7b",
    trust_remote_code=True,  # mandatory for hf models
    max_new_tokens=128,
    top_k=10,
    top_p=0.95,
    temperature=0.8,
)

print(llm.invoke("What is the capital of France ?"))
```

This PR resolves the issue by ensuring that only valid parameters are passed to `SamplingParams`.
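To make the fix concrete, here is a minimal, self-contained sketch of the filtering approach described above. `FakeSamplingParams` is a hypothetical stand-in for `vllm.SamplingParams` so the snippet runs without vLLM installed; the actual change in this PR filters against the annotations of the real class via `SamplingParams.__annotations__.keys()`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSamplingParams:
    # Hypothetical stand-in for vllm.SamplingParams; only a few annotated fields.
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    max_tokens: int = 16
    stop: Optional[list] = None


# Parameters assembled by the integration; `use_beam_search` is not a field
# of the target class and would raise a TypeError if passed through as-is.
params = {"n": 1, "temperature": 0.8, "top_k": 10, "use_beam_search": False}

# Keep only the keys the target class actually declares in its annotations.
known_keys = FakeSamplingParams.__annotations__.keys()
filtered = {k: v for k, v in params.items() if k in known_keys}

sampling_params = FakeSamplingParams(**filtered)  # no TypeError
print(sampling_params)
```

Filtering against the class annotations means any key the target class does not declare, such as `use_beam_search`, is silently dropped instead of raising a `TypeError`.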
183 lines
5.6 KiB
Python
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import pre_init
from pydantic import Field

from langchain_community.llms.openai import BaseOpenAI
from langchain_community.utils.openai import is_openai_v1


class VLLM(BaseLLM):
    """VLLM language model."""

    model: str = ""
    """The name or path of a HuggingFace Transformers model."""

    tensor_parallel_size: Optional[int] = 1
    """The number of GPUs to use for distributed execution with tensor parallelism."""

    trust_remote_code: Optional[bool] = False
    """Trust remote code (e.g., from HuggingFace) when downloading the model
    and tokenizer."""

    n: int = 1
    """Number of output sequences to return for the given prompt."""

    best_of: Optional[int] = None
    """Number of output sequences that are generated from the prompt."""

    presence_penalty: float = 0.0
    """Float that penalizes new tokens based on whether they appear in the
    generated text so far"""

    frequency_penalty: float = 0.0
    """Float that penalizes new tokens based on their frequency in the
    generated text so far"""

    temperature: float = 1.0
    """Float that controls the randomness of the sampling."""

    top_p: float = 1.0
    """Float that controls the cumulative probability of the top tokens to consider."""

    top_k: int = -1
    """Integer that controls the number of top tokens to consider."""

    use_beam_search: bool = False
    """Whether to use beam search instead of sampling."""

    stop: Optional[List[str]] = None
    """List of strings that stop the generation when they are generated."""

    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating tokens after
    the EOS token is generated."""

    max_new_tokens: int = 512
    """Maximum number of tokens to generate per output sequence."""

    logprobs: Optional[int] = None
    """Number of log probabilities to return per output token."""

    dtype: str = "auto"
    """The data type for the model weights and activations."""

    download_dir: Optional[str] = None
    """Directory to download and load the weights. (Default to the default
    cache dir of huggingface)"""

    vllm_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `vllm.LLM` call not explicitly specified."""

    client: Any = None  #: :meta private:

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that python package exists in environment."""

        try:
            from vllm import LLM as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import vllm python package. "
                "Please install it with `pip install vllm`."
            )

        values["client"] = VLLModel(
            model=values["model"],
            tensor_parallel_size=values["tensor_parallel_size"],
            trust_remote_code=values["trust_remote_code"],
            dtype=values["dtype"],
            download_dir=values["download_dir"],
            **values["vllm_kwargs"],
        )

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling vllm."""
        return {
            "n": self.n,
            "best_of": self.best_of,
            "max_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "stop": self.stop,
            "ignore_eos": self.ignore_eos,
            "use_beam_search": self.use_beam_search,
            "logprobs": self.logprobs,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        from vllm import SamplingParams

        # build sampling parameters
        params = {**self._default_params, **kwargs, "stop": stop}

        # filter params for SamplingParams
        known_keys = SamplingParams.__annotations__.keys()
        sample_params = SamplingParams(
            **{k: v for k, v in params.items() if k in known_keys}
        )

        # call the model
        outputs = self.client.generate(prompts, sample_params)

        generations = []
        for output in outputs:
            text = output.outputs[0].text
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "vllm"


class VLLMOpenAI(BaseOpenAI):
    """vLLM OpenAI-compatible API client"""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""

        params: Dict[str, Any] = {
            "model": self.model_name,
            **self._default_params,
            "logit_bias": None,
        }
        if not is_openai_v1():
            params.update(
                {
                    "api_key": self.openai_api_key,
                    "api_base": self.openai_api_base,
                }
            )

        return params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "vllm-openai"
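For completeness, a hedged usage sketch of the `VLLMOpenAI` client defined above. The API base URL and model name are placeholders, and the snippet assumes a vLLM OpenAI-compatible server is already serving the model at that address; adjust both to your deployment.

```python
from langchain_community.llms import VLLMOpenAI

# Placeholder values: assumes a vLLM OpenAI-compatible server is running
# locally and serving this model (not part of the file above).
llm = VLLMOpenAI(
    openai_api_key="EMPTY",
    openai_api_base="http://localhost:8000/v1",
    model_name="mosaicml/mpt-7b",
)

print(llm.invoke("What is the capital of France ?"))
```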