From 760ce596016516031ea0e489b35d78c82906115b Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 9 Sep 2024 15:24:04 -0400 Subject: [PATCH] qxqxqx --- libs/core/langchain_core/messages/utils.py | 10 -- libs/core/langchain_core/prompt_values.py | 17 +- .../prompts/__snapshots__/test_chat.ambr | 34 ---- .../runnables/__snapshots__/test_graph.ambr | 17 -- .../__snapshots__/test_runnable.ambr | 170 ------------------ .../tests/unit_tests/test_pydantic_serde.py | 112 ++++++++++++ 6 files changed, 128 insertions(+), 232 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_pydantic_serde.py diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index f78d877e595..f7ad2436a70 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -53,16 +53,6 @@ AnyMessage = Annotated[ Union[ AIMessage, HumanMessage, - ChatMessage, - SystemMessage, - FunctionMessage, - ToolMessage, - AIMessageChunk, - HumanMessageChunk, - ChatMessageChunk, - SystemMessageChunk, - FunctionMessageChunk, - ToolMessageChunk, ], Field(discriminator=Discriminator("type")), ] diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 27bf3e5df7f..b180f62c766 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -125,12 +125,24 @@ class ImagePromptValue(PromptValue): """Return prompt (image URL) as messages.""" return [HumanMessage(content=[cast(dict, self.image_url)])] +from typing import Annotated, Union +from langchain_core.messages import AIMessage, HumanMessage +from pydantic import Field, Discriminator + +# AnyMessage = Annotated[ +# Union[ +# AIMessage, +# HumanMessage, +# ], +# Field(discriminator=Discriminator("type")), +# ] + class ChatPromptValueConcrete(ChatPromptValue): """Chat prompt value which explicitly lists out the message types it accepts. For use in external schemas.""" - messages: Sequence[AnyMessage] + messages: Sequence[Annotated[Union[AIMessage, HumanMessage], Field(discriminator="type")]] """Sequence of messages.""" type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" @@ -142,3 +154,6 @@ class ChatPromptValueConcrete(ChatPromptValue): Defaults to ["langchain", "prompts", "chat"]. """ return ["langchain", "prompts", "chat"] + + +ChatPromptValueConcrete.model_rebuild() diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 7c3c2ffb456..11cec3e3b2e 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -1349,23 +1349,6 @@ 'history': dict({ 'default': None, 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/$defs/AIMessageChunk', - 'ChatMessageChunk': '#/$defs/ChatMessageChunk', - 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', - 'HumanMessageChunk': '#/$defs/HumanMessageChunk', - 'SystemMessageChunk': '#/$defs/SystemMessageChunk', - 'ToolMessageChunk': '#/$defs/ToolMessageChunk', - 'ai': '#/$defs/AIMessage', - 'chat': '#/$defs/ChatMessage', - 'function': '#/$defs/FunctionMessage', - 'human': '#/$defs/HumanMessage', - 'system': '#/$defs/SystemMessage', - 'tool': '#/$defs/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', @@ -2769,23 +2752,6 @@ 'properties': dict({ 'history': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/$defs/AIMessageChunk', - 'ChatMessageChunk': '#/$defs/ChatMessageChunk', - 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', - 'HumanMessageChunk': '#/$defs/HumanMessageChunk', - 'SystemMessageChunk': '#/$defs/SystemMessageChunk', - 'ToolMessageChunk': '#/$defs/ToolMessageChunk', - 'ai': '#/$defs/AIMessage', - 'chat': '#/$defs/ChatMessage', - 'function': '#/$defs/FunctionMessage', - 'human': '#/$defs/HumanMessage', - 'system': '#/$defs/SystemMessage', - 'tool': '#/$defs/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 34929ad3abd..bf14546547a 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1710,23 +1710,6 @@ 'type': 'string', }), dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/$defs/AIMessageChunk', - 'ChatMessageChunk': '#/$defs/ChatMessageChunk', - 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', - 'HumanMessageChunk': '#/$defs/HumanMessageChunk', - 'SystemMessageChunk': '#/$defs/SystemMessageChunk', - 'ToolMessageChunk': '#/$defs/ToolMessageChunk', - 'ai': '#/$defs/AIMessage', - 'chat': '#/$defs/ChatMessage', - 'function': '#/$defs/FunctionMessage', - 'human': '#/$defs/HumanMessage', - 'system': '#/$defs/SystemMessage', - 'tool': '#/$defs/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index bee55716e86..23bc73ab3a8 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -6855,23 +6855,6 @@ 'properties': dict({ 'history': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/$defs/AIMessageChunk', - 'ChatMessageChunk': '#/$defs/ChatMessageChunk', - 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', - 'HumanMessageChunk': '#/$defs/HumanMessageChunk', - 'SystemMessageChunk': '#/$defs/SystemMessageChunk', - 'ToolMessageChunk': '#/$defs/ToolMessageChunk', - 'ai': '#/$defs/AIMessage', - 'chat': '#/$defs/ChatMessage', - 'function': '#/$defs/FunctionMessage', - 'human': '#/$defs/HumanMessage', - 'system': '#/$defs/SystemMessage', - 'tool': '#/$defs/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', @@ -7320,23 +7303,6 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/$defs/AIMessageChunk', - 'ChatMessageChunk': '#/$defs/ChatMessageChunk', - 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk', - 'HumanMessageChunk': '#/$defs/HumanMessageChunk', - 'SystemMessageChunk': '#/$defs/SystemMessageChunk', - 'ToolMessageChunk': '#/$defs/ToolMessageChunk', - 'ai': '#/$defs/AIMessage', - 'chat': '#/$defs/ChatMessage', - 'function': '#/$defs/FunctionMessage', - 'human': '#/$defs/HumanMessage', - 'system': '#/$defs/SystemMessage', - 'tool': '#/$defs/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/$defs/AIMessage', @@ -8399,23 +8365,6 @@ }), dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -8854,23 +8803,6 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -11253,23 +11185,6 @@ 'type': 'object', }), }), - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -11325,23 +11240,6 @@ }), dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -11780,23 +11678,6 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -12842,23 +12723,6 @@ 'type': 'string', }), dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -14642,23 +14506,6 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', @@ -16118,23 +15965,6 @@ 'properties': dict({ 'messages': dict({ 'items': dict({ - 'discriminator': dict({ - 'mapping': dict({ - 'AIMessageChunk': '#/definitions/AIMessageChunk', - 'ChatMessageChunk': '#/definitions/ChatMessageChunk', - 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk', - 'HumanMessageChunk': '#/definitions/HumanMessageChunk', - 'SystemMessageChunk': '#/definitions/SystemMessageChunk', - 'ToolMessageChunk': '#/definitions/ToolMessageChunk', - 'ai': '#/definitions/AIMessage', - 'chat': '#/definitions/ChatMessage', - 'function': '#/definitions/FunctionMessage', - 'human': '#/definitions/HumanMessage', - 'system': '#/definitions/SystemMessage', - 'tool': '#/definitions/ToolMessage', - }), - 'propertyName': 'type', - }), 'oneOf': list([ dict({ '$ref': '#/definitions/AIMessage', diff --git a/libs/core/tests/unit_tests/test_pydantic_serde.py b/libs/core/tests/unit_tests/test_pydantic_serde.py new file mode 100644 index 00000000000..bd24a1a4950 --- /dev/null +++ b/libs/core/tests/unit_tests/test_pydantic_serde.py @@ -0,0 +1,112 @@ +"""A set of tests that verifies that Union discrimination works correctly with +the various pydantic base models. + +These tests can uncover issues that will also arise during regular instantiation +of the models (i.e., not necessarily from loading or dumping JSON). +""" + +import pytest +from pydantic import RootModel, ValidationError + +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + AnyMessage, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +) +from langchain_core.prompt_values import ChatPromptValueConcrete + + +def test_serde_any_message() -> None: + """Test AnyMessage() serder.""" + + lc_objects = [ + HumanMessage(content="human"), + HumanMessageChunk(content="human"), + AIMessage(content="ai"), + AIMessageChunk(content="ai"), + SystemMessage(content="sys"), + SystemMessageChunk(content="sys"), + FunctionMessage( + name="func", + content="func", + ), + FunctionMessageChunk( + name="func", + content="func", + ), + ChatMessage( + role="human", + content="human", + ), + ChatMessageChunk( + role="human", + content="human", + ), + ] + + Model = RootModel[AnyMessage] + + for lc_object in lc_objects: + d = lc_object.model_dump() + assert "type" in d, f"Missing key `type` for {type(lc_object)}" + obj1 = Model.model_validate(d) + assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" + + with pytest.raises((TypeError, ValidationError)): + # Make sure that specifically validation error is raised + Model.model_validate({}) + + +def test_serde_chat_prompt_value(): + prompt = ChatPromptValueConcrete( + messages=[ + AIMessage( + content="Hello", + ), + HumanMessage( + content=" World", + ) + ] + ) + + # Derived = RootModel[Sequence[Any_]] + + +def test_kookoo(): + import pydantic + from pydantic import __version__ + + from typing import Annotated, Union, Literal, Sequence, Any + from pydantic import BaseModel, Field, Tag, RootModel, Discriminator + import pprint + class Base(BaseModel): + y: int = 'hello' + type: Literal['base'] = 'base' + + class Foo(Base): + type: Literal['foo'] = 'foo' + x: int + + class Bar(Base): + type: Literal['bar'] = 'bar' + x: int + + FooOrBar = Annotated[Union[Foo, Bar], Field(discriminator="type")] + + + class BaseContainer(BaseModel): + messages: Sequence[Base] + + class Container(BaseModel): + messages: Sequence[FooOrBar] + + + Container(messages=[Foo(x=5), Bar(x=2), Foo(x=10)])