mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(huggingface): harden hostname validation and reject URLs in repo_id (#36831)
This commit is contained in:
@@ -65,6 +65,12 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_environment(self) -> Self:
|
def validate_environment(self) -> Self:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
for field_name in ("model", "repo_id"):
|
||||||
|
value = getattr(self, field_name)
|
||||||
|
if value and value.startswith(("http://", "https://")):
|
||||||
|
msg = f"`{field_name}` must be a HuggingFace repo ID, not a URL."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
|
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
|
||||||
"HF_TOKEN"
|
"HF_TOKEN"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@@ -23,8 +24,12 @@ def _is_huggingface_hosted_url(url: str | None) -> bool:
|
|||||||
"""True if url is HF-hosted (huggingface.co or hf.space)."""
|
"""True if url is HF-hosted (huggingface.co or hf.space)."""
|
||||||
if not url:
|
if not url:
|
||||||
return False
|
return False
|
||||||
url_lower = url.lower().strip()
|
hostname = (urlparse(url).hostname or "").lower()
|
||||||
return "huggingface.co" in url_lower or "hf.space" in url_lower
|
return (
|
||||||
|
hostname == "huggingface.co"
|
||||||
|
or hostname == "hf.space"
|
||||||
|
or hostname.endswith((".huggingface.co", ".hf.space"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
VALID_TASKS = (
|
VALID_TASKS = (
|
||||||
@@ -220,6 +225,13 @@ class HuggingFaceEndpoint(LLM):
|
|||||||
endpoint_url = values.get("endpoint_url")
|
endpoint_url = values.get("endpoint_url")
|
||||||
repo_id = values.get("repo_id")
|
repo_id = values.get("repo_id")
|
||||||
|
|
||||||
|
if repo_id and repo_id.startswith(("http://", "https://")):
|
||||||
|
msg = (
|
||||||
|
"`repo_id` must be a HuggingFace repo ID, not a URL. "
|
||||||
|
"Use `endpoint_url` for direct endpoints."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
|
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
|
||||||
msg = (
|
msg = (
|
||||||
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
|
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
|
||||||
|
|||||||
Reference in New Issue
Block a user