mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-13 22:32:33 +00:00
Compare commits
10 Commits
sr/fix-chr
...
eugene/cle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
081de0027d | ||
|
|
6809ae8563 | ||
|
|
7379d7129d | ||
|
|
d10703b4b5 | ||
|
|
148383bdfd | ||
|
|
afab4fef9f | ||
|
|
e6e7347adf | ||
|
|
1e174f6e5e | ||
|
|
0c81f260d1 | ||
|
|
7bdac9855f |
@@ -21,6 +21,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
@@ -98,9 +99,13 @@ class ChatHunyuan(BaseChatModel):
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
hunyuan_app_id: Optional[int] = None
|
||||
hunyuan_app_id: Optional[int] = Field(
|
||||
default_factory=from_env("HUNYUAN_APP_ID", default=None)
|
||||
)
|
||||
"""Hunyuan App ID"""
|
||||
hunyuan_secret_id: Optional[str] = None
|
||||
hunyuan_secret_id: Optional[str] = Field(
|
||||
default_factory=from_env("HUNYUAN_SECRET_ID", default=None)
|
||||
)
|
||||
"""Hunyuan Secret ID"""
|
||||
hunyuan_secret_key: Optional[SecretStr] = None
|
||||
"""Hunyuan Secret Key"""
|
||||
@@ -163,16 +168,6 @@ class ChatHunyuan(BaseChatModel):
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["hunyuan_app_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"hunyuan_app_id",
|
||||
"HUNYUAN_APP_ID",
|
||||
)
|
||||
values["hunyuan_secret_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"hunyuan_secret_id",
|
||||
"HUNYUAN_SECRET_ID",
|
||||
)
|
||||
values["hunyuan_secret_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
|
||||
@@ -55,7 +55,7 @@ from langchain_core.outputs import (
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import from_env, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -219,12 +219,24 @@ class ChatLiteLLM(BaseChatModel):
|
||||
model: str = "gpt-3.5-turbo"
|
||||
model_name: Optional[str] = None
|
||||
"""Model name to use."""
|
||||
openai_api_key: Optional[str] = None
|
||||
azure_api_key: Optional[str] = None
|
||||
anthropic_api_key: Optional[str] = None
|
||||
replicate_api_key: Optional[str] = None
|
||||
cohere_api_key: Optional[str] = None
|
||||
openrouter_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_API_KEY", default="")
|
||||
)
|
||||
azure_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("AZURE_API_KEY", default="")
|
||||
)
|
||||
anthropic_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("ANTHROPIC_API_KEY", default="")
|
||||
)
|
||||
replicate_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("REPLICATE_API_KEY", default="")
|
||||
)
|
||||
cohere_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("COHERE_API_KEY", default="")
|
||||
)
|
||||
openrouter_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("OPENROUTER_API_KEY", default="")
|
||||
)
|
||||
streaming: bool = False
|
||||
api_base: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
@@ -302,24 +314,6 @@ class ChatLiteLLM(BaseChatModel):
|
||||
"Please install it with `pip install litellm`"
|
||||
)
|
||||
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY", default=""
|
||||
)
|
||||
values["azure_api_key"] = get_from_dict_or_env(
|
||||
values, "azure_api_key", "AZURE_API_KEY", default=""
|
||||
)
|
||||
values["anthropic_api_key"] = get_from_dict_or_env(
|
||||
values, "anthropic_api_key", "ANTHROPIC_API_KEY", default=""
|
||||
)
|
||||
values["replicate_api_key"] = get_from_dict_or_env(
|
||||
values, "replicate_api_key", "REPLICATE_API_KEY", default=""
|
||||
)
|
||||
values["openrouter_api_key"] = get_from_dict_or_env(
|
||||
values, "openrouter_api_key", "OPENROUTER_API_KEY", default=""
|
||||
)
|
||||
values["cohere_api_key"] = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY", default=""
|
||||
)
|
||||
values["huggingface_api_key"] = get_from_dict_or_env(
|
||||
values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default=""
|
||||
)
|
||||
|
||||
@@ -3,7 +3,12 @@
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.chat_models.openai import ChatOpenAI
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
@@ -31,9 +36,13 @@ class ChatOctoAI(ChatOpenAI):
|
||||
chat = ChatOctoAI(model_name="mixtral-8x7b-instruct")
|
||||
"""
|
||||
|
||||
octoai_api_base: str = Field(default=DEFAULT_API_BASE)
|
||||
octoai_api_base: str = Field(
|
||||
default_factory=from_env("OCTOAI_API_BASE", default=DEFAULT_API_BASE)
|
||||
)
|
||||
octoai_api_token: SecretStr = Field(default=None)
|
||||
model_name: str = Field(default=DEFAULT_MODEL)
|
||||
model_name: str = Field(
|
||||
default_factory=from_env("MODEL_NAME", default=DEFAULT_MODEL)
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@@ -51,21 +60,9 @@ class ChatOctoAI(ChatOpenAI):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["octoai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"octoai_api_base",
|
||||
"OCTOAI_API_BASE",
|
||||
default=DEFAULT_API_BASE,
|
||||
)
|
||||
values["octoai_api_token"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "octoai_api_token", "OCTOAI_API_TOKEN")
|
||||
)
|
||||
values["model_name"] = get_from_dict_or_env(
|
||||
values,
|
||||
"model_name",
|
||||
"MODEL_NAME",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
@@ -47,7 +47,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
@@ -205,7 +205,9 @@ class ChatOpenAI(BaseChatModel):
|
||||
# When updating this to use a SecretStr
|
||||
# Check for classes that derive from this class (as some of them
|
||||
# may assume openai_api_key is a str)
|
||||
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_API_KEY", default=None), alias="api_key"
|
||||
)
|
||||
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
|
||||
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
@@ -213,7 +215,9 @@ class ChatOpenAI(BaseChatModel):
|
||||
openai_organization: Optional[str] = Field(default=None, alias="organization")
|
||||
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
|
||||
# to support explicit proxy for OpenAI
|
||||
openai_proxy: Optional[str] = None
|
||||
openai_proxy: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_PROXY", default="")
|
||||
)
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
@@ -281,9 +285,6 @@ class ChatOpenAI(BaseChatModel):
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
# Check OPENAI_ORGANIZATION for backwards compatibility.
|
||||
values["openai_organization"] = (
|
||||
values["openai_organization"]
|
||||
@@ -293,12 +294,6 @@ class ChatOpenAI(BaseChatModel):
|
||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
||||
"OPENAI_API_BASE"
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
@@ -111,19 +112,31 @@ class ChatSnowflakeCortex(BaseChatModel):
|
||||
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
snowflake_username: Optional[str] = Field(default=None, alias="username")
|
||||
snowflake_username: Optional[str] = Field(
|
||||
default_factory=from_env("SNOWFLAKE_USERNAME", default=None), alias="username"
|
||||
)
|
||||
"""Automatically inferred from env var `SNOWFLAKE_USERNAME` if not provided."""
|
||||
snowflake_password: Optional[SecretStr] = Field(default=None, alias="password")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_PASSWORD` if not provided."""
|
||||
snowflake_account: Optional[str] = Field(default=None, alias="account")
|
||||
snowflake_account: Optional[str] = Field(
|
||||
default_factory=from_env("SNOWFLAKE_ACCOUNT", default=None), alias="account"
|
||||
)
|
||||
"""Automatically inferred from env var `SNOWFLAKE_ACCOUNT` if not provided."""
|
||||
snowflake_database: Optional[str] = Field(default=None, alias="database")
|
||||
snowflake_database: Optional[str] = Field(
|
||||
default_factory=from_env("SNOWFLAKE_DATABASE", default=None), alias="database"
|
||||
)
|
||||
"""Automatically inferred from env var `SNOWFLAKE_DATABASE` if not provided."""
|
||||
snowflake_schema: Optional[str] = Field(default=None, alias="schema")
|
||||
snowflake_schema: Optional[str] = Field(
|
||||
default_factory=from_env("SNOWFLAKE_SCHEMA", default=None), alias="schema"
|
||||
)
|
||||
"""Automatically inferred from env var `SNOWFLAKE_SCHEMA` if not provided."""
|
||||
snowflake_warehouse: Optional[str] = Field(default=None, alias="warehouse")
|
||||
snowflake_warehouse: Optional[str] = Field(
|
||||
default_factory=from_env("SNOWFLAKE_WAREHOUSE", default=None), alias="warehouse"
|
||||
)
|
||||
"""Automatically inferred from env var `SNOWFLAKE_WAREHOUSE` if not provided."""
|
||||
snowflake_role: Optional[str] = Field(default=None, alias="role")
|
||||
snowflake_role: Optional[str] = Field(
|
||||
default_factory=from_env("SNOWFLAKE_ROLE", default=None), alias="role"
|
||||
)
|
||||
"""Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
@@ -146,27 +159,9 @@ class ChatSnowflakeCortex(BaseChatModel):
|
||||
"`pip install snowflake-snowpark-python`"
|
||||
)
|
||||
|
||||
values["snowflake_username"] = get_from_dict_or_env(
|
||||
values, "snowflake_username", "SNOWFLAKE_USERNAME"
|
||||
)
|
||||
values["snowflake_password"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "snowflake_password", "SNOWFLAKE_PASSWORD")
|
||||
)
|
||||
values["snowflake_account"] = get_from_dict_or_env(
|
||||
values, "snowflake_account", "SNOWFLAKE_ACCOUNT"
|
||||
)
|
||||
values["snowflake_database"] = get_from_dict_or_env(
|
||||
values, "snowflake_database", "SNOWFLAKE_DATABASE"
|
||||
)
|
||||
values["snowflake_schema"] = get_from_dict_or_env(
|
||||
values, "snowflake_schema", "SNOWFLAKE_SCHEMA"
|
||||
)
|
||||
values["snowflake_warehouse"] = get_from_dict_or_env(
|
||||
values, "snowflake_warehouse", "SNOWFLAKE_WAREHOUSE"
|
||||
)
|
||||
values["snowflake_role"] = get_from_dict_or_env(
|
||||
values, "snowflake_role", "SNOWFLAKE_ROLE"
|
||||
)
|
||||
|
||||
connection_params = {
|
||||
"account": values["snowflake_account"],
|
||||
|
||||
@@ -5,7 +5,12 @@ from __future__ import annotations
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
@@ -21,7 +26,9 @@ class AnyscaleEmbeddings(OpenAIEmbeddings):
|
||||
"""AnyScale Endpoints API keys."""
|
||||
model: str = Field(default=DEFAULT_MODEL)
|
||||
"""Model name to use."""
|
||||
anyscale_api_base: str = Field(default=DEFAULT_API_BASE)
|
||||
anyscale_api_base: str = Field(
|
||||
default_factory=from_env("ANYSCALE_API_BASE", default=DEFAULT_API_BASE)
|
||||
)
|
||||
"""Base URL path for API requests."""
|
||||
tiktoken_enabled: bool = False
|
||||
"""Set this to False for non-OpenAI implementations of the embeddings API"""
|
||||
@@ -44,12 +51,6 @@ class AnyscaleEmbeddings(OpenAIEmbeddings):
|
||||
"ANYSCALE_API_KEY",
|
||||
)
|
||||
)
|
||||
values["anyscale_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_api_base",
|
||||
"ANYSCALE_API_BASE",
|
||||
default=DEFAULT_API_BASE,
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ from typing import Dict, List, Optional
|
||||
import requests
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import from_env, pre_init
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,9 +20,15 @@ logger = logging.getLogger(__name__)
|
||||
class ErnieEmbeddings(BaseModel, Embeddings):
|
||||
"""`Ernie Embeddings V1` embedding models."""
|
||||
|
||||
ernie_api_base: Optional[str] = None
|
||||
ernie_client_id: Optional[str] = None
|
||||
ernie_client_secret: Optional[str] = None
|
||||
ernie_api_base: Optional[str] = Field(
|
||||
default_factory=from_env("ERNIE_API_BASE", default="https://aip.baidubce.com")
|
||||
)
|
||||
ernie_client_id: Optional[str] = Field(
|
||||
default_factory=from_env("ERNIE_CLIENT_ID", default=None)
|
||||
)
|
||||
ernie_client_secret: Optional[str] = Field(
|
||||
default_factory=from_env("ERNIE_CLIENT_SECRET", default=None)
|
||||
)
|
||||
access_token: Optional[str] = None
|
||||
|
||||
chunk_size: int = 16
|
||||
@@ -33,19 +39,6 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["ernie_api_base"] = get_from_dict_or_env(
|
||||
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
|
||||
)
|
||||
values["ernie_client_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"ernie_client_id",
|
||||
"ERNIE_CLIENT_ID",
|
||||
)
|
||||
values["ernie_client_secret"] = get_from_dict_or_env(
|
||||
values,
|
||||
"ernie_client_secret",
|
||||
"ERNIE_CLIENT_SECRET",
|
||||
)
|
||||
return values
|
||||
|
||||
def _embedding(self, json: object) -> dict:
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
@@ -22,9 +27,11 @@ class OctoAIEmbeddings(OpenAIEmbeddings):
|
||||
|
||||
octoai_api_token: SecretStr = Field(default=None)
|
||||
"""OctoAI Endpoints API keys."""
|
||||
endpoint_url: str = Field(default=DEFAULT_API_BASE)
|
||||
endpoint_url: str = Field(
|
||||
default_factory=from_env("ENDPOINT_URL", default=DEFAULT_API_BASE)
|
||||
)
|
||||
"""Base URL path for API requests."""
|
||||
model: str = Field(default=DEFAULT_MODEL)
|
||||
model: str = Field(default_factory=from_env("MODEL", default=DEFAULT_MODEL))
|
||||
"""Model name to use."""
|
||||
tiktoken_enabled: bool = False
|
||||
"""Set this to False for non-OpenAI implementations of the embeddings API"""
|
||||
@@ -41,21 +48,9 @@ class OctoAIEmbeddings(OpenAIEmbeddings):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["endpoint_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"endpoint_url",
|
||||
"ENDPOINT_URL",
|
||||
default=DEFAULT_API_BASE,
|
||||
)
|
||||
values["octoai_api_token"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "octoai_api_token", "OCTOAI_API_TOKEN")
|
||||
)
|
||||
values["model"] = get_from_dict_or_env(
|
||||
values,
|
||||
"model",
|
||||
"MODEL",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
@@ -23,6 +23,7 @@ from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.utils import (
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
@@ -202,12 +203,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
# to support Azure OpenAI Service custom endpoints
|
||||
openai_api_type: Optional[str] = None
|
||||
openai_api_type: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_API_TYPE", default="")
|
||||
)
|
||||
# to support explicit proxy for OpenAI
|
||||
openai_proxy: Optional[str] = None
|
||||
openai_proxy: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_PROXY", default="")
|
||||
)
|
||||
embedding_ctx_length: int = 8191
|
||||
"""The maximum number of tokens to embed at once."""
|
||||
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_API_KEY"), alias="api_key"
|
||||
)
|
||||
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
|
||||
openai_organization: Optional[str] = Field(default=None, alias="organization")
|
||||
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
|
||||
@@ -287,24 +294,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
||||
"OPENAI_API_BASE"
|
||||
)
|
||||
values["openai_api_type"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_api_type",
|
||||
"OPENAI_API_TYPE",
|
||||
default="",
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
|
||||
default_api_version = "2023-05-15"
|
||||
# Azure OpenAI embedding models allow a maximum of 16 texts
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Dict, Generator, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils import from_env, get_from_dict_or_env, pre_init
|
||||
|
||||
|
||||
class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
@@ -43,19 +43,27 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
"""
|
||||
|
||||
sambastudio_embeddings_base_url: str = ""
|
||||
sambastudio_embeddings_base_url: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_EMBEDDINGS_BASE_URL", default="")
|
||||
)
|
||||
"""Base url to use"""
|
||||
|
||||
sambastudio_embeddings_base_uri: str = ""
|
||||
"""endpoint base uri"""
|
||||
|
||||
sambastudio_embeddings_project_id: str = ""
|
||||
sambastudio_embeddings_project_id: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_EMBEDDINGS_PROJECT_ID", default="")
|
||||
)
|
||||
"""Project id on sambastudio for model"""
|
||||
|
||||
sambastudio_embeddings_endpoint_id: str = ""
|
||||
sambastudio_embeddings_endpoint_id: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID", default="")
|
||||
)
|
||||
"""endpoint id on sambastudio for model"""
|
||||
|
||||
sambastudio_embeddings_api_key: str = ""
|
||||
sambastudio_embeddings_api_key: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_EMBEDDINGS_API_KEY", default="")
|
||||
)
|
||||
"""sambastudio api key"""
|
||||
|
||||
model_kwargs: dict = {}
|
||||
@@ -67,28 +75,12 @@ class SambaStudioEmbeddings(BaseModel, Embeddings):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
|
||||
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
|
||||
)
|
||||
values["sambastudio_embeddings_base_uri"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambastudio_embeddings_base_uri",
|
||||
"SAMBASTUDIO_EMBEDDINGS_BASE_URI",
|
||||
default="api/predict/generic",
|
||||
)
|
||||
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambastudio_embeddings_project_id",
|
||||
"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID",
|
||||
)
|
||||
values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambastudio_embeddings_endpoint_id",
|
||||
"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID",
|
||||
)
|
||||
values["sambastudio_embeddings_api_key"] = get_from_dict_or_env(
|
||||
values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_tuning_params(self) -> str:
|
||||
|
||||
@@ -4,8 +4,8 @@ import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils import from_env, pre_init
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,11 +13,15 @@ logger = logging.getLogger(__name__)
|
||||
class VolcanoEmbeddings(BaseModel, Embeddings):
|
||||
"""`Volcengine Embeddings` embedding models."""
|
||||
|
||||
volcano_ak: Optional[str] = None
|
||||
volcano_ak: Optional[str] = Field(
|
||||
default_factory=from_env("VOLC_ACCESSKEY", default=None)
|
||||
)
|
||||
"""volcano access key
|
||||
learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""
|
||||
|
||||
volcano_sk: Optional[str] = None
|
||||
volcano_sk: Optional[str] = Field(
|
||||
default_factory=from_env("VOLC_SECRETKEY", default=None)
|
||||
)
|
||||
"""volcano secret key
|
||||
learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""
|
||||
|
||||
@@ -66,16 +70,6 @@ class VolcanoEmbeddings(BaseModel, Embeddings):
|
||||
ValueError: volcengine package not found, please install it with
|
||||
`pip install volcengine`
|
||||
"""
|
||||
values["volcano_ak"] = get_from_dict_or_env(
|
||||
values,
|
||||
"volcano_ak",
|
||||
"VOLC_ACCESSKEY",
|
||||
)
|
||||
values["volcano_sk"] = get_from_dict_or_env(
|
||||
values,
|
||||
"volcano_sk",
|
||||
"VOLC_SECRETKEY",
|
||||
)
|
||||
|
||||
try:
|
||||
from volcengine.maas import MaasService
|
||||
|
||||
@@ -23,6 +23,7 @@ from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
@@ -57,7 +58,11 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
max_retries: int = 2
|
||||
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
|
||||
|
||||
anthropic_api_url: Optional[str] = None
|
||||
anthropic_api_url: Optional[str] = Field(
|
||||
default_factory=from_env(
|
||||
"ANTHROPIC_API_URL", default="https://api.anthropic.com"
|
||||
)
|
||||
)
|
||||
|
||||
anthropic_api_key: Optional[SecretStr] = None
|
||||
|
||||
@@ -82,12 +87,6 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["anthropic_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anthropic_api_url",
|
||||
"ANTHROPIC_API_URL",
|
||||
default="https://api.anthropic.com",
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
|
||||
@@ -15,7 +15,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.llms.openai import (
|
||||
BaseOpenAI,
|
||||
@@ -83,9 +88,13 @@ class Anyscale(BaseOpenAI):
|
||||
"""
|
||||
|
||||
"""Key word arguments to pass to the model."""
|
||||
anyscale_api_base: str = Field(default=DEFAULT_BASE_URL)
|
||||
anyscale_api_base: str = Field(
|
||||
default_factory=from_env("ANYSCALE_API_BASE", default=DEFAULT_BASE_URL)
|
||||
)
|
||||
anyscale_api_key: SecretStr = Field(default=None)
|
||||
model_name: str = Field(default=DEFAULT_MODEL)
|
||||
model_name: str = Field(
|
||||
default_factory=from_env("MODEL_NAME", default=DEFAULT_MODEL)
|
||||
)
|
||||
|
||||
prefix_messages: List = Field(default_factory=list)
|
||||
|
||||
@@ -96,21 +105,9 @@ class Anyscale(BaseOpenAI):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anyscale_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_api_base",
|
||||
"ANYSCALE_API_BASE",
|
||||
default=DEFAULT_BASE_URL,
|
||||
)
|
||||
values["anyscale_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
|
||||
)
|
||||
values["model_name"] = get_from_dict_or_env(
|
||||
values,
|
||||
"model_name",
|
||||
"MODEL_NAME",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
@@ -8,7 +8,12 @@ import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -28,7 +33,12 @@ class BaichuanLLM(LLM):
|
||||
timeout: int = 60
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
baichuan_api_host: Optional[str] = None
|
||||
baichuan_api_host: Optional[str] = Field(
|
||||
default_factory=from_env(
|
||||
"BAICHUAN_API_HOST",
|
||||
default="https://api.baichuan-ai.com/v1/chat/completions",
|
||||
)
|
||||
)
|
||||
baichuan_api_key: Optional[SecretStr] = None
|
||||
|
||||
@pre_init
|
||||
@@ -36,12 +46,6 @@ class BaichuanLLM(LLM):
|
||||
values["baichuan_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
|
||||
)
|
||||
values["baichuan_api_host"] = get_from_dict_or_env(
|
||||
values,
|
||||
"baichuan_api_host",
|
||||
"BAICHUAN_API_HOST",
|
||||
default="https://api.baichuan-ai.com/v1/chat/completions",
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import from_env, pre_init
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
@@ -34,7 +34,9 @@ class EdenAI(LLM):
|
||||
|
||||
base_url: str = "https://api.edenai.run/v2"
|
||||
|
||||
edenai_api_key: Optional[str] = None
|
||||
edenai_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("EDENAI_API_KEY", default=None)
|
||||
)
|
||||
|
||||
feature: Literal["text", "image"] = "text"
|
||||
"""Which generative feature to use, use text by default"""
|
||||
@@ -75,9 +77,6 @@ class EdenAI(LLM):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["edenai_api_key"] = get_from_dict_or_env(
|
||||
values, "edenai_api_key", "EDENAI_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
|
||||
@@ -16,7 +16,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -68,8 +73,12 @@ class MinimaxCommon(BaseModel):
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
minimax_api_host: Optional[str] = None
|
||||
minimax_group_id: Optional[str] = None
|
||||
minimax_api_host: Optional[str] = Field(
|
||||
default_factory=from_env("MINIMAX_API_HOST", default="https://api.minimax.chat")
|
||||
)
|
||||
minimax_group_id: Optional[str] = Field(
|
||||
default_factory=from_env("MINIMAX_GROUP_ID", default=None)
|
||||
)
|
||||
minimax_api_key: Optional[SecretStr] = None
|
||||
|
||||
@pre_init
|
||||
@@ -78,16 +87,7 @@ class MinimaxCommon(BaseModel):
|
||||
values["minimax_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
|
||||
)
|
||||
values["minimax_group_id"] = get_from_dict_or_env(
|
||||
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["minimax_api_host"] = get_from_dict_or_env(
|
||||
values,
|
||||
"minimax_api_host",
|
||||
"MINIMAX_API_HOST",
|
||||
default="https://api.minimax.chat",
|
||||
)
|
||||
values["_client"] = _MinimaxEndpointClient( # type: ignore[call-arg]
|
||||
host=values["minimax_api_host"],
|
||||
api_key=values["minimax_api_key"],
|
||||
|
||||
@@ -5,7 +5,7 @@ import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import from_env, pre_init
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +35,7 @@ class OCIModelDeploymentLLM(LLM):
|
||||
p: float = 0.75
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
|
||||
endpoint: str = ""
|
||||
endpoint: str = Field(default_factory=from_env("OCI_LLM_ENDPOINT", default=""))
|
||||
"""The uri of the endpoint from the deployed Model Deployment model."""
|
||||
|
||||
best_of: int = 1
|
||||
@@ -62,11 +62,6 @@ class OCIModelDeploymentLLM(LLM):
|
||||
) from ex
|
||||
if not values.get("auth", None):
|
||||
values["auth"] = ads.common.auth.default_signer()
|
||||
values["endpoint"] = get_from_dict_or_env(
|
||||
values,
|
||||
"endpoint",
|
||||
"OCI_LLM_ENDPOINT",
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.llms.openai import BaseOpenAI
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
@@ -35,9 +40,13 @@ class OctoAIEndpoint(BaseOpenAI):
|
||||
"""
|
||||
|
||||
"""Key word arguments to pass to the model."""
|
||||
octoai_api_base: str = Field(default=DEFAULT_BASE_URL)
|
||||
octoai_api_base: str = Field(
|
||||
default_factory=from_env("OCTOAI_API_BASE", default=DEFAULT_BASE_URL)
|
||||
)
|
||||
octoai_api_token: SecretStr = Field(default=None)
|
||||
model_name: str = Field(default=DEFAULT_MODEL)
|
||||
model_name: str = Field(
|
||||
default_factory=from_env("MODEL_NAME", default=DEFAULT_MODEL)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -69,21 +78,9 @@ class OctoAIEndpoint(BaseOpenAI):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["octoai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"octoai_api_base",
|
||||
"OCTOAI_API_BASE",
|
||||
default=DEFAULT_BASE_URL,
|
||||
)
|
||||
values["octoai_api_token"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "octoai_api_token", "OCTOAI_API_TOKEN")
|
||||
)
|
||||
values["model_name"] = get_from_dict_or_env(
|
||||
values,
|
||||
"model_name",
|
||||
"MODEL_NAME",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
@@ -30,6 +30,7 @@ from langchain_core.language_models.llms import BaseLLM, create_base_retry_decor
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import (
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
@@ -202,7 +203,9 @@ class BaseOpenAI(BaseLLM):
|
||||
# When updating this to use a SecretStr
|
||||
# Check for classes that derive from this class (as some of them
|
||||
# may assume openai_api_key is a str)
|
||||
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_API_KEY", default=None), alias="api_key"
|
||||
)
|
||||
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
|
||||
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
@@ -210,7 +213,9 @@ class BaseOpenAI(BaseLLM):
|
||||
openai_organization: Optional[str] = Field(default=None, alias="organization")
|
||||
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
|
||||
# to support explicit proxy for OpenAI
|
||||
openai_proxy: Optional[str] = None
|
||||
openai_proxy: Optional[str] = Field(
|
||||
default_factory=from_env("OPENAI_PROXY", default="")
|
||||
)
|
||||
batch_size: int = 20
|
||||
"""Batch size to use when passing multiple documents to generate."""
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
@@ -282,18 +287,9 @@ class BaseOpenAI(BaseLLM):
|
||||
if values["streaming"] and values["best_of"] > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
||||
"OPENAI_API_BASE"
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"openai_proxy",
|
||||
"OPENAI_PROXY",
|
||||
default="",
|
||||
)
|
||||
values["openai_organization"] = (
|
||||
values["openai_organization"]
|
||||
or os.getenv("OPENAI_ORG_ID")
|
||||
|
||||
@@ -6,7 +6,7 @@ import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -54,12 +54,6 @@ class PaiEasEndpoint(LLM):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["eas_service_url"] = get_from_dict_or_env(
|
||||
values, "eas_service_url", "EAS_SERVICE_URL"
|
||||
)
|
||||
values["eas_service_token"] = get_from_dict_or_env(
|
||||
values, "eas_service_token", "EAS_SERVICE_TOKEN"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ import requests
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import from_env, get_from_dict_or_env, pre_init
|
||||
|
||||
|
||||
class SVEndpointHandler:
|
||||
@@ -197,10 +198,14 @@ class Sambaverse(LLM):
|
||||
sambaverse_url: str = ""
|
||||
"""Sambaverse url to use"""
|
||||
|
||||
sambaverse_api_key: str = ""
|
||||
sambaverse_api_key: str = Field(
|
||||
default_factory=from_env("SAMBAVERSE_API_KEY", default="")
|
||||
)
|
||||
"""sambaverse api key"""
|
||||
|
||||
sambaverse_model_name: Optional[str] = None
|
||||
sambaverse_model_name: Optional[str] = Field(
|
||||
default_factory=from_env("SAMBAVERSE_MODEL_NAME", default=None)
|
||||
)
|
||||
"""sambaverse expert model to use"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
@@ -225,12 +230,6 @@ class Sambaverse(LLM):
|
||||
"SAMBAVERSE_URL",
|
||||
default="https://sambaverse.sambanova.ai",
|
||||
)
|
||||
values["sambaverse_api_key"] = get_from_dict_or_env(
|
||||
values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
|
||||
)
|
||||
values["sambaverse_model_name"] = get_from_dict_or_env(
|
||||
values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
@@ -691,19 +690,27 @@ class SambaStudio(LLM):
|
||||
)
|
||||
"""
|
||||
|
||||
sambastudio_base_url: str = ""
|
||||
sambastudio_base_url: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_BASE_URL", default="")
|
||||
)
|
||||
"""Base url to use"""
|
||||
|
||||
sambastudio_base_uri: str = ""
|
||||
"""endpoint base uri"""
|
||||
|
||||
sambastudio_project_id: str = ""
|
||||
sambastudio_project_id: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_PROJECT_ID", default="")
|
||||
)
|
||||
"""Project id on sambastudio for model"""
|
||||
|
||||
sambastudio_endpoint_id: str = ""
|
||||
sambastudio_endpoint_id: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_ENDPOINT_ID", default="")
|
||||
)
|
||||
"""endpoint id on sambastudio for model"""
|
||||
|
||||
sambastudio_api_key: str = ""
|
||||
sambastudio_api_key: str = Field(
|
||||
default_factory=from_env("SAMBASTUDIO_API_KEY", default="")
|
||||
)
|
||||
"""sambastudio api key"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
@@ -732,24 +739,12 @@ class SambaStudio(LLM):
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["sambastudio_base_url"] = get_from_dict_or_env(
|
||||
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
|
||||
)
|
||||
values["sambastudio_base_uri"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambastudio_base_uri",
|
||||
"SAMBASTUDIO_BASE_URI",
|
||||
default="api/predict/generic",
|
||||
)
|
||||
values["sambastudio_project_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
|
||||
)
|
||||
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
|
||||
)
|
||||
values["sambastudio_api_key"] = get_from_dict_or_env(
|
||||
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import from_env, pre_init
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,11 +42,23 @@ class SparkLLM(LLM):
|
||||
"""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
spark_app_id: Optional[str] = None
|
||||
spark_api_key: Optional[str] = None
|
||||
spark_api_secret: Optional[str] = None
|
||||
spark_api_url: Optional[str] = None
|
||||
spark_llm_domain: Optional[str] = None
|
||||
spark_app_id: Optional[str] = Field(
|
||||
default_factory=from_env("IFLYTEK_SPARK_APP_ID", default=None)
|
||||
)
|
||||
spark_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("IFLYTEK_SPARK_API_KEY", default=None)
|
||||
)
|
||||
spark_api_secret: Optional[str] = Field(
|
||||
default_factory=from_env("IFLYTEK_SPARK_API_SECRET", default=None)
|
||||
)
|
||||
spark_api_url: Optional[str] = Field(
|
||||
default_factory=from_env(
|
||||
"IFLYTEK_SPARK_API_URL", default="wss://spark-api.xf-yun.com/v3.1/chat"
|
||||
)
|
||||
)
|
||||
spark_llm_domain: Optional[str] = Field(
|
||||
default_factory=from_env("IFLYTEK_SPARK_LLM_DOMAIN", default="generalv3")
|
||||
)
|
||||
spark_user_id: str = "lc_user"
|
||||
streaming: bool = False
|
||||
request_timeout: int = 30
|
||||
@@ -56,33 +68,6 @@ class SparkLLM(LLM):
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["spark_app_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_app_id",
|
||||
"IFLYTEK_SPARK_APP_ID",
|
||||
)
|
||||
values["spark_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_api_key",
|
||||
"IFLYTEK_SPARK_API_KEY",
|
||||
)
|
||||
values["spark_api_secret"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_api_secret",
|
||||
"IFLYTEK_SPARK_API_SECRET",
|
||||
)
|
||||
values["spark_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_api_url",
|
||||
"IFLYTEK_SPARK_API_URL",
|
||||
"wss://spark-api.xf-yun.com/v3.1/chat",
|
||||
)
|
||||
values["spark_llm_domain"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_llm_domain",
|
||||
"IFLYTEK_SPARK_LLM_DOMAIN",
|
||||
"generalv3",
|
||||
)
|
||||
# put extra params into model_kwargs
|
||||
values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature
|
||||
values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k
|
||||
|
||||
@@ -2,9 +2,14 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_from_dict_or_env,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
|
||||
|
||||
@@ -37,13 +42,19 @@ class ArceeRetriever(BaseRetriever):
|
||||
model: str
|
||||
"""Arcee DALM name"""
|
||||
|
||||
arcee_api_url: str = "https://api.arcee.ai"
|
||||
arcee_api_url: str = Field(
|
||||
default_factory=from_env("ARCEE_API_URL", default="https://api.arcee.ai")
|
||||
)
|
||||
"""Arcee API URL"""
|
||||
|
||||
arcee_api_version: str = "v2"
|
||||
arcee_api_version: str = Field(
|
||||
default_factory=from_env("ARCEE_API_VERSION", default="v2")
|
||||
)
|
||||
"""Arcee API Version"""
|
||||
|
||||
arcee_app_url: str = "https://app.arcee.ai"
|
||||
arcee_app_url: str = Field(
|
||||
default_factory=from_env("ARCEE_APP_URL", default="https://app.arcee.ai")
|
||||
)
|
||||
"""Arcee App URL"""
|
||||
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
@@ -81,24 +92,6 @@ class ArceeRetriever(BaseRetriever):
|
||||
)
|
||||
)
|
||||
|
||||
values["arcee_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_url",
|
||||
"ARCEE_API_URL",
|
||||
)
|
||||
|
||||
values["arcee_app_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_app_url",
|
||||
"ARCEE_APP_URL",
|
||||
)
|
||||
|
||||
values["arcee_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_version",
|
||||
"ARCEE_API_VERSION",
|
||||
)
|
||||
|
||||
# validate model kwargs
|
||||
if values["model_kwargs"]:
|
||||
kw = values["model_kwargs"]
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
from langchain_community.utilities.vertexai import get_client_info
|
||||
|
||||
@@ -58,9 +58,6 @@ class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
|
||||
"Please install it with pip install google-cloud-contentwarehouse"
|
||||
) from exc
|
||||
|
||||
values["project_number"] = get_from_dict_or_env(
|
||||
values, "project_number", "PROJECT_NUMBER"
|
||||
)
|
||||
values["client"] = DocumentServiceClient(
|
||||
client_info=get_client_info(module="document-ai-warehouse")
|
||||
)
|
||||
|
||||
@@ -280,6 +280,10 @@ def from_env(key: str, /) -> Callable[[], str]: ...
|
||||
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, default: None) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
Reference in New Issue
Block a user