fix(huggingface): harden hostname validation and reject URLs in repo_id (#36831)

This commit is contained in:
ccurme
2026-04-16 15:49:48 -04:00
committed by GitHub
parent af0e174ef7
commit a029c7bf1d
2 changed files with 20 additions and 2 deletions

View File

@@ -65,6 +65,12 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
@model_validator(mode="after") @model_validator(mode="after")
def validate_environment(self) -> Self: def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
for field_name in ("model", "repo_id"):
value = getattr(self, field_name)
if value and value.startswith(("http://", "https://")):
msg = f"`{field_name}` must be a HuggingFace repo ID, not a URL."
raise ValueError(msg)
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv( huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN" "HF_TOKEN"
) )

View File

@@ -5,6 +5,7 @@ import logging
import os import os
from collections.abc import AsyncIterator, Iterator, Mapping from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any from typing import Any
from urllib.parse import urlparse
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@@ -23,8 +24,12 @@ def _is_huggingface_hosted_url(url: str | None) -> bool:
"""True if url is HF-hosted (huggingface.co or hf.space).""" """True if url is HF-hosted (huggingface.co or hf.space)."""
if not url: if not url:
return False return False
url_lower = url.lower().strip() hostname = (urlparse(url).hostname or "").lower()
return "huggingface.co" in url_lower or "hf.space" in url_lower return (
hostname == "huggingface.co"
or hostname == "hf.space"
or hostname.endswith((".huggingface.co", ".hf.space"))
)
VALID_TASKS = ( VALID_TASKS = (
@@ -220,6 +225,13 @@ class HuggingFaceEndpoint(LLM):
endpoint_url = values.get("endpoint_url") endpoint_url = values.get("endpoint_url")
repo_id = values.get("repo_id") repo_id = values.get("repo_id")
if repo_id and repo_id.startswith(("http://", "https://")):
msg = (
"`repo_id` must be a HuggingFace repo ID, not a URL. "
"Use `endpoint_url` for direct endpoints."
)
raise ValueError(msg)
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1: if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
msg = ( msg = (
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`," "Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"