fireworks[major]: switch to pydantic v2

This commit is contained in:
Bagatur
2024-09-03 17:41:28 -07:00
parent 9a9ab65030
commit 559d8a4d13
5 changed files with 75 additions and 64 deletions

View File

@@ -68,12 +68,6 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
@@ -85,6 +79,14 @@ from langchain_core.utils.function_calling import (
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
logger = logging.getLogger(__name__)
@@ -354,13 +356,13 @@ class ChatFireworks(BaseChatModel):
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(
populate_by_name=True,
)
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@@ -369,32 +371,32 @@ class ChatFireworks(BaseChatModel):
)
return values
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
client_params = {
"api_key": (
values["fireworks_api_key"].get_secret_value()
if values["fireworks_api_key"]
self.fireworks_api_key.get_secret_value()
if self.fireworks_api_key
else None
),
"base_url": values["fireworks_api_base"],
"timeout": values["request_timeout"],
"base_url": self.fireworks_api_base,
"timeout": self.request_timeout,
}
if not values.get("client"):
values["client"] = Fireworks(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = AsyncFireworks(**client_params).chat.completions
if values["max_retries"]:
values["client"]._max_retries = values["max_retries"]
values["async_client"]._max_retries = values["max_retries"]
return values
if not (self.client or None):
self.client = Fireworks(**client_params).chat.completions
if not (self.async_client or None):
self.async_client = AsyncFireworks(**client_params).chat.completions
if self.max_retries:
self.client._max_retries = self.max_retries
self.async_client._max_retries = self.max_retries
return self
@property
def _default_params(self) -> Dict[str, Any]:

View File

@@ -1,9 +1,12 @@
from typing import Any, Dict, List
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import secret_from_env
from openai import OpenAI # type: ignore
from openai import OpenAI
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
# type: ignore
class FireworksEmbeddings(BaseModel, Embeddings):
@@ -65,7 +68,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
[-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
"""
_client: OpenAI = Field(default=None)
client: OpenAI = Field(default=None, exclude=True) #: :meta private:
fireworks_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
@@ -79,20 +82,25 @@ class FireworksEmbeddings(BaseModel, Embeddings):
"""
model: str = "nomic-ai/nomic-embed-text-v1.5"
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate environment variables."""
values["_client"] = OpenAI(
api_key=values["fireworks_api_key"].get_secret_value(),
self.client = OpenAI(
api_key=self.fireworks_api_key.get_secret_value(),
base_url="https://api.fireworks.ai/inference/v1",
)
return values
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self._client.embeddings.create(input=texts, model=self.model).data
for i in self.client.embeddings.create(input=texts, model=self.model).data
]
def embed_query(self, text: str) -> List[float]:

View File

@@ -10,13 +10,9 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs
from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.utils import build_extra_kwargs, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from langchain_fireworks.version import __version__
@@ -39,8 +35,21 @@ class Fireworks(LLM):
base_url: str = "https://api.fireworks.ai/inference/v1/completions"
"""Base inference API URL."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
"""Fireworks AI API key. Get it here: https://fireworks.ai"""
fireworks_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
"FIREWORKS_API_KEY",
error_message=(
"You must specify an api key. "
"You can pass it an argument as `api_key=...` or "
"set the environment variable `FIREWORKS_API_KEY`."
),
),
)
"""Fireworks API key.
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
"""
model: str
"""Model name. Available models listed here:
https://readme.fireworks.ai/
@@ -74,14 +83,14 @@ class Fireworks(LLM):
the response for each token generation step.
"""
class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)
extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
@@ -90,14 +99,6 @@ class Fireworks(LLM):
)
return values
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
return values
@property
def _llm_type(self) -> str:
"""Return type of model."""

View File

@@ -7,7 +7,7 @@ import json
from typing import Optional
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_fireworks import ChatFireworks

View File

@@ -2,7 +2,7 @@
from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from pydantic import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_fireworks import Fireworks