core[patch]: fix ChatPromptValueConcrete typing (#26106)

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
This commit is contained in:
Bagatur
2024-09-06 17:13:57 -04:00
committed by GitHub
parent 6df9360e32
commit b2c8f2de4c
6 changed files with 371 additions and 132 deletions

View File

@@ -29,6 +29,9 @@ from typing import (
overload, overload,
) )
from pydantic import Discriminator, Field
from typing_extensions import Annotated
from langchain_core.messages.ai import AIMessage, AIMessageChunk from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
@@ -45,19 +48,23 @@ if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
AnyMessage = Union[
AIMessage, AnyMessage = Annotated[
HumanMessage, Union[
ChatMessage, AIMessage,
SystemMessage, HumanMessage,
FunctionMessage, ChatMessage,
ToolMessage, SystemMessage,
AIMessageChunk, FunctionMessage,
HumanMessageChunk, ToolMessage,
ChatMessageChunk, AIMessageChunk,
SystemMessageChunk, HumanMessageChunk,
FunctionMessageChunk, ChatMessageChunk,
ToolMessageChunk, SystemMessageChunk,
FunctionMessageChunk,
ToolMessageChunk,
],
Field(discriminator=Discriminator("type")),
] ]

View File

@@ -1349,7 +1349,24 @@
'history': dict({ 'history': dict({
'default': None, 'default': None,
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
}), }),
@@ -2752,7 +2769,24 @@
'properties': dict({ 'properties': dict({
'history': dict({ 'history': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
}), }),

View File

@@ -1710,40 +1710,61 @@
'type': 'string', 'type': 'string',
}), }),
dict({ dict({
'$ref': '#/$defs/AIMessage', 'discriminator': dict({
}), 'mapping': dict({
dict({ 'AIMessageChunk': '#/$defs/AIMessageChunk',
'$ref': '#/$defs/HumanMessage', 'ChatMessageChunk': '#/$defs/ChatMessageChunk',
}), 'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
dict({ 'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'$ref': '#/$defs/ChatMessage', 'SystemMessageChunk': '#/$defs/SystemMessageChunk',
}), 'ToolMessageChunk': '#/$defs/ToolMessageChunk',
dict({ 'ai': '#/$defs/AIMessage',
'$ref': '#/$defs/SystemMessage', 'chat': '#/$defs/ChatMessage',
}), 'function': '#/$defs/FunctionMessage',
dict({ 'human': '#/$defs/HumanMessage',
'$ref': '#/$defs/FunctionMessage', 'system': '#/$defs/SystemMessage',
}), 'tool': '#/$defs/ToolMessage',
dict({ }),
'$ref': '#/$defs/ToolMessage', 'propertyName': 'type',
}), }),
dict({ 'oneOf': list([
'$ref': '#/$defs/AIMessageChunk', dict({
}), '$ref': '#/$defs/AIMessage',
dict({ }),
'$ref': '#/$defs/HumanMessageChunk', dict({
}), '$ref': '#/$defs/HumanMessage',
dict({ }),
'$ref': '#/$defs/ChatMessageChunk', dict({
}), '$ref': '#/$defs/ChatMessage',
dict({ }),
'$ref': '#/$defs/SystemMessageChunk', dict({
}), '$ref': '#/$defs/SystemMessage',
dict({ }),
'$ref': '#/$defs/FunctionMessageChunk', dict({
}), '$ref': '#/$defs/FunctionMessage',
dict({ }),
'$ref': '#/$defs/ToolMessageChunk', dict({
'$ref': '#/$defs/ToolMessage',
}),
dict({
'$ref': '#/$defs/AIMessageChunk',
}),
dict({
'$ref': '#/$defs/HumanMessageChunk',
}),
dict({
'$ref': '#/$defs/ChatMessageChunk',
}),
dict({
'$ref': '#/$defs/SystemMessageChunk',
}),
dict({
'$ref': '#/$defs/FunctionMessageChunk',
}),
dict({
'$ref': '#/$defs/ToolMessageChunk',
}),
]),
}), }),
]), ]),
'title': 'RunnableParallel<as_list,as_str>Input', 'title': 'RunnableParallel<as_list,as_str>Input',

View File

@@ -6525,7 +6525,24 @@
'properties': dict({ 'properties': dict({
'history': dict({ 'history': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
}), }),
@@ -6973,7 +6990,24 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/$defs/AIMessageChunk',
'ChatMessageChunk': '#/$defs/ChatMessageChunk',
'FunctionMessageChunk': '#/$defs/FunctionMessageChunk',
'HumanMessageChunk': '#/$defs/HumanMessageChunk',
'SystemMessageChunk': '#/$defs/SystemMessageChunk',
'ToolMessageChunk': '#/$defs/ToolMessageChunk',
'ai': '#/$defs/AIMessage',
'chat': '#/$defs/ChatMessage',
'function': '#/$defs/FunctionMessage',
'human': '#/$defs/HumanMessage',
'system': '#/$defs/SystemMessage',
'tool': '#/$defs/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/$defs/AIMessage', '$ref': '#/$defs/AIMessage',
}), }),
@@ -8035,7 +8069,24 @@
}), }),
dict({ dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
}), }),
@@ -8473,7 +8524,24 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
}), }),
@@ -9513,44 +9581,6 @@
# --- # ---
# name: test_schemas[fake_chat_output_schema] # name: test_schemas[fake_chat_output_schema]
dict({ dict({
'anyOf': list([
dict({
'$ref': '#/definitions/AIMessage',
}),
dict({
'$ref': '#/definitions/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
dict({
'$ref': '#/definitions/AIMessageChunk',
}),
dict({
'$ref': '#/definitions/HumanMessageChunk',
}),
dict({
'$ref': '#/definitions/ChatMessageChunk',
}),
dict({
'$ref': '#/definitions/SystemMessageChunk',
}),
dict({
'$ref': '#/definitions/FunctionMessageChunk',
}),
dict({
'$ref': '#/definitions/ToolMessageChunk',
}),
]),
'definitions': dict({ 'definitions': dict({
'AIMessage': dict({ 'AIMessage': dict({
'additionalProperties': True, 'additionalProperties': True,
@@ -10893,6 +10923,61 @@
'type': 'object', 'type': 'object',
}), }),
}), }),
'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({
'$ref': '#/definitions/AIMessage',
}),
dict({
'$ref': '#/definitions/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
dict({
'$ref': '#/definitions/AIMessageChunk',
}),
dict({
'$ref': '#/definitions/HumanMessageChunk',
}),
dict({
'$ref': '#/definitions/ChatMessageChunk',
}),
dict({
'$ref': '#/definitions/SystemMessageChunk',
}),
dict({
'$ref': '#/definitions/FunctionMessageChunk',
}),
dict({
'$ref': '#/definitions/ToolMessageChunk',
}),
]),
'title': 'FakeListChatModelOutput', 'title': 'FakeListChatModelOutput',
}) })
# --- # ---
@@ -10910,7 +10995,24 @@
}), }),
dict({ dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
}), }),
@@ -11348,7 +11450,24 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
}), }),
@@ -12393,40 +12512,61 @@
'type': 'string', 'type': 'string',
}), }),
dict({ dict({
'$ref': '#/definitions/AIMessage', 'discriminator': dict({
}), 'mapping': dict({
dict({ 'AIMessageChunk': '#/definitions/AIMessageChunk',
'$ref': '#/definitions/HumanMessage', 'ChatMessageChunk': '#/definitions/ChatMessageChunk',
}), 'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
dict({ 'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'$ref': '#/definitions/ChatMessage', 'SystemMessageChunk': '#/definitions/SystemMessageChunk',
}), 'ToolMessageChunk': '#/definitions/ToolMessageChunk',
dict({ 'ai': '#/definitions/AIMessage',
'$ref': '#/definitions/SystemMessage', 'chat': '#/definitions/ChatMessage',
}), 'function': '#/definitions/FunctionMessage',
dict({ 'human': '#/definitions/HumanMessage',
'$ref': '#/definitions/FunctionMessage', 'system': '#/definitions/SystemMessage',
}), 'tool': '#/definitions/ToolMessage',
dict({ }),
'$ref': '#/definitions/ToolMessage', 'propertyName': 'type',
}), }),
dict({ 'oneOf': list([
'$ref': '#/definitions/AIMessageChunk', dict({
}), '$ref': '#/definitions/AIMessage',
dict({ }),
'$ref': '#/definitions/HumanMessageChunk', dict({
}), '$ref': '#/definitions/HumanMessage',
dict({ }),
'$ref': '#/definitions/ChatMessageChunk', dict({
}), '$ref': '#/definitions/ChatMessage',
dict({ }),
'$ref': '#/definitions/SystemMessageChunk', dict({
}), '$ref': '#/definitions/SystemMessage',
dict({ }),
'$ref': '#/definitions/FunctionMessageChunk', dict({
}), '$ref': '#/definitions/FunctionMessage',
dict({ }),
'$ref': '#/definitions/ToolMessageChunk', dict({
'$ref': '#/definitions/ToolMessage',
}),
dict({
'$ref': '#/definitions/AIMessageChunk',
}),
dict({
'$ref': '#/definitions/HumanMessageChunk',
}),
dict({
'$ref': '#/definitions/ChatMessageChunk',
}),
dict({
'$ref': '#/definitions/SystemMessageChunk',
}),
dict({
'$ref': '#/definitions/FunctionMessageChunk',
}),
dict({
'$ref': '#/definitions/ToolMessageChunk',
}),
]),
}), }),
]), ]),
'definitions': dict({ 'definitions': dict({
@@ -14172,7 +14312,24 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
}), }),
@@ -15631,7 +15788,24 @@
'properties': dict({ 'properties': dict({
'messages': dict({ 'messages': dict({
'items': dict({ 'items': dict({
'anyOf': list([ 'discriminator': dict({
'mapping': dict({
'AIMessageChunk': '#/definitions/AIMessageChunk',
'ChatMessageChunk': '#/definitions/ChatMessageChunk',
'FunctionMessageChunk': '#/definitions/FunctionMessageChunk',
'HumanMessageChunk': '#/definitions/HumanMessageChunk',
'SystemMessageChunk': '#/definitions/SystemMessageChunk',
'ToolMessageChunk': '#/definitions/ToolMessageChunk',
'ai': '#/definitions/AIMessage',
'chat': '#/definitions/ChatMessage',
'function': '#/definitions/FunctionMessage',
'human': '#/definitions/HumanMessage',
'system': '#/definitions/SystemMessage',
'tool': '#/definitions/ToolMessage',
}),
'propertyName': 'type',
}),
'oneOf': list([
dict({ dict({
'$ref': '#/definitions/AIMessage', '$ref': '#/definitions/AIMessage',
}), }),

View File

@@ -1,5 +1,3 @@
import pytest
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,
@@ -13,7 +11,6 @@ from langchain_core.messages import (
from langchain_core.prompt_values import ChatPromptValueConcrete from langchain_core.prompt_values import ChatPromptValueConcrete
@pytest.mark.xfail(reason="Broken union type.")
def test_chat_prompt_value_concrete() -> None: def test_chat_prompt_value_concrete() -> None:
messages: list = [ messages: list = [
AIMessage("foo"), AIMessage("foo"),

View File

@@ -16,6 +16,7 @@ from langchain_core.messages import (
HumanMessageChunk, HumanMessageChunk,
SystemMessage, SystemMessage,
SystemMessageChunk, SystemMessageChunk,
ToolMessage,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
@@ -73,7 +74,12 @@ def test_serialization_of_wellknown_objects() -> None:
content="human", content="human",
), ),
StringPromptValue(text="hello"), StringPromptValue(text="hello"),
ChatPromptValueConcrete(messages=[AIMessage(content="foo")]),
ChatPromptValueConcrete(messages=[HumanMessage(content="human")]), ChatPromptValueConcrete(messages=[HumanMessage(content="human")]),
ChatPromptValueConcrete(
messages=[ToolMessage(content="foo", tool_call_id="bar")]
),
ChatPromptValueConcrete(messages=[SystemMessage(content="foo")]),
Document(page_content="hello"), Document(page_content="hello"),
AgentFinish(return_values={}, log=""), AgentFinish(return_values={}, log=""),
AgentAction(tool="tool", tool_input="input", log=""), AgentAction(tool="tool", tool_input="input", log=""),