fireworks[patch]: Upgrade @root_validators to be pydantic 2 compliant (#25443)

Update @root_validators to be pydantic 2 compliant
This commit is contained in:
Eugene Yurtsev 2024-08-15 12:56:48 -04:00 committed by GitHub
parent 75ae585deb
commit eb3870e9d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 27 deletions

View File

@ -296,6 +296,10 @@ def from_env(
) -> Callable[[], Optional[str]]: ... ) -> Callable[[], Optional[str]]: ...
@overload
def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ...
def from_env( def from_env(
key: str, key: str,
/, /,

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import json import json
import logging import logging
import os
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
Any, Any,
@ -78,8 +77,6 @@ from langchain_core.pydantic_v1 import (
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import ( from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names, get_pydantic_field_names,
) )
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
@ -87,7 +84,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -322,9 +319,25 @@ class ChatFireworks(BaseChatModel):
"""Default stop sequences.""" """Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key") fireworks_api_key: SecretStr = Field(
"""Automatically inferred from env var `FIREWORKS_API_KEY` if not provided.""" alias="api_key",
fireworks_api_base: Optional[str] = Field(default=None, alias="base_url") default_factory=secret_from_env(
"FIREWORKS_API_KEY",
error_message=(
"You must specify an api key. "
"You can pass it an argument as `api_key=...` or "
"set the environment variable `FIREWORKS_API_KEY`."
),
),
)
"""Fireworks API key.
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
"""
fireworks_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service """Base URL path for API requests, leave blank if not using a proxy or service
emulator.""" emulator."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
@ -356,7 +369,7 @@ class ChatFireworks(BaseChatModel):
) )
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
@ -364,12 +377,6 @@ class ChatFireworks(BaseChatModel):
if values["n"] > 1 and values["streaming"]: if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.") raise ValueError("n must be 1 when streaming.")
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv(
"FIREWORKS_API_BASE"
)
client_params = { client_params = {
"api_key": ( "api_key": (
values["fireworks_api_key"].get_secret_value() values["fireworks_api_key"].get_secret_value()

View File

@ -1,9 +1,8 @@
import os
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str from langchain_core.utils import secret_from_env
from openai import OpenAI # type: ignore from openai import OpenAI # type: ignore
@ -67,22 +66,25 @@ class FireworksEmbeddings(BaseModel, Embeddings):
""" """
_client: OpenAI = Field(default=None) _client: OpenAI = Field(default=None)
fireworks_api_key: SecretStr = convert_to_secret_str("") fireworks_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
"FIREWORKS_API_KEY",
default="",
),
)
"""Fireworks API key.
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
"""
model: str = "nomic-ai/nomic-embed-text-v1.5" model: str = "nomic-ai/nomic-embed-text-v1.5"
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate environment variables.""" """Validate environment variables."""
fireworks_api_key = convert_to_secret_str(
values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or ""
)
values["fireworks_api_key"] = fireworks_api_key
# note this sets it globally for module
# there isn't currently a way to pass it into client
api_key = fireworks_api_key.get_secret_value()
values["_client"] = OpenAI( values["_client"] = OpenAI(
api_key=api_key, base_url="https://api.fireworks.ai/inference/v1" api_key=values["fireworks_api_key"].get_secret_value(),
base_url="https://api.fireworks.ai/inference/v1",
) )
return values return values