From 5ef237cfb6dccd692a8b0f5f21a87b2c839ffce9 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Mon, 30 Sep 2024 21:40:18 -0400
Subject: [PATCH] fix private attr

---
 .../chat_models/snowflake.py                  | 43 ++++++++++---------
 .../langchain_community/embeddings/laser.py   | 25 ++++-------
 .../langchain_community/llms/minimax.py       | 16 ++++---
 .../langchain_community/llms/solar.py         | 10 +++--
 4 files changed, 49 insertions(+), 45 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/snowflake.py b/libs/community/langchain_community/chat_models/snowflake.py
index 52483818830..4c3037f941f 100644
--- a/libs/community/langchain_community/chat_models/snowflake.py
+++ b/libs/community/langchain_community/chat_models/snowflake.py
@@ -19,6 +19,7 @@ from langchain_core.utils import (
 )
 from langchain_core.utils.utils import build_extra_kwargs
 from pydantic import Field, SecretStr, model_validator
+from typing_extensions import Self
 
 SUPPORTED_ROLES: List[str] = [
     "system",
@@ -139,14 +140,6 @@ class ChatSnowflakeCortex(BaseChatModel):
 
     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:
-        try:
-            from snowflake.snowpark import Session
-        except ImportError:
-            raise ImportError(
-                "`snowflake-snowpark-python` package not found, please install it with "
-                "`pip install snowflake-snowpark-python`"
-            )
-
         values["snowflake_username"] = get_from_dict_or_env(
             values, "snowflake_username", "SNOWFLAKE_USERNAME"
         )
@@ -168,23 +161,33 @@ class ChatSnowflakeCortex(BaseChatModel):
         values["snowflake_role"] = get_from_dict_or_env(
             values, "snowflake_role", "SNOWFLAKE_ROLE"
         )
+        return values
 
-        connection_params = {
-            "account": values["snowflake_account"],
-            "user": values["snowflake_username"],
-            "password": values["snowflake_password"].get_secret_value(),
-            "database": values["snowflake_database"],
-            "schema": values["snowflake_schema"],
-            "warehouse": values["snowflake_warehouse"],
-            "role": values["snowflake_role"],
-        }
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
+        """Post initialization."""
+        try:
+            from snowflake.snowpark import Session
+        except ImportError:
+            raise ImportError(
+                "`snowflake-snowpark-python` package not found, please install it with "
+                "`pip install snowflake-snowpark-python`"
+            )
 
         try:
-            values["_sp_session"] = Session.builder.configs(connection_params).create()
+            connection_params = {
+                "account": self.snowflake_account,
+                "user": self.snowflake_username,
+                "password": self.snowflake_password.get_secret_value(),
+                "database": self.snowflake_database,
+                "schema": self.snowflake_schema,
+                "warehouse": self.snowflake_warehouse,
+                "role": self.snowflake_role,
+            }
+            self._sp_session = Session.builder.configs(connection_params).create()
         except Exception as e:
             raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
-
-        return values
+        return self
 
     def __del__(self) -> None:
         if getattr(self, "_sp_session", None) is not None:
diff --git a/libs/community/langchain_community/embeddings/laser.py b/libs/community/langchain_community/embeddings/laser.py
index 1299e55e518..6805edc0c08 100644
--- a/libs/community/langchain_community/embeddings/laser.py
+++ b/libs/community/langchain_community/embeddings/laser.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, List
 
 import numpy as np
 from langchain_core.embeddings import Embeddings
-from langchain_core.utils import pre_init
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
+from typing_extensions import Self
 
 LASER_MULTILINGUAL_MODEL: str = "laser2"
 
@@ -27,7 +27,7 @@ class LaserEmbeddings(BaseModel, Embeddings):
             embeddings = encoder.encode_sentences(["Hello", "World"])
     """
 
-    lang: Optional[str] = None
+    lang: str = LASER_MULTILINGUAL_MODEL
     """The language or language code you'd like to use
     If empty, this implementation will default
     to using a multilingual earlier LASER encoder model (called laser2)
@@ -41,25 +41,18 @@ class LaserEmbeddings(BaseModel, Embeddings):
         extra="forbid",
     )
 
-    @pre_init
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that laser_encoders has been installed."""
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
         try:
             from laser_encoders import LaserEncoderPipeline
-
-            lang = values.get("lang")
-            if lang:
-                encoder_pipeline = LaserEncoderPipeline(lang=lang)
-            else:
-                encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL)
-            values["_encoder_pipeline"] = encoder_pipeline
-
         except ImportError as e:
             raise ImportError(
                 "Could not import 'laser_encoders' Python package. "
                 "Please install it with `pip install laser_encoders`."
             ) from e
-        return values
+
+        self._encoder_pipeline = LaserEncoderPipeline(lang=self.lang)
+        return self
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Generate embeddings for documents using LASER.
diff --git a/libs/community/langchain_community/llms/minimax.py b/libs/community/langchain_community/llms/minimax.py
index e508eb4f959..fdf114b3c6a 100644
--- a/libs/community/langchain_community/llms/minimax.py
+++ b/libs/community/langchain_community/llms/minimax.py
@@ -17,6 +17,7 @@ from langchain_core.callbacks import (
 from langchain_core.language_models.llms import LLM
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
 from pydantic import BaseModel, Field, SecretStr, model_validator
+from typing_extensions import Self
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -89,13 +90,18 @@ class MinimaxCommon(BaseModel):
             "MINIMAX_API_HOST",
             default="https://api.minimax.chat",
         )
-        values["_client"] = _MinimaxEndpointClient(  # type: ignore[call-arg]
-            host=values["minimax_api_host"],
-            api_key=values["minimax_api_key"],
-            group_id=values["minimax_group_id"],
-        )
         return values
 
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
+        """Post initialization."""
+        self._client = _MinimaxEndpointClient(
+            host=self.minimax_api_host,
+            api_key=self.minimax_api_key,
+            group_id=self.minimax_group_id,
+        )
+        return self
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
diff --git a/libs/community/langchain_community/llms/solar.py b/libs/community/langchain_community/llms/solar.py
index b5ea74bd632..071ffee8c68 100644
--- a/libs/community/langchain_community/llms/solar.py
+++ b/libs/community/langchain_community/llms/solar.py
@@ -11,6 +11,7 @@ from pydantic import (
     SecretStr,
     model_validator,
 )
+from typing_extensions import Self
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -90,12 +91,13 @@ class SolarCommon(BaseModel):
 
         if "base_url" in values and not values["base_url"].startswith(SOLAR_SERVICE):
             raise ValueError("base_url must match with: " + SOLAR_SERVICE)
-
-        values["_client"] = _SolarClient(
-            api_key=values["solar_api_key"], base_url=values["base_url"]
-        )
         return values
 
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
+        self._client = _SolarClient(api_key=self.solar_api_key, base_url=self.base_url)
+        return self
+
     @property
     def _llm_type(self) -> str:
         return "solar"
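
Note on the pattern: all four hunks make the same change. The client or session object is no longer built inside a pre-init validate_environment validator by writing into the raw values dict; instead a model_validator(mode="after") hook builds it once the public fields are validated and assigns it to a private attribute on self. The snippet below is a minimal, self-contained sketch of that pattern under assumptions of my own: ExampleLLM, _ExampleClient, api_key, and base_url are illustrative placeholders rather than classes from the patch, and declaring the private slot with PrivateAttr is one assumed way to do it.

from typing import Optional

from pydantic import BaseModel, PrivateAttr, SecretStr, model_validator
from typing_extensions import Self


class _ExampleClient:
    """Hypothetical stand-in for an SDK client (e.g. _SolarClient or a Snowpark Session)."""

    def __init__(self, api_key: str, base_url: str) -> None:
        self.api_key = api_key
        self.base_url = base_url


class ExampleLLM(BaseModel):
    # Public, validated fields.
    api_key: SecretStr
    base_url: str = "https://api.example.com"

    # Private attribute holding the constructed client; it is not a pydantic
    # field, so it is declared with PrivateAttr and set on the instance.
    _client: Optional[_ExampleClient] = PrivateAttr(default=None)

    @model_validator(mode="after")
    def post_init(self) -> Self:
        # At this point every public field has been validated, so the client
        # can be built from self.* and assigned to the private attribute.
        self._client = _ExampleClient(
            api_key=self.api_key.get_secret_value(),
            base_url=self.base_url,
        )
        return self


if __name__ == "__main__":
    llm = ExampleLLM(api_key=SecretStr("sk-test"))
    assert llm._client is not None
    assert llm._client.base_url == "https://api.example.com"

Because the "after" validator runs on a fully validated instance, it can read self.api_key and self.base_url directly instead of pulling entries out of a dict, which mirrors the shift the patch makes from values[...] to self.* in each file.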