partners: support reading HuggingFace params from env (#23309)

Description: 
1. partners/HuggingFace module support reading params from env. Not
adjust langchain_community/.../huggingfaceXX modules since they are
deprecated.
  2. pydantic 2 @root_validator migration.

Issue: #22448 #22819

---------

Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
wenngong 2024-07-02 22:12:45 +08:00 committed by GitHub
parent ffde8a6a09
commit ee5eedfa04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 12 deletions

View File

@ -325,7 +325,7 @@ class ChatHuggingFace(BaseChatModel):
else self.tokenizer else self.tokenizer
) )
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_llm(cls, values: dict) -> dict: def validate_llm(cls, values: dict) -> dict:
if ( if (
not _is_huggingface_hub(values["llm"]) not _is_huggingface_hub(values["llm"])

View File

@ -1,9 +1,9 @@
import json import json
import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2" DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
VALID_TASKS = ("feature-extraction",) VALID_TASKS = ("feature-extraction",)
@ -46,11 +46,15 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid extra = Extra.forbid
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv( values["huggingfacehub_api_token"] = get_from_dict_or_env(
"HUGGINGFACEHUB_API_TOKEN" values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN", None
)
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HF_TOKEN", None
) )
try: try:

View File

@ -1,6 +1,5 @@
import json import json
import logging import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -10,7 +9,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_pydantic_field_names from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -146,18 +145,23 @@ class HuggingFaceEndpoint(LLM):
) )
values["model_kwargs"] = extra values["model_kwargs"] = extra
if "endpoint_url" not in values and "repo_id" not in values:
values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "HF_INFERENCE_ENDPOINT", None
)
if values["endpoint_url"] is None and "repo_id" not in values:
raise ValueError( raise ValueError(
"Please specify an `endpoint_url` or `repo_id` for the model." "Please specify an `endpoint_url` or `repo_id` for the model."
) )
if "endpoint_url" in values and "repo_id" in values: if values["endpoint_url"] is not None and "repo_id" in values:
raise ValueError( raise ValueError(
"Please specify either an `endpoint_url` OR a `repo_id`, not both." "Please specify either an `endpoint_url` OR a `repo_id`, not both."
) )
values["model"] = values.get("endpoint_url") or values.get("repo_id") values["model"] = values.get("endpoint_url") or values.get("repo_id")
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that package is installed and that the API token is valid.""" """Validate that package is installed and that the API token is valid."""
try: try:
@ -168,9 +172,15 @@ class HuggingFaceEndpoint(LLM):
"Could not import huggingface_hub python package. " "Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`." "Please install it with `pip install huggingface_hub`."
) )
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
"HUGGINGFACEHUB_API_TOKEN" values["huggingfacehub_api_token"] = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN", None
) )
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HF_TOKEN", None
)
if huggingfacehub_api_token is not None: if huggingfacehub_api_token is not None:
try: try:
login(token=huggingfacehub_api_token) login(token=huggingfacehub_api_token)