From a2023a1e9690dcb6505ef8a72cf3659530d0f9cf Mon Sep 17 00:00:00 2001 From: Lucain Date: Sat, 21 Sep 2024 01:05:24 +0200 Subject: [PATCH] huggingface; fix huggingface_endpoint.py (initialize clients only with supported kwargs) (#26378) ## Description By default, `HuggingFaceEndpoint` instantiates both the `InferenceClient` and the `AsyncInferenceClient` with the `"server_kwargs"` passed as input. This is an issue as both clients might not support exactly the same kwargs. This has been highlighted in https://github.com/huggingface/huggingface_hub/issues/2522 by @morgandiverrez with the `trust_env` parameter. In order to make `langchain` integration future-proof, I do think it's wiser to forward only the supported parameters to each client. Parameters that are not supported are simply ignored with a warning to the user. From a `huggingface_hub` maintenance perspective, this allows us much more flexibility as we are not constrained to support the exact same kwargs in both clients. ## Issue https://github.com/huggingface/huggingface_hub/issues/2522 ## Dependencies None ## Twitter https://x.com/Wauplin --------- Co-authored-by: Erick Friis --- .../llms/huggingface_endpoint.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py index b270a993041..076cafd4de6 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py @@ -1,3 +1,4 @@ +import inspect import json # type: ignore[import-not-found] import logging import os @@ -212,19 +213,42 @@ class HuggingFaceEndpoint(LLM): from huggingface_hub import AsyncInferenceClient, InferenceClient + # Instantiate clients with supported kwargs + sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters) self.client = InferenceClient( model=self.model, timeout=self.timeout, token=huggingfacehub_api_token, - **self.server_kwargs, + **{ + key: value + for key, value in self.server_kwargs.items() + if key in sync_supported_kwargs + }, ) + + async_supported_kwargs = set(inspect.signature(AsyncInferenceClient).parameters) self.async_client = AsyncInferenceClient( model=self.model, timeout=self.timeout, token=huggingfacehub_api_token, - **self.server_kwargs, + **{ + key: value + for key, value in self.server_kwargs.items() + if key in async_supported_kwargs + }, ) + ignored_kwargs = ( + set(self.server_kwargs.keys()) + - sync_supported_kwargs + - async_supported_kwargs + ) + if len(ignored_kwargs) > 0: + logger.warning( + f"Ignoring following parameters as they are not supported by the " + f"InferenceClient or AsyncInferenceClient: {ignored_kwargs}." + ) + return self @property