community[patch]: Remove usage of @root_validator(allow_reuse=True) (#25235)
Remove usage of @root_validator(allow_reuse=True)
parent a2b4c33bd6
commit 6e57aa7c36
@@ -167,7 +167,7 @@ class GPTRouter(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: int = 256
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         values["gpt_router_api_base"] = get_from_dict_or_env(
             values,
@@ -183,7 +183,10 @@ class GPTRouter(BaseChatModel):
                 "GPT_ROUTER_API_KEY",
             )
         )
+        return values
 
+    @root_validator(pre=True, skip_on_failure=True)
+    def post_init(cls, values: Dict) -> Dict:
         try:
             from gpt_router.client import GPTRouterClient
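The GPT Router hunks above illustrate the pattern this commit applies across the affected classes: `allow_reuse=True` is dropped, environment/config resolution stays in a `pre=True` root validator, and heavier setup (SDK imports, client construction) moves into a separate `post_init` root validator that is skipped when earlier validation fails. A minimal sketch of that shape, using a hypothetical `ExampleIntegration` and `EXAMPLE_API_KEY` rather than the real GPTRouter code:

# Illustrative sketch only: ExampleIntegration and EXAMPLE_API_KEY are
# hypothetical; the classes touched by this diff follow the same shape.
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env


class ExampleIntegration(BaseModel):
    api_key: Optional[str] = None
    client: Any = None  #: :meta private:

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        # Runs before field validation: only resolve raw config / env values here.
        values["api_key"] = get_from_dict_or_env(values, "api_key", "EXAMPLE_API_KEY")
        return values

    @root_validator(pre=False, skip_on_failure=True)
    def post_init(cls, values: Dict) -> Dict:
        # Runs after field validation and is skipped if it failed, so heavier
        # setup such as importing an SDK or building a client is safe here.
        values["client"] = object()  # stand-in for a real SDK client
        return values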
@@ -88,7 +88,7 @@ class ChatPerplexity(BaseChatModel):
     def lc_secrets(self) -> Dict[str, str]:
         return {"pplx_api_key": "PPLX_API_KEY"}
 
-    @root_validator(pre=True, allow_reuse=True)
+    @root_validator(pre=True)
     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
@@ -114,7 +114,7 @@ class ChatPerplexity(BaseChatModel):
         values["model_kwargs"] = extra
         return values
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=False, skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["pplx_api_key"] = get_from_dict_or_env(
@@ -3,7 +3,11 @@ from typing import Any, Dict, List, Optional
 import requests
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils import (
+    convert_to_secret_str,
+    get_from_dict_or_env,
+    secret_from_env,
+)
 from requests import RequestException
 
 BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"
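The new `secret_from_env` import above is what lets several of these classes drop their env-reading validator code: it returns a zero-argument factory meant for `Field(default_factory=...)` (as used in the next hunk) that reads the named environment variable at instantiation time and wraps it in a `SecretStr`. A rough, self-contained sketch (the `DemoSettings` class is illustrative, not part of this diff):

import os
from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.utils import secret_from_env


class DemoSettings(BaseModel):
    # Falls back to the BAICHUAN_API_KEY env var; because default=None is
    # given, the field simply stays None when the variable is unset.
    api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("BAICHUAN_API_KEY", default=None)
    )


os.environ["BAICHUAN_API_KEY"] = "sk-demo"
print(DemoSettings().api_key)  # masked as ********** when printed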
@@ -53,7 +57,10 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
     session: Any  #: :meta private:
     model_name: str = Field(default="Baichuan-Text-Embedding", alias="model")
     """The model used to embed the documents."""
-    baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    baichuan_api_key: Optional[SecretStr] = Field(
+        alias="api_key",
+        default_factory=secret_from_env("BAICHUAN_API_KEY", default=None),
+    )
     """Automatically inferred from env var `BAICHUAN_API_KEY` if not provided."""
     chunk_size: int = 16
     """Chunk size when multiple texts are input"""
@@ -61,22 +68,21 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
     class Config:
         allow_population_by_field_name = True
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=False, skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that auth token exists in environment."""
-        try:
-            baichuan_api_key = convert_to_secret_str(
-                get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
-            )
-        except ValueError as original_exc:
-            try:
-                baichuan_api_key = convert_to_secret_str(
-                    get_from_dict_or_env(
-                        values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
-                    )
-                )
-            except ValueError:
-                raise original_exc
-        values["baichuan_api_key"] = baichuan_api_key
+        if values["baichuan_api_key"] is None:
+            # This is likely here for some backwards compatibility with
+            # BAICHUAN_AUTH_TOKEN
+            baichuan_api_key = convert_to_secret_str(
+                get_from_dict_or_env(
+                    values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN"
+                )
+            )
+            values["baichuan_api_key"] = baichuan_api_key
+        else:
+            baichuan_api_key = values["baichuan_api_key"]
 
         session = requests.Session()
         session.headers.update(
             {
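With the primary env var now handled by the field's `default_factory` (see the earlier Baichuan hunk), the rewritten validator above only has to cover the legacy `BAICHUAN_AUTH_TOKEN` fallback. A reduced, standalone sketch of that fallback logic (the function name is illustrative):

from typing import Dict

from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


def resolve_api_key(values: Dict) -> Dict:
    # Prefer the api key already resolved by the field default; otherwise
    # fall back to the legacy BAICHUAN_AUTH_TOKEN variable.
    if values.get("baichuan_api_key") is None:
        values["baichuan_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN")
        )
    return values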
@@ -53,7 +53,7 @@ class GradientEmbeddings(BaseModel, Embeddings):
     class Config:
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
 
@@ -65,8 +65,15 @@ class GradientEmbeddings(BaseModel, Embeddings):
         )
 
         values["gradient_api_url"] = get_from_dict_or_env(
-            values, "gradient_api_url", "GRADIENT_API_URL"
+            values,
+            "gradient_api_url",
+            "GRADIENT_API_URL",
+            default="https://api.gradient.ai/api",
         )
+        return values
+
+    @root_validator(pre=False, skip_on_failure=True)
+    def post_init(cls, values: Dict) -> Dict:
         try:
             import gradientai
         except ImportError:
@@ -85,7 +92,6 @@ class GradientEmbeddings(BaseModel, Embeddings):
             host=values["gradient_api_url"],
         )
         values["client"] = gradient.get_embeddings_model(slug=values["model"])
-
         return values
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
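Besides the validator split, the Gradient hunks give `gradient_api_url` an explicit default, so the `GRADIENT_API_URL` env var becomes optional. `get_from_dict_or_env` checks the passed dict first, then the environment, then the `default` argument. A small illustration (the wrapper function is hypothetical; the key names are the ones from the diff):

from typing import Dict

from langchain_core.utils import get_from_dict_or_env


def resolve_api_url(values: Dict) -> str:
    # Precedence: values["gradient_api_url"], then $GRADIENT_API_URL,
    # then the hard-coded default.
    return get_from_dict_or_env(
        values,
        "gradient_api_url",
        "GRADIENT_API_URL",
        default="https://api.gradient.ai/api",
    )


print(resolve_api_url({}))  # -> https://api.gradient.ai/api when the env var is unset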
@@ -47,7 +47,7 @@ class InfinityEmbeddings(BaseModel, Embeddings):
     class Config:
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
 
@@ -60,7 +60,7 @@ class InfinityEmbeddingsLocal(BaseModel, Embeddings):
     class Config:
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=False, skip_on_failure=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
 
@@ -12,8 +12,10 @@ from wsgiref.handlers import format_date_time
 import numpy as np
 import requests
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
+from langchain_core.utils import (
+    secret_from_env,
+)
 from numpy import ndarray
 
 # SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
@@ -102,11 +104,18 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
         ]
     """  # noqa: E501
 
-    spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
+    spark_app_id: SecretStr = Field(
+        alias="app_id", default_factory=secret_from_env("SPARK_APP_ID")
+    )
     """Automatically inferred from env var `SPARK_APP_ID` if not provided."""
-    spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    spark_api_key: Optional[SecretStr] = Field(
+        alias="api_key", default_factory=secret_from_env("SPARK_API_KEY", default=None)
+    )
     """Automatically inferred from env var `SPARK_API_KEY` if not provided."""
-    spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
+    spark_api_secret: Optional[SecretStr] = Field(
+        alias="api_secret",
+        default_factory=secret_from_env("SPARK_API_SECRET", default=None),
+    )
     """Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
     base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
     """Base URL path for API requests"""
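Note the asymmetry in the Spark fields above: `spark_app_id` calls `secret_from_env("SPARK_APP_ID")` without a `default`, so when neither `app_id` nor the env var is supplied the factory itself raises a `ValueError` at construction time, while `spark_api_key` and `spark_api_secret` keep `default=None` and stay optional. A hedged sketch of that difference with a stand-in model:

import os
from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.utils import secret_from_env


class Demo(BaseModel):  # illustrative stand-in, not the real embeddings class
    app_id: SecretStr = Field(default_factory=secret_from_env("SPARK_APP_ID"))
    api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("SPARK_API_KEY", default=None)
    )


os.environ.pop("SPARK_APP_ID", None)
try:
    Demo()
except ValueError as err:  # raised by the secret_from_env factory
    print(f"SPARK_APP_ID is required: {err}")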
@@ -118,20 +127,6 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
     class Config:
         allow_population_by_field_name = True
 
-    @root_validator(allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that auth token exists in environment."""
-        values["spark_app_id"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
-        )
-        values["spark_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
-        )
-        values["spark_api_secret"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
-        )
-        return values
-
     def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
         """Internal method to call Spark Embedding API and return embeddings.
 
@@ -77,7 +77,7 @@ class GradientLLM(BaseLLM):
         allow_population_by_field_name = True
         extra = "forbid"
 
-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
 
@@ -88,6 +88,26 @@ class GradientLLM(BaseLLM):
             values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
         )
 
+        values["gradient_api_url"] = get_from_dict_or_env(
+            values, "gradient_api_url", "GRADIENT_API_URL"
+        )
+        return values
+
+    @root_validator(pre=False, skip_on_failure=True)
+    def post_init(cls, values: Dict) -> Dict:
+        """Post init validation."""
+        # Can be most to post_init_validation
+        try:
+            import gradientai  # noqa
+        except ImportError:
+            logging.warning(
+                "DeprecationWarning: `GradientLLM` will use "
+                "`pip install gradientai` in future releases of langchain."
+            )
+        except Exception:
+            pass
+
         # Can be most to post_init_validation
         if (
             values["gradient_access_token"] is None
             or len(values["gradient_access_token"]) < 10
@@ -114,20 +134,6 @@ class GradientLLM(BaseLLM):
             if 0 >= kw.get("max_generated_token_count", 1):
                 raise ValueError("`max_generated_token_count` must be positive")
 
-        values["gradient_api_url"] = get_from_dict_or_env(
-            values, "gradient_api_url", "GRADIENT_API_URL"
-        )
-
-        try:
-            import gradientai  # noqa
-        except ImportError:
-            logging.warning(
-                "DeprecationWarning: `GradientLLM` will use "
-                "`pip install gradientai` in future releases of langchain."
-            )
-        except Exception:
-            pass
-
         return values
 
     @property
@@ -6,8 +6,6 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import root_validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from langchain_core.vectorstores import VectorStore
 
@@ -164,18 +162,6 @@ class NeuralDBVectorStore(VectorStore):
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
         return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]
 
-    @root_validator(allow_reuse=True)
-    def validate_environments(cls, values: Dict) -> Dict:
-        """Validate ThirdAI environment variables."""
-        values["thirdai_key"] = convert_to_secret_str(
-            get_from_dict_or_env(
-                values,
-                "thirdai_key",
-                "THIRDAI_KEY",
-            )
-        )
-        return values
-
     def insert(  # type: ignore[no-untyped-def, no-untyped-def]
         self,
         sources: List[Any],
@@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@pre_init)' -- "*.py" | wc
 # PRs that increase the current count will not be accepted.
 # PRs that decrease update the code in the repository
 # and allow decreasing the count of are welcome!
-current_count=336
+current_count=337
 
 if [ "$count" -gt "$current_count" ]; then
     echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."
@@ -8,7 +8,7 @@ from langchain_community.embeddings import BaichuanTextEmbeddings
 def test_sparkllm_initialization_by_alias() -> None:
     # Effective initialization
     embeddings = BaichuanTextEmbeddings(  # type: ignore[call-arg]
-        model="embedding_model",  # type: ignore[arg-type]
+        model="embedding_model",
         api_key="your-api-key",  # type: ignore[arg-type]
     )
     assert embeddings.model_name == "embedding_model"
@@ -2,7 +2,7 @@ import os
 from typing import cast
 
 import pytest
-from langchain_core.pydantic_v1 import SecretStr, ValidationError
+from langchain_core.pydantic_v1 import SecretStr
 
 from langchain_community.embeddings import SparkLLMTextEmbeddings
 
@@ -43,5 +43,5 @@ def test_initialization_parameters_from_env() -> None:
 
     # Environment variable missing
     del os.environ["SPARK_APP_ID"]
-    with pytest.raises(ValidationError):
+    with pytest.raises(ValueError):
         SparkLLMTextEmbeddings()
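The expected exception changes because `spark_app_id` is now populated by the `secret_from_env` factory at construction time: with `SPARK_APP_ID` deleted, that factory raises a `ValueError` directly, whereas the failure of the old root validator surfaced as a pydantic `ValidationError`. A minimal pytest-style sketch of the same assertion against a stand-in model (names are illustrative):

import os

import pytest
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.utils import secret_from_env


class EnvBackedModel(BaseModel):  # stand-in for SparkLLMTextEmbeddings
    app_id: SecretStr = Field(default_factory=secret_from_env("SPARK_APP_ID"))


def test_missing_env_var_raises_value_error() -> None:
    os.environ.pop("SPARK_APP_ID", None)
    with pytest.raises(ValueError):
        EnvBackedModel()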