mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 06:40:04 +00:00
fireworks[patch]: Upgrade @root_validators to be pydantic 2 compliant (#25443)
Update @root_validators to be pydantic 2 compliant
This commit is contained in:
parent
75ae585deb
commit
eb3870e9d8
@ -296,6 +296,10 @@ def from_env(
|
|||||||
) -> Callable[[], Optional[str]]: ...
|
) -> Callable[[], Optional[str]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ...
|
||||||
|
|
||||||
|
|
||||||
def from_env(
|
def from_env(
|
||||||
key: str,
|
key: str,
|
||||||
/,
|
/,
|
||||||
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -78,8 +77,6 @@ from langchain_core.pydantic_v1 import (
|
|||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
convert_to_secret_str,
|
|
||||||
get_from_dict_or_env,
|
|
||||||
get_pydantic_field_names,
|
get_pydantic_field_names,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.function_calling import (
|
from langchain_core.utils.function_calling import (
|
||||||
@ -87,7 +84,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
from langchain_core.utils.utils import build_extra_kwargs
|
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -322,9 +319,25 @@ class ChatFireworks(BaseChatModel):
|
|||||||
"""Default stop sequences."""
|
"""Default stop sequences."""
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||||
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
|
fireworks_api_key: SecretStr = Field(
|
||||||
"""Automatically inferred from env var `FIREWORKS_API_KEY` if not provided."""
|
alias="api_key",
|
||||||
fireworks_api_base: Optional[str] = Field(default=None, alias="base_url")
|
default_factory=secret_from_env(
|
||||||
|
"FIREWORKS_API_KEY",
|
||||||
|
error_message=(
|
||||||
|
"You must specify an api key. "
|
||||||
|
"You can pass it an argument as `api_key=...` or "
|
||||||
|
"set the environment variable `FIREWORKS_API_KEY`."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
"""Fireworks API key.
|
||||||
|
|
||||||
|
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fireworks_api_base: Optional[str] = Field(
|
||||||
|
alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None)
|
||||||
|
)
|
||||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||||
emulator."""
|
emulator."""
|
||||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||||
@ -356,7 +369,7 @@ class ChatFireworks(BaseChatModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
if values["n"] < 1:
|
if values["n"] < 1:
|
||||||
@ -364,12 +377,6 @@ class ChatFireworks(BaseChatModel):
|
|||||||
if values["n"] > 1 and values["streaming"]:
|
if values["n"] > 1 and values["streaming"]:
|
||||||
raise ValueError("n must be 1 when streaming.")
|
raise ValueError("n must be 1 when streaming.")
|
||||||
|
|
||||||
values["fireworks_api_key"] = convert_to_secret_str(
|
|
||||||
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
|
|
||||||
)
|
|
||||||
values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv(
|
|
||||||
"FIREWORKS_API_BASE"
|
|
||||||
)
|
|
||||||
client_params = {
|
client_params = {
|
||||||
"api_key": (
|
"api_key": (
|
||||||
values["fireworks_api_key"].get_secret_value()
|
values["fireworks_api_key"].get_secret_value()
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import os
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||||
from langchain_core.utils import convert_to_secret_str
|
from langchain_core.utils import secret_from_env
|
||||||
from openai import OpenAI # type: ignore
|
from openai import OpenAI # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@ -67,22 +66,25 @@ class FireworksEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_client: OpenAI = Field(default=None)
|
_client: OpenAI = Field(default=None)
|
||||||
fireworks_api_key: SecretStr = convert_to_secret_str("")
|
fireworks_api_key: SecretStr = Field(
|
||||||
|
alias="api_key",
|
||||||
|
default_factory=secret_from_env(
|
||||||
|
"FIREWORKS_API_KEY",
|
||||||
|
default="",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
"""Fireworks API key.
|
||||||
|
|
||||||
|
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
|
||||||
|
"""
|
||||||
model: str = "nomic-ai/nomic-embed-text-v1.5"
|
model: str = "nomic-ai/nomic-embed-text-v1.5"
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate environment variables."""
|
"""Validate environment variables."""
|
||||||
fireworks_api_key = convert_to_secret_str(
|
|
||||||
values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or ""
|
|
||||||
)
|
|
||||||
values["fireworks_api_key"] = fireworks_api_key
|
|
||||||
|
|
||||||
# note this sets it globally for module
|
|
||||||
# there isn't currently a way to pass it into client
|
|
||||||
api_key = fireworks_api_key.get_secret_value()
|
|
||||||
values["_client"] = OpenAI(
|
values["_client"] = OpenAI(
|
||||||
api_key=api_key, base_url="https://api.fireworks.ai/inference/v1"
|
api_key=values["fireworks_api_key"].get_secret_value(),
|
||||||
|
base_url="https://api.fireworks.ai/inference/v1",
|
||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user