mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
core(fix): revert set_text
optimization (#31555)
Revert serialization regression introduced in https://github.com/langchain-ai/langchain/pull/31238 Fixes https://github.com/langchain-ai/langchain/issues/31486
This commit is contained in:
parent
e455fab5d3
commit
5b165effcd
@ -4,7 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import computed_field
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
@ -25,30 +26,44 @@ class ChatGeneration(Generation):
|
||||
via callbacks).
|
||||
"""
|
||||
|
||||
text: str = ""
|
||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||
message: BaseMessage
|
||||
"""The message output by the chat model."""
|
||||
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
text_ = ""
|
||||
@model_validator(mode="after")
|
||||
def set_text(self) -> Self:
|
||||
"""Set the text attribute to be the contents of the message.
|
||||
|
||||
Args:
|
||||
values: The values of the object.
|
||||
|
||||
Returns:
|
||||
The values of the object with the text attribute set.
|
||||
|
||||
Raises:
|
||||
ValueError: If the message is not a string or a list.
|
||||
"""
|
||||
text = ""
|
||||
if isinstance(self.message.content, str):
|
||||
text_ = self.message.content
|
||||
text = self.message.content
|
||||
# Assumes text in content blocks in OpenAI format.
|
||||
# Uses first text block.
|
||||
elif isinstance(self.message.content, list):
|
||||
for block in self.message.content:
|
||||
if isinstance(block, str):
|
||||
text_ = block
|
||||
text = block
|
||||
break
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text_ = block["text"]
|
||||
text = block["text"]
|
||||
break
|
||||
return text_
|
||||
else:
|
||||
pass
|
||||
self.text = text
|
||||
return self
|
||||
|
||||
|
||||
class ChatGenerationChunk(ChatGeneration):
|
||||
@ -59,7 +74,7 @@ class ChatGenerationChunk(ChatGeneration):
|
||||
|
||||
message: BaseMessageChunk
|
||||
"""The message chunk output by the chat model."""
|
||||
|
||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
|
@ -4,8 +4,6 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import computed_field
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
@ -26,30 +24,14 @@ class Generation(Serializable):
|
||||
for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str = "",
|
||||
generation_info: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize a Generation."""
|
||||
super().__init__(generation_info=generation_info, **kwargs)
|
||||
self._text = text
|
||||
|
||||
# workaround for ChatGeneration so that we can use a computed field to populate
|
||||
# the text field from the message content (parent class needs to have a property)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""The text contents of the output."""
|
||||
return self._text
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[dict[str, Any]] = None
|
||||
"""Raw response from the provider.
|
||||
|
||||
May include things like the reason for finishing or token log probabilities.
|
||||
"""
|
||||
|
||||
type: Literal["Generation"] = "Generation"
|
||||
"""Type is used exclusively for serialization purposes.
|
||||
Set to "Generation" for this class."""
|
||||
@ -71,16 +53,6 @@ class Generation(Serializable):
|
||||
class GenerationChunk(Generation):
|
||||
"""Generation chunk, which can be concatenated with other Generation chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str = "",
|
||||
generation_info: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize a GenerationChunk."""
|
||||
super().__init__(text=text, generation_info=generation_info, **kwargs)
|
||||
self._text = text
|
||||
|
||||
def __add__(self, other: GenerationChunk) -> GenerationChunk:
|
||||
"""Concatenate two GenerationChunks."""
|
||||
if isinstance(other, GenerationChunk):
|
||||
|
@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from langchain_core.load import Serializable, dumpd, load
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
|
||||
|
||||
class NonBoolObj:
|
||||
@ -223,3 +223,8 @@ def test_serialization_with_pydantic() -> None:
|
||||
assert isinstance(deser, ChatGeneration)
|
||||
assert deser.message.content
|
||||
assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
|
||||
|
||||
|
||||
def test_serialization_with_generation() -> None:
|
||||
generation = Generation(text="hello-world")
|
||||
assert dumpd(generation)["kwargs"] == {"text": "hello-world", "type": "Generation"}
|
||||
|
Loading…
Reference in New Issue
Block a user