Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-10 21:35:08 +00:00)
fix private attr
This commit is contained in:
parent e317d457cf, commit 5ef237cfb6
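All four files below get the same treatment: the @pre_init validators stop writing private attributes into the values dict (values["_sp_session"], values["_encoder_pipeline"], values["_client"]), and a Pydantic v2 @model_validator(mode="after") method named post_init builds the object, assigns it to the private attribute on the instance, and returns self. A minimal, runnable sketch of that pattern; the names here (ExampleCommon, api_host, _client) are illustrative and not taken from the diff:

from typing import Any

from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Self


class ExampleCommon(BaseModel):
    """Hypothetical model showing the private-attribute pattern from this commit."""

    api_host: str = "https://example.invalid"

    # A private attribute is not a model field: it is excluded from validation and
    # serialization, and in Pydantic v2 writing values["_client"] inside a pre-init
    # validator does not populate it, which appears to be what this commit fixes.
    _client: Any = PrivateAttr(default=None)

    @model_validator(mode="after")
    def post_init(self) -> Self:
        # Runs on the fully validated instance; assign the private attribute here.
        self._client = {"host": self.api_host}  # stand-in for a real SDK client
        return self


print(ExampleCommon()._client)  # -> {'host': 'https://example.invalid'}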
@@ -19,6 +19,7 @@ from langchain_core.utils import (
 )
 from langchain_core.utils.utils import build_extra_kwargs
 from pydantic import Field, SecretStr, model_validator
+from typing_extensions import Self
 
 SUPPORTED_ROLES: List[str] = [
     "system",
@@ -139,14 +140,6 @@ class ChatSnowflakeCortex(BaseChatModel):
 
     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:
-        try:
-            from snowflake.snowpark import Session
-        except ImportError:
-            raise ImportError(
-                "`snowflake-snowpark-python` package not found, please install it with "
-                "`pip install snowflake-snowpark-python`"
-            )
-
         values["snowflake_username"] = get_from_dict_or_env(
             values, "snowflake_username", "SNOWFLAKE_USERNAME"
         )
@@ -168,23 +161,33 @@ class ChatSnowflakeCortex(BaseChatModel):
         values["snowflake_role"] = get_from_dict_or_env(
             values, "snowflake_role", "SNOWFLAKE_ROLE"
         )
+        return values
 
-        connection_params = {
-            "account": values["snowflake_account"],
-            "user": values["snowflake_username"],
-            "password": values["snowflake_password"].get_secret_value(),
-            "database": values["snowflake_database"],
-            "schema": values["snowflake_schema"],
-            "warehouse": values["snowflake_warehouse"],
-            "role": values["snowflake_role"],
-        }
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
+        """Post initialization."""
+        try:
+            from snowflake.snowpark import Session
+        except ImportError:
+            raise ImportError(
+                "`snowflake-snowpark-python` package not found, please install it with "
+                "`pip install snowflake-snowpark-python`"
+            )
 
         try:
-            values["_sp_session"] = Session.builder.configs(connection_params).create()
+            connection_params = {
+                "account": self.snowflake_account,
+                "user": self.snowflake_username,
+                "password": self.snowflake_password.get_secret_value(),
+                "database": self.snowflake_database,
+                "schema": self.snowflake_schema,
+                "warehouse": self.snowflake_warehouse,
+                "role": self.snowflake_role,
+            }
+            self._sp_session = Session.builder.configs(connection_params).create()
         except Exception as e:
             raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
+        return self
 
-        return values
-
     def __del__(self) -> None:
         if getattr(self, "_sp_session", None) is not None:
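In ChatSnowflakeCortex the optional snowflake-snowpark-python import moves into post_init together with the session setup, so the dependency is only exercised once an instance is actually created. A small runnable sketch of that shape, using sqlite3 as a stand-in for the optional SDK; the class and attribute names are hypothetical, not the class's real API:

from typing import Any

from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Self


class LazySessionModel(BaseModel):
    """Hypothetical model; mirrors the lazy-import shape of the hunk above."""

    _session: Any = PrivateAttr(default=None)

    @model_validator(mode="after")
    def post_init(self) -> Self:
        try:
            # The optional backend is imported only when an instance is built,
            # the way the hunk defers `from snowflake.snowpark import Session`.
            # sqlite3 is a stdlib stand-in, so this ImportError never fires here.
            import sqlite3
        except ImportError:
            raise ImportError("optional dependency not installed")

        try:
            self._session = sqlite3.connect(":memory:")
        except Exception as e:
            raise RuntimeError(f"Failed to create session: {e}")
        return self


print(LazySessionModel()._session)  # -> <sqlite3.Connection object ...>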
@@ -1,9 +1,9 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, List
 
 import numpy as np
 from langchain_core.embeddings import Embeddings
-from langchain_core.utils import pre_init
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
+from typing_extensions import Self
 
 LASER_MULTILINGUAL_MODEL: str = "laser2"
 
@@ -27,7 +27,7 @@ class LaserEmbeddings(BaseModel, Embeddings):
         embeddings = encoder.encode_sentences(["Hello", "World"])
     """
 
-    lang: Optional[str] = None
+    lang: str = LASER_MULTILINGUAL_MODEL
     """The language or language code you'd like to use
     If empty, this implementation will default
     to using a multilingual earlier LASER encoder model (called laser2)
@@ -41,25 +41,18 @@ class LaserEmbeddings(BaseModel, Embeddings):
         extra="forbid",
     )
 
-    @pre_init
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that laser_encoders has been installed."""
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
         try:
             from laser_encoders import LaserEncoderPipeline
-
-            lang = values.get("lang")
-            if lang:
-                encoder_pipeline = LaserEncoderPipeline(lang=lang)
-            else:
-                encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL)
-            values["_encoder_pipeline"] = encoder_pipeline
-
         except ImportError as e:
             raise ImportError(
                 "Could not import 'laser_encoders' Python package. "
                 "Please install it with `pip install laser_encoders`."
             ) from e
-        return values
+
+        self._encoder_pipeline = LaserEncoderPipeline(lang=self.lang)
+        return self
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Generate embeddings for documents using LASER.
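Besides the validator change, LaserEmbeddings.lang is now a plain str defaulting to LASER_MULTILINGUAL_MODEL ("laser2") rather than Optional[str] = None, and post_init always calls LaserEncoderPipeline(lang=self.lang). A hedged usage sketch: the import path and the "eng_Latn" language code are assumptions from typical usage, and it needs laser_encoders installed plus a model download on first run, so it will not work offline:

# Assumed import path; requires `pip install laser_encoders` and network access
# for the first model download.
from langchain_community.embeddings.laser import LaserEmbeddings

default_encoder = LaserEmbeddings()                   # lang falls back to "laser2"
english_encoder = LaserEmbeddings(lang="eng_Latn")    # explicit language code (assumed example)

vectors = english_encoder.embed_documents(["Hello", "World"])
query_vector = english_encoder.embed_query("Hello")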
@@ -17,6 +17,7 @@ from langchain_core.callbacks import (
 from langchain_core.language_models.llms import LLM
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
 from pydantic import BaseModel, Field, SecretStr, model_validator
+from typing_extensions import Self
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -89,13 +90,18 @@ class MinimaxCommon(BaseModel):
             "MINIMAX_API_HOST",
             default="https://api.minimax.chat",
         )
-        values["_client"] = _MinimaxEndpointClient(  # type: ignore[call-arg]
-            host=values["minimax_api_host"],
-            api_key=values["minimax_api_key"],
-            group_id=values["minimax_group_id"],
-        )
         return values
 
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
+        """Post initialization."""
+        self._client = _MinimaxEndpointClient(
+            host=self.minimax_api_host,
+            api_key=self.minimax_api_key,
+            group_id=self.minimax_group_id,
+        )
+        return self
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
@@ -11,6 +11,7 @@ from pydantic import (
     SecretStr,
     model_validator,
 )
+from typing_extensions import Self
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
@@ -90,12 +91,13 @@ class SolarCommon(BaseModel):
 
         if "base_url" in values and not values["base_url"].startswith(SOLAR_SERVICE):
             raise ValueError("base_url must match with: " + SOLAR_SERVICE)
 
-        values["_client"] = _SolarClient(
-            api_key=values["solar_api_key"], base_url=values["base_url"]
-        )
         return values
 
+    @model_validator(mode="after")
+    def post_init(self) -> Self:
+        self._client = _SolarClient(api_key=self.solar_api_key, base_url=self.base_url)
+        return self
+
     @property
     def _llm_type(self) -> str:
         return "solar"