From eb3870e9d83ec6efc2c1486fc742f309cb923619 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 12:56:48 -0400 Subject: [PATCH] fireworks[patch]: Upgrade @root_validators to be pydantic 2 compliant (#25443) Update @root_validators to be pydantic 2 compliant --- libs/core/langchain_core/utils/utils.py | 4 +++ .../langchain_fireworks/chat_models.py | 35 +++++++++++-------- .../langchain_fireworks/embeddings.py | 28 ++++++++------- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 960646ecb70..0956ea1dd90 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -296,6 +296,10 @@ def from_env( ) -> Callable[[], Optional[str]]: ... +@overload +def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ... + + def from_env( key: str, /, diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 9fc9402e8ee..7b489cd5d00 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -4,7 +4,6 @@ from __future__ import annotations import json import logging -import os from operator import itemgetter from typing import ( Any, @@ -78,8 +77,6 @@ from langchain_core.pydantic_v1 import ( from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, get_pydantic_field_names, ) from langchain_core.utils.function_calling import ( @@ -87,7 +84,7 @@ from langchain_core.utils.function_calling import ( convert_to_openai_tool, ) from langchain_core.utils.pydantic import is_basemodel_subclass -from langchain_core.utils.utils import build_extra_kwargs +from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env logger = logging.getLogger(__name__) @@ -322,9 +319,25 @@ class ChatFireworks(BaseChatModel): """Default stop sequences.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - fireworks_api_key: SecretStr = Field(default=None, alias="api_key") - """Automatically inferred from env var `FIREWORKS_API_KEY` if not provided.""" - fireworks_api_base: Optional[str] = Field(default=None, alias="base_url") + fireworks_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env( + "FIREWORKS_API_KEY", + error_message=( + "You must specify an api key. " + "You can pass it an argument as `api_key=...` or " + "set the environment variable `FIREWORKS_API_KEY`." + ), + ), + ) + """Fireworks API key. + + Automatically read from env variable `FIREWORKS_API_KEY` if not provided. + """ + + fireworks_api_base: Optional[str] = Field( + alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None) + ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" request_timeout: Union[float, Tuple[float, float], Any, None] = Field( @@ -356,7 +369,7 @@ class ChatFireworks(BaseChatModel): ) return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -364,12 +377,6 @@ class ChatFireworks(BaseChatModel): if values["n"] > 1 and values["streaming"]: raise ValueError("n must be 1 when streaming.") - values["fireworks_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY") - ) - values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv( - "FIREWORKS_API_BASE" - ) client_params = { "api_key": ( values["fireworks_api_key"].get_secret_value() diff --git a/libs/partners/fireworks/langchain_fireworks/embeddings.py b/libs/partners/fireworks/langchain_fireworks/embeddings.py index bb5693675ea..719bb34a935 100644 --- a/libs/partners/fireworks/langchain_fireworks/embeddings.py +++ b/libs/partners/fireworks/langchain_fireworks/embeddings.py @@ -1,9 +1,8 @@ -import os from typing import Any, Dict, List from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str +from langchain_core.utils import secret_from_env from openai import OpenAI # type: ignore @@ -67,22 +66,25 @@ class FireworksEmbeddings(BaseModel, Embeddings): """ _client: OpenAI = Field(default=None) - fireworks_api_key: SecretStr = convert_to_secret_str("") + fireworks_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env( + "FIREWORKS_API_KEY", + default="", + ), + ) + """Fireworks API key. + + Automatically read from env variable `FIREWORKS_API_KEY` if not provided. + """ model: str = "nomic-ai/nomic-embed-text-v1.5" - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate environment variables.""" - fireworks_api_key = convert_to_secret_str( - values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or "" - ) - values["fireworks_api_key"] = fireworks_api_key - - # note this sets it globally for module - # there isn't currently a way to pass it into client - api_key = fireworks_api_key.get_secret_value() values["_client"] = OpenAI( - api_key=api_key, base_url="https://api.fireworks.ai/inference/v1" + api_key=values["fireworks_api_key"].get_secret_value(), + base_url="https://api.fireworks.ai/inference/v1", ) return values