mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-18 00:51:18 +00:00
perf[core]: remove unnecessary model validators (#31238)
* Remove unnecessary cast of id -> str (can do with a field setting) * Remove unnecessary `set_text` model validator (can be done with a computed field - though we had to make some changes to the `Generation` class to make this possible Before: ~2.4s Blue circles represent time spent in custom validators :( <img width="1337" alt="Screenshot 2025-05-14 at 10 10 12 AM" src="https://github.com/user-attachments/assets/bb4f477f-4ee3-4870-ae93-14ca7f197d55" /> After: ~2.2s <img width="1344" alt="Screenshot 2025-05-14 at 10 11 03 AM" src="https://github.com/user-attachments/assets/99f97d80-49de-462f-856f-9e7e8662adbc" /> We still want to optimize the backwards compatible tool calls model validator, though I think this might involve breaking changes, so wanted to separate that into a different PR. This is circled in green.
This commit is contained in:
parent
1523602196
commit
7263011b24
@ -8,7 +8,7 @@ from io import BufferedReader, BytesIO
|
||||
from pathlib import Path, PurePath
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
|
||||
@ -33,7 +33,7 @@ class BaseMedia(Serializable):
|
||||
# The ID field is optional at the moment.
|
||||
# It will likely become required in a future major release after
|
||||
# it has been adopted by enough vectorstore implementations.
|
||||
id: Optional[str] = None
|
||||
id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
|
||||
"""An optional identifier for the document.
|
||||
|
||||
Ideally this should be unique across the document collection and formatted
|
||||
@ -45,17 +45,6 @@ class BaseMedia(Serializable):
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Arbitrary metadata associated with the content."""
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
|
||||
"""Coerce the id field to a string.
|
||||
|
||||
Args:
|
||||
id_value: The id value to coerce.
|
||||
"""
|
||||
if id_value is not None:
|
||||
return str(id_value)
|
||||
return id_value
|
||||
|
||||
|
||||
class Blob(BaseMedia):
|
||||
"""Blob represents raw data by either reference or value.
|
||||
|
@ -194,6 +194,7 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
# TODO: remove this logic if possible, reducing breaking nature of changes
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> Any:
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.utils import get_bolded_text
|
||||
@ -52,7 +52,7 @@ class BaseMessage(Serializable):
|
||||
model implementation.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
|
||||
"""An optional unique identifier for the message. This should ideally be
|
||||
provided by the provider/model which created the message."""
|
||||
|
||||
@ -60,13 +60,6 @@ class BaseMessage(Serializable):
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
|
||||
"""Coerce the id field to a string."""
|
||||
if id_value is not None:
|
||||
return str(id_value)
|
||||
return id_value
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -4,8 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import computed_field
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
@ -26,48 +25,30 @@ class ChatGeneration(Generation):
|
||||
via callbacks).
|
||||
"""
|
||||
|
||||
text: str = ""
|
||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||
message: BaseMessage
|
||||
"""The message output by the chat model."""
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_text(self) -> Self:
|
||||
"""Set the text attribute to be the contents of the message.
|
||||
|
||||
Args:
|
||||
values: The values of the object.
|
||||
|
||||
Returns:
|
||||
The values of the object with the text attribute set.
|
||||
|
||||
Raises:
|
||||
ValueError: If the message is not a string or a list.
|
||||
"""
|
||||
try:
|
||||
text = ""
|
||||
if isinstance(self.message.content, str):
|
||||
text = self.message.content
|
||||
# Assumes text in content blocks in OpenAI format.
|
||||
# Uses first text block.
|
||||
elif isinstance(self.message.content, list):
|
||||
for block in self.message.content:
|
||||
if isinstance(block, str):
|
||||
text = block
|
||||
break
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text = block["text"]
|
||||
break
|
||||
else:
|
||||
pass
|
||||
self.text = text
|
||||
except (KeyError, AttributeError) as e:
|
||||
msg = "Error while initializing ChatGeneration"
|
||||
raise ValueError(msg) from e
|
||||
return self
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
text_ = ""
|
||||
if isinstance(self.message.content, str):
|
||||
text_ = self.message.content
|
||||
# Assumes text in content blocks in OpenAI format.
|
||||
# Uses first text block.
|
||||
elif isinstance(self.message.content, list):
|
||||
for block in self.message.content:
|
||||
if isinstance(block, str):
|
||||
text_ = block
|
||||
break
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text_ = block["text"]
|
||||
break
|
||||
return text_
|
||||
|
||||
|
||||
class ChatGenerationChunk(ChatGeneration):
|
||||
@ -78,7 +59,7 @@ class ChatGenerationChunk(ChatGeneration):
|
||||
|
||||
message: BaseMessageChunk
|
||||
"""The message chunk output by the chat model."""
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
|
||||
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
|
@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import computed_field
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
@ -24,14 +26,30 @@ class Generation(Serializable):
|
||||
for more information.
|
||||
"""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
def __init__(
|
||||
self,
|
||||
text: str = "",
|
||||
generation_info: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize a Generation."""
|
||||
super().__init__(generation_info=generation_info, **kwargs)
|
||||
self._text = text
|
||||
|
||||
# workaround for ChatGeneration so that we can use a computed field to populate
|
||||
# the text field from the message content (parent class needs to have a property)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""The text contents of the output."""
|
||||
return self._text
|
||||
|
||||
generation_info: Optional[dict[str, Any]] = None
|
||||
"""Raw response from the provider.
|
||||
|
||||
May include things like the reason for finishing or token log probabilities.
|
||||
"""
|
||||
|
||||
type: Literal["Generation"] = "Generation"
|
||||
"""Type is used exclusively for serialization purposes.
|
||||
Set to "Generation" for this class."""
|
||||
@ -53,6 +71,16 @@ class Generation(Serializable):
|
||||
class GenerationChunk(Generation):
|
||||
"""Generation chunk, which can be concatenated with other Generation chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str = "",
|
||||
generation_info: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize a GenerationChunk."""
|
||||
super().__init__(text=text, generation_info=generation_info, **kwargs)
|
||||
self._text = text
|
||||
|
||||
def __add__(self, other: GenerationChunk) -> GenerationChunk:
|
||||
"""Concatenate two GenerationChunks."""
|
||||
if isinstance(other, GenerationChunk):
|
||||
|
Loading…
Reference in New Issue
Block a user