mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-24 16:37:46 +00:00
Add type to message chunks (#11232)
This commit is contained in:
parent
fb66b392c6
commit
8b4cb4eb60
@ -149,6 +149,7 @@ class HumanMessage(BaseMessage):
|
||||
"""
|
||||
|
||||
type: Literal["human"] = "human"
|
||||
is_chunk: Literal[False] = False
|
||||
|
||||
|
||||
HumanMessage.update_forward_refs()
|
||||
@ -157,7 +158,10 @@ HumanMessage.update_forward_refs()
|
||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
"""A Human Message chunk."""
|
||||
|
||||
pass
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
is_chunk: Literal[True] = True # type: ignore[assignment]
|
||||
|
||||
|
||||
class AIMessage(BaseMessage):
|
||||
@ -169,6 +173,7 @@ class AIMessage(BaseMessage):
|
||||
"""
|
||||
|
||||
type: Literal["ai"] = "ai"
|
||||
is_chunk: Literal[False] = False
|
||||
|
||||
|
||||
AIMessage.update_forward_refs()
|
||||
@ -177,6 +182,11 @@ AIMessage.update_forward_refs()
|
||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"""A Message chunk from an AI."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
is_chunk: Literal[True] = True # type: ignore[assignment]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, AIMessageChunk):
|
||||
if self.example != other.example:
|
||||
@ -201,6 +211,7 @@ class SystemMessage(BaseMessage):
|
||||
"""
|
||||
|
||||
type: Literal["system"] = "system"
|
||||
is_chunk: Literal[False] = False
|
||||
|
||||
|
||||
SystemMessage.update_forward_refs()
|
||||
@ -209,7 +220,10 @@ SystemMessage.update_forward_refs()
|
||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
"""A System Message chunk."""
|
||||
|
||||
pass
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
is_chunk: Literal[True] = True # type: ignore[assignment]
|
||||
|
||||
|
||||
class FunctionMessage(BaseMessage):
|
||||
@ -219,6 +233,7 @@ class FunctionMessage(BaseMessage):
|
||||
"""The name of the function that was executed."""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
is_chunk: Literal[False] = False
|
||||
|
||||
|
||||
FunctionMessage.update_forward_refs()
|
||||
@ -227,6 +242,11 @@ FunctionMessage.update_forward_refs()
|
||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
"""A Function Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
is_chunk: Literal[True] = True # type: ignore[assignment]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, FunctionMessageChunk):
|
||||
if self.name != other.name:
|
||||
@ -252,6 +272,7 @@ class ChatMessage(BaseMessage):
|
||||
"""The speaker / role of the Message."""
|
||||
|
||||
type: Literal["chat"] = "chat"
|
||||
is_chunk: Literal[False] = False
|
||||
|
||||
|
||||
ChatMessage.update_forward_refs()
|
||||
@ -260,6 +281,11 @@ ChatMessage.update_forward_refs()
|
||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
"""A Chat Message chunk."""
|
||||
|
||||
# Ignoring mypy re-assignment here since we're overriding the value
|
||||
# to make sure that the chunk variant can be discriminated from the
|
||||
# non-chunk variant.
|
||||
is_chunk: Literal[True] = True # type: ignore[assignment]
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, ChatMessageChunk):
|
||||
if self.role != other.role:
|
||||
|
@ -1693,6 +1693,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
@ -1719,6 +1727,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'role': dict({
|
||||
'title': 'Role',
|
||||
'type': 'string',
|
||||
@ -1786,6 +1802,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@ -1822,6 +1846,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
@ -1865,6 +1897,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
@ -1936,6 +1976,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
@ -1962,6 +2010,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'role': dict({
|
||||
'title': 'Role',
|
||||
'type': 'string',
|
||||
@ -2029,6 +2085,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@ -2065,6 +2129,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
@ -2108,6 +2180,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
@ -2163,6 +2243,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': True,
|
||||
'enum': list([
|
||||
True,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
@ -2189,6 +2277,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': True,
|
||||
'enum': list([
|
||||
True,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'role': dict({
|
||||
'title': 'Role',
|
||||
'type': 'string',
|
||||
@ -2220,6 +2316,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': True,
|
||||
'enum': list([
|
||||
True,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@ -2256,6 +2360,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': True,
|
||||
'enum': list([
|
||||
True,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
@ -2282,6 +2394,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': True,
|
||||
'enum': list([
|
||||
True,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
@ -2328,6 +2448,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
@ -2354,6 +2482,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'role': dict({
|
||||
'title': 'Role',
|
||||
'type': 'string',
|
||||
@ -2421,6 +2557,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@ -2457,6 +2601,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
@ -2500,6 +2652,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
@ -2538,6 +2698,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
@ -2564,6 +2732,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'role': dict({
|
||||
'title': 'Role',
|
||||
'type': 'string',
|
||||
@ -2631,6 +2807,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@ -2667,6 +2851,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
@ -2721,6 +2913,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
@ -2783,6 +2983,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'ai',
|
||||
'enum': list([
|
||||
@ -2809,6 +3017,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'role': dict({
|
||||
'title': 'Role',
|
||||
'type': 'string',
|
||||
@ -2840,6 +3056,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'name': dict({
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
@ -2876,6 +3100,14 @@
|
||||
'title': 'Example',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'human',
|
||||
'enum': list([
|
||||
@ -2905,6 +3137,14 @@
|
||||
'title': 'Content',
|
||||
'type': 'string',
|
||||
}),
|
||||
'is_chunk': dict({
|
||||
'default': False,
|
||||
'enum': list([
|
||||
False,
|
||||
]),
|
||||
'title': 'Is Chunk',
|
||||
'type': 'boolean',
|
||||
}),
|
||||
'type': dict({
|
||||
'default': 'system',
|
||||
'enum': list([
|
||||
|
@ -960,7 +960,11 @@ async def test_prompt_with_chat_model(
|
||||
tracer = FakeTracer()
|
||||
assert [
|
||||
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
|
||||
] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
|
||||
] == [
|
||||
AIMessageChunk(content="f"),
|
||||
AIMessageChunk(content="o"),
|
||||
AIMessageChunk(content="o"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
|
@ -1,11 +1,20 @@
|
||||
"""Test formatting functionality."""
|
||||
|
||||
import unittest
|
||||
from typing import Union
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
get_buffer_string,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
@ -70,3 +79,50 @@ def test_multiple_msg() -> None:
|
||||
sys_msg,
|
||||
]
|
||||
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
||||
|
||||
|
||||
def test_distinguish_messages() -> None:
|
||||
"""Test that pydantic is able to discriminate between similar looking messages."""
|
||||
|
||||
class WellKnownTypes(BaseModel):
|
||||
__root__: Union[
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
SystemMessage,
|
||||
FunctionMessage,
|
||||
HumanMessageChunk,
|
||||
AIMessageChunk,
|
||||
SystemMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
ChatMessageChunk,
|
||||
ChatMessage,
|
||||
]
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="human"),
|
||||
HumanMessageChunk(content="human"),
|
||||
AIMessage(content="ai"),
|
||||
AIMessageChunk(content="ai"),
|
||||
SystemMessage(content="sys"),
|
||||
SystemMessageChunk(content="sys"),
|
||||
FunctionMessage(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
FunctionMessageChunk(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
ChatMessage(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
ChatMessageChunk(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
]
|
||||
|
||||
for msg in messages:
|
||||
obj1 = WellKnownTypes.parse_obj(msg.dict())
|
||||
assert type(obj1.__root__) == type(msg), f"failed for {type(msg)}"
|
||||
|
Loading…
Reference in New Issue
Block a user