mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 17:54:23 +00:00
test type narrowing option
This commit is contained in:
parent
bc5e8e0c17
commit
81a4a051ab
@ -8,6 +8,7 @@ from typing import Any, Literal, Optional, Union, cast
|
|||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
from typing_extensions import NotRequired, Self, TypedDict, override
|
from typing_extensions import NotRequired, Self, TypedDict, override
|
||||||
|
|
||||||
|
from langchain_core.messages import ContentBlock
|
||||||
from langchain_core.messages.base import (
|
from langchain_core.messages.base import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
@ -178,7 +179,7 @@ class AIMessage(BaseMessage):
|
|||||||
"""The type of the message (used for deserialization). Defaults to "ai"."""
|
"""The type of the message (used for deserialization). Defaults to "ai"."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
self, content: Union[str, list[Union[str, ContentBlock, dict]]], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Pass in content as positional arg.
|
"""Pass in content as positional arg.
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Any, Literal, Union
|
from typing import Any, Literal, Union
|
||||||
|
|
||||||
|
from langchain_core.messages import ContentBlock
|
||||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||||
|
|
||||||
|
|
||||||
@ -41,7 +42,7 @@ class HumanMessage(BaseMessage):
|
|||||||
"""The type of the message (used for serialization). Defaults to "human"."""
|
"""The type of the message (used for serialization). Defaults to "human"."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
self, content: Union[str, list[Union[str, ContentBlock, dict]]], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Pass in content as positional arg.
|
"""Pass in content as positional arg.
|
||||||
|
|
||||||
|
@ -31,7 +31,10 @@ from typing import (
|
|||||||
from pydantic import Discriminator, Field, Tag
|
from pydantic import Discriminator, Field, Tag
|
||||||
|
|
||||||
from langchain_core.exceptions import ErrorCode, create_message
|
from langchain_core.exceptions import ErrorCode, create_message
|
||||||
from langchain_core.messages import convert_to_openai_data_block, is_data_content_block
|
from langchain_core.messages import (
|
||||||
|
convert_to_openai_data_block,
|
||||||
|
is_data_content_block,
|
||||||
|
)
|
||||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||||
@ -1011,8 +1014,6 @@ def convert_to_openai_messages(
|
|||||||
|
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
oai_msg: dict = {"role": _get_message_openai_role(message)}
|
oai_msg: dict = {"role": _get_message_openai_role(message)}
|
||||||
tool_messages: list = []
|
|
||||||
content: Union[str, list[dict]]
|
|
||||||
|
|
||||||
if message.name:
|
if message.name:
|
||||||
oai_msg["name"] = message.name
|
oai_msg["name"] = message.name
|
||||||
@ -1023,14 +1024,37 @@ def convert_to_openai_messages(
|
|||||||
if isinstance(message, ToolMessage):
|
if isinstance(message, ToolMessage):
|
||||||
oai_msg["tool_call_id"] = message.tool_call_id
|
oai_msg["tool_call_id"] = message.tool_call_id
|
||||||
|
|
||||||
|
content, tool_messages = _extract_content(i, message, oai_msg, text_format)
|
||||||
|
oai_msg["content"] = content
|
||||||
|
if message.content and not oai_msg["content"] and tool_messages:
|
||||||
|
oai_messages.extend(tool_messages)
|
||||||
|
else:
|
||||||
|
oai_messages.extend([oai_msg, *tool_messages])
|
||||||
|
|
||||||
|
if is_single:
|
||||||
|
return oai_messages[0]
|
||||||
|
return oai_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content(
|
||||||
|
idx: int,
|
||||||
|
message: BaseMessage,
|
||||||
|
oai_msg: dict,
|
||||||
|
text_format: Literal["string", "block"],
|
||||||
|
) -> tuple[Union[str, list[dict]], list]:
|
||||||
|
"""Extract content from a message and format it according to OpenAI standards."""
|
||||||
|
content: Union[str, list[dict]]
|
||||||
|
tool_messages: list = []
|
||||||
if not message.content:
|
if not message.content:
|
||||||
content = "" if text_format == "string" else []
|
content = "" if text_format == "string" else []
|
||||||
elif isinstance(message.content, str):
|
return content, tool_messages
|
||||||
|
if isinstance(message.content, str):
|
||||||
if text_format == "string":
|
if text_format == "string":
|
||||||
content = message.content
|
content = message.content
|
||||||
else:
|
else:
|
||||||
content = [{"type": "text", "text": message.content}]
|
content = [{"type": "text", "text": message.content}]
|
||||||
elif text_format == "string" and all(
|
return content, tool_messages
|
||||||
|
if text_format == "string" and all(
|
||||||
isinstance(block, str) or block.get("type") == "text"
|
isinstance(block, str) or block.get("type") == "text"
|
||||||
for block in message.content
|
for block in message.content
|
||||||
):
|
):
|
||||||
@ -1038,17 +1062,22 @@ def convert_to_openai_messages(
|
|||||||
block if isinstance(block, str) else block["text"]
|
block if isinstance(block, str) else block["text"]
|
||||||
for block in message.content
|
for block in message.content
|
||||||
)
|
)
|
||||||
else:
|
return content, tool_messages
|
||||||
|
|
||||||
content = []
|
content = []
|
||||||
for j, block in enumerate(message.content):
|
for block_idx, block in enumerate(message.content):
|
||||||
# OpenAI format
|
# OpenAI format
|
||||||
if isinstance(block, str):
|
if isinstance(block, str):
|
||||||
content.append({"type": "text", "text": block})
|
content.append({"type": "text", "text": block})
|
||||||
elif block.get("type") == "text":
|
continue
|
||||||
|
|
||||||
|
block = cast("dict", block)
|
||||||
|
|
||||||
|
if block.get("type") == "text":
|
||||||
if missing := [k for k in ("text",) if k not in block]:
|
if missing := [k for k in ("text",) if k not in block]:
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': 'text' "
|
f"messages[{idx}].content[{block_idx}] has 'type': 'text' "
|
||||||
f"but is missing expected key(s) "
|
f"but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1058,7 +1087,7 @@ def convert_to_openai_messages(
|
|||||||
if missing := [k for k in ("image_url",) if k not in block]:
|
if missing := [k for k in ("image_url",) if k not in block]:
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': 'image_url' "
|
f"messages[{idx}].content[{block_idx}] has 'type': 'image_url' "
|
||||||
f"but is missing expected key(s) "
|
f"but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1089,7 +1118,7 @@ def convert_to_openai_messages(
|
|||||||
]:
|
]:
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': 'image' "
|
f"messages[{idx}].content[{block_idx}] has 'type': 'image' "
|
||||||
f"but 'source' is missing expected key(s) "
|
f"but 'source' is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1107,12 +1136,10 @@ def convert_to_openai_messages(
|
|||||||
)
|
)
|
||||||
# Bedrock converse
|
# Bedrock converse
|
||||||
elif image := block.get("image"):
|
elif image := block.get("image"):
|
||||||
if missing := [
|
if missing := [k for k in ("source", "format") if k not in image]:
|
||||||
k for k in ("source", "format") if k not in image
|
|
||||||
]:
|
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has key 'image', "
|
f"messages[{idx}].content[{block_idx}] has key 'image', "
|
||||||
f"but 'image' is missing expected key(s) "
|
f"but 'image' is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1122,16 +1149,14 @@ def convert_to_openai_messages(
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": (
|
"url": (f"data:image/{image['format']};base64,{b64_image}")
|
||||||
f"data:image/{image['format']};base64,{b64_image}"
|
|
||||||
)
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': 'image' "
|
f"messages[{idx}].content[{block_idx}] has 'type': 'image' "
|
||||||
f"but does not have a 'source' or 'image' key. Full "
|
f"but does not have a 'source' or 'image' key. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1155,12 +1180,10 @@ def convert_to_openai_messages(
|
|||||||
):
|
):
|
||||||
content.append(block)
|
content.append(block)
|
||||||
elif block.get("type") == "tool_use":
|
elif block.get("type") == "tool_use":
|
||||||
if missing := [
|
if missing := [k for k in ("id", "name", "input") if k not in block]:
|
||||||
k for k in ("id", "name", "input") if k not in block
|
|
||||||
]:
|
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': "
|
f"messages[{idx}].content[{block_idx}] has 'type': "
|
||||||
f"'tool_use', but is missing expected key(s) "
|
f"'tool_use', but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1181,12 +1204,10 @@ def convert_to_openai_messages(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif block.get("type") == "tool_result":
|
elif block.get("type") == "tool_result":
|
||||||
if missing := [
|
if missing := [k for k in ("content", "tool_use_id") if k not in block]:
|
||||||
k for k in ("content", "tool_use_id") if k not in block
|
|
||||||
]:
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': "
|
f"messages[{idx}].content[{block_idx}] has 'type': "
|
||||||
f"'tool_result', but is missing expected key(s) "
|
f"'tool_result', but is missing expected key(s) "
|
||||||
f"{missing}. Full content block:\n\n{block}"
|
f"{missing}. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1198,15 +1219,13 @@ def convert_to_openai_messages(
|
|||||||
)
|
)
|
||||||
# Recurse to make sure tool message contents are OpenAI format.
|
# Recurse to make sure tool message contents are OpenAI format.
|
||||||
tool_messages.extend(
|
tool_messages.extend(
|
||||||
convert_to_openai_messages(
|
convert_to_openai_messages([tool_message], text_format=text_format)
|
||||||
[tool_message], text_format=text_format
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
elif (block.get("type") == "json") or "json" in block:
|
elif (block.get("type") == "json") or "json" in block:
|
||||||
if "json" not in block:
|
if "json" not in block:
|
||||||
msg = (
|
msg = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': 'json' "
|
f"messages[{idx}].content[{block_idx}] has 'type': 'json' "
|
||||||
f"but does not have a 'json' key. Full "
|
f"but does not have a 'json' key. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1218,15 +1237,12 @@ def convert_to_openai_messages(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif (block.get("type") == "guard_content") or "guard_content" in block:
|
elif (block.get("type") == "guard_content") or "guard_content" in block:
|
||||||
if (
|
if "guard_content" not in block or "text" not in block["guard_content"]:
|
||||||
"guard_content" not in block
|
|
||||||
or "text" not in block["guard_content"]
|
|
||||||
):
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': "
|
f"messages[{idx}].content[{block_idx}] has 'type': "
|
||||||
f"'guard_content' but does not have a "
|
f"'guard_content' but does not have a "
|
||||||
f"messages[{i}].content[{j}]['guard_content']['text'] "
|
f"messages[{idx}].content[{block_idx}]['guard_content']['text'] "
|
||||||
f"key. Full content block:\n\n{block}"
|
f"key. Full content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -1239,7 +1255,7 @@ def convert_to_openai_messages(
|
|||||||
if missing := [k for k in ("mime_type", "data") if k not in block]:
|
if missing := [k for k in ("mime_type", "data") if k not in block]:
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] has 'type': "
|
f"messages[{idx}].content[{block_idx}] has 'type': "
|
||||||
f"'media' but does not have key(s) {missing}. Full "
|
f"'media' but does not have key(s) {missing}. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1265,7 +1281,7 @@ def convert_to_openai_messages(
|
|||||||
else:
|
else:
|
||||||
err = (
|
err = (
|
||||||
f"Unrecognized content block at "
|
f"Unrecognized content block at "
|
||||||
f"messages[{i}].content[{j}] does not match OpenAI, "
|
f"messages[{idx}].content[{block_idx}] does not match OpenAI, "
|
||||||
f"Anthropic, Bedrock Converse, or VertexAI format. Full "
|
f"Anthropic, Bedrock Converse, or VertexAI format. Full "
|
||||||
f"content block:\n\n{block}"
|
f"content block:\n\n{block}"
|
||||||
)
|
)
|
||||||
@ -1274,15 +1290,7 @@ def convert_to_openai_messages(
|
|||||||
block["type"] != "text" for block in content
|
block["type"] != "text" for block in content
|
||||||
):
|
):
|
||||||
content = "\n".join(block["text"] for block in content)
|
content = "\n".join(block["text"] for block in content)
|
||||||
oai_msg["content"] = content
|
return content, tool_messages
|
||||||
if message.content and not oai_msg["content"] and tool_messages:
|
|
||||||
oai_messages.extend(tool_messages)
|
|
||||||
else:
|
|
||||||
oai_messages.extend([oai_msg, *tool_messages])
|
|
||||||
|
|
||||||
if is_single:
|
|
||||||
return oai_messages[0]
|
|
||||||
return oai_messages
|
|
||||||
|
|
||||||
|
|
||||||
def _first_max_tokens(
|
def _first_max_tokens(
|
||||||
|
@ -1,18 +1,20 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from collections.abc import Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from typing_extensions import override
|
from typing_extensions import TypeGuard, override
|
||||||
|
|
||||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
|
ReasoningContentBlock,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
TextContentBlock,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
@ -1457,3 +1459,32 @@ def test_get_buffer_string_with_empty_content() -> None:
|
|||||||
expected = "Human: \nAI: \nSystem: "
|
expected = "Human: \nAI: \nSystem: "
|
||||||
actual = get_buffer_string(messages)
|
actual = get_buffer_string(messages)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def is_reasoning_block(block: Mapping[str, Any]) -> TypeGuard[ReasoningContentBlock]:
|
||||||
|
"""Check if a block is a ReasoningContentBlock."""
|
||||||
|
return block.get("type") == "reasoning"
|
||||||
|
|
||||||
|
|
||||||
|
def is_text_block(block: Mapping[str, Any]) -> TypeGuard[TextContentBlock]:
|
||||||
|
"""Check if a block is a TextContentBlock."""
|
||||||
|
return block.get("type") == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_typing() -> None:
|
||||||
|
"""Test typing on things"""
|
||||||
|
message = AIMessage(
|
||||||
|
content="Hello",
|
||||||
|
)
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
# This should not raise an error
|
||||||
|
message.content = message.content + " world"
|
||||||
|
elif isinstance(message.content, list):
|
||||||
|
all_contents = []
|
||||||
|
for block in message.content:
|
||||||
|
if isinstance(block, dict):
|
||||||
|
block = cast("dict", block)
|
||||||
|
if is_text_block(block):
|
||||||
|
all_contents.append(block["text"])
|
||||||
|
if is_reasoning_block(block):
|
||||||
|
all_contents.append(block.get("reasoning", "foo"))
|
||||||
|
Loading…
Reference in New Issue
Block a user