mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
fix private attr
This commit is contained in:
parent
e317d457cf
commit
5ef237cfb6
@ -19,6 +19,7 @@ from langchain_core.utils import (
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
SUPPORTED_ROLES: List[str] = [
|
||||
"system",
|
||||
@ -139,14 +140,6 @@ class ChatSnowflakeCortex(BaseChatModel):
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
try:
|
||||
from snowflake.snowpark import Session
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`snowflake-snowpark-python` package not found, please install it with "
|
||||
"`pip install snowflake-snowpark-python`"
|
||||
)
|
||||
|
||||
values["snowflake_username"] = get_from_dict_or_env(
|
||||
values, "snowflake_username", "SNOWFLAKE_USERNAME"
|
||||
)
|
||||
@ -168,23 +161,33 @@ class ChatSnowflakeCortex(BaseChatModel):
|
||||
values["snowflake_role"] = get_from_dict_or_env(
|
||||
values, "snowflake_role", "SNOWFLAKE_ROLE"
|
||||
)
|
||||
return values
|
||||
|
||||
connection_params = {
|
||||
"account": values["snowflake_account"],
|
||||
"user": values["snowflake_username"],
|
||||
"password": values["snowflake_password"].get_secret_value(),
|
||||
"database": values["snowflake_database"],
|
||||
"schema": values["snowflake_schema"],
|
||||
"warehouse": values["snowflake_warehouse"],
|
||||
"role": values["snowflake_role"],
|
||||
}
|
||||
@model_validator(mode="after")
|
||||
def post_init(self) -> Self:
|
||||
"""Post initialization."""
|
||||
try:
|
||||
from snowflake.snowpark import Session
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`snowflake-snowpark-python` package not found, please install it with "
|
||||
"`pip install snowflake-snowpark-python`"
|
||||
)
|
||||
|
||||
try:
|
||||
values["_sp_session"] = Session.builder.configs(connection_params).create()
|
||||
connection_params = {
|
||||
"account": self.snowflake_account,
|
||||
"user": self.snowflake_username,
|
||||
"password": self.snowflake_password.get_secret_value(),
|
||||
"database": self.snowflake_database,
|
||||
"schema": self.snowflake_schema,
|
||||
"warehouse": self.snowflake_warehouse,
|
||||
"role": self.snowflake_role,
|
||||
}
|
||||
self._sp_session = Session.builder.configs(connection_params).create()
|
||||
except Exception as e:
|
||||
raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, "_sp_session", None) is not None:
|
||||
|
@ -1,9 +1,9 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
LASER_MULTILINGUAL_MODEL: str = "laser2"
|
||||
|
||||
@ -27,7 +27,7 @@ class LaserEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = encoder.encode_sentences(["Hello", "World"])
|
||||
"""
|
||||
|
||||
lang: Optional[str] = None
|
||||
lang: str = LASER_MULTILINGUAL_MODEL
|
||||
"""The language or language code you'd like to use
|
||||
If empty, this implementation will default
|
||||
to using a multilingual earlier LASER encoder model (called laser2)
|
||||
@ -41,25 +41,18 @@ class LaserEmbeddings(BaseModel, Embeddings):
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that laser_encoders has been installed."""
|
||||
@model_validator(mode="after")
|
||||
def post_init(self) -> Self:
|
||||
try:
|
||||
from laser_encoders import LaserEncoderPipeline
|
||||
|
||||
lang = values.get("lang")
|
||||
if lang:
|
||||
encoder_pipeline = LaserEncoderPipeline(lang=lang)
|
||||
else:
|
||||
encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL)
|
||||
values["_encoder_pipeline"] = encoder_pipeline
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import 'laser_encoders' Python package. "
|
||||
"Please install it with `pip install laser_encoders`."
|
||||
) from e
|
||||
return values
|
||||
|
||||
self._encoder_pipeline = LaserEncoderPipeline(lang=self.lang)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for documents using LASER.
|
||||
|
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -89,13 +90,18 @@ class MinimaxCommon(BaseModel):
|
||||
"MINIMAX_API_HOST",
|
||||
default="https://api.minimax.chat",
|
||||
)
|
||||
values["_client"] = _MinimaxEndpointClient( # type: ignore[call-arg]
|
||||
host=values["minimax_api_host"],
|
||||
api_key=values["minimax_api_key"],
|
||||
group_id=values["minimax_group_id"],
|
||||
)
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init(self) -> Self:
|
||||
"""Post initialization."""
|
||||
self._client = _MinimaxEndpointClient(
|
||||
host=self.minimax_api_host,
|
||||
api_key=self.minimax_api_key,
|
||||
group_id=self.minimax_group_id,
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
|
@ -11,6 +11,7 @@ from pydantic import (
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@ -90,12 +91,13 @@ class SolarCommon(BaseModel):
|
||||
|
||||
if "base_url" in values and not values["base_url"].startswith(SOLAR_SERVICE):
|
||||
raise ValueError("base_url must match with: " + SOLAR_SERVICE)
|
||||
|
||||
values["_client"] = _SolarClient(
|
||||
api_key=values["solar_api_key"], base_url=values["base_url"]
|
||||
)
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init(self) -> Self:
|
||||
self._client = _SolarClient(api_key=self.solar_api_key, base_url=self.base_url)
|
||||
return self
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "solar"
|
||||
|
Loading…
Reference in New Issue
Block a user