fix(huggingface): avoid hf api calls when using local HuggingFaceEndpoint (#35633)

This commit is contained in:
Darshan Gorasiya
2026-03-29 02:14:36 +01:00
committed by GitHub
parent 61443c2580
commit aba72f7229
2 changed files with 94 additions and 2 deletions

View File

@@ -18,6 +18,15 @@ from typing_extensions import Self
logger = logging.getLogger(__name__)
def _is_huggingface_hosted_url(url: str | None) -> bool:
"""True if url is HF-hosted (huggingface.co or hf.space)."""
if not url:
return False
url_lower = url.lower().strip()
return "huggingface.co" in url_lower or "hf.space" in url_lower
VALID_TASKS = (
"text2text-generation",
"text-generation",
@@ -234,6 +243,11 @@ class HuggingFaceEndpoint(LLM):
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)
# Local/custom endpoint URL -> don't pass HF token (avoids 401s and egress).
if self.endpoint_url and not _is_huggingface_hosted_url(self.endpoint_url):
client_api_key: str | None = None
else:
client_api_key = huggingfacehub_api_token
from huggingface_hub import ( # type: ignore[import]
AsyncInferenceClient, # type: ignore[import]
@@ -245,7 +259,7 @@ class HuggingFaceEndpoint(LLM):
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
api_key=huggingfacehub_api_token,
api_key=client_api_key,
provider=self.provider, # type: ignore[arg-type]
**{
key: value
@@ -258,7 +272,7 @@ class HuggingFaceEndpoint(LLM):
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
api_key=huggingfacehub_api_token,
api_key=client_api_key,
provider=self.provider, # type: ignore[arg-type]
**{
key: value

View File

@@ -0,0 +1,78 @@
"""Tests for HuggingFaceEndpoint with local/custom endpoint_url (no HF API calls)."""
from unittest.mock import MagicMock, patch
import pytest
from langchain_huggingface.llms.huggingface_endpoint import (
HuggingFaceEndpoint,
_is_huggingface_hosted_url,
)
# (url, expected-classification) pairs: local/custom hosts vs HF-hosted ones.
_HOSTED_URL_CASES = [
    (None, False),
    ("", False),
    ("http://localhost:8010/", False),
    ("http://127.0.0.1:8080", False),
    ("http://my-tgi.internal/", False),
    ("https://api.inference-api.azure-api.net/", False),
    ("https://abc.huggingface.co/inference", True),
    ("https://xyz.hf.space/", True),
]


@pytest.mark.parametrize(("url", "expected"), _HOSTED_URL_CASES)
def test_is_huggingface_hosted_url(
    url: str | None,
    expected: bool,  # noqa: FBT001
) -> None:
    """Classify each URL as local/custom or HF-hosted."""
    result = _is_huggingface_hosted_url(url)
    assert result is expected
@patch("huggingface_hub.AsyncInferenceClient")
@patch("huggingface_hub.InferenceClient")
def test_local_endpoint_does_not_pass_api_key(
    mock_inference_client: MagicMock,
    mock_async_client: MagicMock,
) -> None:
    """A local endpoint_url must not forward the HF token to either client."""
    mock_inference_client.return_value = MagicMock()
    mock_async_client.return_value = MagicMock()

    HuggingFaceEndpoint(  # type: ignore[call-arg]
        endpoint_url="http://localhost:8010/",
        max_new_tokens=64,
    )

    # Neither the sync nor the async client should receive an api_key.
    for client_mock in (mock_inference_client, mock_async_client):
        client_mock.assert_called_once()
        assert client_mock.call_args[1].get("api_key") is None
    # The local URL itself is what gets passed through as the model.
    assert mock_inference_client.call_args[1].get("model") == "http://localhost:8010/"
@patch("huggingface_hub.AsyncInferenceClient")
@patch("huggingface_hub.InferenceClient")
def test_huggingface_hosted_endpoint_keeps_api_key(
    mock_inference_client: MagicMock,
    mock_async_client: MagicMock,
) -> None:
    """An HF-hosted endpoint_url still forwards the token to both clients."""
    mock_inference_client.return_value = MagicMock()
    mock_async_client.return_value = MagicMock()

    HuggingFaceEndpoint(  # type: ignore[call-arg]
        endpoint_url="https://abc.huggingface.co/inference",
        max_new_tokens=64,
        huggingfacehub_api_token="hf_xxx",  # noqa: S106
    )

    # assert_called_once first: without it, an un-constructed client makes
    # call_args None and the test dies with an opaque AttributeError.
    mock_inference_client.assert_called_once()
    assert mock_inference_client.call_args[1].get("api_key") == "hf_xxx"
    # The async client must also keep the token, not just the sync one.
    mock_async_client.assert_called_once()
    assert mock_async_client.call_args[1].get("api_key") == "hf_xxx"