Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-18 17:11:25 +00:00)
perf[core]: remove unnecessary model validators (#31238)

* Remove the unnecessary cast of `id` -> str (can be done with a field setting).
* Remove the unnecessary `set_text` model validator (can be done with a computed field, though we had to make some changes to the `Generation` class to make this possible).

Before: ~2.4s. Blue circles represent time spent in custom validators :(

<img width="1337" alt="Screenshot 2025-05-14 at 10 10 12 AM" src="https://github.com/user-attachments/assets/bb4f477f-4ee3-4870-ae93-14ca7f197d55" />

After: ~2.2s

<img width="1344" alt="Screenshot 2025-05-14 at 10 11 03 AM" src="https://github.com/user-attachments/assets/99f97d80-49de-462f-856f-9e7e8662adbc" />

We still want to optimize the backwards-compatible tool calls model validator (circled in green), though that will likely involve breaking changes, so it is separated into a different PR.
This commit is contained in: parent 1523602196 · commit 7263011b24
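To illustrate the first bullet, here is a minimal sketch (hypothetical model names, not the langchain-core source) of the idea: a Python-level `field_validator` that stringifies ids can be replaced by pydantic's built-in `coerce_numbers_to_str` field setting, which performs the coercion inside pydantic-core instead of calling back into Python for every instance.

```python
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class WithValidator(BaseModel):
    """Old approach: a custom validator runs in Python on every instantiation."""

    id: Optional[str] = None

    @field_validator("id", mode="before")
    @classmethod
    def cast_id_to_str(cls, id_value: object) -> Optional[str]:
        # Stringify anything that is not None.
        return str(id_value) if id_value is not None else None


class WithFieldSetting(BaseModel):
    """New approach: numeric ids are coerced to str by pydantic-core itself."""

    id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)


assert WithValidator(id=123).id == "123"
assert WithFieldSetting(id=123).id == "123"
```

One small behavioral difference worth noting: `coerce_numbers_to_str=True` only coerces numbers, while the removed validator would have stringified any non-None value.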
`langchain_core.documents.base` (BaseMedia / Blob):

@@ -8,7 +8,7 @@ from io import BufferedReader, BytesIO
 from pathlib import Path, PurePath
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

-from pydantic import ConfigDict, Field, field_validator, model_validator
+from pydantic import ConfigDict, Field, model_validator

 from langchain_core.load.serializable import Serializable

@@ -33,7 +33,7 @@ class BaseMedia(Serializable):
     # The ID field is optional at the moment.
     # It will likely become required in a future major release after
     # it has been adopted by enough vectorstore implementations.
-    id: Optional[str] = None
+    id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
     """An optional identifier for the document.

     Ideally this should be unique across the document collection and formatted
@@ -45,17 +45,6 @@ class BaseMedia(Serializable):
     metadata: dict = Field(default_factory=dict)
     """Arbitrary metadata associated with the content."""

-    @field_validator("id", mode="before")
-    def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
-        """Coerce the id field to a string.
-
-        Args:
-            id_value: The id value to coerce.
-        """
-        if id_value is not None:
-            return str(id_value)
-        return id_value
-

 class Blob(BaseMedia):
     """Blob represents raw data by either reference or value.
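A quick usage check of the change above (assumed example; `Document` inherits the `id` field from `BaseMedia`):

```python
from langchain_core.documents import Document

# Numeric ids are still accepted and end up as strings, now via
# coerce_numbers_to_str in pydantic-core rather than a Python field_validator.
doc = Document(page_content="hello", id=42)
assert doc.id == "42"
```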
`langchain_core.messages.ai` (AIMessage):

@@ -194,6 +194,7 @@ class AIMessage(BaseMessage):
             "invalid_tool_calls": self.invalid_tool_calls,
         }

+    # TODO: remove this logic if possible, reducing breaking nature of changes
     @model_validator(mode="before")
     @classmethod
     def _backwards_compat_tool_calls(cls, values: dict) -> Any:
`langchain_core.messages.base` (BaseMessage):

@@ -4,7 +4,7 @@ from __future__ import annotations

 from typing import TYPE_CHECKING, Any, Optional, Union, cast

-from pydantic import ConfigDict, Field, field_validator
+from pydantic import ConfigDict, Field

 from langchain_core.load.serializable import Serializable
 from langchain_core.utils import get_bolded_text
@@ -52,7 +52,7 @@ class BaseMessage(Serializable):
     model implementation.
     """

-    id: Optional[str] = None
+    id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
     """An optional unique identifier for the message. This should ideally be
     provided by the provider/model which created the message."""

@@ -60,13 +60,6 @@ class BaseMessage(Serializable):
         extra="allow",
     )

-    @field_validator("id", mode="before")
-    def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
-        """Coerce the id field to a string."""
-        if id_value is not None:
-            return str(id_value)
-        return id_value
-
     def __init__(
         self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
     ) -> None:
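The ~2.4s → ~2.2s numbers in the PR description come from the authors' profiling screenshots. A hypothetical micro-benchmark along these lines (not the authors' setup) is one way to observe the effect, since message construction cost is dominated by pydantic validation and the removed validators ran on every message:

```python
import timeit

from langchain_core.messages import AIMessage

# Construct many messages with an integer id to exercise the id-coercion path.
elapsed = timeit.timeit(
    lambda: AIMessage(content="hello", id=12345),
    number=50_000,
)
print(f"50k AIMessage constructions: {elapsed:.2f}s")
```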
`langchain_core.outputs.chat_generation` (ChatGeneration / ChatGenerationChunk):

@@ -4,8 +4,7 @@ from __future__ import annotations

 from typing import Literal, Union

-from pydantic import model_validator
-from typing_extensions import Self
+from pydantic import computed_field

 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
@@ -26,48 +25,30 @@ class ChatGeneration(Generation):
     via callbacks).
     """

-    text: str = ""
-    """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
     message: BaseMessage
     """The message output by the chat model."""
-    # Override type to be ChatGeneration, ignore mypy error as this is intentional
     type: Literal["ChatGeneration"] = "ChatGeneration"  # type: ignore[assignment]
     """Type is used exclusively for serialization purposes."""

-    @model_validator(mode="after")
-    def set_text(self) -> Self:
-        """Set the text attribute to be the contents of the message.
-
-        Args:
-            values: The values of the object.
-
-        Returns:
-            The values of the object with the text attribute set.
-
-        Raises:
-            ValueError: If the message is not a string or a list.
-        """
-        try:
-            text = ""
-            if isinstance(self.message.content, str):
-                text = self.message.content
-            # Assumes text in content blocks in OpenAI format.
-            # Uses first text block.
-            elif isinstance(self.message.content, list):
-                for block in self.message.content:
-                    if isinstance(block, str):
-                        text = block
-                        break
-                    if isinstance(block, dict) and "text" in block:
-                        text = block["text"]
-                        break
-            else:
-                pass
-            self.text = text
-        except (KeyError, AttributeError) as e:
-            msg = "Error while initializing ChatGeneration"
-            raise ValueError(msg) from e
-        return self
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def text(self) -> str:
+        """Set the text attribute to be the contents of the message."""
+        text_ = ""
+        if isinstance(self.message.content, str):
+            text_ = self.message.content
+        # Assumes text in content blocks in OpenAI format.
+        # Uses first text block.
+        elif isinstance(self.message.content, list):
+            for block in self.message.content:
+                if isinstance(block, str):
+                    text_ = block
+                    break
+                if isinstance(block, dict) and "text" in block:
+                    text_ = block["text"]
+                    break
+        return text_


 class ChatGenerationChunk(ChatGeneration):
@@ -78,7 +59,7 @@ class ChatGenerationChunk(ChatGeneration):

     message: BaseMessageChunk
     """The message chunk output by the chat model."""
-    # Override type to be ChatGeneration, ignore mypy error as this is intentional
     type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk"  # type: ignore[assignment]
     """Type is used exclusively for serialization purposes."""

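Assumed behavior after this change: `text` is no longer a stored field populated by a `model_validator`; it is derived from the message on access, and because it is declared as a `computed_field` it still appears in serialized output.

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

gen = ChatGeneration(message=AIMessage(content="hello"))
assert gen.text == "hello"  # derived from message.content when accessed
assert gen.model_dump()["text"] == "hello"  # computed fields are included by model_dump
```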
`langchain_core.outputs.generation` (Generation / GenerationChunk):

@@ -4,6 +4,8 @@ from __future__ import annotations

 from typing import Any, Literal, Optional

+from pydantic import computed_field
+
 from langchain_core.load import Serializable
 from langchain_core.utils._merge import merge_dicts

@@ -24,14 +26,30 @@ class Generation(Serializable):
     for more information.
     """

-    text: str
-    """Generated text output."""
+    def __init__(
+        self,
+        text: str = "",
+        generation_info: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        """Initialize a Generation."""
+        super().__init__(generation_info=generation_info, **kwargs)
+        self._text = text
+
+    # workaround for ChatGeneration so that we can use a computed field to populate
+    # the text field from the message content (parent class needs to have a property)
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def text(self) -> str:
+        """The text contents of the output."""
+        return self._text

     generation_info: Optional[dict[str, Any]] = None
     """Raw response from the provider.

     May include things like the reason for finishing or token log probabilities.
     """

     type: Literal["Generation"] = "Generation"
     """Type is used exclusively for serialization purposes.
     Set to "Generation" for this class."""
@@ -53,6 +71,16 @@ class Generation(Serializable):
 class GenerationChunk(Generation):
     """Generation chunk, which can be concatenated with other Generation chunks."""

+    def __init__(
+        self,
+        text: str = "",
+        generation_info: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        """Initialize a GenerationChunk."""
+        super().__init__(text=text, generation_info=generation_info, **kwargs)
+        self._text = text
+
     def __add__(self, other: GenerationChunk) -> GenerationChunk:
         """Concatenate two GenerationChunks."""
         if isinstance(other, GenerationChunk):
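The `Generation` changes are the workaround the PR description mentions: for `ChatGeneration` to override `text` with a computed property, the parent class must expose `text` as a property-backed computed field rather than a plain field. Below is a distilled sketch of that pattern with hypothetical names (`Parent`, `Child`, `source`), using an explicit `PrivateAttr` for clarity.

```python
from typing import Any

from pydantic import BaseModel, PrivateAttr, computed_field


class Parent(BaseModel):
    # Backing store for the computed field; set in __init__ so callers can
    # still pass text=... like they would for a regular field.
    _text: str = PrivateAttr(default="")

    def __init__(self, text: str = "", **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._text = text

    @computed_field  # type: ignore[prop-decorator]
    @property
    def text(self) -> str:
        return self._text


class Child(Parent):
    source: str = ""

    @computed_field  # type: ignore[prop-decorator]
    @property
    def text(self) -> str:
        # Derived on access; no model_validator has to run at construction time.
        return self.source.upper()


assert Parent(text="hi").text == "hi"
assert Child(source="hi").text == "HI"
assert "text" in Child(source="hi").model_dump()  # still serialized
```

Because the explicit `__init__` still accepts `text` and stashes it on the private attribute, existing `Generation(text=...)` and `GenerationChunk(text=...)` callers keep working.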