community[patch]: Update @root_validators to use explicit pre=True or pre=False (#23737)

This commit is contained in:
Eugene Yurtsev 2024-07-02 10:47:21 -04:00 committed by GitHub
parent b664dbcc36
commit 46ff0f7a3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 15 additions and 8 deletions

View File

@ -80,7 +80,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
nice to other users nice to other users
by de-prioritizing your request below concurrent ones.""" by de-prioritizing your request below concurrent ones."""
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
aleph_alpha_api_key = get_from_dict_or_env( aleph_alpha_api_key = get_from_dict_or_env(

View File

@ -16,7 +16,7 @@ class AwaEmbeddings(BaseModel, Embeddings):
client: Any #: :meta private: client: Any #: :meta private:
model: str = "all-mpnet-base-v2" model: str = "all-mpnet-base-v2"
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that awadb library is installed.""" """Validate that awadb library is installed."""

View File

@ -54,7 +54,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
validate_base_url: bool = True validate_base_url: bool = True
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
# Check OPENAI_KEY for backwards compatibility. # Check OPENAI_KEY for backwards compatibility.
@ -96,8 +96,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
# See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
values["chunk_size"] = min(values["chunk_size"], 16) values["chunk_size"] = min(values["chunk_size"], 16)
try: try:
import openai import openai # noqa: F401
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "
@ -137,6 +136,14 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
"/deployments/" + values["deployment"] "/deployments/" + values["deployment"]
) )
values["deployment"] = None values["deployment"] = None
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init_validator(cls, values: Dict) -> Dict:
"""Validate that the base url is set."""
import openai
if is_openai_v1():
client_params = { client_params = {
"api_version": values["openai_api_version"], "api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"], "azure_endpoint": values["azure_endpoint"],

View File

@ -73,7 +73,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid extra = Extra.forbid
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that AWS credentials to and python package exists in environment.""" """Validate that AWS credentials to and python package exists in environment."""

View File

@ -48,7 +48,7 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid extra = Extra.forbid
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that we have all required info to access Clarifai """Validate that we have all required info to access Clarifai
platform and python package exists in environment.""" platform and python package exists in environment."""

View File

@ -54,7 +54,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid extra = Extra.forbid
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
cohere_api_key = get_from_dict_or_env( cohere_api_key = get_from_dict_or_env(