mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
huggingface: fix model
param population (#24743)
- **Description:** Fix the validation error for `endpoint_url` for HuggingFaceEndpoint. I have given a descriptive detail of the isse in the issue that I have created. - **Issue:** #24742 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
c7a8af2e75
commit
9a29398fe6
@ -1,5 +1,6 @@
|
|||||||
import json # type: ignore[import-not-found]
|
import json # type: ignore[import-not-found]
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -66,9 +67,10 @@ class HuggingFaceEndpoint(LLM):
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
endpoint_url: Optional[str] = None
|
endpoint_url: Optional[str] = None
|
||||||
"""Endpoint URL to use."""
|
"""Endpoint URL to use. If repo_id is not specified then this needs to given or
|
||||||
|
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
|
||||||
repo_id: Optional[str] = None
|
repo_id: Optional[str] = None
|
||||||
"""Repo to use."""
|
"""Repo to use. If endpoint_url is not specified then this needs to given"""
|
||||||
huggingfacehub_api_token: Optional[str] = None
|
huggingfacehub_api_token: Optional[str] = None
|
||||||
max_new_tokens: int = 512
|
max_new_tokens: int = 512
|
||||||
"""Maximum number of generated tokens"""
|
"""Maximum number of generated tokens"""
|
||||||
@ -146,19 +148,38 @@ class HuggingFaceEndpoint(LLM):
|
|||||||
|
|
||||||
values["model_kwargs"] = extra
|
values["model_kwargs"] = extra
|
||||||
|
|
||||||
values["endpoint_url"] = get_from_dict_or_env(
|
# to correctly create the InferenceClient and AsyncInferenceClient
|
||||||
values, "endpoint_url", "HF_INFERENCE_ENDPOINT", None
|
# in validate_environment, we need to populate values["model"].
|
||||||
)
|
# from InferenceClient docstring:
|
||||||
|
# model (`str`, `optional`):
|
||||||
|
# The model to run inference with. Can be a model id hosted on the Hugging
|
||||||
|
# Face Hub, e.g. `bigcode/starcoder`
|
||||||
|
# or a URL to a deployed Inference Endpoint. Defaults to None, in which
|
||||||
|
# case a recommended model is
|
||||||
|
# automatically selected for the task.
|
||||||
|
|
||||||
if values["endpoint_url"] is None and "repo_id" not in values:
|
# this string could be in 3 places of descending priority:
|
||||||
|
# 2. values["model"] or values["endpoint_url"] or values["repo_id"]
|
||||||
|
# (equal priority - don't allow both set)
|
||||||
|
# 3. values["HF_INFERENCE_ENDPOINT"] (if none above set)
|
||||||
|
|
||||||
|
model = values.get("model")
|
||||||
|
endpoint_url = values.get("endpoint_url")
|
||||||
|
repo_id = values.get("repo_id")
|
||||||
|
|
||||||
|
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please specify an `endpoint_url` or `repo_id` for the model."
|
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
|
||||||
|
"not more than one."
|
||||||
)
|
)
|
||||||
if values["endpoint_url"] is not None and "repo_id" in values:
|
values["model"] = (
|
||||||
|
model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
|
||||||
|
)
|
||||||
|
if not values["model"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please specify either an `endpoint_url` OR a `repo_id`, not both."
|
"Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
|
||||||
|
"model."
|
||||||
)
|
)
|
||||||
values["model"] = values.get("endpoint_url") or values.get("repo_id")
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator(pre=False, skip_on_failure=True)
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user