From 7263011b24b02e71cb4c6f8798d16815adf65e40 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 14 May 2025 10:20:22 -0700 Subject: [PATCH] perf[core]: remove unnecessary model validators (#31238) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove unnecessary cast of id -> str (can be done with a field setting) * Remove unnecessary `set_text` model validator (can be done with a computed field - though we had to make some changes to the `Generation` class to make this possible) Before: ~2.4s Blue circles represent time spent in custom validators :( Screenshot 2025-05-14 at 10 10 12 AM After: ~2.2s Screenshot 2025-05-14 at 10 11 03 AM We still want to optimize the backwards compatible tool calls model validator, though I think this might involve breaking changes, so I wanted to separate that into a different PR. This is circled in green. --- libs/core/langchain_core/documents/base.py | 15 +---- libs/core/langchain_core/messages/ai.py | 1 + libs/core/langchain_core/messages/base.py | 11 +--- .../langchain_core/outputs/chat_generation.py | 61 +++++++------------ .../core/langchain_core/outputs/generation.py | 32 +++++++++- 5 files changed, 56 insertions(+), 64 deletions(-) diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index fba997f4959..0323bb1e886 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -8,7 +8,7 @@ from io import BufferedReader, BytesIO from pathlib import Path, PurePath from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast -from pydantic import ConfigDict, Field, field_validator, model_validator +from pydantic import ConfigDict, Field, model_validator from langchain_core.load.serializable import Serializable @@ -33,7 +33,7 @@ class BaseMedia(Serializable): # The ID field is optional at the moment. 
# It will likely become required in a future major release after # it has been adopted by enough vectorstore implementations. - id: Optional[str] = None + id: Optional[str] = Field(default=None, coerce_numbers_to_str=True) """An optional identifier for the document. Ideally this should be unique across the document collection and formatted @@ -45,17 +45,6 @@ class BaseMedia(Serializable): metadata: dict = Field(default_factory=dict) """Arbitrary metadata associated with the content.""" - @field_validator("id", mode="before") - def cast_id_to_str(cls, id_value: Any) -> Optional[str]: - """Coerce the id field to a string. - - Args: - id_value: The id value to coerce. - """ - if id_value is not None: - return str(id_value) - return id_value - class Blob(BaseMedia): """Blob represents raw data by either reference or value. diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 6b067864186..34e5c217270 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -194,6 +194,7 @@ class AIMessage(BaseMessage): "invalid_tool_calls": self.invalid_tool_calls, } + # TODO: remove this logic if possible, reducing breaking nature of changes @model_validator(mode="before") @classmethod def _backwards_compat_tool_calls(cls, values: dict) -> Any: diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 58049d28e9e..1ce0122166b 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Optional, Union, cast -from pydantic import ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field from langchain_core.load.serializable import Serializable from langchain_core.utils import get_bolded_text @@ -52,7 +52,7 @@ class BaseMessage(Serializable): model implementation. 
""" - id: Optional[str] = None + id: Optional[str] = Field(default=None, coerce_numbers_to_str=True) """An optional unique identifier for the message. This should ideally be provided by the provider/model which created the message.""" @@ -60,13 +60,6 @@ class BaseMessage(Serializable): extra="allow", ) - @field_validator("id", mode="before") - def cast_id_to_str(cls, id_value: Any) -> Optional[str]: - """Coerce the id field to a string.""" - if id_value is not None: - return str(id_value) - return id_value - def __init__( self, content: Union[str, list[Union[str, dict]]], **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py index 25ea5684e67..ed4f1066341 100644 --- a/libs/core/langchain_core/outputs/chat_generation.py +++ b/libs/core/langchain_core/outputs/chat_generation.py @@ -4,8 +4,7 @@ from __future__ import annotations from typing import Literal, Union -from pydantic import model_validator -from typing_extensions import Self +from pydantic import computed_field from langchain_core.messages import BaseMessage, BaseMessageChunk from langchain_core.outputs.generation import Generation @@ -26,48 +25,30 @@ class ChatGeneration(Generation): via callbacks). """ - text: str = "" - """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" message: BaseMessage """The message output by the chat model.""" - # Override type to be ChatGeneration, ignore mypy error as this is intentional + type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment] """Type is used exclusively for serialization purposes.""" - @model_validator(mode="after") - def set_text(self) -> Self: - """Set the text attribute to be the contents of the message. - - Args: - values: The values of the object. - - Returns: - The values of the object with the text attribute set. - - Raises: - ValueError: If the message is not a string or a list. 
- """ - try: - text = "" - if isinstance(self.message.content, str): - text = self.message.content - # Assumes text in content blocks in OpenAI format. - # Uses first text block. - elif isinstance(self.message.content, list): - for block in self.message.content: - if isinstance(block, str): - text = block - break - if isinstance(block, dict) and "text" in block: - text = block["text"] - break - else: - pass - self.text = text - except (KeyError, AttributeError) as e: - msg = "Error while initializing ChatGeneration" - raise ValueError(msg) from e - return self + @computed_field # type: ignore[prop-decorator] + @property + def text(self) -> str: + """Set the text attribute to be the contents of the message.""" + text_ = "" + if isinstance(self.message.content, str): + text_ = self.message.content + # Assumes text in content blocks in OpenAI format. + # Uses first text block. + elif isinstance(self.message.content, list): + for block in self.message.content: + if isinstance(block, str): + text_ = block + break + if isinstance(block, dict) and "text" in block: + text_ = block["text"] + break + return text_ class ChatGenerationChunk(ChatGeneration): @@ -78,7 +59,7 @@ class ChatGenerationChunk(ChatGeneration): message: BaseMessageChunk """The message chunk output by the chat model.""" - # Override type to be ChatGeneration, ignore mypy error as this is intentional + type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] """Type is used exclusively for serialization purposes.""" diff --git a/libs/core/langchain_core/outputs/generation.py b/libs/core/langchain_core/outputs/generation.py index 8f3bbe5a77c..75ac5b81db7 100644 --- a/libs/core/langchain_core/outputs/generation.py +++ b/libs/core/langchain_core/outputs/generation.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import Any, Literal, Optional +from pydantic import computed_field + from langchain_core.load import Serializable from langchain_core.utils._merge 
import merge_dicts @@ -24,14 +26,30 @@ class Generation(Serializable): for more information. """ - text: str - """Generated text output.""" + def __init__( + self, + text: str = "", + generation_info: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + """Initialize a Generation.""" + super().__init__(generation_info=generation_info, **kwargs) + self._text = text + + # workaround for ChatGeneration so that we can use a computed field to populate + # the text field from the message content (parent class needs to have a property) + @computed_field # type: ignore[prop-decorator] + @property + def text(self) -> str: + """The text contents of the output.""" + return self._text generation_info: Optional[dict[str, Any]] = None """Raw response from the provider. May include things like the reason for finishing or token log probabilities. """ + type: Literal["Generation"] = "Generation" """Type is used exclusively for serialization purposes. Set to "Generation" for this class.""" @@ -53,6 +71,16 @@ class Generation(Serializable): class GenerationChunk(Generation): """Generation chunk, which can be concatenated with other Generation chunks.""" + def __init__( + self, + text: str = "", + generation_info: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + """Initialize a GenerationChunk.""" + super().__init__(text=text, generation_info=generation_info, **kwargs) + self._text = text + def __add__(self, other: GenerationChunk) -> GenerationChunk: """Concatenate two GenerationChunks.""" if isinstance(other, GenerationChunk):