fix(huggingface): avoid hf api calls when using local HuggingFaceEndpoint (#35633)

This commit is contained in:
Darshan Gorasiya
2026-03-29 02:14:36 +01:00
committed by GitHub
parent 61443c2580
commit aba72f7229
2 changed files with 94 additions and 2 deletions

View File

@@ -18,6 +18,15 @@ from typing_extensions import Self
logger = logging.getLogger(__name__)
def _is_huggingface_hosted_url(url: str | None) -> bool:
"""True if url is HF-hosted (huggingface.co or hf.space)."""
if not url:
return False
url_lower = url.lower().strip()
return "huggingface.co" in url_lower or "hf.space" in url_lower
VALID_TASKS = (
"text2text-generation",
"text-generation",
@@ -234,6 +243,11 @@ class HuggingFaceEndpoint(LLM):
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)
# Local/custom endpoint URL -> don't pass HF token (avoids 401s and egress).
if self.endpoint_url and not _is_huggingface_hosted_url(self.endpoint_url):
client_api_key: str | None = None
else:
client_api_key = huggingfacehub_api_token
from huggingface_hub import ( # type: ignore[import]
AsyncInferenceClient, # type: ignore[import]
@@ -245,7 +259,7 @@ class HuggingFaceEndpoint(LLM):
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
api_key=huggingfacehub_api_token,
api_key=client_api_key,
provider=self.provider, # type: ignore[arg-type]
**{
key: value
@@ -258,7 +272,7 @@ class HuggingFaceEndpoint(LLM):
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
api_key=huggingfacehub_api_token,
api_key=client_api_key,
provider=self.provider, # type: ignore[arg-type]
**{
key: value