mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 02:58:06 +00:00
fix ChatMessageChunk concat error (#10174)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. --> - Description: fix `ChatMessageChunk` concat error - Issue: #10173 - Dependencies: None - Tag maintainer: @baskaryan, @eyurtsev, @rlancemartin - Twitter handle: None --------- Co-authored-by: wangshuai.scotty <wangshuai.scotty@bytedance.com> Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
4322b246aa
commit
88a02076af
@ -117,6 +117,14 @@ class BaseMessageChunk(BaseMessage):
|
||||
# If both are (subclasses of) BaseMessageChunk,
|
||||
# concat into a single BaseMessageChunk
|
||||
|
||||
if isinstance(self, ChatMessageChunk):
|
||||
return self.__class__(
|
||||
role=self.role,
|
||||
content=self.content + other.content,
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
return self.__class__(
|
||||
content=self.content + other.content,
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
@ -168,7 +176,22 @@ class AIMessage(BaseMessage):
|
||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"""A Message chunk from an AI."""
|
||||
|
||||
pass
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, AIMessageChunk):
|
||||
if self.example != other.example:
|
||||
raise ValueError(
|
||||
"Cannot concatenate AIMessageChunks with different example values."
|
||||
)
|
||||
|
||||
return self.__class__(
|
||||
example=self.example,
|
||||
content=self.content + other.content,
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
@ -203,7 +226,22 @@ class FunctionMessage(BaseMessage):
|
||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
"""A Function Message chunk."""
|
||||
|
||||
pass
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, FunctionMessageChunk):
|
||||
if self.name != other.name:
|
||||
raise ValueError(
|
||||
"Cannot concatenate FunctionMessageChunks with different names."
|
||||
)
|
||||
|
||||
return self.__class__(
|
||||
name=self.name,
|
||||
content=self.content + other.content,
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage):
|
||||
@ -221,7 +259,22 @@ class ChatMessage(BaseMessage):
|
||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
"""A Chat Message chunk."""
|
||||
|
||||
pass
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, ChatMessageChunk):
|
||||
if self.role != other.role:
|
||||
raise ValueError(
|
||||
"Cannot concatenate ChatMessageChunks with different roles."
|
||||
)
|
||||
|
||||
return self.__class__(
|
||||
role=self.role,
|
||||
content=self.content + other.content,
|
||||
additional_kwargs=self._merge_kwargs_dict(
|
||||
self.additional_kwargs, other.additional_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> dict:
|
||||
|
@ -1,4 +1,11 @@
|
||||
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
|
||||
import pytest
|
||||
|
||||
from langchain.schema.messages import (
|
||||
AIMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessageChunk,
|
||||
)
|
||||
|
||||
|
||||
def test_message_chunks() -> None:
|
||||
@ -36,3 +43,54 @@ def test_message_chunks() -> None:
|
||||
}
|
||||
},
|
||||
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
|
||||
|
||||
|
||||
def test_chat_message_chunks() -> None:
|
||||
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||
role="User", content=" indeed."
|
||||
) == ChatMessageChunk(
|
||||
role="User", content="I am indeed."
|
||||
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
|
||||
role="Assistant", content=" indeed."
|
||||
)
|
||||
|
||||
assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk(
|
||||
content=" indeed."
|
||||
) == ChatMessageChunk(
|
||||
role="User", content="I am indeed."
|
||||
), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501
|
||||
|
||||
assert AIMessageChunk(content="I am") + ChatMessageChunk(
|
||||
role="User", content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
content="I am indeed."
|
||||
), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501
|
||||
|
||||
|
||||
def test_function_message_chunks() -> None:
|
||||
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||
name="hello", content=" indeed."
|
||||
) == FunctionMessageChunk(
|
||||
name="hello", content="I am indeed."
|
||||
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
|
||||
name="bye", content=" indeed."
|
||||
)
|
||||
|
||||
|
||||
def test_ani_message_chunks() -> None:
|
||||
assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||
example=True, content=" indeed."
|
||||
) == AIMessageChunk(
|
||||
example=True, content="I am indeed."
|
||||
), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||
example=False, content=" indeed."
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user