mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 13:59:49 +00:00
fireworks[major]: switch to pydantic v2
This commit is contained in:
@@ -68,12 +68,6 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
@@ -85,6 +79,14 @@ from langchain_core.utils.function_calling import (
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -354,13 +356,13 @@ class ChatFireworks(BaseChatModel):
|
||||
max_retries: Optional[int] = None
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -369,32 +371,32 @@ class ChatFireworks(BaseChatModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
if self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
if self.n > 1 and self.streaming:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
client_params = {
|
||||
"api_key": (
|
||||
values["fireworks_api_key"].get_secret_value()
|
||||
if values["fireworks_api_key"]
|
||||
self.fireworks_api_key.get_secret_value()
|
||||
if self.fireworks_api_key
|
||||
else None
|
||||
),
|
||||
"base_url": values["fireworks_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"base_url": self.fireworks_api_base,
|
||||
"timeout": self.request_timeout,
|
||||
}
|
||||
|
||||
if not values.get("client"):
|
||||
values["client"] = Fireworks(**client_params).chat.completions
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = AsyncFireworks(**client_params).chat.completions
|
||||
if values["max_retries"]:
|
||||
values["client"]._max_retries = values["max_retries"]
|
||||
values["async_client"]._max_retries = values["max_retries"]
|
||||
return values
|
||||
if not (self.client or None):
|
||||
self.client = Fireworks(**client_params).chat.completions
|
||||
if not (self.async_client or None):
|
||||
self.async_client = AsyncFireworks(**client_params).chat.completions
|
||||
if self.max_retries:
|
||||
self.client._max_retries = self.max_retries
|
||||
self.async_client._max_retries = self.max_retries
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
|
@@ -1,9 +1,12 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import secret_from_env
|
||||
from openai import OpenAI # type: ignore
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
# type: ignore
|
||||
|
||||
|
||||
class FireworksEmbeddings(BaseModel, Embeddings):
|
||||
@@ -65,7 +68,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
|
||||
[-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
|
||||
"""
|
||||
|
||||
_client: OpenAI = Field(default=None)
|
||||
client: OpenAI = Field(default=None, exclude=True) #: :meta private:
|
||||
fireworks_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env(
|
||||
@@ -79,20 +82,25 @@ class FireworksEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
model: str = "nomic-ai/nomic-embed-text-v1.5"
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate environment variables."""
|
||||
values["_client"] = OpenAI(
|
||||
api_key=values["fireworks_api_key"].get_secret_value(),
|
||||
self.client = OpenAI(
|
||||
api_key=self.fireworks_api_key.get_secret_value(),
|
||||
base_url="https://api.fireworks.ai/inference/v1",
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
return [
|
||||
i.embedding
|
||||
for i in self._client.embeddings.create(input=texts, model=self.model).data
|
||||
for i in self.client.embeddings.create(input=texts, model=self.model).data
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
|
@@ -10,13 +10,9 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from langchain_core.utils.utils import build_extra_kwargs, secret_from_env
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
from langchain_fireworks.version import __version__
|
||||
|
||||
@@ -39,8 +35,21 @@ class Fireworks(LLM):
|
||||
|
||||
base_url: str = "https://api.fireworks.ai/inference/v1/completions"
|
||||
"""Base inference API URL."""
|
||||
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
|
||||
"""Fireworks AI API key. Get it here: https://fireworks.ai"""
|
||||
fireworks_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env(
|
||||
"FIREWORKS_API_KEY",
|
||||
error_message=(
|
||||
"You must specify an api key. "
|
||||
"You can pass it an argument as `api_key=...` or "
|
||||
"set the environment variable `FIREWORKS_API_KEY`."
|
||||
),
|
||||
),
|
||||
)
|
||||
"""Fireworks API key.
|
||||
|
||||
Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
|
||||
"""
|
||||
model: str
|
||||
"""Model name. Available models listed here:
|
||||
https://readme.fireworks.ai/
|
||||
@@ -74,14 +83,14 @@ class Fireworks(LLM):
|
||||
the response for each token generation step.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
extra = "forbid"
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -90,14 +99,6 @@ class Fireworks(LLM):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["fireworks_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of model."""
|
||||
|
@@ -7,7 +7,7 @@ import json
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_fireworks import ChatFireworks
|
||||
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pydantic import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain_fireworks import Fireworks
|
||||
|
Reference in New Issue
Block a user