This commit is contained in:
Eugene Yurtsev
2024-09-09 15:24:04 -04:00
parent 0319ccd273
commit 760ce59601
6 changed files with 128 additions and 232 deletions

View File

@@ -53,16 +53,6 @@ AnyMessage = Annotated[
Union[ Union[
AIMessage, AIMessage,
HumanMessage, HumanMessage,
ChatMessage,
SystemMessage,
FunctionMessage,
ToolMessage,
AIMessageChunk,
HumanMessageChunk,
ChatMessageChunk,
SystemMessageChunk,
FunctionMessageChunk,
ToolMessageChunk,
], ],
Field(discriminator=Discriminator("type")), Field(discriminator=Discriminator("type")),
] ]

View File

@@ -125,12 +125,24 @@ class ImagePromptValue(PromptValue):
"""Return prompt (image URL) as messages.""" """Return prompt (image URL) as messages."""
return [HumanMessage(content=[cast(dict, self.image_url)])] return [HumanMessage(content=[cast(dict, self.image_url)])]
from typing import Annotated, Union
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import Field, Discriminator
# AnyMessage = Annotated[
# Union[
# AIMessage,
# HumanMessage,
# ],
# Field(discriminator=Discriminator("type")),
# ]
class ChatPromptValueConcrete(ChatPromptValue): class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts. """Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas.""" For use in external schemas."""
messages: Sequence[AnyMessage] messages: Sequence[Annotated[Union[AIMessage, HumanMessage], Field(discriminator="type")]]
"""Sequence of messages.""" """Sequence of messages."""
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
@@ -142,3 +154,6 @@ class ChatPromptValueConcrete(ChatPromptValue):
Defaults to ["langchain", "prompts", "chat"]. Defaults to ["langchain", "prompts", "chat"].
""" """
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
ChatPromptValueConcrete.model_rebuild()

View File

@@ -1349,23 +1349,6 @@
'history': dict({ 'history': dict({
'default': None, 'default': None,
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
@@ -2769,23 +2752,6 @@
'properties': dict({ 'properties': dict({
'history': dict({ 'history': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',

View File

@@ -1710,23 +1710,6 @@
'type': 'string', 'type': 'string',
}), }),
dict({ dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',

View File

@@ -6855,23 +6855,6 @@
'properties': dict({ 'properties': dict({
'history': dict({ 'history': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
@@ -7320,23 +7303,6 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
@@ -8399,23 +8365,6 @@
}), }),
dict({ dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -8854,23 +8803,6 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -11253,23 +11185,6 @@
'type': 'object', 'type': 'object',
}), }),
}), }),
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -11325,23 +11240,6 @@
}), }),
dict({ dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -11780,23 +11678,6 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -12842,23 +12723,6 @@
'type': 'string', 'type': 'string',
}), }),
dict({ dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -14642,23 +14506,6 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
@@ -16118,23 +15965,6 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([ 'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',

View File

@@ -0,0 +1,112 @@
"""A set of tests that verifies that Union discrimination works correctly with
the various pydantic base models.
These tests can uncover issues that will also arise during regular instantiation
of the models (i.e., not necessarily from loading or dumping JSON).
"""
import pytest
from pydantic import RootModel, ValidationError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
AnyMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.prompt_values import ChatPromptValueConcrete
def test_serde_any_message() -> None:
"""Test AnyMessage() serder."""
lc_objects = [
HumanMessage(content="human"),
HumanMessageChunk(content="human"),
AIMessage(content="ai"),
AIMessageChunk(content="ai"),
SystemMessage(content="sys"),
SystemMessageChunk(content="sys"),
FunctionMessage(
name="func",
content="func",
),
FunctionMessageChunk(
name="func",
content="func",
),
ChatMessage(
role="human",
content="human",
),
ChatMessageChunk(
role="human",
content="human",
),
]
Model = RootModel[AnyMessage]
for lc_object in lc_objects:
d = lc_object.model_dump()
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
obj1 = Model.model_validate(d)
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
with pytest.raises((TypeError, ValidationError)):
# Make sure that specifically validation error is raised
Model.model_validate({})
def test_serde_chat_prompt_value():
prompt = ChatPromptValueConcrete(
messages=[
AIMessage(
content="Hello",
),
HumanMessage(
content=" World",
)
]
)
# Derived = RootModel[Sequence[Any_]]
def test_kookoo():
import pydantic
from pydantic import __version__
from typing import Annotated, Union, Literal, Sequence, Any
from pydantic import BaseModel, Field, Tag, RootModel, Discriminator
import pprint
class Base(BaseModel):
y: int = 'hello'
type: Literal['base'] = 'base'
class Foo(Base):
type: Literal['foo'] = 'foo'
x: int
class Bar(Base):
type: Literal['bar'] = 'bar'
x: int
FooOrBar = Annotated[Union[Foo, Bar], Field(discriminator="type")]
class BaseContainer(BaseModel):
messages: Sequence[Base]
class Container(BaseModel):
messages: Sequence[FooOrBar]
Container(messages=[Foo(x=5), Bar(x=2), Foo(x=10)])