From 46ff0f7a3cf5fcc5e3c40d6c8cb2939e0e280c7a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 2 Jul 2024 10:47:21 -0400 Subject: [PATCH] community[patch]: Update @root_validators to use explicit pre=True or pre=False (#23737) --- .../langchain_community/embeddings/aleph_alpha.py | 2 +- .../community/langchain_community/embeddings/awa.py | 2 +- .../langchain_community/embeddings/azure_openai.py | 13 ++++++++++--- .../langchain_community/embeddings/bedrock.py | 2 +- .../langchain_community/embeddings/clarifai.py | 2 +- .../langchain_community/embeddings/cohere.py | 2 +- 6 files changed, 15 insertions(+), 8 deletions(-) diff --git a/libs/community/langchain_community/embeddings/aleph_alpha.py b/libs/community/langchain_community/embeddings/aleph_alpha.py index f2d78756573..409373d8bff 100644 --- a/libs/community/langchain_community/embeddings/aleph_alpha.py +++ b/libs/community/langchain_community/embeddings/aleph_alpha.py @@ -80,7 +80,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings): nice to other users by de-prioritizing your request below concurrent ones.""" - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" aleph_alpha_api_key = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/awa.py b/libs/community/langchain_community/embeddings/awa.py index 9145d8006a7..6f09b2d064c 100644 --- a/libs/community/langchain_community/embeddings/awa.py +++ b/libs/community/langchain_community/embeddings/awa.py @@ -16,7 +16,7 @@ class AwaEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: model: str = "all-mpnet-base-v2" - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that awadb library is installed.""" diff --git a/libs/community/langchain_community/embeddings/azure_openai.py b/libs/community/langchain_community/embeddings/azure_openai.py index f8eeca2b2cd..e607c3a4161 100644 --- a/libs/community/langchain_community/embeddings/azure_openai.py +++ b/libs/community/langchain_community/embeddings/azure_openai.py @@ -54,7 +54,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" validate_base_url: bool = True - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" # Check OPENAI_KEY for backwards compatibility. @@ -96,8 +96,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings values["chunk_size"] = min(values["chunk_size"], 16) try: - import openai - + import openai # noqa: F401 except ImportError: raise ImportError( "Could not import openai python package. " @@ -137,6 +136,14 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): "/deployments/" + values["deployment"] ) values["deployment"] = None + return values + + @root_validator(pre=False, skip_on_failure=True) + def post_init_validator(cls, values: Dict) -> Dict: + """Validate that the base url is set.""" + import openai + + if is_openai_v1(): client_params = { "api_version": values["openai_api_version"], "azure_endpoint": values["azure_endpoint"], diff --git a/libs/community/langchain_community/embeddings/bedrock.py b/libs/community/langchain_community/embeddings/bedrock.py index 2c95fc8be79..f17f717d3f5 100644 --- a/libs/community/langchain_community/embeddings/bedrock.py +++ b/libs/community/langchain_community/embeddings/bedrock.py @@ -73,7 +73,7 @@ class BedrockEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that AWS credentials to and python package exists in environment.""" diff --git a/libs/community/langchain_community/embeddings/clarifai.py b/libs/community/langchain_community/embeddings/clarifai.py index e846df927a7..b0800c64bf7 100644 --- a/libs/community/langchain_community/embeddings/clarifai.py +++ b/libs/community/langchain_community/embeddings/clarifai.py @@ -48,7 +48,7 @@ class ClarifaiEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that we have all required info to access Clarifai platform and python package exists in environment.""" diff --git a/libs/community/langchain_community/embeddings/cohere.py b/libs/community/langchain_community/embeddings/cohere.py index eafc7fcbe9a..9f8d18eaaa1 100644 --- a/libs/community/langchain_community/embeddings/cohere.py +++ b/libs/community/langchain_community/embeddings/cohere.py @@ -54,7 +54,7 @@ class CohereEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" cohere_api_key = get_from_dict_or_env(