diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 7b489cd5d00..c5e6d9c42fd 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -68,12 +68,6 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import (
-    BaseModel,
-    Field,
-    SecretStr,
-    root_validator,
-)
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils import (
@@ -85,6 +79,14 @@ from langchain_core.utils.function_calling import (
 )
 from langchain_core.utils.pydantic import is_basemodel_subclass
 from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    SecretStr,
+    model_validator,
+)
+from typing_extensions import Self
 
 logger = logging.getLogger(__name__)
 
@@ -354,13 +356,13 @@ class ChatFireworks(BaseChatModel):
     max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
 
-    class Config:
-        """Configuration for this pydantic object."""
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
 
-        allow_population_by_field_name = True
-
-    @root_validator(pre=True)
-    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
@@ -369,32 +371,32 @@ class ChatFireworks(BaseChatModel):
         )
         return values
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
         client_params = {
             "api_key": (
-                values["fireworks_api_key"].get_secret_value()
-                if values["fireworks_api_key"]
+                self.fireworks_api_key.get_secret_value()
+                if self.fireworks_api_key
                 else None
             ),
-            "base_url": values["fireworks_api_base"],
-            "timeout": values["request_timeout"],
+            "base_url": self.fireworks_api_base,
+            "timeout": self.request_timeout,
         }
 
-        if not values.get("client"):
-            values["client"] = Fireworks(**client_params).chat.completions
-        if not values.get("async_client"):
-            values["async_client"] = AsyncFireworks(**client_params).chat.completions
-        if values["max_retries"]:
-            values["client"]._max_retries = values["max_retries"]
-            values["async_client"]._max_retries = values["max_retries"]
-        return values
+        if not (self.client or None):
+            self.client = Fireworks(**client_params).chat.completions
+        if not (self.async_client or None):
+            self.async_client = AsyncFireworks(**client_params).chat.completions
+        if self.max_retries:
+            self.client._max_retries = self.max_retries
+            self.async_client._max_retries = self.max_retries
+        return self
 
     @property
     def _default_params(self) -> Dict[str, Any]:
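Note (not part of the diff): the change above is the core of the pydantic v1 to v2 migration for `ChatFireworks`. `@root_validator(pre=True)` becomes `@model_validator(mode="before")` on a classmethod that still sees the raw input dict, while `@root_validator(pre=False, skip_on_failure=True)` becomes `@model_validator(mode="after")`, which runs on the already-constructed instance and returns `Self`, so `values["n"]` lookups turn into plain attribute access. The sketch below shows the same before/after split on a minimal, hypothetical `ExampleClient` model; it only illustrates the pydantic v2 API and is not code from this PR.

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
from typing_extensions import Self


class ExampleClient(BaseModel):
    """Toy stand-in for ChatFireworks, used only to illustrate the validator migration."""

    model_config = ConfigDict(populate_by_name=True)

    api_key: Optional[SecretStr] = None
    n: int = 1
    streaming: bool = False
    model_kwargs: Dict[str, Any] = {}

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        # Pre-validation hook: runs on the raw input dict, like the old
        # @root_validator(pre=True). Unknown kwargs are folded into model_kwargs.
        known = set(cls.model_fields)
        extra = values.setdefault("model_kwargs", {})
        for name in list(values):
            if name not in known:
                extra[name] = values.pop(name)
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        # Post-validation hook: runs on the constructed instance, so fields are
        # read as attributes instead of values["..."] lookups.
        if self.n < 1:
            raise ValueError("n must be at least 1.")
        if self.n > 1 and self.streaming:
            raise ValueError("n must be 1 when streaming.")
        return self


print(ExampleClient(api_key="secret", temperature=0.2).model_kwargs)  # {'temperature': 0.2}
```

Doing the sanity checks in the `after` hook means defaults and type coercion have already been applied, and an `after` validator simply does not run when field validation fails, which is why the `skip_on_failure` flag disappears in the migration.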
diff --git a/libs/partners/fireworks/langchain_fireworks/embeddings.py b/libs/partners/fireworks/langchain_fireworks/embeddings.py
index 719bb34a935..8fd67f116ca 100644
--- a/libs/partners/fireworks/langchain_fireworks/embeddings.py
+++ b/libs/partners/fireworks/langchain_fireworks/embeddings.py
@@ -1,9 +1,12 @@
-from typing import Any, Dict, List
+from typing import List
 
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import secret_from_env
-from openai import OpenAI  # type: ignore
+from openai import OpenAI
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
+from typing_extensions import Self
+
+# type: ignore
 
 
 class FireworksEmbeddings(BaseModel, Embeddings):
@@ -65,7 +68,7 @@
                 [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
     """
 
-    _client: OpenAI = Field(default=None)
+    client: OpenAI = Field(default=None, exclude=True)  #: :meta private:
     fireworks_api_key: SecretStr = Field(
         alias="api_key",
         default_factory=secret_from_env(
@@ -79,20 +82,25 @@
     """
     model: str = "nomic-ai/nomic-embed-text-v1.5"
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    model_config = ConfigDict(
+        populate_by_name=True,
+        arbitrary_types_allowed=True,
+    )
+
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate environment variables."""
-        values["_client"] = OpenAI(
-            api_key=values["fireworks_api_key"].get_secret_value(),
+        self.client = OpenAI(
+            api_key=self.fireworks_api_key.get_secret_value(),
             base_url="https://api.fireworks.ai/inference/v1",
         )
-        return values
+        return self
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
         return [
             i.embedding
-            for i in self._client.embeddings.create(input=texts, model=self.model).data
+            for i in self.client.embeddings.create(input=texts, model=self.model).data
         ]
 
     def embed_query(self, text: str) -> List[float]:
diff --git a/libs/partners/fireworks/langchain_fireworks/llms.py b/libs/partners/fireworks/langchain_fireworks/llms.py
index 747c59ecadd..3189483c914 100644
--- a/libs/partners/fireworks/langchain_fireworks/llms.py
+++ b/libs/partners/fireworks/langchain_fireworks/llms.py
@@ -10,13 +10,9 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.llms import LLM
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
-from langchain_core.utils import (
-    convert_to_secret_str,
-    get_from_dict_or_env,
-    get_pydantic_field_names,
-)
-from langchain_core.utils.utils import build_extra_kwargs
+from langchain_core.utils import get_pydantic_field_names
+from langchain_core.utils.utils import build_extra_kwargs, secret_from_env
+from pydantic import ConfigDict, Field, SecretStr, model_validator
 
 from langchain_fireworks.version import __version__
 
@@ -39,8 +35,21 @@ class Fireworks(LLM):
 
     base_url: str = "https://api.fireworks.ai/inference/v1/completions"
     """Base inference API URL."""
-    fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
-    """Fireworks AI API key. Get it here: https://fireworks.ai"""
+    fireworks_api_key: SecretStr = Field(
+        alias="api_key",
+        default_factory=secret_from_env(
+            "FIREWORKS_API_KEY",
+            error_message=(
+                "You must specify an api key. "
+                "You can pass it an argument as `api_key=...` or "
+                "set the environment variable `FIREWORKS_API_KEY`."
+            ),
+        ),
+    )
+    """Fireworks API key.
+
+    Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
+    """
     model: str
     """Model name. Available models listed here: https://readme.fireworks.ai/
 
@@ -74,14 +83,14 @@ class Fireworks(LLM):
     the response for each token generation step.
     """
 
-    class Config:
-        """Configuration for this pydantic object."""
+    model_config = ConfigDict(
+        extra="forbid",
+        populate_by_name=True,
+    )
 
-        extra = "forbid"
-        allow_population_by_field_name = True
-
-    @root_validator(pre=True)
-    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
@@ -90,14 +99,6 @@
         )
         return values
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key exists in environment."""
-        values["fireworks_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
-        )
-        return values
-
     @property
     def _llm_type(self) -> str:
         """Return type of model."""
diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
index 443f9e5f47b..88a1cd46cfc 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
@@ -7,7 +7,7 @@ import json
 from typing import Optional
 
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel
 
 from langchain_fireworks import ChatFireworks
 
diff --git a/libs/partners/fireworks/tests/unit_tests/test_llms.py b/libs/partners/fireworks/tests/unit_tests/test_llms.py
index e2fb8a131e4..265df7ede83 100644
--- a/libs/partners/fireworks/tests/unit_tests/test_llms.py
+++ b/libs/partners/fireworks/tests/unit_tests/test_llms.py
@@ -2,7 +2,7 @@
 
 from typing import cast
 
-from langchain_core.pydantic_v1 import SecretStr
+from pydantic import SecretStr
 from pytest import CaptureFixture, MonkeyPatch
 
 from langchain_fireworks import Fireworks