From 7263011b24b02e71cb4c6f8798d16815adf65e40 Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 14 May 2025 10:20:22 -0700
Subject: [PATCH] perf[core]: remove unnecessary model validators (#31238)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Remove unnecessary cast of id -> str (can do with a field setting)
* Remove unnecessary `set_text` model validator (can be done with a
computed field — though we had to make some changes to the `Generation`
class to make this possible)
Before: ~2.4s
Blue circles in the profiling screenshots (attached to the PR) represent
time spent in custom validators :(
After: ~2.2s
We still want to optimize the backwards-compatible tool calls model
validator, though I think this might involve breaking changes, so wanted
to separate that into a different PR. This is circled in green in the
profiling screenshot.
---
libs/core/langchain_core/documents/base.py | 15 +----
libs/core/langchain_core/messages/ai.py | 1 +
libs/core/langchain_core/messages/base.py | 11 +---
.../langchain_core/outputs/chat_generation.py | 61 +++++++------------
.../core/langchain_core/outputs/generation.py | 32 +++++++++-
5 files changed, 56 insertions(+), 64 deletions(-)
diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py
index fba997f4959..0323bb1e886 100644
--- a/libs/core/langchain_core/documents/base.py
+++ b/libs/core/langchain_core/documents/base.py
@@ -8,7 +8,7 @@ from io import BufferedReader, BytesIO
from pathlib import Path, PurePath
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
-from pydantic import ConfigDict, Field, field_validator, model_validator
+from pydantic import ConfigDict, Field, model_validator
from langchain_core.load.serializable import Serializable
@@ -33,7 +33,7 @@ class BaseMedia(Serializable):
# The ID field is optional at the moment.
# It will likely become required in a future major release after
# it has been adopted by enough vectorstore implementations.
- id: Optional[str] = None
+ id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
"""An optional identifier for the document.
Ideally this should be unique across the document collection and formatted
@@ -45,17 +45,6 @@ class BaseMedia(Serializable):
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata associated with the content."""
- @field_validator("id", mode="before")
- def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
- """Coerce the id field to a string.
-
- Args:
- id_value: The id value to coerce.
- """
- if id_value is not None:
- return str(id_value)
- return id_value
-
class Blob(BaseMedia):
"""Blob represents raw data by either reference or value.
diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py
index 6b067864186..34e5c217270 100644
--- a/libs/core/langchain_core/messages/ai.py
+++ b/libs/core/langchain_core/messages/ai.py
@@ -194,6 +194,7 @@ class AIMessage(BaseMessage):
"invalid_tool_calls": self.invalid_tool_calls,
}
+ # TODO: remove this logic if possible, reducing breaking nature of changes
@model_validator(mode="before")
@classmethod
def _backwards_compat_tool_calls(cls, values: dict) -> Any:
diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py
index 58049d28e9e..1ce0122166b 100644
--- a/libs/core/langchain_core/messages/base.py
+++ b/libs/core/langchain_core/messages/base.py
@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, Union, cast
-from pydantic import ConfigDict, Field, field_validator
+from pydantic import ConfigDict, Field
from langchain_core.load.serializable import Serializable
from langchain_core.utils import get_bolded_text
@@ -52,7 +52,7 @@ class BaseMessage(Serializable):
model implementation.
"""
- id: Optional[str] = None
+ id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""
@@ -60,13 +60,6 @@ class BaseMessage(Serializable):
extra="allow",
)
- @field_validator("id", mode="before")
- def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
- """Coerce the id field to a string."""
- if id_value is not None:
- return str(id_value)
- return id_value
-
def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py
index 25ea5684e67..ed4f1066341 100644
--- a/libs/core/langchain_core/outputs/chat_generation.py
+++ b/libs/core/langchain_core/outputs/chat_generation.py
@@ -4,8 +4,7 @@ from __future__ import annotations
from typing import Literal, Union
-from pydantic import model_validator
-from typing_extensions import Self
+from pydantic import computed_field
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
@@ -26,48 +25,30 @@ class ChatGeneration(Generation):
via callbacks).
"""
- text: str = ""
- """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
- # Override type to be ChatGeneration, ignore mypy error as this is intentional
+
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
- @model_validator(mode="after")
- def set_text(self) -> Self:
- """Set the text attribute to be the contents of the message.
-
- Args:
- values: The values of the object.
-
- Returns:
- The values of the object with the text attribute set.
-
- Raises:
- ValueError: If the message is not a string or a list.
- """
- try:
- text = ""
- if isinstance(self.message.content, str):
- text = self.message.content
- # Assumes text in content blocks in OpenAI format.
- # Uses first text block.
- elif isinstance(self.message.content, list):
- for block in self.message.content:
- if isinstance(block, str):
- text = block
- break
- if isinstance(block, dict) and "text" in block:
- text = block["text"]
- break
- else:
- pass
- self.text = text
- except (KeyError, AttributeError) as e:
- msg = "Error while initializing ChatGeneration"
- raise ValueError(msg) from e
- return self
+ @computed_field # type: ignore[prop-decorator]
+ @property
+ def text(self) -> str:
+ """Set the text attribute to be the contents of the message."""
+ text_ = ""
+ if isinstance(self.message.content, str):
+ text_ = self.message.content
+ # Assumes text in content blocks in OpenAI format.
+ # Uses first text block.
+ elif isinstance(self.message.content, list):
+ for block in self.message.content:
+ if isinstance(block, str):
+ text_ = block
+ break
+ if isinstance(block, dict) and "text" in block:
+ text_ = block["text"]
+ break
+ return text_
class ChatGenerationChunk(ChatGeneration):
@@ -78,7 +59,7 @@ class ChatGenerationChunk(ChatGeneration):
message: BaseMessageChunk
"""The message chunk output by the chat model."""
- # Override type to be ChatGeneration, ignore mypy error as this is intentional
+
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
diff --git a/libs/core/langchain_core/outputs/generation.py b/libs/core/langchain_core/outputs/generation.py
index 8f3bbe5a77c..75ac5b81db7 100644
--- a/libs/core/langchain_core/outputs/generation.py
+++ b/libs/core/langchain_core/outputs/generation.py
@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import Any, Literal, Optional
+from pydantic import computed_field
+
from langchain_core.load import Serializable
from langchain_core.utils._merge import merge_dicts
@@ -24,14 +26,30 @@ class Generation(Serializable):
for more information.
"""
- text: str
- """Generated text output."""
+ def __init__(
+ self,
+ text: str = "",
+ generation_info: Optional[dict[str, Any]] = None,
+ **kwargs: Any,
+ ):
+ """Initialize a Generation."""
+ super().__init__(generation_info=generation_info, **kwargs)
+ self._text = text
+
+ # workaround for ChatGeneration so that we can use a computed field to populate
+ # the text field from the message content (parent class needs to have a property)
+ @computed_field # type: ignore[prop-decorator]
+ @property
+ def text(self) -> str:
+ """The text contents of the output."""
+ return self._text
generation_info: Optional[dict[str, Any]] = None
"""Raw response from the provider.
May include things like the reason for finishing or token log probabilities.
"""
+
type: Literal["Generation"] = "Generation"
"""Type is used exclusively for serialization purposes.
Set to "Generation" for this class."""
@@ -53,6 +71,16 @@ class Generation(Serializable):
class GenerationChunk(Generation):
"""Generation chunk, which can be concatenated with other Generation chunks."""
+ def __init__(
+ self,
+ text: str = "",
+ generation_info: Optional[dict[str, Any]] = None,
+ **kwargs: Any,
+ ):
+ """Initialize a GenerationChunk."""
+ super().__init__(text=text, generation_info=generation_info, **kwargs)
+ self._text = text
+
def __add__(self, other: GenerationChunk) -> GenerationChunk:
"""Concatenate two GenerationChunks."""
if isinstance(other, GenerationChunk):