diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 9e9c5228cb2..f78d877e595 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -29,6 +29,9 @@ from typing import ( overload, ) +from pydantic import Discriminator, Field +from typing_extensions import Annotated + from langchain_core.messages.ai import AIMessage, AIMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.chat import ChatMessage, ChatMessageChunk @@ -45,19 +48,23 @@ if TYPE_CHECKING: from langchain_core.prompt_values import PromptValue from langchain_core.runnables.base import Runnable -AnyMessage = Union[ - AIMessage, - HumanMessage, - ChatMessage, - SystemMessage, - FunctionMessage, - ToolMessage, - AIMessageChunk, - HumanMessageChunk, - ChatMessageChunk, - SystemMessageChunk, - FunctionMessageChunk, - ToolMessageChunk, + +AnyMessage = Annotated[ + Union[ + AIMessage, + HumanMessage, + ChatMessage, + SystemMessage, + FunctionMessage, + ToolMessage, + AIMessageChunk, + HumanMessageChunk, + ChatMessageChunk, + SystemMessageChunk, + FunctionMessageChunk, + ToolMessageChunk, + ], + Field(discriminator=Discriminator("type")), ] diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 29ab354d126..7c3c2ffb456 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -1349,7 +1349,24 @@ 'history': dict({ 'default': None, 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/$defs/AIMessageChunk', + 'ChatMessageChunk': '#/$defs/ChatMessageChunk', + 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', + 'HumanMessageChunk': '#/$defs/HumanMessageChunk', + 'SystemMessageChunk': '#/$defs/SystemMessageChunk', + 'ToolMessageChunk': '#/$defs/ToolMessageChunk', + 'ai': '#/$defs/AIMessage', + 'chat': '#/$defs/ChatMessage', + 'function': '#/$defs/FunctionMessage', + 'human': '#/$defs/HumanMessage', + 'system': '#/$defs/SystemMessage', + 'tool': '#/$defs/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', }), @@ -2752,7 +2769,24 @@ 'properties': dict({ 'history': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/$defs/AIMessageChunk', + 'ChatMessageChunk': '#/$defs/ChatMessageChunk', + 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', + 'HumanMessageChunk': '#/$defs/HumanMessageChunk', + 'SystemMessageChunk': '#/$defs/SystemMessageChunk', + 'ToolMessageChunk': '#/$defs/ToolMessageChunk', + 'ai': '#/$defs/AIMessage', + 'chat': '#/$defs/ChatMessage', + 'function': '#/$defs/FunctionMessage', + 'human': '#/$defs/HumanMessage', + 'system': '#/$defs/SystemMessage', + 'tool': '#/$defs/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', }), diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 2a6ff6ad79c..34929ad3abd 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1710,40 +1710,61 @@ 'type': 'string', }), dict({ - '$ref': '#/$defs/AIMessage', - }), - dict({ - '$ref': '#/$defs/HumanMessage', - }), - dict({ - '$ref': '#/$defs/ChatMessage', - }), - dict({ - '$ref': '#/$defs/SystemMessage', - }), - dict({ - '$ref': '#/$defs/FunctionMessage', - }), - dict({ - '$ref': '#/$defs/ToolMessage', - }), - dict({ - '$ref': '#/$defs/AIMessageChunk', - }), - dict({ - '$ref': '#/$defs/HumanMessageChunk', - }), - dict({ - '$ref': '#/$defs/ChatMessageChunk', - }), - dict({ - '$ref': '#/$defs/SystemMessageChunk', - }), - dict({ - '$ref': '#/$defs/FunctionMessageChunk', - }), - dict({ - '$ref': '#/$defs/ToolMessageChunk', + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/$defs/AIMessageChunk', + 'ChatMessageChunk': '#/$defs/ChatMessageChunk', + 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', + 'HumanMessageChunk': '#/$defs/HumanMessageChunk', + 'SystemMessageChunk': '#/$defs/SystemMessageChunk', + 'ToolMessageChunk': '#/$defs/ToolMessageChunk', + 'ai': '#/$defs/AIMessage', + 'chat': '#/$defs/ChatMessage', + 'function': '#/$defs/FunctionMessage', + 'human': '#/$defs/HumanMessage', + 'system': '#/$defs/SystemMessage', + 'tool': '#/$defs/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ + dict({ + '$ref': '#/$defs/AIMessage', + }), + dict({ + '$ref': '#/$defs/HumanMessage', + }), + dict({ + '$ref': '#/$defs/ChatMessage', + }), + dict({ + '$ref': '#/$defs/SystemMessage', + }), + dict({ + '$ref': '#/$defs/FunctionMessage', + }), + dict({ + '$ref': '#/$defs/ToolMessage', + }), + dict({ + '$ref': '#/$defs/AIMessageChunk', + }), + dict({ + '$ref': '#/$defs/HumanMessageChunk', + }), + dict({ + '$ref': '#/$defs/ChatMessageChunk', + }), + dict({ + '$ref': '#/$defs/SystemMessageChunk', + }), + dict({ + '$ref': '#/$defs/FunctionMessageChunk', + }), + dict({ + '$ref': '#/$defs/ToolMessageChunk', + }), + ]), }), ]), 'title': 'RunnableParallelInput', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index b9d75f3c74d..07be8ecbd5d 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -6525,7 +6525,24 @@ 'properties': dict({ 'history': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/$defs/AIMessageChunk', + 'ChatMessageChunk': '#/$defs/ChatMessageChunk', + 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', + 'HumanMessageChunk': '#/$defs/HumanMessageChunk', + 'SystemMessageChunk': '#/$defs/SystemMessageChunk', + 'ToolMessageChunk': '#/$defs/ToolMessageChunk', + 'ai': '#/$defs/AIMessage', + 'chat': '#/$defs/ChatMessage', + 'function': '#/$defs/FunctionMessage', + 'human': '#/$defs/HumanMessage', + 'system': '#/$defs/SystemMessage', + 'tool': '#/$defs/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', }), @@ -6973,7 +6990,24 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/$defs/AIMessageChunk', + 'ChatMessageChunk': '#/$defs/ChatMessageChunk', + 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', + 'HumanMessageChunk': '#/$defs/HumanMessageChunk', + 'SystemMessageChunk': '#/$defs/SystemMessageChunk', + 'ToolMessageChunk': '#/$defs/ToolMessageChunk', + 'ai': '#/$defs/AIMessage', + 'chat': '#/$defs/ChatMessage', + 'function': '#/$defs/FunctionMessage', + 'human': '#/$defs/HumanMessage', + 'system': '#/$defs/SystemMessage', + 'tool': '#/$defs/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', }), @@ -8035,7 +8069,24 @@ }), dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', }), @@ -8473,7 +8524,24 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', }), @@ -9513,44 +9581,6 @@ # --- # name: test_schemas[fake_chat_output_schema] dict({ - 'anyOf': list([ - dict({ - '$ref': '#/definitions/AIMessage', - }), - dict({ - '$ref': '#/definitions/HumanMessage', - }), - dict({ - '$ref': '#/definitions/ChatMessage', - }), - dict({ - '$ref': '#/definitions/SystemMessage', - }), - dict({ - '$ref': '#/definitions/FunctionMessage', - }), - dict({ - '$ref': '#/definitions/ToolMessage', - }), - dict({ - '$ref': '#/definitions/AIMessageChunk', - }), - dict({ - '$ref': '#/definitions/HumanMessageChunk', - }), - dict({ - '$ref': '#/definitions/ChatMessageChunk', - }), - dict({ - '$ref': '#/definitions/SystemMessageChunk', - }), - dict({ - '$ref': '#/definitions/FunctionMessageChunk', - }), - dict({ - '$ref': '#/definitions/ToolMessageChunk', - }), - ]), 'definitions': dict({ 'AIMessage': dict({ 'additionalProperties': True, @@ -10893,6 +10923,61 @@ 'type': 'object', }), }), + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ + dict({ + '$ref': '#/definitions/AIMessage', + }), + dict({ + '$ref': '#/definitions/HumanMessage', + }), + dict({ + '$ref': '#/definitions/ChatMessage', + }), + dict({ + '$ref': '#/definitions/SystemMessage', + }), + dict({ + '$ref': '#/definitions/FunctionMessage', + }), + dict({ + '$ref': '#/definitions/ToolMessage', + }), + dict({ + '$ref': '#/definitions/AIMessageChunk', + }), + dict({ + '$ref': '#/definitions/HumanMessageChunk', + }), + dict({ + '$ref': '#/definitions/ChatMessageChunk', + }), + dict({ + '$ref': '#/definitions/SystemMessageChunk', + }), + dict({ + '$ref': '#/definitions/FunctionMessageChunk', + }), + dict({ + '$ref': '#/definitions/ToolMessageChunk', + }), + ]), 'title': 'FakeListChatModelOutput', }) # --- @@ -10910,7 +10995,24 @@ }), dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', }), @@ -11348,7 +11450,24 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', }), @@ -12393,40 +12512,61 @@ 'type': 'string', }), dict({ - '$ref': '#/definitions/AIMessage', - }), - dict({ - '$ref': '#/definitions/HumanMessage', - }), - dict({ - '$ref': '#/definitions/ChatMessage', - }), - dict({ - '$ref': '#/definitions/SystemMessage', - }), - dict({ - '$ref': '#/definitions/FunctionMessage', - }), - dict({ - '$ref': '#/definitions/ToolMessage', - }), - dict({ - '$ref': '#/definitions/AIMessageChunk', - }), - dict({ - '$ref': '#/definitions/HumanMessageChunk', - }), - dict({ - '$ref': '#/definitions/ChatMessageChunk', - }), - dict({ - '$ref': '#/definitions/SystemMessageChunk', - }), - dict({ - '$ref': '#/definitions/FunctionMessageChunk', - }), - dict({ - '$ref': '#/definitions/ToolMessageChunk', + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ + dict({ + '$ref': '#/definitions/AIMessage', + }), + dict({ + '$ref': '#/definitions/HumanMessage', + }), + dict({ + '$ref': '#/definitions/ChatMessage', + }), + dict({ + '$ref': '#/definitions/SystemMessage', + }), + dict({ + '$ref': '#/definitions/FunctionMessage', + }), + dict({ + '$ref': '#/definitions/ToolMessage', + }), + dict({ + '$ref': '#/definitions/AIMessageChunk', + }), + dict({ + '$ref': '#/definitions/HumanMessageChunk', + }), + dict({ + '$ref': '#/definitions/ChatMessageChunk', + }), + dict({ + '$ref': '#/definitions/SystemMessageChunk', + }), + dict({ + '$ref': '#/definitions/FunctionMessageChunk', + }), + dict({ + '$ref': '#/definitions/ToolMessageChunk', + }), + ]), }), ]), 'definitions': dict({ @@ -14172,7 +14312,24 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', }), @@ -15631,7 +15788,24 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'anyOf': list([ + 'discriminator': dict({ + 'mapping': dict({ + 'AIMessageChunk': '#/definitions/AIMessageChunk', + 'ChatMessageChunk': '#/definitions/ChatMessageChunk', + 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', + 'HumanMessageChunk': '#/definitions/HumanMessageChunk', + 'SystemMessageChunk': '#/definitions/SystemMessageChunk', + 'ToolMessageChunk': '#/definitions/ToolMessageChunk', + 'ai': '#/definitions/AIMessage', + 'chat': '#/definitions/ChatMessage', + 'function': '#/definitions/FunctionMessage', + 'human': '#/definitions/HumanMessage', + 'system': '#/definitions/SystemMessage', + 'tool': '#/definitions/ToolMessage', + }), + 'propertyName': 'type', + }), + 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', }), diff --git a/libs/core/tests/unit_tests/test_prompt_values.py b/libs/core/tests/unit_tests/test_prompt_values.py index 625cd51202b..6a08a4270ac 100644 --- a/libs/core/tests/unit_tests/test_prompt_values.py +++ b/libs/core/tests/unit_tests/test_prompt_values.py @@ -1,5 +1,3 @@ -import pytest - from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -13,7 +11,6 @@ from langchain_core.messages import ( from langchain_core.prompt_values import ChatPromptValueConcrete -@pytest.mark.xfail(reason="Broken union type.") def test_chat_prompt_value_concrete() -> None: messages: list = [ AIMessage("foo"), diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index 1c4136a60e9..aa0cc26dd9a 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -16,6 +16,7 @@ from langchain_core.messages import ( HumanMessageChunk, SystemMessage, SystemMessageChunk, + ToolMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue @@ -73,7 +74,12 @@ def test_serialization_of_wellknown_objects() -> None: content="human", ), StringPromptValue(text="hello"), + ChatPromptValueConcrete(messages=[AIMessage(content="foo")]), ChatPromptValueConcrete(messages=[HumanMessage(content="human")]), + ChatPromptValueConcrete( + messages=[ToolMessage(content="foo", tool_call_id="bar")] + ), + ChatPromptValueConcrete(messages=[SystemMessage(content="foo")]), Document(page_content="hello"), AgentFinish(return_values={}, log=""), AgentAction(tool="tool", tool_input="input", log=""),