mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
huggingface: fix model
param population (#24743)
- **Description:** Fix the validation error for `endpoint_url` for HuggingFaceEndpoint. I have given a detailed description of the issue in the issue that I have created. - **Issue:** #24742 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
c7a8af2e75
commit
9a29398fe6
@ -1,5 +1,6 @@
|
||||
import json # type: ignore[import-not-found]
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
@ -66,9 +67,10 @@ class HuggingFaceEndpoint(LLM):
|
||||
""" # noqa: E501
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
"""Endpoint URL to use."""
|
||||
"""Endpoint URL to use. If repo_id is not specified then this needs to given or
|
||||
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
|
||||
repo_id: Optional[str] = None
|
||||
"""Repo to use."""
|
||||
"""Repo to use. If endpoint_url is not specified then this needs to given"""
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
max_new_tokens: int = 512
|
||||
"""Maximum number of generated tokens"""
|
||||
@ -146,19 +148,38 @@ class HuggingFaceEndpoint(LLM):
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
|
||||
values["endpoint_url"] = get_from_dict_or_env(
|
||||
values, "endpoint_url", "HF_INFERENCE_ENDPOINT", None
|
||||
)
|
||||
# to correctly create the InferenceClient and AsyncInferenceClient
|
||||
# in validate_environment, we need to populate values["model"].
|
||||
# from InferenceClient docstring:
|
||||
# model (`str`, `optional`):
|
||||
# The model to run inference with. Can be a model id hosted on the Hugging
|
||||
# Face Hub, e.g. `bigcode/starcoder`
|
||||
# or a URL to a deployed Inference Endpoint. Defaults to None, in which
|
||||
# case a recommended model is
|
||||
# automatically selected for the task.
|
||||
|
||||
if values["endpoint_url"] is None and "repo_id" not in values:
|
||||
# this string could be in 3 places of descending priority:
|
||||
# 2. values["model"] or values["endpoint_url"] or values["repo_id"]
|
||||
# (equal priority - don't allow both set)
|
||||
# 3. values["HF_INFERENCE_ENDPOINT"] (if none above set)
|
||||
|
||||
model = values.get("model")
|
||||
endpoint_url = values.get("endpoint_url")
|
||||
repo_id = values.get("repo_id")
|
||||
|
||||
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
|
||||
raise ValueError(
|
||||
"Please specify an `endpoint_url` or `repo_id` for the model."
|
||||
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
|
||||
"not more than one."
|
||||
)
|
||||
if values["endpoint_url"] is not None and "repo_id" in values:
|
||||
values["model"] = (
|
||||
model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
|
||||
)
|
||||
if not values["model"]:
|
||||
raise ValueError(
|
||||
"Please specify either an `endpoint_url` OR a `repo_id`, not both."
|
||||
"Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
|
||||
"model."
|
||||
)
|
||||
values["model"] = values.get("endpoint_url") or values.get("repo_id")
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
|
Loading…
Reference in New Issue
Block a user