diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py index b270a993041..076cafd4de6 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py @@ -1,3 +1,4 @@ +import inspect import json # type: ignore[import-not-found] import logging import os @@ -212,19 +213,42 @@ class HuggingFaceEndpoint(LLM): from huggingface_hub import AsyncInferenceClient, InferenceClient + # Instantiate clients with supported kwargs + sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters) self.client = InferenceClient( model=self.model, timeout=self.timeout, token=huggingfacehub_api_token, - **self.server_kwargs, + **{ + key: value + for key, value in self.server_kwargs.items() + if key in sync_supported_kwargs + }, ) + + async_supported_kwargs = set(inspect.signature(AsyncInferenceClient).parameters) self.async_client = AsyncInferenceClient( model=self.model, timeout=self.timeout, token=huggingfacehub_api_token, - **self.server_kwargs, + **{ + key: value + for key, value in self.server_kwargs.items() + if key in async_supported_kwargs + }, ) + ignored_kwargs = ( + set(self.server_kwargs.keys()) + - sync_supported_kwargs + - async_supported_kwargs + ) + if len(ignored_kwargs) > 0: + logger.warning( + f"Ignoring following parameters as they are not supported by the " + f"InferenceClient or AsyncInferenceClient: {ignored_kwargs}." + ) + return self @property