huggingface: fix huggingface_endpoint.py (initialize clients only with supported kwargs) (#26378)

## Description

By default, `HuggingFaceEndpoint` instantiates both the
`InferenceClient` and the `AsyncInferenceClient` with the
`server_kwargs` passed as input. This is a problem because the two
clients might not support exactly the same kwargs, as @morgandiverrez
highlighted in https://github.com/huggingface/huggingface_hub/issues/2522
with the `trust_env` parameter. To make the `langchain` integration
future-proof, I think it's wiser to forward only the supported
parameters to each client. Parameters that neither client supports are
ignored with a warning to the user. From a `huggingface_hub`
maintenance perspective, this gives us much more flexibility, as we are
not constrained to support the exact same kwargs in both clients.
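
The filtering idea can be sketched in isolation. The snippet below is a minimal illustration, not the code from this PR: `SyncClient`, `AsyncClient`, and `filter_supported_kwargs` are hypothetical stand-ins, used so the example runs without `huggingface_hub` installed. It only assumes that `inspect.signature` on a class reports the parameters of its constructor.

```python
import inspect


class SyncClient:
    """Hypothetical stand-in for a synchronous client."""

    def __init__(self, model=None, timeout=None, proxies=None):
        self.model = model
        self.timeout = timeout
        self.proxies = proxies


class AsyncClient:
    """Hypothetical stand-in for an async client that accepts one extra option."""

    def __init__(self, model=None, timeout=None, proxies=None, trust_env=False):
        self.model = model
        self.timeout = timeout
        self.proxies = proxies
        self.trust_env = trust_env


def filter_supported_kwargs(cls, kwargs):
    """Keep only the kwargs that appear in the constructor signature of `cls`."""
    supported = set(inspect.signature(cls).parameters)
    return {key: value for key, value in kwargs.items() if key in supported}


# The same user-provided kwargs are filtered separately for each client.
server_kwargs = {"proxies": {"https": "http://localhost:8080"}, "trust_env": True}

sync_kwargs = filter_supported_kwargs(SyncClient, server_kwargs)
async_kwargs = filter_supported_kwargs(AsyncClient, server_kwargs)

sync_client = SyncClient(model="my-model", **sync_kwargs)
async_client = AsyncClient(model="my-model", **async_kwargs)
```

In this sketch `trust_env` reaches only `AsyncClient`, so constructing `SyncClient` no longer fails on an unexpected keyword argument.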

## Issue

https://github.com/huggingface/huggingface_hub/issues/2522

## Dependencies

None

## Twitter 

https://x.com/Wauplin

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Lucain 2024-09-21 01:05:24 +02:00 committed by GitHub
parent f2285376a5
commit a2023a1e96


@@ -1,3 +1,4 @@
+import inspect
 import json  # type: ignore[import-not-found]
 import logging
 import os
@@ -212,17 +213,40 @@ class HuggingFaceEndpoint(LLM):
         from huggingface_hub import AsyncInferenceClient, InferenceClient
+        # Instantiate clients with supported kwargs
+        sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
         self.client = InferenceClient(
             model=self.model,
             timeout=self.timeout,
             token=huggingfacehub_api_token,
-            **self.server_kwargs,
+            **{
+                key: value
+                for key, value in self.server_kwargs.items()
+                if key in sync_supported_kwargs
+            },
         )
+        async_supported_kwargs = set(inspect.signature(AsyncInferenceClient).parameters)
         self.async_client = AsyncInferenceClient(
             model=self.model,
             timeout=self.timeout,
             token=huggingfacehub_api_token,
-            **self.server_kwargs,
+            **{
+                key: value
+                for key, value in self.server_kwargs.items()
+                if key in async_supported_kwargs
+            },
         )
+        ignored_kwargs = (
+            set(self.server_kwargs.keys())
+            - sync_supported_kwargs
+            - async_supported_kwargs
+        )
+        if len(ignored_kwargs) > 0:
+            logger.warning(
+                f"Ignoring following parameters as they are not supported by the "
+                f"InferenceClient or AsyncInferenceClient: {ignored_kwargs}."
+            )
         return self
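
A note on the warning logic in the hunk above: the set difference subtracts both supported sets, so a key is reported only when neither client accepts it. A key supported by just one client is forwarded to that client and dropped for the other without a warning. The sketch below reproduces only that set arithmetic. The parameter names and the `find_ignored_kwargs` helper are illustrative assumptions, not the real client signatures.

```python
def find_ignored_kwargs(passed_kwargs, sync_supported, async_supported):
    """Return the keys that neither client supports (same arithmetic as the diff)."""
    return set(passed_kwargs) - set(sync_supported) - set(async_supported)


# Assumed, illustrative parameter sets; the real signatures may differ.
sync_supported = {"model", "timeout", "token", "proxies"}
async_supported = {"model", "timeout", "token", "proxies", "trust_env"}

passed_kwargs = {"proxies": None, "trust_env": True, "retries": 3}

ignored = find_ignored_kwargs(passed_kwargs, sync_supported, async_supported)
print(ignored)  # only "retries" is unsupported by both clients
```

Here `trust_env` is absent from the result even though the sync client would not receive it, which matches the behavior of the change as written.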