diff --git a/libs/langchain/langchain/llms/vllm.py b/libs/langchain/langchain/llms/vllm.py index 1a6e1a5910b..537a9bbb6f3 100644 --- a/libs/langchain/langchain/llms/vllm.py +++ b/libs/langchain/langchain/llms/vllm.py @@ -62,6 +62,10 @@ class VLLM(BaseLLM): dtype: str = "auto" """The data type for the model weights and activations.""" + download_dir: Optional[str] = None + """Directory to download and load the weights. (Default to the default + cache dir of huggingface)""" + vllm_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `vllm.LLM` call not explicitly specified.""" @@ -84,6 +88,7 @@ class VLLM(BaseLLM): tensor_parallel_size=values["tensor_parallel_size"], trust_remote_code=values["trust_remote_code"], dtype=values["dtype"], + download_dir=values["download_dir"], **values["vllm_kwargs"], )