mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 12:09:58 +00:00
fix ChatMessageChunk concat error (#10174)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. --> - Description: fix `ChatMessageChunk` concat error - Issue: #10173 - Dependencies: None - Tag maintainer: @baskaryan, @eyurtsev, @rlancemartin - Twitter handle: None --------- Co-authored-by: wangshuai.scotty <wangshuai.scotty@bytedance.com> Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
4322b246aa
commit
88a02076af
@ -117,6 +117,14 @@ class BaseMessageChunk(BaseMessage):
|
|||||||
# If both are (subclasses of) BaseMessageChunk,
|
# If both are (subclasses of) BaseMessageChunk,
|
||||||
# concat into a single BaseMessageChunk
|
# concat into a single BaseMessageChunk
|
||||||
|
|
||||||
|
if isinstance(self, ChatMessageChunk):
|
||||||
|
return self.__class__(
|
||||||
|
role=self.role,
|
||||||
|
content=self.content + other.content,
|
||||||
|
additional_kwargs=self._merge_kwargs_dict(
|
||||||
|
self.additional_kwargs, other.additional_kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
content=self.content + other.content,
|
content=self.content + other.content,
|
||||||
additional_kwargs=self._merge_kwargs_dict(
|
additional_kwargs=self._merge_kwargs_dict(
|
||||||
@ -168,7 +176,22 @@ class AIMessage(BaseMessage):
|
|||||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||||
"""A Message chunk from an AI."""
|
"""A Message chunk from an AI."""
|
||||||
|
|
||||||
pass
|
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||||
|
if isinstance(other, AIMessageChunk):
|
||||||
|
if self.example != other.example:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot concatenate AIMessageChunks with different example values."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
example=self.example,
|
||||||
|
content=self.content + other.content,
|
||||||
|
additional_kwargs=self._merge_kwargs_dict(
|
||||||
|
self.additional_kwargs, other.additional_kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().__add__(other)
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(BaseMessage):
|
class SystemMessage(BaseMessage):
|
||||||
@ -203,7 +226,22 @@ class FunctionMessage(BaseMessage):
|
|||||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||||
"""A Function Message chunk."""
|
"""A Function Message chunk."""
|
||||||
|
|
||||||
pass
|
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||||
|
if isinstance(other, FunctionMessageChunk):
|
||||||
|
if self.name != other.name:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot concatenate FunctionMessageChunks with different names."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
name=self.name,
|
||||||
|
content=self.content + other.content,
|
||||||
|
additional_kwargs=self._merge_kwargs_dict(
|
||||||
|
self.additional_kwargs, other.additional_kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().__add__(other)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseMessage):
|
class ChatMessage(BaseMessage):
|
||||||
@ -221,7 +259,22 @@ class ChatMessage(BaseMessage):
|
|||||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||||
"""A Chat Message chunk."""
|
"""A Chat Message chunk."""
|
||||||
|
|
||||||
pass
|
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||||
|
if isinstance(other, ChatMessageChunk):
|
||||||
|
if self.role != other.role:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot concatenate ChatMessageChunks with different roles."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
role=self.role,
|
||||||
|
content=self.content + other.content,
|
||||||
|
additional_kwargs=self._merge_kwargs_dict(
|
||||||
|
self.additional_kwargs, other.additional_kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().__add__(other)
|
||||||
|
|
||||||
|
|
||||||
def _message_to_dict(message: BaseMessage) -> dict:
|
def _message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
@ -1,4 +1,11 @@
|
|||||||
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
|
import pytest
|
||||||
|
|
||||||
|
from langchain.schema.messages import (
|
||||||
|
AIMessageChunk,
|
||||||
|
ChatMessageChunk,
|
||||||
|
FunctionMessageChunk,
|
||||||
|
HumanMessageChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_message_chunks() -> None:
|
def test_message_chunks() -> None:
|
||||||
@ -36,3 +43,54 @@ def test_message_chunks() -> None:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_message_chunks() -> None:
|
||||||
|
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||||
|
role="User", content=" indeed."
|
||||||
|
) == ChatMessageChunk(
|
||||||
|
role="User", content="I am indeed."
|
||||||
|
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||||
|
role="Assistant", content=" indeed."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk(
|
||||||
|
content=" indeed."
|
||||||
|
) == ChatMessageChunk(
|
||||||
|
role="User", content="I am indeed."
|
||||||
|
), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501
|
||||||
|
|
||||||
|
assert AIMessageChunk(content="I am") + ChatMessageChunk(
|
||||||
|
role="User", content=" indeed."
|
||||||
|
) == AIMessageChunk(
|
||||||
|
content="I am indeed."
|
||||||
|
), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_message_chunks() -> None:
|
||||||
|
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||||
|
name="hello", content=" indeed."
|
||||||
|
) == FunctionMessageChunk(
|
||||||
|
name="hello", content="I am indeed."
|
||||||
|
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||||
|
name="bye", content=" indeed."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ani_message_chunks() -> None:
|
||||||
|
assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||||
|
example=True, content=" indeed."
|
||||||
|
) == AIMessageChunk(
|
||||||
|
example=True, content="I am indeed."
|
||||||
|
), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||||
|
example=False, content=" indeed."
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user