community[patch]: Remove usage of @root_validator(allow_reuse=True) (#25235)

Remove usage of @root_validator(allow_reuse=True)
Eugene Yurtsev 2024-08-09 10:57:42 -04:00 committed by GitHub
parent a2b4c33bd6
commit 6e57aa7c36
12 changed files with 76 additions and 74 deletions
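
The change applies one pattern across these files: each `@root_validator(allow_reuse=True)` is either removed, turned into a `pre=True` validator that resolves values from the environment before field validation, or split into such a pre-validator plus a `pre=False, skip_on_failure=True` validator that handles optional imports and client setup afterwards. Below is a minimal sketch of that split; `ExampleClientModel`, `EXAMPLE_API_KEY`, and the `requests.Session` client are illustrative stand-ins, not code from this commit.

# Illustrative sketch of the validator split used throughout this commit.
from typing import Any, Dict

import requests
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env


class ExampleClientModel(BaseModel):
    example_api_key: str = ""
    session: Any = None  #: :meta private:

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        # Pre-validation: resolve the key from passed kwargs or the environment
        # before pydantic validates the declared fields.
        values["example_api_key"] = get_from_dict_or_env(
            values, "example_api_key", "EXAMPLE_API_KEY", default="dummy-key"
        )
        return values

    @root_validator(pre=False, skip_on_failure=True)
    def post_init(cls, values: Dict) -> Dict:
        # Post-validation: runs only if field validation succeeded, so client
        # setup never masks a real validation error.
        session = requests.Session()
        session.headers.update({"Authorization": f"Bearer {values['example_api_key']}"})
        values["session"] = session
        return values


model = ExampleClientModel(example_api_key="sk-test")
print(type(model.session))  # <class 'requests.sessions.Session'>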

View File

@@ -167,7 +167,7 @@ class GPTRouter(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: int = 256
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         values["gpt_router_api_base"] = get_from_dict_or_env(
             values,
@@ -183,7 +183,10 @@ class GPTRouter(BaseChatModel):
                 "GPT_ROUTER_API_KEY",
             )
         )
+        return values
 
+    @root_validator(pre=True, skip_on_failure=True)
+    def post_init(cls, values: Dict) -> Dict:
         try:
             from gpt_router.client import GPTRouterClient

View File

@@ -88,7 +88,7 @@ class ChatPerplexity(BaseChatModel):
     def lc_secrets(self) -> Dict[str, str]:
         return {"pplx_api_key": "PPLX_API_KEY"}
 
-    @root_validator(pre=True, allow_reuse=True)
+    @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
@@ -114,7 +114,7 @@ class ChatPerplexity(BaseChatModel):
         values["model_kwargs"] = extra
         return values
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=False, skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["pplx_api_key"] = get_from_dict_or_env(

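For ChatPerplexity the two validators end up with different timing for a reason: `build_extra` must stay `pre=True` so unknown kwargs are folded into `model_kwargs` before field validation sees them, while the environment check can run afterwards with `skip_on_failure=True`. A small sketch of that split, using a hypothetical `ExampleChat` model and `EXAMPLE_API_KEY` variable (with a stubbed default so the snippet runs offline):

# Illustrative sketch only; names and defaults are not from this commit.
from typing import Any, Dict

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names


class ExampleChat(BaseModel):
    temperature: float = 0.7
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    example_api_key: str = ""

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Move kwargs that are not declared fields into model_kwargs so
        # field validation never has to deal with them.
        known = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for name in list(values):
            if name not in known:
                extra[name] = values.pop(name)
        values["model_kwargs"] = extra
        return values

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        # Runs after field validation; falls back to the env var or a stub.
        values["example_api_key"] = get_from_dict_or_env(
            values, "example_api_key", "EXAMPLE_API_KEY", default="stub-key"
        )
        return values


chat = ExampleChat(top_p=0.9)  # unknown kwarg lands in model_kwargs
print(chat.model_kwargs)  # {'top_p': 0.9}
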
View File

@@ -3,7 +3,11 @@ from typing import Any, Dict, List, Optional
 import requests
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils import (
+    convert_to_secret_str,
+    get_from_dict_or_env,
+    secret_from_env,
+)
 from requests import RequestException
 
 BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"
@@ -53,7 +57,10 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
     session: Any  #: :meta private:
     model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
     """The model used to embed the documents."""
-    baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    baichuan_api_key: Optional[SecretStr] = Field(
+        alias="api_key",
+        default_factory=secret_from_env("BAICHUAN_API_KEY", default=None),
+    )
     """Automatically inferred from env var `BAICHUAN_API_KEY` if not provided."""
     chunk_size: int = 16
     """Chunk size when multiple texts are input"""
@@ -61,22 +68,21 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
     class Config:
         allow_population_by_field_name = True
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=False, skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that auth token exists in environment."""
-        try:
-            baichuan_api_key = convert_to_secret_str(
-                get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
-            )
-        except ValueError as original_exc:
-            try:
-                baichuan_api_key = convert_to_secret_str(
-                    get_from_dict_or_env(
-                        values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
-                    )
-                )
-            except ValueError:
-                raise original_exc
+        if values["baichuan_api_key"] is None:
+            # This is likely here for some backwards compatibility with
+            # BAICHUAN_AUTH_TOKEN
+            baichuan_api_key = convert_to_secret_str(
+                get_from_dict_or_env(
+                    values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
+                )
+            )
+            values["baichuan_api_key"] = baichuan_api_key
+        else:
+            baichuan_api_key = values["baichuan_api_key"]
         session = requests.Session()
         session.headers.update(
             {

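The Baichuan change moves the primary lookup out of the validator entirely: `secret_from_env("BAICHUAN_API_KEY", default=None)` runs as the field's `default_factory`, so by the time the post-validator executes the field is already a `SecretStr` or `None`, and only the legacy `BAICHUAN_AUTH_TOKEN` fallback remains to handle. A compact sketch of the same shape with hypothetical names (`ExampleEmbeddings`, `EXAMPLE_API_KEY`, `EXAMPLE_AUTH_TOKEN`):

# Illustrative sketch; model and env var names are hypothetical.
from typing import Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    secret_from_env,
)


class ExampleEmbeddings(BaseModel):
    # Resolved from the env var at field-default time when not passed explicitly.
    example_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env("EXAMPLE_API_KEY", default=None),
    )

    class Config:
        allow_population_by_field_name = True

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        if values["example_api_key"] is None:
            # Only the legacy variable is still resolved here; this raises a
            # ValueError if it is missing as well.
            values["example_api_key"] = convert_to_secret_str(
                get_from_dict_or_env(values, "example_auth_token", "EXAMPLE_AUTH_TOKEN")
            )
        return values


emb = ExampleEmbeddings(api_key="sk-test")
print(emb.example_api_key)  # **********
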
View File

@@ -53,7 +53,7 @@ class GradientEmbeddings(BaseModel, Embeddings):
     class Config:
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
 
@@ -65,8 +65,15 @@ class GradientEmbeddings(BaseModel, Embeddings):
         )
         values["gradient_api_url"] = get_from_dict_or_env(
-            values, "gradient_api_url", "GRADIENT_API_URL"
+            values,
+            "gradient_api_url",
+            "GRADIENT_API_URL",
+            default="https://api.gradient.ai/api",
         )
+        return values
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def post_init(cls, values: Dict) -> Dict:
         try:
             import gradientai
         except ImportError:
@@ -85,7 +92,6 @@ class GradientEmbeddings(BaseModel, Embeddings):
             host=values["gradient_api_url"],
         )
         values["client"] = gradient.get_embeddings_model(slug=values["model"])
-
         return values
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:

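Aside from the validator split, one behavioral change is buried here: `gradient_api_url` now falls back to `https://api.gradient.ai/api` instead of requiring `GRADIENT_API_URL`. A quick sketch of the resolution order (dict value, then environment variable, then the supplied default):

# Illustrative sketch of get_from_dict_or_env's default= fallback.
import os

from langchain_core.utils import get_from_dict_or_env

os.environ.pop("GRADIENT_API_URL", None)  # simulate an unset environment
url = get_from_dict_or_env(
    {}, "gradient_api_url", "GRADIENT_API_URL", default="https://api.gradient.ai/api"
)
print(url)  # https://api.gradient.ai/api
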
View File

@@ -47,7 +47,7 @@ class InfinityEmbeddings(BaseModel, Embeddings):
     class Config:
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""

View File

@@ -60,7 +60,7 @@ class InfinityEmbeddingsLocal(BaseModel, Embeddings):
     class Config:
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=False, skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""

View File

@@ -12,8 +12,10 @@ from wsgiref.handlers import format_date_time
 import numpy as np
 import requests
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
+from langchain_core.utils import (
+    secret_from_env,
+)
 from numpy import ndarray
 
 # SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
@@ -102,11 +104,18 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
     ]
     """ # noqa: E501
 
-    spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
+    spark_app_id: SecretStr = Field(
+        alias="app_id", default_factory=secret_from_env("SPARK_APP_ID")
+    )
     """Automatically inferred from env var `SPARK_APP_ID` if not provided."""
-    spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    spark_api_key: Optional[SecretStr] = Field(
+        alias="api_key", default_factory=secret_from_env("SPARK_API_KEY", default=None)
+    )
     """Automatically inferred from env var `SPARK_API_KEY` if not provided."""
-    spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
+    spark_api_secret: Optional[SecretStr] = Field(
+        alias="api_secret",
+        default_factory=secret_from_env("SPARK_API_SECRET", default=None),
+    )
     """Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
     base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
     """Base URL path for API requests"""
@@ -118,20 +127,6 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
     class Config:
         allow_population_by_field_name = True
 
-    @root_validator(allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that auth token exists in environment."""
-        values["spark_app_id"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
-        )
-        values["spark_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
-        )
-        values["spark_api_secret"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
-        )
-        return values
-
     def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
         """Internal method to call Spark Embedding API and return embeddings.

View File

@@ -77,7 +77,7 @@ class GradientLLM(BaseLLM):
         allow_population_by_field_name = True
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
 
@@ -88,6 +88,26 @@ class GradientLLM(BaseLLM):
             values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
         )
 
+        values["gradient_api_url"] = get_from_dict_or_env(
+            values, "gradient_api_url", "GRADIENT_API_URL"
+        )
+        return values
+
+    @root_validator(pre=False, skip_on_failure=True)
+    def post_init(cls, values: Dict) -> Dict:
+        """Post init validation."""
+        # Can be most to post_init_validation
+        try:
+            import gradientai  # noqa
+        except ImportError:
+            logging.warning(
+                "DeprecationWarning: `GradientLLM` will use "
+                "`pip install gradientai` in future releases of langchain."
+            )
+        except Exception:
+            pass
+
         # Can be most to post_init_validation
         if (
             values["gradient_access_token"] is None
             or len(values["gradient_access_token"]) < 10
@@ -114,20 +134,6 @@ class GradientLLM(BaseLLM):
             if 0 >= kw.get("max_generated_token_count", 1):
                 raise ValueError("`max_generated_token_count` must be positive")
 
-        values["gradient_api_url"] = get_from_dict_or_env(
-            values, "gradient_api_url", "GRADIENT_API_URL"
-        )
-
-        try:
-            import gradientai  # noqa
-        except ImportError:
-            logging.warning(
-                "DeprecationWarning: `GradientLLM` will use "
-                "`pip install gradientai` in future releases of langchain."
-            )
-        except Exception:
-            pass
-
         return values
 
     @property

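The `pre=False, skip_on_failure=True` half of the split matters for error reporting: in pydantic v1, a post root validator without `skip_on_failure=True` still runs after a field has failed validation, and the missing key then surfaces as a `KeyError` that hides the real error. A small self-contained sketch (the `Guarded` model is illustrative, not from this commit):

# Illustrative sketch of skip_on_failure semantics in pydantic v1.
from typing import Dict

from langchain_core.pydantic_v1 import BaseModel, ValidationError, root_validator


class Guarded(BaseModel):
    token: str

    @root_validator(pre=False, skip_on_failure=True)
    def check_token(cls, values: Dict) -> Dict:
        # Only reached when field validation succeeded, so the key is present.
        if len(values["token"]) < 10:
            raise ValueError("`token` looks too short")
        return values


try:
    Guarded()  # `token` is missing: field validation fails, check_token is skipped
except ValidationError as exc:
    print(exc)
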
View File

@@ -6,8 +6,6 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import root_validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from langchain_core.vectorstores import VectorStore
@@ -164,18 +162,6 @@ class NeuralDBVectorStore(VectorStore):
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
         return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]
 
-    @root_validator(allow_reuse=True)
-    def validate_environments(cls, values: Dict) -> Dict:
-        """Validate ThirdAI environment variables."""
-        values["thirdai_key"] = convert_to_secret_str(
-            get_from_dict_or_env(
-                values,
-                "thirdai_key",
-                "THIRDAI_KEY",
-            )
-        )
-        return values
-
     def insert(  # type: ignore[no-untyped-def, no-untyped-def]
         self,
         sources: List[Any],

View File

@@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@pre_init)' -- "*.py" | wc
 # PRs that increase the current count will not be accepted.
 # PRs that decrease update the code in the repository
 # and allow decreasing the count of are welcome!
-current_count=336
+current_count=337
 
 if [ "$count" -gt "$current_count" ]; then
     echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."

View File

@@ -8,7 +8,7 @@ from langchain_community.embeddings import BaichuanTextEmbeddings
 def test_sparkllm_initialization_by_alias() -> None:
     # Effective initialization
     embeddings = BaichuanTextEmbeddings(  # type: ignore[call-arg]
-        model="embedding_model",  # type: ignore[arg-type]
+        model="embedding_model",
         api_key="your-api-key",  # type: ignore[arg-type]
     )
     assert embeddings.model_name == "embedding_model"

View File

@@ -2,7 +2,7 @@ import os
 from typing import cast
 
 import pytest
-from langchain_core.pydantic_v1 import SecretStr, ValidationError
+from langchain_core.pydantic_v1 import SecretStr
 
 from langchain_community.embeddings import SparkLLMTextEmbeddings
@@ -43,5 +43,5 @@ def test_initialization_parameters_from_env() -> None:
     # Environment variable missing
     del os.environ["SPARK_APP_ID"]
-    with pytest.raises(ValidationError):
+    with pytest.raises(ValueError):
         SparkLLMTextEmbeddings()
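
The expected exception changes because the secret is no longer resolved inside a root validator: with `default_factory=secret_from_env("SPARK_APP_ID")` and no default, the factory itself raises a plain `ValueError` when the variable is absent (and `ValidationError` is a `ValueError` subclass in pydantic v1 anyway, so the looser assertion holds either way). A minimal sketch of that behavior, assuming `secret_from_env` raises on a missing required variable as the `spark_app_id` field above relies on:

# Illustrative sketch of the required secret_from_env factory.
import os

from langchain_core.utils import secret_from_env

os.environ.pop("SPARK_APP_ID", None)
factory = secret_from_env("SPARK_APP_ID")  # no default: the variable is required
try:
    factory()
except ValueError as exc:
    print(exc)  # explains that SPARK_APP_ID was not found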