diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py index e61c2292122..25dfd4fc753 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py @@ -65,6 +65,12 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings): @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" + for field_name in ("model", "repo_id"): + value = getattr(self, field_name) + if value and value.startswith(("http://", "https://")): + msg = f"`{field_name}` must be a HuggingFace repo ID, not a URL." + raise ValueError(msg) + huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv( "HF_TOKEN" ) diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py index 4d86c97d24d..a3a15b1c675 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py @@ -5,6 +5,7 @@ import logging import os from collections.abc import AsyncIterator, Iterator, Mapping from typing import Any +from urllib.parse import urlparse from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -23,8 +24,12 @@ def _is_huggingface_hosted_url(url: str | None) -> bool: """True if url is HF-hosted (huggingface.co or hf.space).""" if not url: return False - url_lower = url.lower().strip() - return "huggingface.co" in url_lower or "hf.space" in url_lower + hostname = (urlparse(url).hostname or "").lower() + return ( + hostname == "huggingface.co" + or hostname == "hf.space" + or hostname.endswith((".huggingface.co", ".hf.space")) + ) VALID_TASKS = ( @@ -220,6 +225,13 @@ class HuggingFaceEndpoint(LLM): endpoint_url = values.get("endpoint_url") repo_id = values.get("repo_id") + if repo_id and repo_id.startswith(("http://", "https://")): + msg = ( + "`repo_id` must be a HuggingFace repo ID, not a URL. " + "Use `endpoint_url` for direct endpoints." + ) + raise ValueError(msg) + if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1: msg = ( "Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"