fix private attr

This commit is contained in:
Eugene Yurtsev 2024-09-30 21:40:18 -04:00
parent e317d457cf
commit 5ef237cfb6
4 changed files with 49 additions and 45 deletions

View File

@ -19,6 +19,7 @@ from langchain_core.utils import (
) )
from langchain_core.utils.utils import build_extra_kwargs from langchain_core.utils.utils import build_extra_kwargs
from pydantic import Field, SecretStr, model_validator from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self
SUPPORTED_ROLES: List[str] = [ SUPPORTED_ROLES: List[str] = [
"system", "system",
@ -139,14 +140,6 @@ class ChatSnowflakeCortex(BaseChatModel):
@pre_init @pre_init
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
try:
from snowflake.snowpark import Session
except ImportError:
raise ImportError(
"`snowflake-snowpark-python` package not found, please install it with "
"`pip install snowflake-snowpark-python`"
)
values["snowflake_username"] = get_from_dict_or_env( values["snowflake_username"] = get_from_dict_or_env(
values, "snowflake_username", "SNOWFLAKE_USERNAME" values, "snowflake_username", "SNOWFLAKE_USERNAME"
) )
@ -168,23 +161,33 @@ class ChatSnowflakeCortex(BaseChatModel):
values["snowflake_role"] = get_from_dict_or_env( values["snowflake_role"] = get_from_dict_or_env(
values, "snowflake_role", "SNOWFLAKE_ROLE" values, "snowflake_role", "SNOWFLAKE_ROLE"
) )
return values
connection_params = { @model_validator(mode="after")
"account": values["snowflake_account"], def post_init(self) -> Self:
"user": values["snowflake_username"], """Post initialization."""
"password": values["snowflake_password"].get_secret_value(), try:
"database": values["snowflake_database"], from snowflake.snowpark import Session
"schema": values["snowflake_schema"], except ImportError:
"warehouse": values["snowflake_warehouse"], raise ImportError(
"role": values["snowflake_role"], "`snowflake-snowpark-python` package not found, please install it with "
} "`pip install snowflake-snowpark-python`"
)
try: try:
values["_sp_session"] = Session.builder.configs(connection_params).create() connection_params = {
"account": self.snowflake_account,
"user": self.snowflake_username,
"password": self.snowflake_password.get_secret_value(),
"database": self.snowflake_database,
"schema": self.snowflake_schema,
"warehouse": self.snowflake_warehouse,
"role": self.snowflake_role,
}
self._sp_session = Session.builder.configs(connection_params).create()
except Exception as e: except Exception as e:
raise ChatSnowflakeCortexError(f"Failed to create session: {e}") raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
return self
return values
def __del__(self) -> None: def __del__(self) -> None:
if getattr(self, "_sp_session", None) is not None: if getattr(self, "_sp_session", None) is not None:

View File

@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional from typing import Any, List
import numpy as np import numpy as np
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict from typing_extensions import Self
LASER_MULTILINGUAL_MODEL: str = "laser2" LASER_MULTILINGUAL_MODEL: str = "laser2"
@ -27,7 +27,7 @@ class LaserEmbeddings(BaseModel, Embeddings):
embeddings = encoder.encode_sentences(["Hello", "World"]) embeddings = encoder.encode_sentences(["Hello", "World"])
""" """
lang: Optional[str] = None lang: str = LASER_MULTILINGUAL_MODEL
"""The language or language code you'd like to use """The language or language code you'd like to use
If empty, this implementation will default If empty, this implementation will default
to using a multilingual earlier LASER encoder model (called laser2) to using a multilingual earlier LASER encoder model (called laser2)
@ -41,25 +41,18 @@ class LaserEmbeddings(BaseModel, Embeddings):
extra="forbid", extra="forbid",
) )
@pre_init @model_validator(mode="after")
def validate_environment(cls, values: Dict) -> Dict: def post_init(self) -> Self:
"""Validate that laser_encoders has been installed."""
try: try:
from laser_encoders import LaserEncoderPipeline from laser_encoders import LaserEncoderPipeline
lang = values.get("lang")
if lang:
encoder_pipeline = LaserEncoderPipeline(lang=lang)
else:
encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL)
values["_encoder_pipeline"] = encoder_pipeline
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Could not import 'laser_encoders' Python package. " "Could not import 'laser_encoders' Python package. "
"Please install it with `pip install laser_encoders`." "Please install it with `pip install laser_encoders`."
) from e ) from e
return values
self._encoder_pipeline = LaserEncoderPipeline(lang=self.lang)
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for documents using LASER. """Generate embeddings for documents using LASER.

View File

@ -17,6 +17,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import BaseModel, Field, SecretStr, model_validator from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_community.llms.utils import enforce_stop_tokens from langchain_community.llms.utils import enforce_stop_tokens
@ -89,13 +90,18 @@ class MinimaxCommon(BaseModel):
"MINIMAX_API_HOST", "MINIMAX_API_HOST",
default="https://api.minimax.chat", default="https://api.minimax.chat",
) )
values["_client"] = _MinimaxEndpointClient( # type: ignore[call-arg]
host=values["minimax_api_host"],
api_key=values["minimax_api_key"],
group_id=values["minimax_group_id"],
)
return values return values
@model_validator(mode="after")
def post_init(self) -> Self:
"""Post initialization."""
self._client = _MinimaxEndpointClient(
host=self.minimax_api_host,
api_key=self.minimax_api_key,
group_id=self.minimax_group_id,
)
return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""

View File

@ -11,6 +11,7 @@ from pydantic import (
SecretStr, SecretStr,
model_validator, model_validator,
) )
from typing_extensions import Self
from langchain_community.llms.utils import enforce_stop_tokens from langchain_community.llms.utils import enforce_stop_tokens
@ -90,12 +91,13 @@ class SolarCommon(BaseModel):
if "base_url" in values and not values["base_url"].startswith(SOLAR_SERVICE): if "base_url" in values and not values["base_url"].startswith(SOLAR_SERVICE):
raise ValueError("base_url must match with: " + SOLAR_SERVICE) raise ValueError("base_url must match with: " + SOLAR_SERVICE)
values["_client"] = _SolarClient(
api_key=values["solar_api_key"], base_url=values["base_url"]
)
return values return values
@model_validator(mode="after")
def post_init(self) -> Self:
self._client = _SolarClient(api_key=self.solar_api_key, base_url=self.base_url)
return self
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "solar" return "solar"