mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 05:48:40 +00:00
qxqxqx
This commit is contained in:
@@ -53,16 +53,6 @@ AnyMessage = Annotated[
|
||||
Union[
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
ChatMessage,
|
||||
SystemMessage,
|
||||
FunctionMessage,
|
||||
ToolMessage,
|
||||
AIMessageChunk,
|
||||
HumanMessageChunk,
|
||||
ChatMessageChunk,
|
||||
SystemMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
ToolMessageChunk,
|
||||
],
|
||||
Field(discriminator=Discriminator("type")),
|
||||
]
|
||||
|
@@ -125,12 +125,24 @@ class ImagePromptValue(PromptValue):
|
||||
"""Return prompt (image URL) as messages."""
|
||||
return [HumanMessage(content=[cast(dict, self.image_url)])]
|
||||
|
||||
from typing import Annotated, Union
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import Field, Discriminator
|
||||
|
||||
# AnyMessage = Annotated[
|
||||
# Union[
|
||||
# AIMessage,
|
||||
# HumanMessage,
|
||||
# ],
|
||||
# Field(discriminator=Discriminator("type")),
|
||||
# ]
|
||||
|
||||
|
||||
class ChatPromptValueConcrete(ChatPromptValue):
|
||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||
For use in external schemas."""
|
||||
|
||||
messages: Sequence[AnyMessage]
|
||||
messages: Sequence[Annotated[Union[AIMessage, HumanMessage], Field(discriminator="type")]]
|
||||
"""Sequence of messages."""
|
||||
|
||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
||||
@@ -142,3 +154,6 @@ class ChatPromptValueConcrete(ChatPromptValue):
|
||||
Defaults to ["langchain", "prompts", "chat"].
|
||||
"""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
|
||||
ChatPromptValueConcrete.model_rebuild()
|
||||
|
@@ -1349,23 +1349,6 @@
|
||||
'history': dict({
|
||||
'default': None,
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/$defs/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
|
||||
'ai': '#/$defs/AIMessage',
|
||||
'chat': '#/$defs/ChatMessage',
|
||||
'function': '#/$defs/FunctionMessage',
|
||||
'human': '#/$defs/HumanMessage',
|
||||
'system': '#/$defs/SystemMessage',
|
||||
'tool': '#/$defs/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/$defs/AIMessage',
|
||||
@@ -2769,23 +2752,6 @@
|
||||
'properties': dict({
|
||||
'history': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/$defs/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
|
||||
'ai': '#/$defs/AIMessage',
|
||||
'chat': '#/$defs/ChatMessage',
|
||||
'function': '#/$defs/FunctionMessage',
|
||||
'human': '#/$defs/HumanMessage',
|
||||
'system': '#/$defs/SystemMessage',
|
||||
'tool': '#/$defs/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/$defs/AIMessage',
|
||||
|
@@ -1710,23 +1710,6 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/$defs/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
|
||||
'ai': '#/$defs/AIMessage',
|
||||
'chat': '#/$defs/ChatMessage',
|
||||
'function': '#/$defs/FunctionMessage',
|
||||
'human': '#/$defs/HumanMessage',
|
||||
'system': '#/$defs/SystemMessage',
|
||||
'tool': '#/$defs/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/$defs/AIMessage',
|
||||
|
@@ -6855,23 +6855,6 @@
|
||||
'properties': dict({
|
||||
'history': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/$defs/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
|
||||
'ai': '#/$defs/AIMessage',
|
||||
'chat': '#/$defs/ChatMessage',
|
||||
'function': '#/$defs/FunctionMessage',
|
||||
'human': '#/$defs/HumanMessage',
|
||||
'system': '#/$defs/SystemMessage',
|
||||
'tool': '#/$defs/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/$defs/AIMessage',
|
||||
@@ -7320,23 +7303,6 @@
|
||||
'properties': dict({
|
||||
'messages': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/$defs/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
|
||||
'ai': '#/$defs/AIMessage',
|
||||
'chat': '#/$defs/ChatMessage',
|
||||
'function': '#/$defs/FunctionMessage',
|
||||
'human': '#/$defs/HumanMessage',
|
||||
'system': '#/$defs/SystemMessage',
|
||||
'tool': '#/$defs/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/$defs/AIMessage',
|
||||
@@ -8399,23 +8365,6 @@
|
||||
}),
|
||||
dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -8854,23 +8803,6 @@
|
||||
'properties': dict({
|
||||
'messages': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -11253,23 +11185,6 @@
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -11325,23 +11240,6 @@
|
||||
}),
|
||||
dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -11780,23 +11678,6 @@
|
||||
'properties': dict({
|
||||
'messages': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -12842,23 +12723,6 @@
|
||||
'type': 'string',
|
||||
}),
|
||||
dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -14642,23 +14506,6 @@
|
||||
'properties': dict({
|
||||
'messages': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
@@ -16118,23 +15965,6 @@
|
||||
'properties': dict({
|
||||
'messages': dict({
|
||||
'items': dict({
|
||||
'discriminator': dict({
|
||||
'mapping': dict({
|
||||
'AIMessageChunk': '#/definitions/AIMessageChunk',
|
||||
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
|
||||
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
|
||||
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
|
||||
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
|
||||
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
|
||||
'ai': '#/definitions/AIMessage',
|
||||
'chat': '#/definitions/ChatMessage',
|
||||
'function': '#/definitions/FunctionMessage',
|
||||
'human': '#/definitions/HumanMessage',
|
||||
'system': '#/definitions/SystemMessage',
|
||||
'tool': '#/definitions/ToolMessage',
|
||||
}),
|
||||
'propertyName': 'type',
|
||||
}),
|
||||
'oneOf': list([
|
||||
dict({
|
||||
'$ref': '#/definitions/AIMessage',
|
||||
|
112
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
112
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""A set of tests that verifies that Union discrimination works correctly with
|
||||
the various pydantic base models.
|
||||
|
||||
These tests can uncover issues that will also arise during regular instantiation
|
||||
of the models (i.e., not necessarily from loading or dumping JSON).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import RootModel, ValidationError
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
AnyMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValueConcrete
|
||||
|
||||
|
||||
def test_serde_any_message() -> None:
|
||||
"""Test AnyMessage() serder."""
|
||||
|
||||
lc_objects = [
|
||||
HumanMessage(content="human"),
|
||||
HumanMessageChunk(content="human"),
|
||||
AIMessage(content="ai"),
|
||||
AIMessageChunk(content="ai"),
|
||||
SystemMessage(content="sys"),
|
||||
SystemMessageChunk(content="sys"),
|
||||
FunctionMessage(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
FunctionMessageChunk(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
ChatMessage(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
ChatMessageChunk(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
]
|
||||
|
||||
Model = RootModel[AnyMessage]
|
||||
|
||||
for lc_object in lc_objects:
|
||||
d = lc_object.model_dump()
|
||||
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||
obj1 = Model.model_validate(d)
|
||||
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
|
||||
|
||||
with pytest.raises((TypeError, ValidationError)):
|
||||
# Make sure that specifically validation error is raised
|
||||
Model.model_validate({})
|
||||
|
||||
|
||||
def test_serde_chat_prompt_value():
|
||||
prompt = ChatPromptValueConcrete(
|
||||
messages=[
|
||||
AIMessage(
|
||||
content="Hello",
|
||||
),
|
||||
HumanMessage(
|
||||
content=" World",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Derived = RootModel[Sequence[Any_]]
|
||||
|
||||
|
||||
def test_kookoo():
|
||||
import pydantic
|
||||
from pydantic import __version__
|
||||
|
||||
from typing import Annotated, Union, Literal, Sequence, Any
|
||||
from pydantic import BaseModel, Field, Tag, RootModel, Discriminator
|
||||
import pprint
|
||||
class Base(BaseModel):
|
||||
y: int = 'hello'
|
||||
type: Literal['base'] = 'base'
|
||||
|
||||
class Foo(Base):
|
||||
type: Literal['foo'] = 'foo'
|
||||
x: int
|
||||
|
||||
class Bar(Base):
|
||||
type: Literal['bar'] = 'bar'
|
||||
x: int
|
||||
|
||||
FooOrBar = Annotated[Union[Foo, Bar], Field(discriminator="type")]
|
||||
|
||||
|
||||
class BaseContainer(BaseModel):
|
||||
messages: Sequence[Base]
|
||||
|
||||
class Container(BaseModel):
|
||||
messages: Sequence[FooOrBar]
|
||||
|
||||
|
||||
Container(messages=[Foo(x=5), Bar(x=2), Foo(x=10)])
|
Reference in New Issue
Block a user