community[patch]: Remove usage of @root_validator(allow_reuse=True) (#25235)

Remove usage of @root_validator(allow_reuse=True)
This commit is contained in:
Eugene Yurtsev 2024-08-09 10:57:42 -04:00 committed by GitHub
parent a2b4c33bd6
commit 6e57aa7c36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 76 additions and 74 deletions

View File

@ -167,7 +167,7 @@ class GPTRouter(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
max_tokens: int = 256
@root_validator(allow_reuse=True)
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["gpt_router_api_base"] = get_from_dict_or_env(
values,
@ -183,7 +183,10 @@ class GPTRouter(BaseChatModel):
"GPT_ROUTER_API_KEY",
)
)
return values
@root_validator(pre=True, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
try:
from gpt_router.client import GPTRouterClient

View File

@ -88,7 +88,7 @@ class ChatPerplexity(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]:
return {"pplx_api_key": "PPLX_API_KEY"}
@root_validator(pre=True, allow_reuse=True)
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
@ -114,7 +114,7 @@ class ChatPerplexity(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator(allow_reuse=True)
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["pplx_api_key"] = get_from_dict_or_env(

View File

@ -3,7 +3,11 @@ from typing import Any, Dict, List, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
secret_from_env,
)
from requests import RequestException
BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"
@ -53,7 +57,10 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
session: Any #: :meta private:
model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
"""The model used to embed the documents."""
baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
baichuan_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("BAICHUAN_API_KEY", default=None),
)
"""Automatically inferred from env var `BAICHUAN_API_KEY` if not provided."""
chunk_size: int = 16
"""Chunk size when multiple texts are input"""
@ -61,22 +68,21 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
class Config:
allow_population_by_field_name = True
@root_validator(allow_reuse=True)
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that auth token exists in environment."""
try:
if values["baichuan_api_key"] is None:
# This is likely here for some backwards compatibility with
# BAICHUAN_AUTH_TOKEN
baichuan_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
)
except ValueError as original_exc:
try:
baichuan_api_key = convert_to_secret_str(
get_from_dict_or_env(
values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
)
get_from_dict_or_env(
values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
)
except ValueError:
raise original_exc
)
values["baichuan_api_key"] = baichuan_api_key
else:
baichuan_api_key = values["baichuan_api_key"]
session = requests.Session()
session.headers.update(
{

View File

@ -53,7 +53,7 @@ class GradientEmbeddings(BaseModel, Embeddings):
class Config:
extra = "forbid"
@root_validator(allow_reuse=True)
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -65,8 +65,15 @@ class GradientEmbeddings(BaseModel, Embeddings):
)
values["gradient_api_url"] = get_from_dict_or_env(
values, "gradient_api_url", "GRADIENT_API_URL"
values,
"gradient_api_url",
"GRADIENT_API_URL",
default="https://api.gradient.ai/api",
)
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
try:
import gradientai
except ImportError:
@ -85,7 +92,6 @@ class GradientEmbeddings(BaseModel, Embeddings):
host=values["gradient_api_url"],
)
values["client"] = gradient.get_embeddings_model(slug=values["model"])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:

View File

@ -47,7 +47,7 @@ class InfinityEmbeddings(BaseModel, Embeddings):
class Config:
extra = "forbid"
@root_validator(allow_reuse=True)
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

View File

@ -60,7 +60,7 @@ class InfinityEmbeddingsLocal(BaseModel, Embeddings):
class Config:
extra = "forbid"
@root_validator(allow_reuse=True)
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

View File

@ -12,8 +12,10 @@ from wsgiref.handlers import format_date_time
import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.utils import (
secret_from_env,
)
from numpy import ndarray
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
@ -102,11 +104,18 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
]
""" # noqa: E501
spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
spark_app_id: SecretStr = Field(
alias="app_id", default_factory=secret_from_env("SPARK_APP_ID")
)
"""Automatically inferred from env var `SPARK_APP_ID` if not provided."""
spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
spark_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("SPARK_API_KEY", default=None)
)
"""Automatically inferred from env var `SPARK_API_KEY` if not provided."""
spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
spark_api_secret: Optional[SecretStr] = Field(
alias="api_secret",
default_factory=secret_from_env("SPARK_API_SECRET", default=None),
)
"""Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
"""Base URL path for API requests"""
@ -118,20 +127,6 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
class Config:
allow_population_by_field_name = True
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that auth token exists in environment."""
values["spark_app_id"] = convert_to_secret_str(
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
)
values["spark_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
)
values["spark_api_secret"] = convert_to_secret_str(
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
)
return values
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
"""Internal method to call Spark Embedding API and return embeddings.

View File

@ -77,7 +77,7 @@ class GradientLLM(BaseLLM):
allow_population_by_field_name = True
extra = "forbid"
@root_validator(allow_reuse=True)
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -88,6 +88,26 @@ class GradientLLM(BaseLLM):
values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
)
values["gradient_api_url"] = get_from_dict_or_env(
values, "gradient_api_url", "GRADIENT_API_URL"
)
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
"""Post init validation."""
# Can be most to post_init_validation
try:
import gradientai # noqa
except ImportError:
logging.warning(
"DeprecationWarning: `GradientLLM` will use "
"`pip install gradientai` in future releases of langchain."
)
except Exception:
pass
# Can be most to post_init_validation
if (
values["gradient_access_token"] is None
or len(values["gradient_access_token"]) < 10
@ -114,20 +134,6 @@ class GradientLLM(BaseLLM):
if 0 >= kw.get("max_generated_token_count", 1):
raise ValueError("`max_generated_token_count` must be positive")
values["gradient_api_url"] = get_from_dict_or_env(
values, "gradient_api_url", "GRADIENT_API_URL"
)
try:
import gradientai # noqa
except ImportError:
logging.warning(
"DeprecationWarning: `GradientLLM` will use "
"`pip install gradientai` in future releases of langchain."
)
except Exception:
pass
return values
@property

View File

@ -6,8 +6,6 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
@ -164,18 +162,6 @@ class NeuralDBVectorStore(VectorStore):
offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
return [str(offset + i) for i in range(len(texts))] # type: ignore[arg-type]
@root_validator(allow_reuse=True)
def validate_environments(cls, values: Dict) -> Dict:
"""Validate ThirdAI environment variables."""
values["thirdai_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"thirdai_key",
"THIRDAI_KEY",
)
)
return values
def insert( # type: ignore[no-untyped-def, no-untyped-def]
self,
sources: List[Any],

View File

@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@pre_init)' -- "*.py" | wc
# PRs that increase the current count will not be accepted.
# PRs that decrease update the code in the repository
# and allow decreasing the count of are welcome!
current_count=336
current_count=337
if [ "$count" -gt "$current_count" ]; then
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."

View File

@ -8,7 +8,7 @@ from langchain_community.embeddings import BaichuanTextEmbeddings
def test_sparkllm_initialization_by_alias() -> None:
# Effective initialization
embeddings = BaichuanTextEmbeddings( # type: ignore[call-arg]
model="embedding_model", # type: ignore[arg-type]
model="embedding_model",
api_key="your-api-key", # type: ignore[arg-type]
)
assert embeddings.model_name == "embedding_model"

View File

@ -2,7 +2,7 @@ import os
from typing import cast
import pytest
from langchain_core.pydantic_v1 import SecretStr, ValidationError
from langchain_core.pydantic_v1 import SecretStr
from langchain_community.embeddings import SparkLLMTextEmbeddings
@ -43,5 +43,5 @@ def test_initialization_parameters_from_env() -> None:
# Environment variable missing
del os.environ["SPARK_APP_ID"]
with pytest.raises(ValidationError):
with pytest.raises(ValueError):
SparkLLMTextEmbeddings()