perf[core]: remove unnecessary model validators (#31238)

* Remove unnecessary cast of id -> str (can be done with the `coerce_numbers_to_str` field setting)
* Remove unnecessary `set_text` model validator (can be done with a
computed field, though we had to make some changes to the `Generation`
class to make this possible)

Before: ~2.4s

Blue circles represent time spent in custom validators :(

<img width="1337" alt="Screenshot 2025-05-14 at 10 10 12 AM"
src="https://github.com/user-attachments/assets/bb4f477f-4ee3-4870-ae93-14ca7f197d55"
/>


After: ~2.2s

<img width="1344" alt="Screenshot 2025-05-14 at 10 11 03 AM"
src="https://github.com/user-attachments/assets/99f97d80-49de-462f-856f-9e7e8662adbc"
/>

We still want to optimize the backwards-compatible tool calls model
validator, though I think this might involve breaking changes, so I wanted
to separate that into a different PR. That validator is circled in green.
This commit is contained in:
Sydney Runkle 2025-05-14 10:20:22 -07:00 committed by GitHub
parent 1523602196
commit 7263011b24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 64 deletions

View File

@ -8,7 +8,7 @@ from io import BufferedReader, BytesIO
from pathlib import Path, PurePath
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic import ConfigDict, Field, model_validator
from langchain_core.load.serializable import Serializable
@ -33,7 +33,7 @@ class BaseMedia(Serializable):
# The ID field is optional at the moment.
# It will likely become required in a future major release after
# it has been adopted by enough vectorstore implementations.
id: Optional[str] = None
id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
"""An optional identifier for the document.
Ideally this should be unique across the document collection and formatted
@ -45,17 +45,6 @@ class BaseMedia(Serializable):
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata associated with the content."""
@field_validator("id", mode="before")
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
"""Coerce the id field to a string.
Args:
id_value: The id value to coerce.
"""
if id_value is not None:
return str(id_value)
return id_value
class Blob(BaseMedia):
"""Blob represents raw data by either reference or value.

View File

@ -194,6 +194,7 @@ class AIMessage(BaseMessage):
"invalid_tool_calls": self.invalid_tool_calls,
}
# TODO: remove this logic if possible, reducing breaking nature of changes
@model_validator(mode="before")
@classmethod
def _backwards_compat_tool_calls(cls, values: dict) -> Any:

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field
from langchain_core.load.serializable import Serializable
from langchain_core.utils import get_bolded_text
@ -52,7 +52,7 @@ class BaseMessage(Serializable):
model implementation.
"""
id: Optional[str] = None
id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""
@ -60,13 +60,6 @@ class BaseMessage(Serializable):
extra="allow",
)
@field_validator("id", mode="before")
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
"""Coerce the id field to a string."""
if id_value is not None:
return str(id_value)
return id_value
def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:

View File

@ -4,8 +4,7 @@ from __future__ import annotations
from typing import Literal, Union
from pydantic import model_validator
from typing_extensions import Self
from pydantic import computed_field
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
@ -26,48 +25,30 @@ class ChatGeneration(Generation):
via callbacks).
"""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@model_validator(mode="after")
def set_text(self) -> Self:
"""Set the text attribute to be the contents of the message.
Args:
values: The values of the object.
Returns:
The values of the object with the text attribute set.
Raises:
ValueError: If the message is not a string or a list.
"""
try:
text = ""
@computed_field # type: ignore[prop-decorator]
@property
def text(self) -> str:
"""Set the text attribute to be the contents of the message."""
text_ = ""
if isinstance(self.message.content, str):
text = self.message.content
text_ = self.message.content
# Assumes text in content blocks in OpenAI format.
# Uses first text block.
elif isinstance(self.message.content, list):
for block in self.message.content:
if isinstance(block, str):
text = block
text_ = block
break
if isinstance(block, dict) and "text" in block:
text = block["text"]
text_ = block["text"]
break
else:
pass
self.text = text
except (KeyError, AttributeError) as e:
msg = "Error while initializing ChatGeneration"
raise ValueError(msg) from e
return self
return text_
class ChatGenerationChunk(ChatGeneration):
@ -78,7 +59,7 @@ class ChatGenerationChunk(ChatGeneration):
message: BaseMessageChunk
"""The message chunk output by the chat model."""
# Override type to be ChatGenerationChunk, ignore mypy error as this is intentional
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""

View File

@ -4,6 +4,8 @@ from __future__ import annotations
from typing import Any, Literal, Optional
from pydantic import computed_field
from langchain_core.load import Serializable
from langchain_core.utils._merge import merge_dicts
@ -24,14 +26,30 @@ class Generation(Serializable):
for more information.
"""
text: str
"""Generated text output."""
def __init__(
self,
text: str = "",
generation_info: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
"""Initialize a Generation."""
super().__init__(generation_info=generation_info, **kwargs)
self._text = text
# workaround for ChatGeneration so that we can use a computed field to populate
# the text field from the message content (parent class needs to have a property)
@computed_field # type: ignore[prop-decorator]
@property
def text(self) -> str:
"""The text contents of the output."""
return self._text
generation_info: Optional[dict[str, Any]] = None
"""Raw response from the provider.
May include things like the reason for finishing or token log probabilities.
"""
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes.
Set to "Generation" for this class."""
@ -53,6 +71,16 @@ class Generation(Serializable):
class GenerationChunk(Generation):
"""Generation chunk, which can be concatenated with other Generation chunks."""
def __init__(
self,
text: str = "",
generation_info: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
"""Initialize a GenerationChunk."""
super().__init__(text=text, generation_info=generation_info, **kwargs)
self._text = text
def __add__(self, other: GenerationChunk) -> GenerationChunk:
"""Concatenate two GenerationChunks."""
if isinstance(other, GenerationChunk):