mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
core(fix): revert set_text
optimization (#31555)
Revert serialization regression introduced in https://github.com/langchain-ai/langchain/pull/31238 Fixes https://github.com/langchain-ai/langchain/issues/31486
This commit is contained in:
parent
e455fab5d3
commit
5b165effcd
@ -4,7 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from pydantic import computed_field
|
from pydantic import model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||||
from langchain_core.outputs.generation import Generation
|
from langchain_core.outputs.generation import Generation
|
||||||
@ -25,30 +26,44 @@ class ChatGeneration(Generation):
|
|||||||
via callbacks).
|
via callbacks).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
text: str = ""
|
||||||
|
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||||
message: BaseMessage
|
message: BaseMessage
|
||||||
"""The message output by the chat model."""
|
"""The message output by the chat model."""
|
||||||
|
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||||
"""Type is used exclusively for serialization purposes."""
|
"""Type is used exclusively for serialization purposes."""
|
||||||
|
|
||||||
@computed_field # type: ignore[prop-decorator]
|
@model_validator(mode="after")
|
||||||
@property
|
def set_text(self) -> Self:
|
||||||
def text(self) -> str:
|
"""Set the text attribute to be the contents of the message.
|
||||||
"""Set the text attribute to be the contents of the message."""
|
|
||||||
text_ = ""
|
Args:
|
||||||
|
values: The values of the object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The values of the object with the text attribute set.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the message is not a string or a list.
|
||||||
|
"""
|
||||||
|
text = ""
|
||||||
if isinstance(self.message.content, str):
|
if isinstance(self.message.content, str):
|
||||||
text_ = self.message.content
|
text = self.message.content
|
||||||
# Assumes text in content blocks in OpenAI format.
|
# Assumes text in content blocks in OpenAI format.
|
||||||
# Uses first text block.
|
# Uses first text block.
|
||||||
elif isinstance(self.message.content, list):
|
elif isinstance(self.message.content, list):
|
||||||
for block in self.message.content:
|
for block in self.message.content:
|
||||||
if isinstance(block, str):
|
if isinstance(block, str):
|
||||||
text_ = block
|
text = block
|
||||||
break
|
break
|
||||||
if isinstance(block, dict) and "text" in block:
|
if isinstance(block, dict) and "text" in block:
|
||||||
text_ = block["text"]
|
text = block["text"]
|
||||||
break
|
break
|
||||||
return text_
|
else:
|
||||||
|
pass
|
||||||
|
self.text = text
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class ChatGenerationChunk(ChatGeneration):
|
class ChatGenerationChunk(ChatGeneration):
|
||||||
@ -59,7 +74,7 @@ class ChatGenerationChunk(ChatGeneration):
|
|||||||
|
|
||||||
message: BaseMessageChunk
|
message: BaseMessageChunk
|
||||||
"""The message chunk output by the chat model."""
|
"""The message chunk output by the chat model."""
|
||||||
|
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||||
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment]
|
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment]
|
||||||
"""Type is used exclusively for serialization purposes."""
|
"""Type is used exclusively for serialization purposes."""
|
||||||
|
|
||||||
|
@ -4,8 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from pydantic import computed_field
|
|
||||||
|
|
||||||
from langchain_core.load import Serializable
|
from langchain_core.load import Serializable
|
||||||
from langchain_core.utils._merge import merge_dicts
|
from langchain_core.utils._merge import merge_dicts
|
||||||
|
|
||||||
@ -26,30 +24,14 @@ class Generation(Serializable):
|
|||||||
for more information.
|
for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
text: str
|
||||||
self,
|
"""Generated text output."""
|
||||||
text: str = "",
|
|
||||||
generation_info: Optional[dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
"""Initialize a Generation."""
|
|
||||||
super().__init__(generation_info=generation_info, **kwargs)
|
|
||||||
self._text = text
|
|
||||||
|
|
||||||
# workaround for ChatGeneration so that we can use a computed field to populate
|
|
||||||
# the text field from the message content (parent class needs to have a property)
|
|
||||||
@computed_field # type: ignore[prop-decorator]
|
|
||||||
@property
|
|
||||||
def text(self) -> str:
|
|
||||||
"""The text contents of the output."""
|
|
||||||
return self._text
|
|
||||||
|
|
||||||
generation_info: Optional[dict[str, Any]] = None
|
generation_info: Optional[dict[str, Any]] = None
|
||||||
"""Raw response from the provider.
|
"""Raw response from the provider.
|
||||||
|
|
||||||
May include things like the reason for finishing or token log probabilities.
|
May include things like the reason for finishing or token log probabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["Generation"] = "Generation"
|
type: Literal["Generation"] = "Generation"
|
||||||
"""Type is used exclusively for serialization purposes.
|
"""Type is used exclusively for serialization purposes.
|
||||||
Set to "Generation" for this class."""
|
Set to "Generation" for this class."""
|
||||||
@ -71,16 +53,6 @@ class Generation(Serializable):
|
|||||||
class GenerationChunk(Generation):
|
class GenerationChunk(Generation):
|
||||||
"""Generation chunk, which can be concatenated with other Generation chunks."""
|
"""Generation chunk, which can be concatenated with other Generation chunks."""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text: str = "",
|
|
||||||
generation_info: Optional[dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
"""Initialize a GenerationChunk."""
|
|
||||||
super().__init__(text=text, generation_info=generation_info, **kwargs)
|
|
||||||
self._text = text
|
|
||||||
|
|
||||||
def __add__(self, other: GenerationChunk) -> GenerationChunk:
|
def __add__(self, other: GenerationChunk) -> GenerationChunk:
|
||||||
"""Concatenate two GenerationChunks."""
|
"""Concatenate two GenerationChunks."""
|
||||||
if isinstance(other, GenerationChunk):
|
if isinstance(other, GenerationChunk):
|
||||||
|
@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
from langchain_core.load import Serializable, dumpd, load
|
from langchain_core.load import Serializable, dumpd, load
|
||||||
from langchain_core.load.serializable import _is_field_useful
|
from langchain_core.load.serializable import _is_field_useful
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.outputs import ChatGeneration
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
|
|
||||||
|
|
||||||
class NonBoolObj:
|
class NonBoolObj:
|
||||||
@ -223,3 +223,8 @@ def test_serialization_with_pydantic() -> None:
|
|||||||
assert isinstance(deser, ChatGeneration)
|
assert isinstance(deser, ChatGeneration)
|
||||||
assert deser.message.content
|
assert deser.message.content
|
||||||
assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
|
assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialization_with_generation() -> None:
|
||||||
|
generation = Generation(text="hello-world")
|
||||||
|
assert dumpd(generation)["kwargs"] == {"text": "hello-world", "type": "Generation"}
|
||||||
|
Loading…
Reference in New Issue
Block a user