diff --git a/libs/langchain/langchain/llms/vllm.py b/libs/langchain/langchain/llms/vllm.py index 9f456fde514..1a6e1a5910b 100644 --- a/libs/langchain/langchain/llms/vllm.py +++ b/libs/langchain/langchain/llms/vllm.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import BaseLLM from langchain.llms.openai import BaseOpenAI -from langchain.pydantic_v1 import root_validator +from langchain.pydantic_v1 import Field, root_validator from langchain.schema.output import Generation, LLMResult @@ -62,6 +62,9 @@ class VLLM(BaseLLM): dtype: str = "auto" """The data type for the model weights and activations.""" + vllm_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `vllm.LLM` call not explicitly specified.""" + client: Any #: :meta private: @root_validator() @@ -81,6 +84,7 @@ class VLLM(BaseLLM): tensor_parallel_size=values["tensor_parallel_size"], trust_remote_code=values["trust_remote_code"], dtype=values["dtype"], + **values["vllm_kwargs"], ) return values