fireworks[major]: switch to pydantic v2 (#26004)

This commit is contained in:
Bagatur
2024-09-04 09:18:10 -07:00
committed by GitHub
8 changed files with 125 additions and 70 deletions

View File

@@ -68,12 +68,6 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call, parse_tool_call,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import ( from langchain_core.utils import (
@@ -85,6 +79,14 @@ from langchain_core.utils.function_calling import (
) )
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -354,13 +356,13 @@ class ChatFireworks(BaseChatModel):
max_retries: Optional[int] = None max_retries: Optional[int] = None
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
class Config: model_config = ConfigDict(
"""Configuration for this pydantic object.""" populate_by_name=True,
)
allow_population_by_field_name = True @model_validator(mode="before")
@classmethod
@root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
@@ -369,32 +371,32 @@ class ChatFireworks(BaseChatModel):
) )
return values return values
@root_validator(pre=False, skip_on_failure=True) @model_validator(mode="after")
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if self.n < 1:
raise ValueError("n must be at least 1.") raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]: if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.") raise ValueError("n must be 1 when streaming.")
client_params = { client_params = {
"api_key": ( "api_key": (
values["fireworks_api_key"].get_secret_value() self.fireworks_api_key.get_secret_value()
if values["fireworks_api_key"] if self.fireworks_api_key
else None else None
), ),
"base_url": values["fireworks_api_base"], "base_url": self.fireworks_api_base,
"timeout": values["request_timeout"], "timeout": self.request_timeout,
} }
if not values.get("client"): if not self.client:
values["client"] = Fireworks(**client_params).chat.completions self.client = Fireworks(**client_params).chat.completions
if not values.get("async_client"): if not self.async_client:
values["async_client"] = AsyncFireworks(**client_params).chat.completions self.async_client = AsyncFireworks(**client_params).chat.completions
if values["max_retries"]: if self.max_retries:
values["client"]._max_retries = values["max_retries"] self.client._max_retries = self.max_retries
values["async_client"]._max_retries = values["max_retries"] self.async_client._max_retries = self.max_retries
return values return self
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
@@ -803,7 +805,7 @@ class ChatFireworks(BaseChatModel):
from typing import Optional from typing import Optional
from langchain_fireworks import ChatFireworks from langchain_fireworks import ChatFireworks
from langchain_core.pydantic_v1 import BaseModel, Field from pydantic import BaseModel, Field
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
@@ -834,7 +836,7 @@ class ChatFireworks(BaseChatModel):
.. code-block:: python .. code-block:: python
from langchain_fireworks import ChatFireworks from langchain_fireworks import ChatFireworks
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
@@ -921,7 +923,7 @@ class ChatFireworks(BaseChatModel):
.. code-block:: .. code-block::
from langchain_fireworks import ChatFireworks from langchain_fireworks import ChatFireworks
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
answer: str answer: str

View File

@@ -1,9 +1,12 @@
from typing import Any, Dict, List from typing import List
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import secret_from_env from langchain_core.utils import secret_from_env
from openai import OpenAI # type: ignore from openai import OpenAI
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
# type: ignore
class FireworksEmbeddings(BaseModel, Embeddings): class FireworksEmbeddings(BaseModel, Embeddings):
@@ -65,7 +68,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
[-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
""" """
_client: OpenAI = Field(default=None) client: OpenAI = Field(default=None, exclude=True) #: :meta private:
fireworks_api_key: SecretStr = Field( fireworks_api_key: SecretStr = Field(
alias="api_key", alias="api_key",
default_factory=secret_from_env( default_factory=secret_from_env(
@@ -79,20 +82,25 @@ class FireworksEmbeddings(BaseModel, Embeddings):
""" """
model: str = "nomic-ai/nomic-embed-text-v1.5" model: str = "nomic-ai/nomic-embed-text-v1.5"
@root_validator(pre=False, skip_on_failure=True) model_config = ConfigDict(
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: populate_by_name=True,
arbitrary_types_allowed=True,
)
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate environment variables.""" """Validate environment variables."""
values["_client"] = OpenAI( self.client = OpenAI(
api_key=values["fireworks_api_key"].get_secret_value(), api_key=self.fireworks_api_key.get_secret_value(),
base_url="https://api.fireworks.ai/inference/v1", base_url="https://api.fireworks.ai/inference/v1",
) )
return values return self
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.""" """Embed search docs."""
return [ return [
i.embedding i.embedding
for i in self._client.embeddings.create(input=texts, model=self.model).data for i in self.client.embeddings.create(input=texts, model=self.model).data
] ]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:

View File

@@ -10,13 +10,9 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils import ( from langchain_core.utils.utils import build_extra_kwargs, secret_from_env
convert_to_secret_str, from pydantic import ConfigDict, Field, SecretStr, model_validator
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs
from langchain_fireworks.version import __version__ from langchain_fireworks.version import __version__
@@ -39,8 +35,21 @@ class Fireworks(LLM):
base_url: str = "https://api.fireworks.ai/inference/v1/completions" base_url: str = "https://api.fireworks.ai/inference/v1/completions"
"""Base inference API URL.""" """Base inference API URL."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key") fireworks_api_key: SecretStr = Field(
"""Fireworks AI API key. Get it here: https://fireworks.ai""" alias="api_key",
default_factory=secret_from_env(
"FIREWORKS_API_KEY",
error_message=(
"You must specify an api key. "
"You can pass it an argument as `api_key=...` or "
"set the environment variable `FIREWORKS_API_KEY`."
),
),
)
"""Fireworks API key.
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
"""
model: str model: str
"""Model name. Available models listed here: """Model name. Available models listed here:
https://readme.fireworks.ai/ https://readme.fireworks.ai/
@@ -74,14 +83,14 @@ class Fireworks(LLM):
the response for each token generation step. the response for each token generation step.
""" """
class Config: model_config = ConfigDict(
"""Configuration for this pydantic object.""" extra="forbid",
populate_by_name=True,
)
extra = "forbid" @model_validator(mode="before")
allow_population_by_field_name = True @classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
@@ -90,14 +99,6 @@ class Fireworks(LLM):
) )
return values return values
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
return values
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
"""Return type of model.""" """Return type of model."""

View File

@@ -20,8 +20,8 @@ result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
if [ -n "$result" ]; then if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:" echo "ERROR: The following lines need to be updated:"
echo "$result" echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1." echo "Please replace the code with an import from pydantic."
echo "For example, replace 'from pydantic import BaseModel'" echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'" echo "with 'from pydantic import BaseModel'"
exit 1 exit 1
fi fi

View File

@@ -7,7 +7,7 @@ import json
from typing import Optional from typing import Optional
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.pydantic_v1 import BaseModel from pydantic import BaseModel
from langchain_fireworks import ChatFireworks from langchain_fireworks import ChatFireworks

View File

@@ -0,0 +1,30 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.embeddings import Embeddings
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests
from langchain_fireworks import FireworksEmbeddings
class TestFireworksStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
return FireworksEmbeddings
@property
def embeddings_params(self) -> dict:
return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
},
{},
{
"fireworks_api_key": "api_key",
},
)

View File

@@ -2,7 +2,7 @@
from typing import cast from typing import cast
from langchain_core.pydantic_v1 import SecretStr from pydantic import SecretStr
from pytest import CaptureFixture, MonkeyPatch from pytest import CaptureFixture, MonkeyPatch
from langchain_fireworks import Fireworks from langchain_fireworks import Fireworks

View File

@@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Type from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found] from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
@@ -18,3 +18,17 @@ class TestFireworksStandard(ChatModelUnitTests):
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return {"api_key": "test_api_key"} return {"api_key": "test_api_key"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
"FIREWORKS_API_BASE": "https://base.com",
},
{},
{
"fireworks_api_key": "api_key",
"fireworks_api_base": "https://base.com",
},
)