fix private attr

This commit is contained in:
Eugene Yurtsev 2024-09-30 21:40:18 -04:00
parent e317d457cf
commit 5ef237cfb6
4 changed files with 49 additions and 45 deletions

View File

@ -19,6 +19,7 @@ from langchain_core.utils import (
)
from langchain_core.utils.utils import build_extra_kwargs
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self
SUPPORTED_ROLES: List[str] = [
"system",
@ -139,14 +140,6 @@ class ChatSnowflakeCortex(BaseChatModel):
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
try:
from snowflake.snowpark import Session
except ImportError:
raise ImportError(
"`snowflake-snowpark-python` package not found, please install it with "
"`pip install snowflake-snowpark-python`"
)
values["snowflake_username"] = get_from_dict_or_env(
values, "snowflake_username", "SNOWFLAKE_USERNAME"
)
@ -168,23 +161,33 @@ class ChatSnowflakeCortex(BaseChatModel):
values["snowflake_role"] = get_from_dict_or_env(
values, "snowflake_role", "SNOWFLAKE_ROLE"
)
return values
connection_params = {
"account": values["snowflake_account"],
"user": values["snowflake_username"],
"password": values["snowflake_password"].get_secret_value(),
"database": values["snowflake_database"],
"schema": values["snowflake_schema"],
"warehouse": values["snowflake_warehouse"],
"role": values["snowflake_role"],
}
@model_validator(mode="after")
def post_init(self) -> Self:
"""Post initialization."""
try:
from snowflake.snowpark import Session
except ImportError:
raise ImportError(
"`snowflake-snowpark-python` package not found, please install it with "
"`pip install snowflake-snowpark-python`"
)
try:
values["_sp_session"] = Session.builder.configs(connection_params).create()
connection_params = {
"account": self.snowflake_account,
"user": self.snowflake_username,
"password": self.snowflake_password.get_secret_value(),
"database": self.snowflake_database,
"schema": self.snowflake_schema,
"warehouse": self.snowflake_warehouse,
"role": self.snowflake_role,
}
self._sp_session = Session.builder.configs(connection_params).create()
except Exception as e:
raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
return values
return self
def __del__(self) -> None:
if getattr(self, "_sp_session", None) is not None:

View File

@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional
from typing import Any, List
import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self
LASER_MULTILINGUAL_MODEL: str = "laser2"
@ -27,7 +27,7 @@ class LaserEmbeddings(BaseModel, Embeddings):
embeddings = encoder.encode_sentences(["Hello", "World"])
"""
lang: Optional[str] = None
lang: str = LASER_MULTILINGUAL_MODEL
"""The language or language code you'd like to use
If empty, this implementation will default
to using a multilingual earlier LASER encoder model (called laser2)
@ -41,25 +41,18 @@ class LaserEmbeddings(BaseModel, Embeddings):
extra="forbid",
)
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that laser_encoders has been installed."""
@model_validator(mode="after")
def post_init(self) -> Self:
try:
from laser_encoders import LaserEncoderPipeline
lang = values.get("lang")
if lang:
encoder_pipeline = LaserEncoderPipeline(lang=lang)
else:
encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL)
values["_encoder_pipeline"] = encoder_pipeline
except ImportError as e:
raise ImportError(
"Could not import 'laser_encoders' Python package. "
"Please install it with `pip install laser_encoders`."
) from e
return values
self._encoder_pipeline = LaserEncoderPipeline(lang=self.lang)
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for documents using LASER.

View File

@ -17,6 +17,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models.llms import LLM
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_community.llms.utils import enforce_stop_tokens
@ -89,13 +90,18 @@ class MinimaxCommon(BaseModel):
"MINIMAX_API_HOST",
default="https://api.minimax.chat",
)
values["_client"] = _MinimaxEndpointClient( # type: ignore[call-arg]
host=values["minimax_api_host"],
api_key=values["minimax_api_key"],
group_id=values["minimax_group_id"],
)
return values
@model_validator(mode="after")
def post_init(self) -> Self:
"""Post initialization."""
self._client = _MinimaxEndpointClient(
host=self.minimax_api_host,
api_key=self.minimax_api_key,
group_id=self.minimax_group_id,
)
return self
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""

View File

@ -11,6 +11,7 @@ from pydantic import (
SecretStr,
model_validator,
)
from typing_extensions import Self
from langchain_community.llms.utils import enforce_stop_tokens
@ -90,12 +91,13 @@ class SolarCommon(BaseModel):
if "base_url" in values and not values["base_url"].startswith(SOLAR_SERVICE):
raise ValueError("base_url must match with: " + SOLAR_SERVICE)
values["_client"] = _SolarClient(
api_key=values["solar_api_key"], base_url=values["base_url"]
)
return values
@model_validator(mode="after")
def post_init(self) -> Self:
self._client = _SolarClient(api_key=self.solar_api_key, base_url=self.base_url)
return self
@property
def _llm_type(self) -> str:
return "solar"