From 2c180d645ea691d3430ed3b0382cc2b0be7cd404 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 8 Jul 2024 16:09:29 -0400 Subject: [PATCH] core[minor],community[minor]: Upgrade all @root_validator() to @pre_init (#23841) This PR introduces a @pre_init decorator that's a @root_validator(pre=True) but with all the defaults populated! --- .../chat_models/azure_openai.py | 6 +- .../langchain_community/chat_models/edenai.py | 5 +- .../chat_models/google_palm.py | 6 +- .../chat_models/hunyuan.py | 3 +- .../chat_models/jinachat.py | 3 +- .../chat_models/kinetica.py | 6 +- .../langchain_community/chat_models/konko.py | 6 +- .../chat_models/litellm.py | 6 +- .../chat_models/moonshot.py | 4 +- .../langchain_community/chat_models/octoai.py | 6 +- .../langchain_community/chat_models/openai.py | 3 +- .../langchain_community/chat_models/premai.py | 5 +- .../chat_models/snowflake.py | 3 +- .../langchain_community/chat_models/solar.py | 6 +- .../langchain_community/chat_models/tongyi.py | 6 +- .../chat_models/vertexai.py | 4 +- .../langchain_community/chat_models/yuan2.py | 3 +- .../embeddings/anyscale.py | 6 +- .../embeddings/baidu_qianfan_endpoint.py | 6 +- .../embeddings/deepinfra.py | 6 +- .../langchain_community/embeddings/edenai.py | 5 +- .../langchain_community/embeddings/embaas.py | 6 +- .../langchain_community/embeddings/ernie.py | 6 +- .../embeddings/fastembed.py | 5 +- .../embeddings/gigachat.py | 5 +- .../embeddings/google_palm.py | 6 +- .../langchain_community/embeddings/laser.py | 5 +- .../embeddings/llm_rails.py | 6 +- .../langchain_community/embeddings/localai.py | 8 +- .../langchain_community/embeddings/minimax.py | 6 +- .../langchain_community/embeddings/nemo.py | 5 +- .../embeddings/nlpcloud.py | 6 +- .../embeddings/oci_generative_ai.py | 5 +- .../embeddings/octoai_embeddings.py | 6 +- .../langchain_community/embeddings/openai.py | 8 +- .../langchain_community/embeddings/premai.py | 6 +- .../embeddings/sagemaker_endpoint.py | 5 +- .../embeddings/sambanova.py | 6 +- .../langchain_community/embeddings/solar.py | 6 +- .../embeddings/vertexai.py | 4 +- .../embeddings/volcengine.py | 6 +- .../langchain_community/embeddings/yandex.py | 6 +- .../langchain_community/llms/ai21.py | 6 +- .../langchain_community/llms/aleph_alpha.py | 8 +- .../langchain_community/llms/anthropic.py | 5 +- .../langchain_community/llms/anyscale.py | 6 +- .../langchain_community/llms/aphrodite.py | 5 +- .../langchain_community/llms/baichuan.py | 6 +- .../llms/baidu_qianfan_endpoint.py | 6 +- .../langchain_community/llms/bananadev.py | 4 +- .../langchain_community/llms/beam.py | 4 +- .../langchain_community/llms/bedrock.py | 8 +- .../langchain_community/llms/cerebriumai.py | 4 +- .../langchain_community/llms/clarifai.py | 5 +- .../langchain_community/llms/cohere.py | 6 +- .../langchain_community/llms/ctransformers.py | 4 +- .../langchain_community/llms/ctranslate2.py | 5 +- .../langchain_community/llms/deepinfra.py | 6 +- .../langchain_community/llms/deepsparse.py | 8 +- .../langchain_community/llms/edenai.py | 4 +- .../langchain_community/llms/exllamav2.py | 5 +- .../langchain_community/llms/fireworks.py | 6 +- .../langchain_community/llms/friendli.py | 5 +- .../langchain_community/llms/gigachat.py | 4 +- .../langchain_community/llms/google_palm.py | 6 +- .../langchain_community/llms/gooseai.py | 4 +- .../langchain_community/llms/gpt4all.py | 5 +- .../llms/huggingface_endpoint.py | 8 +- .../llms/huggingface_hub.py | 6 +- .../llms/huggingface_text_gen_inference.py | 4 +- .../langchain_community/llms/llamacpp.py | 
4 +- .../langchain_community/llms/manifest.py | 5 +- .../langchain_community/llms/minimax.py | 4 +- .../langchain_community/llms/moonshot.py | 4 +- .../langchain_community/llms/mosaicml.py | 6 +- .../langchain_community/llms/nlpcloud.py | 6 +- ..._data_science_model_deployment_endpoint.py | 6 +- .../llms/oci_generative_ai.py | 5 +- .../llms/octoai_endpoint.py | 6 +- .../langchain_community/llms/opaqueprompts.py | 6 +- .../langchain_community/llms/openai.py | 12 ++- .../langchain_community/llms/openlm.py | 4 +- .../llms/pai_eas_endpoint.py | 5 +- .../langchain_community/llms/petals.py | 4 +- .../langchain_community/llms/pipelineai.py | 4 +- .../llms/predictionguard.py | 6 +- .../langchain_community/llms/replicate.py | 4 +- .../langchain_community/llms/rwkv.py | 5 +- .../llms/sagemaker_endpoint.py | 5 +- .../langchain_community/llms/sambanova.py | 8 +- .../langchain_community/llms/solar.py | 4 +- .../langchain_community/llms/sparkllm.py | 6 +- .../langchain_community/llms/stochasticai.py | 4 +- .../llms/symblai_nebula.py | 6 +- .../langchain_community/llms/tongyi.py | 6 +- .../langchain_community/llms/vertexai.py | 7 +- .../langchain_community/llms/vllm.py | 5 +- .../llms/volcengine_maas.py | 6 +- .../langchain_community/llms/watsonxllm.py | 6 +- .../langchain_community/llms/writer.py | 6 +- .../langchain_community/llms/yandex.py | 6 +- .../langchain_community/retrievers/arcee.py | 6 +- .../google_cloud_documentai_warehouse.py | 5 +- .../retrievers/pinecone_hybrid_search.py | 5 +- .../qdrant_sparse_vector_retriever.py | 5 +- .../retrievers/thirdai_neuraldb.py | 6 +- .../vectorstores/test_vectara.py | 6 +- .../tests/unit_tests/llms/test_aleph_alpha.py | 4 +- .../tests/unit_tests/llms/test_bedrock.py | 7 +- .../load/__snapshots__/test_dump.ambr | 6 +- libs/core/langchain_core/utils/__init__.py | 2 + libs/core/langchain_core/utils/pydantic.py | 37 +++++++++ .../tests/unit_tests/utils/test_imports.py | 1 + .../tests/unit_tests/utils/test_pydantic.py | 75 +++++++++++++++++++ 114 files changed, 439 insertions(+), 276 deletions(-) create mode 100644 libs/core/tests/unit_tests/utils/test_pydantic.py diff --git a/libs/community/langchain_community/chat_models/azure_openai.py b/libs/community/langchain_community/chat_models/azure_openai.py index d70b9f76029..e53112583a6 100644 --- a/libs/community/langchain_community/chat_models/azure_openai.py +++ b/libs/community/langchain_community/chat_models/azure_openai.py @@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Union from langchain_core._api.deprecation import deprecated from langchain_core.outputs import ChatResult -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.utils.openai import is_openai_v1 @@ -106,7 +106,7 @@ class AzureChatOpenAI(ChatOpenAI): """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "azure_openai"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: diff --git a/libs/community/langchain_community/chat_models/edenai.py b/libs/community/langchain_community/chat_models/edenai.py index 28dcb2347c5..0f16b0937ba 100644 --- a/libs/community/langchain_community/chat_models/edenai.py +++ 
b/libs/community/langchain_community/chat_models/edenai.py @@ -50,11 +50,10 @@ from langchain_core.pydantic_v1 import ( Extra, Field, SecretStr, - root_validator, ) from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_community.utilities.requests import Requests @@ -300,7 +299,7 @@ class ChatEdenAI(BaseChatModel): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["edenai_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/google_palm.py b/libs/community/langchain_community/chat_models/google_palm.py index 9fbb55abe33..28d1d9018ce 100644 --- a/libs/community/langchain_community/chat_models/google_palm.py +++ b/libs/community/langchain_community/chat_models/google_palm.py @@ -21,8 +21,8 @@ from langchain_core.outputs import ( ChatGeneration, ChatResult, ) -from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -261,7 +261,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel): """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "google_palm"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, top_p, and top_k.""" google_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/hunyuan.py b/libs/community/langchain_community/chat_models/hunyuan.py index e7681d41551..45ab9212367 100644 --- a/libs/community/langchain_community/chat_models/hunyuan.py +++ b/libs/community/langchain_community/chat_models/hunyuan.py @@ -29,6 +29,7 @@ from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, get_pydantic_field_names, + pre_init, ) logger = logging.getLogger(__name__) @@ -190,7 +191,7 @@ class ChatHunyuan(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: values["hunyuan_api_base"] = get_from_dict_or_env( values, diff --git a/libs/community/langchain_community/chat_models/jinachat.py b/libs/community/langchain_community/chat_models/jinachat.py index dae380092cc..4e0a0f01ba2 100644 --- a/libs/community/langchain_community/chat_models/jinachat.py +++ b/libs/community/langchain_community/chat_models/jinachat.py @@ -45,6 +45,7 @@ from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, get_pydantic_field_names, + pre_init, ) from tenacity import ( before_sleep_log, @@ -218,7 +219,7 @@ class JinaChat(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["jinachat_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/kinetica.py 
b/libs/community/langchain_community/chat_models/kinetica.py index 9362e550608..3724d82b38c 100644 --- a/libs/community/langchain_community/chat_models/kinetica.py +++ b/libs/community/langchain_community/chat_models/kinetica.py @@ -11,6 +11,8 @@ from importlib.metadata import version from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from langchain_core.utils import pre_init + if TYPE_CHECKING: import gpudb @@ -24,7 +26,7 @@ from langchain_core.messages import ( ) from langchain_core.output_parsers.transform import BaseOutputParser from langchain_core.outputs import ChatGeneration, ChatResult, Generation -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field LOG = logging.getLogger(__name__) @@ -341,7 +343,7 @@ class ChatKinetica(BaseChatModel): kdbc: Any = Field(exclude=True) """ Kinetica DB connection. """ - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Pydantic object validator.""" diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py index efc6f64b27e..d4ee329d835 100644 --- a/libs/community/langchain_community/chat_models/konko.py +++ b/libs/community/langchain_community/chat_models/konko.py @@ -23,8 +23,8 @@ from langchain_core.callbacks import ( ) from langchain_core.messages import AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.adapters.openai import ( convert_message_to_dict, @@ -85,7 +85,7 @@ class ChatKonko(ChatOpenAI): max_tokens: int = 20 """Maximum number of tokens to generate.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["konko_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py index e57a5b07252..a16cb187c4f 100644 --- a/libs/community/langchain_community/chat_models/litellm.py +++ b/libs/community/langchain_community/chat_models/litellm.py @@ -48,10 +48,10 @@ from langchain_core.outputs import ( ChatGenerationChunk, ChatResult, ) -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_core.utils.function_calling import convert_to_openai_tool logger = logging.getLogger(__name__) @@ -249,7 +249,7 @@ class ChatLiteLLM(BaseChatModel): return _completion_with_retry(**kwargs) - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, top_p, and top_k.""" try: diff --git a/libs/community/langchain_community/chat_models/moonshot.py b/libs/community/langchain_community/chat_models/moonshot.py index 17966960e1e..fd8a455f888 100644 --- 
a/libs/community/langchain_community/chat_models/moonshot.py +++ b/libs/community/langchain_community/chat_models/moonshot.py @@ -2,10 +2,10 @@ from typing import Dict -from langchain_core.pydantic_v1 import root_validator from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, + pre_init, ) from langchain_community.chat_models import ChatOpenAI @@ -29,7 +29,7 @@ class MoonshotChat(MoonshotCommon, ChatOpenAI): # type: ignore[misc] moonshot = MoonshotChat(model="moonshot-v1-8k") """ - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the environment is set up correctly.""" values["moonshot_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/octoai.py b/libs/community/langchain_community/chat_models/octoai.py index a284989cd9b..d799e1db7b7 100644 --- a/libs/community/langchain_community/chat_models/octoai.py +++ b/libs/community/langchain_community/chat_models/octoai.py @@ -2,8 +2,8 @@ from typing import Dict -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.utils.openai import is_openai_v1 @@ -48,7 +48,7 @@ class ChatOctoAI(ChatOpenAI): def is_lc_serializable(cls) -> bool: return False - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["octoai_api_base"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py index daea2c1050d..3a0ae345110 100644 --- a/libs/community/langchain_community/chat_models/openai.py +++ b/libs/community/langchain_community/chat_models/openai.py @@ -49,6 +49,7 @@ from langchain_core.runnables import Runnable from langchain_core.utils import ( get_from_dict_or_env, get_pydantic_field_names, + pre_init, ) from langchain_community.adapters.openai import ( @@ -274,7 +275,7 @@ class ChatOpenAI(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: diff --git a/libs/community/langchain_community/chat_models/premai.py b/libs/community/langchain_community/chat_models/premai.py index b326434e9cc..311e94e763e 100644 --- a/libs/community/langchain_community/chat_models/premai.py +++ b/libs/community/langchain_community/chat_models/premai.py @@ -40,9 +40,8 @@ from langchain_core.pydantic_v1 import ( Extra, Field, SecretStr, - root_validator, ) -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init if TYPE_CHECKING: from premai.api.chat_completions.v1_chat_completions_create import ( @@ -249,7 +248,7 @@ class ChatPremAI(BaseChatModel, BaseModel): allow_population_by_field_name = True arbitrary_types_allowed = True - @root_validator() + @pre_init def validate_environments(cls, values: Dict) -> Dict: """Validate that the package is installed and that the API token is valid""" try: diff --git a/libs/community/langchain_community/chat_models/snowflake.py 
b/libs/community/langchain_community/chat_models/snowflake.py index c25d2254f97..f5af38afdd2 100644 --- a/libs/community/langchain_community/chat_models/snowflake.py +++ b/libs/community/langchain_community/chat_models/snowflake.py @@ -16,6 +16,7 @@ from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, get_pydantic_field_names, + pre_init, ) from langchain_core.utils.utils import build_extra_kwargs @@ -135,7 +136,7 @@ class ChatSnowflakeCortex(BaseChatModel): ) return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: try: from snowflake.snowpark import Session diff --git a/libs/community/langchain_community/chat_models/solar.py b/libs/community/langchain_community/chat_models/solar.py index 417880757bb..8363d55895a 100644 --- a/libs/community/langchain_community/chat_models/solar.py +++ b/libs/community/langchain_community/chat_models/solar.py @@ -3,8 +3,8 @@ from typing import Dict from langchain_core._api import deprecated -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.chat_models import ChatOpenAI from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon @@ -37,7 +37,7 @@ class SolarChat(SolarCommon, ChatOpenAI): arbitrary_types_allowed = True extra = "ignore" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the environment is set up correctly.""" values["solar_api_key"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 939dd620d5d..7941aa07264 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -49,10 +49,10 @@ from langchain_core.outputs import ( ChatGenerationChunk, ChatResult, ) -from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils.function_calling import convert_to_openai_tool from requests.exceptions import HTTPError from tenacity import ( @@ -431,7 +431,7 @@ class ChatTongyi(BaseChatModel): """Return type of llm.""" return "tongyi" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["dashscope_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/vertexai.py b/libs/community/langchain_community/chat_models/vertexai.py index 39da4e4afc5..db038f615c1 100644 --- a/libs/community/langchain_community/chat_models/vertexai.py +++ b/libs/community/langchain_community/chat_models/vertexai.py @@ -27,7 +27,7 @@ from langchain_core.messages import ( SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils import pre_init from langchain_community.llms.vertexai import ( _VertexAICommon, @@ -225,7 +225,7 @@ class 
ChatVertexAI(_VertexAICommon, BaseChatModel): """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "vertexai"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in environment.""" is_gemini = is_gemini_model(values["model_name"]) diff --git a/libs/community/langchain_community/chat_models/yuan2.py b/libs/community/langchain_community/chat_models/yuan2.py index 87640ba048d..0deb6394e92 100644 --- a/libs/community/langchain_community/chat_models/yuan2.py +++ b/libs/community/langchain_community/chat_models/yuan2.py @@ -44,6 +44,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.utils import ( get_from_dict_or_env, get_pydantic_field_names, + pre_init, ) from tenacity import ( before_sleep_log, @@ -166,7 +167,7 @@ class ChatYuan2(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["yuan2_api_key"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/anyscale.py b/libs/community/langchain_community/embeddings/anyscale.py index 135b0918bf2..831946a4da5 100644 --- a/libs/community/langchain_community/embeddings/anyscale.py +++ b/libs/community/langchain_community/embeddings/anyscale.py @@ -4,8 +4,8 @@ from __future__ import annotations from typing import Dict -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_community.utils.openai import is_openai_v1 @@ -34,7 +34,7 @@ class AnyscaleEmbeddings(OpenAIEmbeddings): "anyscale_api_key": "ANYSCALE_API_KEY", } - @root_validator() + @pre_init def validate_environment(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" values["anyscale_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py index 41bbd96984d..d2b75578a35 100644 --- a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py @@ -4,8 +4,8 @@ import logging from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -48,7 +48,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings): model_kwargs: Dict[str, Any] = Field(default_factory=dict) """extra params for model invoke using with `do`.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """ Validate whether qianfan_ak and qianfan_sk in the environment variables or diff --git a/libs/community/langchain_community/embeddings/deepinfra.py b/libs/community/langchain_community/embeddings/deepinfra.py 
index 4dbc0e4dad8..09f2c0dbe48 100644 --- a/libs/community/langchain_community/embeddings/deepinfra.py +++ b/libs/community/langchain_community/embeddings/deepinfra.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import get_from_dict_or_env, pre_init DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32" MAX_BATCH_SIZE = 1024 @@ -59,7 +59,7 @@ class DeepInfraEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" deepinfra_api_token = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/edenai.py b/libs/community/langchain_community/embeddings/edenai.py index 9446969d332..7bebad4b608 100644 --- a/libs/community/langchain_community/embeddings/edenai.py +++ b/libs/community/langchain_community/embeddings/edenai.py @@ -6,9 +6,8 @@ from langchain_core.pydantic_v1 import ( Extra, Field, SecretStr, - root_validator, ) -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.utilities.requests import Requests @@ -35,7 +34,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["edenai_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/embeddings/embaas.py b/libs/community/langchain_community/embeddings/embaas.py index 00f23116bf4..9801267f464 100644 --- a/libs/community/langchain_community/embeddings/embaas.py +++ b/libs/community/langchain_community/embeddings/embaas.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from requests.adapters import HTTPAdapter, Retry from typing_extensions import NotRequired, TypedDict @@ -61,7 +61,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" embaas_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/embeddings/ernie.py b/libs/community/langchain_community/embeddings/ernie.py index b6bff742bd1..66888bbc68a 100644 --- a/libs/community/langchain_community/embeddings/ernie.py +++ b/libs/community/langchain_community/embeddings/ernie.py @@ -6,9 +6,9 @@ from typing import Dict, List, Optional import requests from langchain_core._api.deprecation import deprecated from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.config import 
run_in_executor -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class ErnieEmbeddings(BaseModel, Embeddings): _lock = threading.Lock() - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: values["ernie_api_base"] = get_from_dict_or_env( values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com" diff --git a/libs/community/langchain_community/embeddings/fastembed.py b/libs/community/langchain_community/embeddings/fastembed.py index 4f03ab70fe2..a5ebfdebda1 100644 --- a/libs/community/langchain_community/embeddings/fastembed.py +++ b/libs/community/langchain_community/embeddings/fastembed.py @@ -2,7 +2,8 @@ from typing import Any, Dict, List, Literal, Optional import numpy as np from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import pre_init class FastEmbedEmbeddings(BaseModel, Embeddings): @@ -54,7 +55,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that FastEmbed has been installed.""" model_name = values.get("model_name") diff --git a/libs/community/langchain_community/embeddings/gigachat.py b/libs/community/langchain_community/embeddings/gigachat.py index 8fb6f233fc6..473878103cf 100644 --- a/libs/community/langchain_community/embeddings/gigachat.py +++ b/libs/community/langchain_community/embeddings/gigachat.py @@ -5,7 +5,8 @@ from functools import cached_property from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import pre_init logger = logging.getLogger(__name__) @@ -77,7 +78,7 @@ class GigaChatEmbeddings(BaseModel, Embeddings): key_file_password=self.key_file_password, ) - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate authenticate data in environment and python package is installed.""" try: diff --git a/libs/community/langchain_community/embeddings/google_palm.py b/libs/community/langchain_community/embeddings/google_palm.py index 0ffb5cfb6fb..2c960c93e84 100644 --- a/libs/community/langchain_community/embeddings/google_palm.py +++ b/libs/community/langchain_community/embeddings/google_palm.py @@ -4,8 +4,8 @@ import logging from typing import Any, Callable, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -62,7 +62,7 @@ class GooglePalmEmbeddings(BaseModel, Embeddings): show_progress_bar: bool = False """Whether to show a tqdm progress bar. 
Must have `tqdm` installed.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists.""" google_api_key = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/laser.py b/libs/community/langchain_community/embeddings/laser.py index 675cd74f211..f85d9f6eb1e 100644 --- a/libs/community/langchain_community/embeddings/laser.py +++ b/libs/community/langchain_community/embeddings/laser.py @@ -2,7 +2,8 @@ from typing import Any, Dict, List, Optional import numpy as np from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import pre_init LASER_MULTILINGUAL_MODEL: str = "laser2" @@ -41,7 +42,7 @@ class LaserEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that laser_encoders has been installed.""" try: diff --git a/libs/community/langchain_community/embeddings/llm_rails.py b/libs/community/langchain_community/embeddings/llm_rails.py index 5019cba61cf..fe10dad0b85 100644 --- a/libs/community/langchain_community/embeddings/llm_rails.py +++ b/libs/community/langchain_community/embeddings/llm_rails.py @@ -4,8 +4,8 @@ from typing import Dict, List, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init class LLMRailsEmbeddings(BaseModel, Embeddings): @@ -37,7 +37,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/embeddings/localai.py b/libs/community/langchain_community/embeddings/localai.py index fe620e9a6fa..45cd41c3334 100644 --- a/libs/community/langchain_community/embeddings/localai.py +++ b/libs/community/langchain_community/embeddings/localai.py @@ -17,7 +17,11 @@ from typing import ( from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, + pre_init, +) from tenacity import ( AsyncRetrying, before_sleep_log, @@ -193,7 +197,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/minimax.py b/libs/community/langchain_community/embeddings/minimax.py index 38e56c553b8..810708dca84 100644 --- a/libs/community/langchain_community/embeddings/minimax.py +++ b/libs/community/langchain_community/embeddings/minimax.py @@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, List, Optional import requests from langchain_core.embeddings import Embeddings 
-from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -84,7 +84,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that group id and api key exists in environment.""" minimax_group_id = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/nemo.py b/libs/community/langchain_community/embeddings/nemo.py index b21c0d7e000..95d00c3bbb4 100644 --- a/libs/community/langchain_community/embeddings/nemo.py +++ b/libs/community/langchain_community/embeddings/nemo.py @@ -8,7 +8,8 @@ import aiohttp import requests from langchain_core._api.deprecation import deprecated from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import pre_init def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool: @@ -58,7 +59,7 @@ class NeMoEmbeddings(BaseModel, Embeddings): model: str = "NV-Embed-QA-003" api_endpoint_url: str = "http://localhost:8088/v1/embeddings" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the end point is alive using the values that are provided.""" diff --git a/libs/community/langchain_community/embeddings/nlpcloud.py b/libs/community/langchain_community/embeddings/nlpcloud.py index 748d63b9005..69f08d28572 100644 --- a/libs/community/langchain_community/embeddings/nlpcloud.py +++ b/libs/community/langchain_community/embeddings/nlpcloud.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import get_from_dict_or_env, pre_init class NLPCloudEmbeddings(BaseModel, Embeddings): @@ -30,7 +30,7 @@ class NLPCloudEmbeddings(BaseModel, Embeddings): ) -> None: super().__init__(model_name=model_name, gpu=gpu, **kwargs) - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" nlpcloud_api_key = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/oci_generative_ai.py b/libs/community/langchain_community/embeddings/oci_generative_ai.py index 9b08aee4e44..70d6a7ed4f5 100644 --- a/libs/community/langchain_community/embeddings/oci_generative_ai.py +++ b/libs/community/langchain_community/embeddings/oci_generative_ai.py @@ -2,7 +2,8 @@ from enum import Enum from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import pre_init CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" @@ -89,7 +90,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: # pylint: 
disable=no-self-argument """Validate that OCI config and python package exists in environment.""" diff --git a/libs/community/langchain_community/embeddings/octoai_embeddings.py b/libs/community/langchain_community/embeddings/octoai_embeddings.py index 43ce14c38e6..287c04c2645 100644 --- a/libs/community/langchain_community/embeddings/octoai_embeddings.py +++ b/libs/community/langchain_community/embeddings/octoai_embeddings.py @@ -1,7 +1,7 @@ from typing import Dict -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_community.utils.openai import is_openai_v1 @@ -38,7 +38,7 @@ class OctoAIEmbeddings(OpenAIEmbeddings): def lc_secrets(self) -> Dict[str, str]: return {"octoai_api_token": "OCTOAI_API_TOKEN"} - @root_validator() + @pre_init def validate_environment(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" values["endpoint_url"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/openai.py b/libs/community/langchain_community/embeddings/openai.py index 22ddac802a4..713e2535d85 100644 --- a/libs/community/langchain_community/embeddings/openai.py +++ b/libs/community/langchain_community/embeddings/openai.py @@ -22,7 +22,11 @@ import numpy as np from langchain_core._api.deprecation import deprecated from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, + pre_init, +) from tenacity import ( AsyncRetrying, before_sleep_log, @@ -282,7 +286,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["openai_api_key"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/premai.py b/libs/community/langchain_community/embeddings/premai.py index 9f2585c3ed2..0ae42759776 100644 --- a/libs/community/langchain_community/embeddings/premai.py +++ b/libs/community/langchain_community/embeddings/premai.py @@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, List, Optional, Union from langchain_core.embeddings import Embeddings from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, SecretStr +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ class PremAIEmbeddings(BaseModel, Embeddings): client: Any - @root_validator() + @pre_init def validate_environments(cls, values: Dict) -> Dict: """Validate that the package is installed and that the API token is valid""" try: diff --git a/libs/community/langchain_community/embeddings/sagemaker_endpoint.py b/libs/community/langchain_community/embeddings/sagemaker_endpoint.py index 1f01e1ecf4d..945b33f93d9 100644 --- 
a/libs/community/langchain_community/embeddings/sagemaker_endpoint.py +++ b/libs/community/langchain_community/embeddings/sagemaker_endpoint.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import pre_init from langchain_community.llms.sagemaker_endpoint import ContentHandlerBase @@ -115,7 +116,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings): extra = Extra.forbid arbitrary_types_allowed = True - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Dont do anything if client provided externally""" if values.get("client") is not None: diff --git a/libs/community/langchain_community/embeddings/sambanova.py b/libs/community/langchain_community/embeddings/sambanova.py index 1b360f25470..6f86b417454 100644 --- a/libs/community/langchain_community/embeddings/sambanova.py +++ b/libs/community/langchain_community/embeddings/sambanova.py @@ -3,8 +3,8 @@ from typing import Dict, Generator, List, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import get_from_dict_or_env, pre_init class SambaStudioEmbeddings(BaseModel, Embeddings): @@ -64,7 +64,7 @@ class SambaStudioEmbeddings(BaseModel, Embeddings): batch_size: int = 32 """Batch size for the embedding models""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["sambastudio_embeddings_base_url"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/solar.py b/libs/community/langchain_community/embeddings/solar.py index d7e05df0962..8622525c286 100644 --- a/libs/community/langchain_community/embeddings/solar.py +++ b/libs/community/langchain_community/embeddings/solar.py @@ -6,8 +6,8 @@ from typing import Any, Callable, Dict, List, Optional import requests from langchain_core._api import deprecated from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -80,7 +80,7 @@ class SolarEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate api key exists in environment.""" solar_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/embeddings/vertexai.py b/libs/community/langchain_community/embeddings/vertexai.py index 4328a4816be..c1b38c9ab8f 100644 --- a/libs/community/langchain_community/embeddings/vertexai.py +++ b/libs/community/langchain_community/embeddings/vertexai.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple from langchain_core._api.deprecation import deprecated from langchain_core.embeddings import Embeddings from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.pydantic_v1 
import root_validator +from langchain_core.utils import pre_init from langchain_community.llms.vertexai import _VertexAICommon from langchain_community.utilities.vertexai import raise_vertex_import_error @@ -33,7 +33,7 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings): show_progress_bar: bool = False """Whether to show a tqdm progress bar. Must have `tqdm` installed.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validates that the python package exists in environment.""" cls._try_init_vertexai(values) diff --git a/libs/community/langchain_community/embeddings/volcengine.py b/libs/community/langchain_community/embeddings/volcengine.py index 98ac729b968..9b8e3992648 100644 --- a/libs/community/langchain_community/embeddings/volcengine.py +++ b/libs/community/langchain_community/embeddings/volcengine.py @@ -4,8 +4,8 @@ import logging from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ class VolcanoEmbeddings(BaseModel, Embeddings): client: Any """volcano client""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """ Validate whether volcano_ak and volcano_sk in the environment variables or diff --git a/libs/community/langchain_community/embeddings/yandex.py b/libs/community/langchain_community/embeddings/yandex.py index 6fea5f121a3..c262e1acdac 100644 --- a/libs/community/langchain_community/embeddings/yandex.py +++ b/libs/community/langchain_community/embeddings/yandex.py @@ -7,8 +7,8 @@ import time from typing import Any, Callable, Dict, List, Sequence from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -76,7 +76,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): allow_population_by_field_name = True - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that iam token exists in environment.""" diff --git a/libs/community/langchain_community/llms/ai21.py b/libs/community/langchain_community/llms/ai21.py index dd86ba516ae..484077a3f2c 100644 --- a/libs/community/langchain_community/llms/ai21.py +++ b/libs/community/langchain_community/llms/ai21.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, cast import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init class AI21PenaltyData(BaseModel): @@ -73,7 +73,7 @@ class AI21(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" 
ai21_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/aleph_alpha.py b/libs/community/langchain_community/llms/aleph_alpha.py index d2e56ac35d5..c34ffcd4b51 100644 --- a/libs/community/langchain_community/llms/aleph_alpha.py +++ b/libs/community/langchain_community/llms/aleph_alpha.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Optional, Sequence from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -129,7 +129,7 @@ class AlephAlpha(LLM): """Stop sequences to use.""" # Client params - aleph_alpha_api_key: Optional[str] = None + aleph_alpha_api_key: Optional[SecretStr] = None """API key for Aleph Alpha API.""" host: str = "https://api.aleph-alpha.com" """The hostname of the API host. @@ -167,7 +167,7 @@ class AlephAlpha(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["aleph_alpha_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/anthropic.py b/libs/community/langchain_community/llms/anthropic.py index 8daf3fccabd..3727da55b62 100644 --- a/libs/community/langchain_community/llms/anthropic.py +++ b/libs/community/langchain_community/llms/anthropic.py @@ -25,6 +25,7 @@ from langchain_core.utils import ( check_package_version, get_from_dict_or_env, get_pydantic_field_names, + pre_init, ) from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str @@ -74,7 +75,7 @@ class _AnthropicCommon(BaseLanguageModel): ) return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anthropic_api_key"] = convert_to_secret_str( @@ -185,7 +186,7 @@ class Anthropic(LLM, _AnthropicCommon): allow_population_by_field_name = True arbitrary_types_allowed = True - @root_validator() + @pre_init def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" warnings.warn( diff --git a/libs/community/langchain_community/llms/anyscale.py b/libs/community/langchain_community/llms/anyscale.py index 3d4f727d3de..2e43fa38ae5 100644 --- a/libs/community/langchain_community/llms/anyscale.py +++ b/libs/community/langchain_community/llms/anyscale.py @@ -14,8 +14,8 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.openai import ( BaseOpenAI, @@ -93,7 +93,7 @@ class Anyscale(BaseOpenAI): def is_lc_serializable(cls) -> bool: return False - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anyscale_api_base"] = 
get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/aphrodite.py b/libs/community/langchain_community/llms/aphrodite.py index 11720a088f2..d8bf1c1871c 100644 --- a/libs/community/langchain_community/llms/aphrodite.py +++ b/libs/community/langchain_community/llms/aphrodite.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import BaseLLM from langchain_core.outputs import Generation, LLMResult -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import pre_init class Aphrodite(BaseLLM): @@ -157,7 +158,7 @@ class Aphrodite(BaseLLM): client: Any #: :meta private: - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that python package exists in environment.""" diff --git a/libs/community/langchain_community/llms/baichuan.py b/libs/community/langchain_community/llms/baichuan.py index 2f897b2c6d2..4cc614ce8ac 100644 --- a/libs/community/langchain_community/llms/baichuan.py +++ b/libs/community/langchain_community/llms/baichuan.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -31,7 +31,7 @@ class BaichuanLLM(LLM): baichuan_api_host: Optional[str] = None baichuan_api_key: Optional[SecretStr] = None - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: values["baichuan_api_key"] = convert_to_secret_str( get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY") diff --git a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py index b92b0cd5f61..4d091f0614c 100644 --- a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py @@ -16,8 +16,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -76,7 +76,7 @@ class QianfanLLMEndpoint(LLM): In the case of other model, passing these params will not affect the result. 
""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: values["qianfan_ak"] = convert_to_secret_str( get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/bananadev.py b/libs/community/langchain_community/llms/bananadev.py index 43ee44b83df..32e075133ba 100644 --- a/libs/community/langchain_community/llms/bananadev.py +++ b/libs/community/langchain_community/llms/bananadev.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, cast from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -63,7 +63,7 @@ class Banana(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" banana_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/beam.py b/libs/community/langchain_community/llms/beam.py index 9476cf07d05..b69abdb6bb5 100644 --- a/libs/community/langchain_community/llms/beam.py +++ b/libs/community/langchain_community/llms/beam.py @@ -10,7 +10,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -95,7 +95,7 @@ class Beam(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" beam_client_id = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py index 713d6eb84b6..3300f19077c 100644 --- a/libs/community/langchain_community/llms/bedrock.py +++ b/libs/community/langchain_community/llms/bedrock.py @@ -21,8 +21,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens from langchain_community.utilities.anthropic import ( @@ -389,7 +389,7 @@ class BedrockBase(BaseModel, ABC): ...Logic to handle guardrail intervention... 
""" # noqa: E501 - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that AWS credentials to and python package exists in environment.""" @@ -743,7 +743,7 @@ class Bedrock(LLM, BedrockBase): """ - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: model_id = values["model_id"] if model_id.startswith("anthropic.claude-3"): diff --git a/libs/community/langchain_community/llms/cerebriumai.py b/libs/community/langchain_community/llms/cerebriumai.py index c9e219995ae..417bedf802d 100644 --- a/libs/community/langchain_community/llms/cerebriumai.py +++ b/libs/community/langchain_community/llms/cerebriumai.py @@ -5,7 +5,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -62,7 +62,7 @@ class CerebriumAI(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" cerebriumai_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/clarifai.py b/libs/community/langchain_community/llms/clarifai.py index 8a3c9b04212..83812968393 100644 --- a/libs/community/langchain_community/llms/clarifai.py +++ b/libs/community/langchain_community/llms/clarifai.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import Generation, LLMResult -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.utils import pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -53,7 +54,7 @@ class Clarifai(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that we have all required info to access Clarifai platform and python package exists in environment.""" diff --git a/libs/community/langchain_community/llms/cohere.py b/libs/community/langchain_community/llms/cohere.py index 02e505e2fbe..c266fcfafcf 100644 --- a/libs/community/langchain_community/llms/cohere.py +++ b/libs/community/langchain_community/llms/cohere.py @@ -10,8 +10,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.load.serializable import Serializable -from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -95,7 +95,7 @@ class BaseCohere(Serializable): user_agent: str = "langchain" """Identifier for the application making the request.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" try: diff --git 
a/libs/community/langchain_community/llms/ctransformers.py b/libs/community/langchain_community/llms/ctransformers.py index 95c13ae487f..612e6041db5 100644 --- a/libs/community/langchain_community/llms/ctransformers.py +++ b/libs/community/langchain_community/llms/ctransformers.py @@ -6,7 +6,7 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils import pre_init class CTransformers(LLM): @@ -57,7 +57,7 @@ class CTransformers(LLM): """Return type of llm.""" return "ctransformers" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that ``ctransformers`` package is installed.""" try:
diff --git a/libs/community/langchain_community/llms/ctranslate2.py b/libs/community/langchain_community/llms/ctranslate2.py index 84e357aa8e5..22c5de078ac 100644 --- a/libs/community/langchain_community/llms/ctranslate2.py +++ b/libs/community/langchain_community/llms/ctranslate2.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional, Union from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, LLMResult -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import pre_init class CTranslate2(BaseLLM): @@ -50,7 +51,7 @@ class CTranslate2(BaseLLM): explicitly specified. """ - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that python package exists in environment."""
diff --git a/libs/community/langchain_community/llms/deepinfra.py b/libs/community/langchain_community/llms/deepinfra.py index 6d8deb0caa8..0dbf4bcc668 100644 --- a/libs/community/langchain_community/llms/deepinfra.py +++ b/libs/community/langchain_community/llms/deepinfra.py @@ -8,8 +8,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.utilities.requests import Requests @@ -43,7 +43,7 @@ class DeepInfra(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" deepinfra_api_token = get_from_dict_or_env(
diff --git a/libs/community/langchain_community/llms/deepsparse.py b/libs/community/langchain_community/llms/deepsparse.py index 5560edafd42..106750e2731 100644 --- a/libs/community/langchain_community/llms/deepsparse.py +++ b/libs/community/langchain_community/llms/deepsparse.py @@ -1,12 +1,13 @@ # flake8: noqa from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union from langchain_core.pydantic_v1 import root_validator from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import LLM from langchain_community.llms.utils import enforce_stop_tokens +from langchain_core.utils import pre_init from langchain_core.outputs import GenerationChunk @@ -55,7 +56,7 @@ class DeepSparse(LLM): """Return type of llm.""" return "deepsparse" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that ``deepsparse`` package is installed.""" try:
diff --git a/libs/community/langchain_community/llms/edenai.py b/libs/community/langchain_community/llms/edenai.py index 281b4cbd195..d9c1b5a7caf 100644 --- a/libs/community/langchain_community/llms/edenai.py +++ b/libs/community/langchain_community/llms/edenai.py @@ -10,7 +10,7 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens from langchain_community.utilities.requests import Requests @@ -73,7 +73,7 @@ class EdenAI(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["edenai_api_key"] = get_from_dict_or_env(
diff --git a/libs/community/langchain_community/llms/exllamav2.py b/libs/community/langchain_community/llms/exllamav2.py index 2962466e463..3e577cbf977 100644 --- a/libs/community/langchain_community/llms/exllamav2.py +++ b/libs/community/langchain_community/llms/exllamav2.py @@ -3,7 +3,8 @@ from typing import Any, Dict, Iterator, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import pre_init class ExLlamaV2(LLM): @@ -58,7 +59,7 @@ class ExLlamaV2(LLM): disallowed_tokens: List[int] = Field(None) """List of tokens to disallow during generation.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: try: import torch
diff --git a/libs/community/langchain_community/llms/fireworks.py b/libs/community/langchain_community/llms/fireworks.py index a2ecc74c3f3..a8f55c2012f 100644 --- a/libs/community/langchain_community/llms/fireworks.py +++ b/libs/community/langchain_community/llms/fireworks.py @@ -9,8 +9,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, pre_init from langchain_core.utils.env import get_from_dict_or_env @@ -61,7 +61,7 @@ class Fireworks(BaseLLM): """Get the namespace of the langchain object.""" return ["langchain", "llms", "fireworks"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key in environment.""" try:
diff --git a/libs/community/langchain_community/llms/friendli.py b/libs/community/langchain_community/llms/friendli.py index 88e23e4ac8a..5f137b078e1 100644 --- a/libs/community/langchain_community/llms/friendli.py +++ 
b/libs/community/langchain_community/llms/friendli.py @@ -10,7 +10,8 @@ from langchain_core.callbacks.manager import ( from langchain_core.language_models.llms import LLM from langchain_core.load.serializable import Serializable from langchain_core.outputs import GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import pre_init from langchain_core.utils.env import get_from_dict_or_env from langchain_core.utils.utils import convert_to_secret_str @@ -66,7 +67,7 @@ class BaseFriendli(Serializable): # is used by default. top_p: Optional[float] = None - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate if personal access token is provided in environment.""" try: diff --git a/libs/community/langchain_community/llms/gigachat.py b/libs/community/langchain_community/llms/gigachat.py index 67b604b2d20..a8fbb1654c4 100644 --- a/libs/community/langchain_community/llms/gigachat.py +++ b/libs/community/langchain_community/llms/gigachat.py @@ -11,7 +11,7 @@ from langchain_core.callbacks import ( from langchain_core.language_models.llms import BaseLLM from langchain_core.load.serializable import Serializable from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils import pre_init if TYPE_CHECKING: import gigachat @@ -113,7 +113,7 @@ class _BaseGigaChat(Serializable): verbose=self.verbose, ) - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate authenticate data in environment and python package is installed.""" try: diff --git a/libs/community/langchain_community/llms/google_palm.py b/libs/community/langchain_community/llms/google_palm.py index b0e7e22c094..278a46ea2d7 100644 --- a/libs/community/langchain_community/llms/google_palm.py +++ b/libs/community/langchain_community/llms/google_palm.py @@ -6,8 +6,8 @@ from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, SecretStr +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms import BaseLLM from langchain_community.utilities.vertexai import create_retry_decorator @@ -107,7 +107,7 @@ class GooglePalm(BaseLLM, BaseModel): """Get the namespace of the langchain object.""" return ["langchain", "llms", "google_palm"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists.""" google_api_key = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/gooseai.py b/libs/community/langchain_community/llms/gooseai.py index 27ff257ab63..617a825b2c9 100644 --- a/libs/community/langchain_community/llms/gooseai.py +++ b/libs/community/langchain_community/llms/gooseai.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from 
langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -86,7 +86,7 @@ class GooseAI(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" gooseai_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/gpt4all.py b/libs/community/langchain_community/llms/gpt4all.py index 7824b00aa5a..42fef0034e8 100644 --- a/libs/community/langchain_community/llms/gpt4all.py +++ b/libs/community/langchain_community/llms/gpt4all.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional, Set from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Extra, Field +from langchain_core.utils import pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -127,7 +128,7 @@ class GPT4All(LLM): "streaming": self.streaming, } - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in the environment.""" try: diff --git a/libs/community/langchain_community/llms/huggingface_endpoint.py b/libs/community/langchain_community/llms/huggingface_endpoint.py index 1b233c53481..31884057ff2 100644 --- a/libs/community/langchain_community/llms/huggingface_endpoint.py +++ b/libs/community/langchain_community/llms/huggingface_endpoint.py @@ -10,7 +10,11 @@ from langchain_core.callbacks import ( from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, + pre_init, +) logger = logging.getLogger(__name__) @@ -162,7 +166,7 @@ class HuggingFaceEndpoint(LLM): values["model"] = values.get("endpoint_url") or values.get("repo_id") return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that package is installed and that the API token is valid.""" try: diff --git a/libs/community/langchain_community/llms/huggingface_hub.py b/libs/community/langchain_community/llms/huggingface_hub.py index ab2977e9834..a49d959ed96 100644 --- a/libs/community/langchain_community/llms/huggingface_hub.py +++ b/libs/community/langchain_community/llms/huggingface_hub.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -61,7 +61,7 @@ class HuggingFaceHub(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" huggingfacehub_api_token = 
get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/huggingface_text_gen_inference.py b/libs/community/langchain_community/llms/huggingface_text_gen_inference.py index 322007ec174..d432949b8fc 100644 --- a/libs/community/langchain_community/llms/huggingface_text_gen_inference.py +++ b/libs/community/langchain_community/llms/huggingface_text_gen_inference.py @@ -9,7 +9,7 @@ from langchain_core.callbacks import ( from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_pydantic_field_names +from langchain_core.utils import get_pydantic_field_names, pre_init logger = logging.getLogger(__name__) @@ -134,7 +134,7 @@ class HuggingFaceTextGenInference(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that python package exists in environment.""" diff --git a/libs/community/langchain_community/llms/llamacpp.py b/libs/community/langchain_community/llms/llamacpp.py index 39e58093a6b..0bd6d7e010c 100644 --- a/libs/community/langchain_community/llms/llamacpp.py +++ b/libs/community/langchain_community/llms/llamacpp.py @@ -8,7 +8,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_pydantic_field_names +from langchain_core.utils import get_pydantic_field_names, pre_init from langchain_core.utils.utils import build_extra_kwargs logger = logging.getLogger(__name__) @@ -133,7 +133,7 @@ class LlamaCpp(LLM): verbose: bool = True """Print verbose output to stderr.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that llama-cpp-python library is installed.""" try: diff --git a/libs/community/langchain_community/llms/manifest.py b/libs/community/langchain_community/llms/manifest.py index 2852ab1d7c7..7e4f45d7c91 100644 --- a/libs/community/langchain_community/llms/manifest.py +++ b/libs/community/langchain_community/llms/manifest.py @@ -2,7 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import pre_init class ManifestWrapper(LLM): @@ -16,7 +17,7 @@ class ManifestWrapper(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that python package exists in environment.""" try: diff --git a/libs/community/langchain_community/llms/minimax.py b/libs/community/langchain_community/llms/minimax.py index df0cb53ff03..08b1ec38991 100644 --- a/libs/community/langchain_community/llms/minimax.py +++ b/libs/community/langchain_community/llms/minimax.py @@ -16,7 +16,7 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ 
-72,7 +72,7 @@ class MinimaxCommon(BaseModel): minimax_group_id: Optional[str] = None minimax_api_key: Optional[SecretStr] = None - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["minimax_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/moonshot.py b/libs/community/langchain_community/llms/moonshot.py index d72659f1266..0370db3a242 100644 --- a/libs/community/langchain_community/llms/moonshot.py +++ b/libs/community/langchain_community/llms/moonshot.py @@ -4,7 +4,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LLM from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -79,7 +79,7 @@ class MoonshotCommon(BaseModel): """ return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["moonshot_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/mosaicml.py b/libs/community/langchain_community/llms/mosaicml.py index f7fcba0c7ea..75bf50ca909 100644 --- a/libs/community/langchain_community/llms/mosaicml.py +++ b/libs/community/langchain_community/llms/mosaicml.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -64,7 +64,7 @@ class MosaicML(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" mosaicml_api_token = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/nlpcloud.py b/libs/community/langchain_community/llms/nlpcloud.py index a73087918ce..0ac0451dc36 100644 --- a/libs/community/langchain_community/llms/nlpcloud.py +++ b/libs/community/langchain_community/llms/nlpcloud.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init class NLPCloud(LLM): @@ -56,7 +56,7 @@ class NLPCloud(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["nlpcloud_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py 
b/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py index acd4ff5e11d..0c1178d29c9 100644 --- a/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py +++ b/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ class OCIModelDeploymentLLM(LLM): """Stop words to use when generating. Model output is cut off at the first occurrence of any of these substrings.""" - @root_validator() + @pre_init def validate_environment( # pylint: disable=no-self-argument cls, values: Dict ) -> Dict: diff --git a/libs/community/langchain_community/llms/oci_generative_ai.py b/libs/community/langchain_community/llms/oci_generative_ai.py index b513d8e2124..3c4feb1aa5b 100644 --- a/libs/community/langchain_community/llms/oci_generative_ai.py +++ b/libs/community/langchain_community/llms/oci_generative_ai.py @@ -8,7 +8,8 @@ from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -99,7 +100,7 @@ class OCIGenAIBase(BaseModel, ABC): is_stream: bool = False """Whether to stream back partial progress""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that OCI config and python package exists in environment.""" diff --git a/libs/community/langchain_community/llms/octoai_endpoint.py b/libs/community/langchain_community/llms/octoai_endpoint.py index a2bc1484011..ff5302c3ff0 100644 --- a/libs/community/langchain_community/llms/octoai_endpoint.py +++ b/libs/community/langchain_community/llms/octoai_endpoint.py @@ -1,7 +1,7 @@ from typing import Any, Dict -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.openai import BaseOpenAI from langchain_community.utils.openai import is_openai_v1 @@ -66,7 +66,7 @@ class OctoAIEndpoint(BaseOpenAI): """Return type of llm.""" return "octoai_endpoint" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["octoai_api_base"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/opaqueprompts.py b/libs/community/langchain_community/llms/opaqueprompts.py index 67bd5b62722..d00354183f9 100644 --- a/libs/community/langchain_community/llms/opaqueprompts.py +++ b/libs/community/langchain_community/llms/opaqueprompts.py @@ -5,8 +5,8 @@ from langchain_core.callbacks import 
CallbackManagerForLLMRun from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.llms import LLM from langchain_core.messages import AIMessage -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class OpaquePrompts(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validates that the OpaquePrompts API key and the Python package exist.""" try: diff --git a/libs/community/langchain_community/llms/openai.py b/libs/community/langchain_community/llms/openai.py index 757e3d85ba2..3274652b604 100644 --- a/libs/community/langchain_community/llms/openai.py +++ b/libs/community/langchain_community/llms/openai.py @@ -29,7 +29,11 @@ from langchain_core.callbacks import ( from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, + pre_init, +) from langchain_core.utils.utils import build_extra_kwargs from langchain_community.utils.openai import is_openai_v1 @@ -269,7 +273,7 @@ class BaseOpenAI(BaseLLM): ) return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -818,7 +822,7 @@ class AzureOpenAI(BaseOpenAI): """Get the namespace of the langchain object.""" return ["langchain", "llms", "openai"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -1025,7 +1029,7 @@ class OpenAIChat(BaseLLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" openai_api_key = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/openlm.py b/libs/community/langchain_community/llms/openlm.py index 47b303012bf..1601a3bd068 100644 --- a/libs/community/langchain_community/llms/openlm.py +++ b/libs/community/langchain_community/llms/openlm.py @@ -1,6 +1,6 @@ from typing import Any, Dict -from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils import pre_init from langchain_community.llms.openai import BaseOpenAI @@ -16,7 +16,7 @@ class OpenLM(BaseOpenAI): def _invocation_params(self) -> Dict[str, Any]: return {**{"model": self.model_name}, **super()._invocation_params} - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: try: import openlm diff --git a/libs/community/langchain_community/llms/pai_eas_endpoint.py b/libs/community/langchain_community/llms/pai_eas_endpoint.py index c84083b17ea..5f447bda4c1 100644 --- a/libs/community/langchain_community/llms/pai_eas_endpoint.py +++ b/libs/community/langchain_community/llms/pai_eas_endpoint.py @@ -6,8 +6,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from 
langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -52,7 +51,7 @@ class PaiEasEndpoint(LLM): version: Optional[str] = "2.0" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["eas_service_url"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/petals.py b/libs/community/langchain_community/llms/petals.py index 9112d18a1df..f907c7420ed 100644 --- a/libs/community/langchain_community/llms/petals.py +++ b/libs/community/langchain_community/llms/petals.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -86,7 +86,7 @@ class Petals(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" huggingface_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/pipelineai.py b/libs/community/langchain_community/llms/pipelineai.py index 170722d920a..8d2b4c2074e 100644 --- a/libs/community/langchain_community/llms/pipelineai.py +++ b/libs/community/langchain_community/llms/pipelineai.py @@ -10,7 +10,7 @@ from langchain_core.pydantic_v1 import ( SecretStr, root_validator, ) -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -65,7 +65,7 @@ class PipelineAI(LLM, BaseModel): values["pipeline_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" pipeline_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/predictionguard.py b/libs/community/langchain_community/llms/predictionguard.py index 62115509cbc..45638c6ddea 100644 --- a/libs/community/langchain_community/llms/predictionguard.py +++ b/libs/community/langchain_community/llms/predictionguard.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -53,7 +53,7 @@ class PredictionGuard(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the access token and python package exists in environment.""" token = get_from_dict_or_env(values, "token", 
"PREDICTIONGUARD_TOKEN") diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index 0a1d9cffccd..10407ad7354 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -7,7 +7,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init if TYPE_CHECKING: from replicate.prediction import Prediction @@ -97,7 +97,7 @@ class Replicate(LLM): values["model_kwargs"] = extra return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" replicate_api_token = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/rwkv.py b/libs/community/langchain_community/llms/rwkv.py index 7381afe7db9..344e1d5de65 100644 --- a/libs/community/langchain_community/llms/rwkv.py +++ b/libs/community/langchain_community/llms/rwkv.py @@ -8,7 +8,8 @@ from typing import Any, Dict, List, Mapping, Optional, Set from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra +from langchain_core.utils import pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -98,7 +99,7 @@ class RWKV(LLM, BaseModel): "verbose", } - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in the environment.""" try: diff --git a/libs/community/langchain_community/llms/sagemaker_endpoint.py b/libs/community/langchain_community/llms/sagemaker_endpoint.py index 4954c0707e3..89237890aa1 100644 --- a/libs/community/langchain_community/llms/sagemaker_endpoint.py +++ b/libs/community/langchain_community/llms/sagemaker_endpoint.py @@ -7,7 +7,8 @@ from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVa from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -249,7 +250,7 @@ class SagemakerEndpoint(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Dont do anything if client provided externally""" if values.get("client") is not None: diff --git a/libs/community/langchain_community/llms/sambanova.py b/libs/community/langchain_community/llms/sambanova.py index 6017f9b8b3b..2b59d2b8c76 100644 --- a/libs/community/langchain_community/llms/sambanova.py +++ b/libs/community/langchain_community/llms/sambanova.py @@ -5,8 +5,8 @@ import requests from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from 
langchain_core.utils import get_from_dict_or_env, pre_init class SVEndpointHandler: @@ -218,7 +218,7 @@ class Sambaverse(LLM): def is_lc_serializable(cls) -> bool: return True - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["sambaverse_url"] = get_from_dict_or_env( @@ -731,7 +731,7 @@ class SambaStudio(LLM): """Return type of llm.""" return "Sambastudio LLM" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["sambastudio_base_url"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/solar.py b/libs/community/langchain_community/llms/solar.py index e8e7b7e11d6..65dfdc58ed5 100644 --- a/libs/community/langchain_community/llms/solar.py +++ b/libs/community/langchain_community/llms/solar.py @@ -4,7 +4,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LLM from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -68,7 +68,7 @@ class SolarCommon(BaseModel): def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: api_key = get_from_dict_or_env(values, "solar_api_key", "SOLAR_API_KEY") if api_key is None or len(api_key) == 0: diff --git a/libs/community/langchain_community/llms/sparkllm.py b/libs/community/langchain_community/llms/sparkllm.py index c74abf21fd0..4e10e3b2536 100644 --- a/libs/community/langchain_community/llms/sparkllm.py +++ b/libs/community/langchain_community/llms/sparkllm.py @@ -17,8 +17,8 @@ from wsgiref.handlers import format_date_time from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ class SparkLLM(LLM): top_k: int = 4 model_kwargs: Dict[str, Any] = Field(default_factory=dict) - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: values["spark_app_id"] = get_from_dict_or_env( values, diff --git a/libs/community/langchain_community/llms/stochasticai.py b/libs/community/langchain_community/llms/stochasticai.py index 4cfc6d5ee94..090de9553f5 100644 --- a/libs/community/langchain_community/llms/stochasticai.py +++ b/libs/community/langchain_community/llms/stochasticai.py @@ -6,7 +6,7 @@ import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -58,7 +58,7 @@ class StochasticAI(LLM): values["model_kwargs"] = extra return 
values - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" stochasticai_api_key = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/symblai_nebula.py b/libs/community/langchain_community/llms/symblai_nebula.py index 9e0b2687a2f..b8f43008d99 100644 --- a/libs/community/langchain_community/llms/symblai_nebula.py +++ b/libs/community/langchain_community/llms/symblai_nebula.py @@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, List, Mapping, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from requests import ConnectTimeout, ReadTimeout, RequestException from tenacity import ( before_sleep_log, @@ -65,7 +65,7 @@ class Nebula(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" nebula_service_url = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/tongyi.py b/libs/community/langchain_community/llms/tongyi.py index 8e13b6e03f1..7a75d116acf 100644 --- a/libs/community/langchain_community/llms/tongyi.py +++ b/libs/community/langchain_community/llms/tongyi.py @@ -24,8 +24,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import get_from_dict_or_env, pre_init from requests.exceptions import HTTPError from tenacity import ( before_sleep_log, @@ -198,7 +198,7 @@ class Tongyi(BaseLLM): """Return type of llm.""" return "tongyi" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["dashscope_api_key"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/llms/vertexai.py b/libs/community/langchain_community/llms/vertexai.py index 8498bed5ff9..42572f3ac19 100644 --- a/libs/community/langchain_community/llms/vertexai.py +++ b/libs/community/langchain_community/llms/vertexai.py @@ -10,7 +10,8 @@ from langchain_core.callbacks.manager import ( ) from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.utils import pre_init from langchain_community.utilities.vertexai import ( create_retry_decorator, @@ -222,7 +223,7 @@ class VertexAI(_VertexAICommon, BaseLLM): """Get the namespace of the langchain object.""" return ["langchain", "llms", "vertexai"] - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in environment.""" tuned_model_name = values.get("tuned_model_name") @@ -409,7 +410,7 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM): 
"Set result_arg to None if output of the model is expected to be a string." "Otherwise, if it's a dict, provided an argument that contains the result." - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in environment.""" try: diff --git a/libs/community/langchain_community/llms/vllm.py b/libs/community/langchain_community/llms/vllm.py index 5710cf48723..f4eddf46350 100644 --- a/libs/community/langchain_community/llms/vllm.py +++ b/libs/community/langchain_community/llms/vllm.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, LLMResult -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import pre_init from langchain_community.llms.openai import BaseOpenAI from langchain_community.utils.openai import is_openai_v1 @@ -73,7 +74,7 @@ class VLLM(BaseLLM): client: Any #: :meta private: - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that python package exists in environment.""" diff --git a/libs/community/langchain_community/llms/volcengine_maas.py b/libs/community/langchain_community/llms/volcengine_maas.py index aa6a37cc2c4..dab4989adf3 100644 --- a/libs/community/langchain_community/llms/volcengine_maas.py +++ b/libs/community/langchain_community/llms/volcengine_maas.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Iterator, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init class VolcEngineMaasBase(BaseModel): @@ -52,7 +52,7 @@ class VolcEngineMaasBase(BaseModel): """Timeout for read response from volc engine maas endpoint. 
Default is 60 seconds.""" - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: volc_engine_maas_ak = convert_to_secret_str( get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY") diff --git a/libs/community/langchain_community/llms/watsonxllm.py b/libs/community/langchain_community/llms/watsonxllm.py index 53255f17116..952e9f573e4 100644 --- a/libs/community/langchain_community/llms/watsonxllm.py +++ b/libs/community/langchain_community/llms/watsonxllm.py @@ -6,8 +6,8 @@ from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init logger = logging.getLogger(__name__) @@ -115,7 +115,7 @@ class WatsonxLLM(BaseLLM): "instance_id": "WATSONX_INSTANCE_ID", } - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that credentials and python package exists in environment.""" values["url"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/llms/writer.py b/libs/community/langchain_community/llms/writer.py index 3b7bc6a06c4..23d2986e8d2 100644 --- a/libs/community/langchain_community/llms/writer.py +++ b/libs/community/langchain_community/llms/writer.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.llms.utils import enforce_stop_tokens @@ -69,7 +69,7 @@ class Writer(LLM): extra = Extra.forbid - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and organization id exist in environment.""" diff --git a/libs/community/langchain_community/llms/yandex.py b/libs/community/langchain_community/llms/yandex.py index df5780f0c88..39e79ba4f4a 100644 --- a/libs/community/langchain_community/llms/yandex.py +++ b/libs/community/langchain_community/llms/yandex.py @@ -9,8 +9,8 @@ from langchain_core.callbacks import ( ) from langchain_core.language_models.llms import LLM from langchain_core.load.serializable import Serializable -from langchain_core.pydantic_v1 import SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.pydantic_v1 import SecretStr +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from tenacity import ( before_sleep_log, retry, @@ -74,7 +74,7 @@ class _BaseYandexGPT(Serializable): "max_retries": self.max_retries, } - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that iam token exists in environment.""" diff --git a/libs/community/langchain_community/retrievers/arcee.py b/libs/community/langchain_community/retrievers/arcee.py index 90c6a996af6..821808f2f39 100644 --- 
a/libs/community/langchain_community/retrievers/arcee.py +++ b/libs/community/langchain_community/retrievers/arcee.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator +from langchain_core.pydantic_v1 import Extra, SecretStr from langchain_core.retrievers import BaseRetriever -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter @@ -70,7 +70,7 @@ class ArceeRetriever(BaseRetriever): self._client.validate_model_training_status() - @root_validator() + @pre_init def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" diff --git a/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py b/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py index b989cef5b02..d4722775013 100644 --- a/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py +++ b/libs/community/langchain_community/retrievers/google_cloud_documentai_warehouse.py @@ -5,9 +5,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.pydantic_v1 import root_validator from langchain_core.retrievers import BaseRetriever -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_community.utilities.vertexai import get_client_info @@ -48,7 +47,7 @@ class GoogleDocumentAIWarehouseRetriever(BaseRetriever): """The limit on the number of documents returned.""" client: "DocumentServiceClient" = None #: :meta private: - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validates the environment.""" try: diff --git a/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py b/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py index fdc18287165..b3a1c0cfcc9 100644 --- a/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py +++ b/libs/community/langchain_community/retrievers/pinecone_hybrid_search.py @@ -6,8 +6,9 @@ from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra from langchain_core.retrievers import BaseRetriever +from langchain_core.utils import pre_init def hash_text(text: str) -> str: @@ -136,7 +137,7 @@ class PineconeHybridSearchRetriever(BaseRetriever): namespace=namespace, ) - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" try: diff --git a/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py b/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py index daf8917173d..40bd31db82c 100644 --- 
a/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py +++ b/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py @@ -15,8 +15,9 @@ from typing import ( from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra from langchain_core.retrievers import BaseRetriever +from langchain_core.utils import pre_init from langchain_community.vectorstores.qdrant import Qdrant, QdrantException @@ -49,7 +50,7 @@ class QdrantSparseVectorRetriever(BaseRetriever): extra = Extra.forbid arbitrary_types_allowed = True - @root_validator() + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that 'qdrant_client' python package exists in environment.""" try: diff --git a/libs/community/langchain_community/retrievers/thirdai_neuraldb.py b/libs/community/langchain_community/retrievers/thirdai_neuraldb.py index e4a25aea4a1..22015b30331 100644 --- a/libs/community/langchain_community/retrievers/thirdai_neuraldb.py +++ b/libs/community/langchain_community/retrievers/thirdai_neuraldb.py @@ -7,9 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator +from langchain_core.pydantic_v1 import Extra, SecretStr from langchain_core.retrievers import BaseRetriever -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init class NeuralDBRetriever(BaseRetriever): @@ -110,7 +110,7 @@ class NeuralDBRetriever(BaseRetriever): return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB.from_checkpoint(checkpoint)) # type: ignore[arg-type] - @root_validator() + @pre_init def validate_environments(cls, values: Dict) -> Dict: """Validate ThirdAI environment variables.""" values["thirdai_key"] = convert_to_secret_str( diff --git a/libs/community/tests/integration_tests/vectorstores/test_vectara.py b/libs/community/tests/integration_tests/vectorstores/test_vectara.py index dcaae7dd0c9..d8f557bc6d1 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_vectara.py +++ b/libs/community/tests/integration_tests/vectorstores/test_vectara.py @@ -1,6 +1,6 @@ import tempfile import urllib.request -from typing import Iterable +from typing import Generator, Iterable import pytest from langchain_core.documents import Document @@ -102,10 +102,10 @@ def test_vectara_add_documents(vectara1: Vectara) -> None: @pytest.fixture(scope="function") -def vectara2(): # type: ignore[no-untyped-def] +def vectara2() -> Generator[Vectara, None, None]: # download documents to local storage and then upload as files # attention paper and deep learning book - vectara2: Vectara = Vectara() + vectara2: Vectara = Vectara() # type: ignore urls = [ ( diff --git a/libs/community/tests/unit_tests/llms/test_aleph_alpha.py b/libs/community/tests/unit_tests/llms/test_aleph_alpha.py index ae09ecd116f..57cd544d7eb 100644 --- a/libs/community/tests/unit_tests/llms/test_aleph_alpha.py +++ b/libs/community/tests/unit_tests/llms/test_aleph_alpha.py @@ -9,7 +9,7 @@ from langchain_community.llms.aleph_alpha import AlephAlpha @pytest.mark.requires("aleph_alpha_client") def test_api_key_is_secret_string() -> None: - llm = 
AlephAlpha(aleph_alpha_api_key="secret-api-key") # type: ignore[call-arg] + llm = AlephAlpha(aleph_alpha_api_key="secret-api-key") # type: ignore assert isinstance(llm.aleph_alpha_api_key, SecretStr) @@ -17,7 +17,7 @@ def test_api_key_is_secret_string() -> None: def test_api_key_masked_when_passed_via_constructor( capsys: CaptureFixture, ) -> None: - llm = AlephAlpha(aleph_alpha_api_key="secret-api-key") # type: ignore[call-arg] + llm = AlephAlpha(aleph_alpha_api_key="secret-api-key") # type: ignore print(llm.aleph_alpha_api_key, end="") # noqa: T201 captured = capsys.readouterr() diff --git a/libs/community/tests/unit_tests/llms/test_bedrock.py b/libs/community/tests/unit_tests/llms/test_bedrock.py index 60012182936..42af3a01121 100644 --- a/libs/community/tests/unit_tests/llms/test_bedrock.py +++ b/libs/community/tests/unit_tests/llms/test_bedrock.py @@ -277,9 +277,12 @@ async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]: async def test_bedrock_async_streaming_call() -> None: # Mock boto3 import mock_boto3 = MagicMock() + session = MagicMock() + session.region_name = "region" + mock_boto3.Session.return_value = session mock_boto3.Session.return_value.client.return_value = ( - MagicMock() - ) # Mocking the client method of the Session object + session # Mocking the client method of the Session object + ) with patch.dict( "sys.modules", {"boto3": mock_boto3} diff --git a/libs/community/tests/unit_tests/load/__snapshots__/test_dump.ambr b/libs/community/tests/unit_tests/load/__snapshots__/test_dump.ambr index 2f98e081056..2789472be22 100644 --- a/libs/community/tests/unit_tests/load/__snapshots__/test_dump.ambr +++ b/libs/community/tests/unit_tests/load/__snapshots__/test_dump.ambr @@ -139,7 +139,7 @@ "model_name": "davinci", "temperature": 0.5, "max_tokens": 256, - "top_p": 1, + "top_p": 1.0, "n": 1, "best_of": 1, "openai_api_key": { @@ -655,7 +655,7 @@ "model_name": "davinci", "temperature": 0.5, "max_tokens": 256, - "top_p": 1, + "top_p": 1.0, "n": 1, "best_of": 1, "openai_api_key": { @@ -817,7 +817,7 @@ "model_name": "davinci", "temperature": 0.7, "max_tokens": 256, - "top_p": 1, + "top_p": 1.0, "n": 1, "best_of": 1, "openai_api_key": { diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py index 80fbf680f3e..654f81df56b 100644 --- a/libs/core/langchain_core/utils/__init__.py +++ b/libs/core/langchain_core/utils/__init__.py @@ -16,6 +16,7 @@ from langchain_core.utils.input import ( ) from langchain_core.utils.iter import batch_iterate from langchain_core.utils.loading import try_load_from_hub +from langchain_core.utils.pydantic import pre_init from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value from langchain_core.utils.utils import ( build_extra_kwargs, @@ -50,6 +51,7 @@ __all__ = [ "stringify_dict", "comma_list", "stringify_value", + "pre_init", "batch_iterate", "abatch_iterate", ] diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 80ddb81fcb9..a79eb70fefe 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -1,5 +1,10 @@ """Utilities for tests.""" +from functools import wraps +from typing import Any, Callable, Dict, Type + +from langchain_core.pydantic_v1 import BaseModel, root_validator + def get_pydantic_major_version() -> int: """Get the major version of Pydantic.""" @@ -12,3 +17,35 @@ def get_pydantic_major_version() -> int: PYDANTIC_MAJOR_VERSION = 
get_pydantic_major_version() + + +# How to type hint this? +def pre_init(func: Callable) -> Any: + """Decorator to run a function before model initialization.""" + + @root_validator(pre=True) + @wraps(func) + def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]: + """Decorator to run a function before model initialization.""" + # Insert default values + fields = cls.__fields__ + for name, field_info in fields.items(): + # Check if allow_population_by_field_name is enabled + # If yes, then set the field name to the alias + if hasattr(cls, "Config"): + if hasattr(cls.Config, "allow_population_by_field_name"): + if cls.Config.allow_population_by_field_name: + if field_info.alias in values: + values[name] = values.pop(field_info.alias) + + if name not in values or values[name] is None: + if not field_info.required: + if field_info.default_factory is not None: + values[name] = field_info.default_factory() + else: + values[name] = field_info.default + + # Call the decorated function + return func(cls, values) + + return wrapper diff --git a/libs/core/tests/unit_tests/utils/test_imports.py b/libs/core/tests/unit_tests/utils/test_imports.py index 8a1d4236688..8cb909d3f70 100644 --- a/libs/core/tests/unit_tests/utils/test_imports.py +++ b/libs/core/tests/unit_tests/utils/test_imports.py @@ -24,6 +24,7 @@ EXPECTED_ALL = [ "stringify_dict", "comma_list", "stringify_value", + "pre_init", ] diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py new file mode 100644 index 00000000000..cedf4db7c25 --- /dev/null +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -0,0 +1,75 @@ +"""Test for some custom pydantic decorators.""" + +from typing import Any, Dict, Optional + +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.utils.pydantic import pre_init + + +def test_pre_init_decorator() -> None: + class Foo(BaseModel): + x: int = 5 + y: int + + @pre_init + def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]: + v["y"] = v["x"] + 1 + return v + + # Type ignore initialization b/c y is marked as required + foo = Foo() # type: ignore + assert foo.y == 6 + foo = Foo(x=10) # type: ignore + assert foo.y == 11 + + +def test_pre_init_decorator_with_more_defaults() -> None: + class Foo(BaseModel): + a: int = 1 + b: Optional[int] = None + c: int = Field(default=2) + d: int = Field(default_factory=lambda: 3) + + @pre_init + def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]: + assert v["a"] == 1 + assert v["b"] is None + assert v["c"] == 2 + assert v["d"] == 3 + return v + + # Try to create an instance of Foo + # nothing is required, but mypy can't track the default for `c` + Foo() # type: ignore + + +def test_with_aliases() -> None: + class Foo(BaseModel): + x: int = Field(default=1, alias="y") + z: int + + class Config: + allow_population_by_field_name = True + + @pre_init + def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]: + v["z"] = v["x"] + return v + + # Based on defaults + # z is required + foo = Foo() # type: ignore + assert foo.x == 1 + assert foo.z == 1 + + # Based on field name + # z is required + foo = Foo(x=2) # type: ignore + assert foo.x == 2 + assert foo.z == 2 + + # Based on alias + # z is required + foo = Foo(y=2) # type: ignore + assert foo.x == 2 + assert foo.z == 2
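
For reference, a minimal sketch of how a downstream class would adopt the decorator introduced above. This is not part of the patch: the FakeEmbeddings class and its fields are hypothetical, and only `pre_init` (imported from `langchain_core.utils`) and its defaults-populating behaviour come from the diff itself.

# Hypothetical adopter of @pre_init; class and field names are illustrative only.
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import pre_init


class FakeEmbeddings(BaseModel):
    api_key: Optional[str] = None
    batch_size: int = 16
    client: Any = None  #: :meta private:

    # The old form was `@root_validator()`, which ran after defaults were applied,
    # so validators could read `values["batch_size"]` directly. `@pre_init` runs
    # before model initialization, but (per the wrapper above) it copies field
    # defaults into `values` first, so that existing validator bodies keep working
    # without `values.get(..., fallback)` boilerplate.
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        if not values["api_key"]:
            raise ValueError("api_key missing; pass it explicitly in this sketch")
        values["client"] = object()  # stand-in for a real SDK client
        return values


# The default for batch_size (16) is already visible inside the validator,
# even though only api_key is passed here.
emb = FakeEmbeddings(api_key="secret")
assert emb.batch_size == 16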