Compare commits

...

17 Commits

Author SHA1 Message Date
Bagatur
40f4676c64 merge 2024-10-09 16:44:05 -07:00
Bagatur
eb9170a121 fmt 2024-08-30 13:31:47 -07:00
Bagatur
9bf6215571 fmt 2024-08-30 12:54:55 -07:00
Bagatur
eb7e485886 fmt 2024-08-30 12:52:06 -07:00
Bagatur
b31faf6572 Merge branch 'master' into bagatur/format_content_as 2024-08-30 12:06:50 -07:00
Bagatur
2f5a3e5181 Merge branch 'master' into bagatur/format_content_as 2024-08-29 18:08:43 -07:00
Bagatur
bdea415fe1 Merge branch 'master' into bagatur/format_content_as 2024-08-29 17:46:01 -07:00
Bagatur
b9835f8308 fmt 2024-08-29 17:45:50 -07:00
Bagatur
3de115e46d fmt 2024-08-29 17:40:25 -07:00
Bagatur
cc70f2afe9 fmt 2024-08-29 17:35:19 -07:00
Bagatur
992884b410 fmt 2024-08-29 17:18:41 -07:00
Bagatur
fb002faba4 Merge branch 'master' into bagatur/format_content_as 2024-08-29 13:37:08 -07:00
Bagatur
fc46822bb7 wip transform 2024-08-29 11:38:18 -07:00
Bagatur
ad009f9c69 Merge branch 'master' into bagatur/format_content_as 2024-08-29 09:26:32 -07:00
Bagatur
ca5773dc0a fmt 2024-08-28 17:01:23 -07:00
Bagatur
bd48abe54a fmt 2024-08-28 16:54:15 -07:00
Bagatur
49f7c8cdd8 core[minor]: Add msg content formatting util 2024-08-28 16:31:30 -07:00
6 changed files with 1766 additions and 21 deletions

View File

@@ -44,6 +44,7 @@ from langchain_core.messages.utils import (
_message_from_dict,
convert_to_messages,
filter_messages,
format_messages,
get_buffer_string,
merge_message_runs,
message_chunk_to_message,
@@ -83,4 +84,5 @@ __all__ = [
"filter_messages",
"merge_message_runs",
"trim_messages",
"format_messages",
]

File diff suppressed because it is too large Load Diff

View File

@@ -32,6 +32,7 @@ EXPECTED_ALL = [
"filter_messages",
"merge_message_runs",
"trim_messages",
"format_messages",
]

View File

@@ -1,10 +1,13 @@
import json
from collections.abc import Iterator
from typing import Any, Callable
import pytest
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
@@ -12,11 +15,14 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.utils import (
_bytes_to_b64_str,
convert_to_messages,
filter_messages,
format_messages,
merge_message_runs,
trim_messages,
)
from langchain_core.runnables import RunnableLambda
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -568,3 +574,842 @@ def test_convert_to_messages() -> None:
@pytest.mark.xfail(reason="AI message does not support refusal key yet.")
def test_convert_to_messages_openai_refusal() -> None:
    """Feed ``convert_to_messages`` an assistant dict that only has ``refusal``.

    The test is marked ``xfail``; it makes no assertion on the result and
    simply exercises the conversion call.
    """
    refusal_only = {"role": "assistant", "refusal": "9.1"}
    convert_to_messages([refusal_only])
def create_base64_image(format: str = "jpeg") -> str:
    """Return a small base64 ``data:`` URI for an image of the given format.

    Args:
        format: Image subtype used in the media type (``image/<format>``).
            The name shadows the builtin but is kept so existing keyword
            callers continue to work.

    Returns:
        A string of the form ``data:image/<format>;base64,<payload>``.
    """
    # The payload is only the leading bytes of a JPEG/JFIF header. Callers in
    # this file check the URI's media type and pass the value through, so a
    # short, valid base64 string is enough and keeps the line readable.
    payload = "/9j/4AAQSkZJRg=="
    return f"data:image/{format};base64,{payload}"
def test_format_messages_single_message() -> None:
    """A single message passed in yields a single ``BaseMessage`` back."""
    formatted = format_messages(
        HumanMessage(content="Hello"),
        format="langchain-openai",
        text_format="string",
    )
    assert isinstance(formatted, BaseMessage)
    assert formatted.content == "Hello"
def test_format_messages_multiple_messages() -> None:
    """A list of three messages yields a list of three messages, text intact."""
    texts = ["System message", "Human message", "AI message"]
    inputs = [
        SystemMessage(content=texts[0]),
        HumanMessage(content=texts[1]),
        AIMessage(content=texts[2]),
    ]
    formatted = format_messages(
        inputs, format="langchain-openai", text_format="string"
    )
    assert isinstance(formatted, list)
    assert len(formatted) == 3
    assert all(isinstance(item, BaseMessage) for item in formatted)
    assert [item.content for item in formatted] == texts
def test_format_messages_openai_string() -> None:
    """With ``text_format="string"`` text blocks become one newline-joined string."""
    human = HumanMessage(
        content=[{"type": "text", "text": word} for word in ("Hello", "World")]
    )
    ai = AIMessage(
        content=[{"type": "text", "text": word} for word in ("Hi", "there")]
    )
    formatted = format_messages(
        [human, ai], format="langchain-openai", text_format="string"
    )
    contents = [item.content for item in formatted]
    assert contents == ["Hello\nWorld", "Hi\nthere"]
def test_format_messages_openai_block() -> None:
    """With ``text_format="block"`` string content becomes a one-item text-block list."""
    raw_texts = ["Hello", "Hi there"]
    formatted = format_messages(
        [HumanMessage(content=raw_texts[0]), AIMessage(content=raw_texts[1])],
        format="langchain-openai",
        text_format="block",
    )
    expected = [[{"type": "text", "text": text}] for text in raw_texts]
    assert [item.content for item in formatted] == expected
def test_format_messages_anthropic_string() -> None:
    """Anthropic format with ``text_format="string"`` joins text blocks with newlines."""
    human = HumanMessage(
        content=[{"type": "text", "text": word} for word in ("Hello", "World")]
    )
    ai = AIMessage(
        content=[{"type": "text", "text": word} for word in ("Hi", "there")]
    )
    formatted = format_messages(
        [human, ai], format="langchain-anthropic", text_format="string"
    )
    contents = [item.content for item in formatted]
    assert contents == ["Hello\nWorld", "Hi\nthere"]
def test_format_messages_anthropic_block() -> None:
    """Anthropic format with ``text_format="block"`` wraps strings in a text block."""
    raw_texts = ["Hello", "Hi there"]
    formatted = format_messages(
        [HumanMessage(content=raw_texts[0]), AIMessage(content=raw_texts[1])],
        format="langchain-anthropic",
        text_format="block",
    )
    expected = [[{"type": "text", "text": text}] for text in raw_texts]
    assert [item.content for item in formatted] == expected
def test_format_messages_invalid_format() -> None:
    """An unknown ``format`` value raises ``ValueError`` mentioning the format."""
    inputs = [HumanMessage(content="Hello")]
    with pytest.raises(ValueError, match="Unrecognized format="):
        format_messages(inputs, format="invalid", text_format="string")
def test_format_messages_openai_image() -> None:
    """An ``image_url`` block keeps its type and URL under the OpenAI format."""
    image_uri = create_base64_image()
    content_blocks = [
        {"type": "text", "text": "Here's an image:"},
        {"type": "image_url", "image_url": {"url": image_uri}},
    ]
    formatted = format_messages(
        [HumanMessage(content=content_blocks)],
        format="langchain-openai",
        text_format="block",
    )
    image_block = formatted[0].content[1]
    assert image_block["type"] == "image_url"
    assert image_block["image_url"]["url"] == image_uri
def test_format_messages_anthropic_image() -> None:
    """An ``image_url`` block becomes a base64 ``image`` block under the Anthropic format."""
    image_uri = create_base64_image()
    content_blocks = [
        {"type": "text", "text": "Here's an image:"},
        {"type": "image_url", "image_url": {"url": image_uri}},
    ]
    formatted = format_messages(
        [HumanMessage(content=content_blocks)],
        format="langchain-anthropic",
        text_format="block",
    )
    image_block = formatted[0].content[1]
    assert image_block["type"] == "image"
    assert image_block["source"]["type"] == "base64"
    assert image_block["source"]["media_type"] == "image/jpeg"
def test_format_messages_tool_message() -> None:
    """A ``ToolMessage`` stays a ``ToolMessage`` and keeps its ``tool_call_id``."""
    source = ToolMessage(content="Tool result", tool_call_id="123")
    formatted = format_messages(
        [source], format="langchain-openai", text_format="block"
    )
    first = formatted[0]
    assert isinstance(first, ToolMessage)
    assert first.content == [{"type": "text", "text": "Tool result"}]
    assert first.tool_call_id == "123"
def test_format_messages_tool_use() -> None:
    """A ``tool_use`` content block shows up in ``tool_calls`` under the OpenAI format."""
    tool_use_block = {
        "type": "tool_use",
        "id": "123",
        "name": "calculator",
        "input": "2+2",
    }
    formatted = format_messages(
        [AIMessage(content=[tool_use_block])],
        format="langchain-openai",
        text_format="block",
    )
    call = formatted[0].tool_calls[0]
    assert call["id"] == "123"
    assert call["name"] == "calculator"
    assert call["args"] == "2+2"
def test_format_messages_json() -> None:
    """A ``json`` block becomes a text block whose text parses back to the same data."""
    payload = {"key": "value"}
    source = HumanMessage(content=[{"type": "json", "json": payload}])
    formatted = format_messages(
        [source], format="langchain-openai", text_format="block"
    )
    block = formatted[0].content[0]
    assert block["type"] == "text"
    assert json.loads(block["text"]) == payload
def test_format_messages_guard_content() -> None:
    """A ``guard_content`` block becomes a plain text block with the guarded text."""
    guard_block = {
        "type": "guard_content",
        "guard_content": {"text": "Protected content"},
    }
    formatted = format_messages(
        [HumanMessage(content=[guard_block])],
        format="langchain-openai",
        text_format="block",
    )
    block = formatted[0].content[0]
    assert block["type"] == "text"
    assert block["text"] == "Protected content"
def test_format_messages_vertexai_image() -> None:
    """A ``media`` block with raw bytes becomes an ``image_url`` block with a data URI."""
    raw = b"image_bytes"
    media_block = {"type": "media", "mime_type": "image/jpeg", "data": raw}
    formatted = format_messages(
        [HumanMessage(content=[media_block])],
        format="langchain-openai",
        text_format="block",
    )
    block = formatted[0].content[0]
    expected_url = f"data:image/jpeg;base64,{_bytes_to_b64_str(raw)}"
    assert block["type"] == "image_url"
    assert block["image_url"]["url"] == expected_url
def test_format_messages_invalid_block() -> None:
    """An unknown block type raises ``ValueError`` for both target formats."""
    inputs = [HumanMessage(content=[{"type": "invalid", "foo": "bar"}])]
    for target in ("langchain-openai", "langchain-anthropic"):
        with pytest.raises(ValueError, match="Unrecognized content block"):
            format_messages(inputs, format=target, text_format="block")
def test_format_messages_empty_message() -> None:
    """Empty string content stays an empty string in string mode."""
    blank = HumanMessage(content="")
    formatted = format_messages(
        blank, format="langchain-openai", text_format="string"
    )
    assert formatted.content == ""
def test_format_messages_empty_list() -> None:
    """An empty input list produces an empty output list."""
    formatted = format_messages(
        [], format="langchain-openai", text_format="string"
    )
    assert formatted == []
def test_format_messages_mixed_content_types() -> None:
    """A bare string, a text block and an image block all come back as dict blocks."""
    mixed_content = [
        "Text message",
        {"type": "text", "text": "Structured text"},
        {"type": "image_url", "image_url": create_base64_image()},
    ]
    formatted = format_messages(
        [HumanMessage(content=mixed_content)],
        format="langchain-openai",
        text_format="block",
    )
    blocks = formatted[0].content
    assert len(blocks) == 3
    for position in range(3):
        assert isinstance(blocks[position], dict)
def test_format_messages_anthropic_tool_calls() -> None:
    """Tool calls on an AI message appear as ``tool_use`` blocks in Anthropic content.

    The expected content is a block list even though ``text_format="string"``
    is requested, and ``tool_calls`` must match the input message's.
    """
    call = {"type": "tool_call", "name": "foo", "id": "1", "args": {"bar": "baz"}}
    source = AIMessage("blah", tool_calls=[call])
    formatted = format_messages(
        source, format="langchain-anthropic", text_format="string"
    )
    expected_content = [
        {"type": "text", "text": "blah"},
        {"type": "tool_use", "id": "1", "name": "foo", "input": {"bar": "baz"}},
    ]
    assert formatted.content == expected_content
    assert formatted.tool_calls == source.tool_calls
def test_format_messages_declarative() -> None:
    """Calling ``format_messages`` without messages gives an object with ``invoke``."""
    formatter = format_messages(format="langchain-openai", text_format="block")
    image_uri = create_base64_image()
    content_blocks = [
        {"type": "text", "text": "Here's an image:"},
        {"type": "image_url", "image_url": {"url": image_uri}},
    ]
    formatted = formatter.invoke([HumanMessage(content=content_blocks)])
    image_block = formatted[0].content[1]
    assert image_block["type"] == "image_url"
    assert image_block["image_url"]["url"] == image_uri
def _stream_oai(input_: Any) -> Iterator:
    """Yield a fixed sequence of chunks with string content and a chunked tool call.

    ``input_`` is ignored. The sequence is: three text chunks, the start of a
    ``multiply`` tool call, three chunks carrying pieces of its JSON args, and
    a final empty chunk. Every tool-call chunk uses index 0.
    """
    call_id = "call_hr4yN4ZN7zv9Vmc5Cuahp4K8"

    def call_chunk(name: Any, args: str, id_: Any) -> dict:
        # One entry for ``tool_call_chunks``.
        return {
            "name": name,
            "args": args,
            "id": id_,
            "index": 0,
            "type": "tool_call_chunk",
        }

    def invalid_call(args: str) -> dict:
        # One entry for ``invalid_tool_calls``.
        return {
            "name": None,
            "args": args,
            "id": None,
            "error": None,
            "type": "invalid_tool_call",
        }

    sequence = [AIMessageChunk(content=text) for text in ("", "Certainly", "!")]
    sequence.append(
        AIMessageChunk(
            content="",
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": call_id,
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[call_chunk("multiply", "", call_id)],
        )
    )
    sequence.append(
        AIMessageChunk(
            content="",
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[call_chunk(None, '{"', None)],
        )
    )
    for piece in ('a": 5, "b": 2', "}"):
        sequence.append(
            AIMessageChunk(
                content="",
                invalid_tool_calls=[invalid_call(piece)],
                tool_call_chunks=[call_chunk(None, piece, None)],
            )
        )
    sequence.append(AIMessageChunk(content=""))
    for chunk in sequence:
        yield chunk
def _stream_anthropic(input_: Any) -> Iterator:
    """Yield a fixed sequence of chunks with block-list content and a chunked tool call.

    ``input_`` is ignored. Text arrives as ``text`` blocks at index 0; the
    ``multiply`` tool call arrives as ``tool_use`` blocks at index 1, first
    with ``input`` and then with ``partial_json`` pieces. The last chunk has
    empty string content.
    """
    call_id = "call_hr4yN4ZN7zv9Vmc5Cuahp4K8"

    def text_block(text: str) -> dict:
        return {"text": text, "type": "text", "index": 0}

    def partial_block(piece: str) -> dict:
        return {"partial_json": piece, "type": "tool_use", "index": 1}

    def call_chunk(name: Any, args: str, id_: Any) -> dict:
        # One entry for ``tool_call_chunks``.
        return {
            "name": name,
            "args": args,
            "id": id_,
            "index": 1,
            "type": "tool_call_chunk",
        }

    def invalid_call(args: str) -> dict:
        # One entry for ``invalid_tool_calls``.
        return {
            "name": None,
            "args": args,
            "id": None,
            "error": None,
            "type": "invalid_tool_call",
        }

    sequence = [
        AIMessageChunk(content=[]),
        AIMessageChunk(content=[text_block("Certainly")]),
        AIMessageChunk(content=[text_block("!")]),
        AIMessageChunk(
            content=[
                {
                    "id": call_id,
                    "input": {},
                    "name": "multiply",
                    "type": "tool_use",
                    "index": 1,
                }
            ],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": call_id,
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[call_chunk("multiply", "", call_id)],
        ),
        AIMessageChunk(
            content=[partial_block('{"')],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[call_chunk(None, '{"', None)],
        ),
    ]
    for piece in ('a": 5, "b": 2', "}"):
        sequence.append(
            AIMessageChunk(
                content=[partial_block(piece)],
                invalid_tool_calls=[invalid_call(piece)],
                tool_call_chunks=[call_chunk(None, piece, None)],
            )
        )
    sequence.append(AIMessageChunk(content=""))
    for chunk in sequence:
        yield chunk
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_string_stream(stream: Callable) -> None:
    """Stream either fixture through an OpenAI/string formatter.

    Every expected chunk has empty or plain string content. The tool-call
    fields keep the index of the source fixture: 1 for ``_stream_anthropic``,
    0 otherwise.
    """
    formatter = format_messages(format="langchain-openai", text_format="string")
    pipeline = RunnableLambda(stream) | formatter
    idx = 1 if stream == _stream_anthropic else 0
    call_id = "call_hr4yN4ZN7zv9Vmc5Cuahp4K8"

    def call_chunk(name: Any, args: str, id_: Any) -> dict:
        return {
            "name": name,
            "args": args,
            "id": id_,
            "index": idx,
            "type": "tool_call_chunk",
        }

    def invalid_call(args: str) -> dict:
        return {
            "name": None,
            "args": args,
            "id": None,
            "error": None,
            "type": "invalid_tool_call",
        }

    expected = [AIMessageChunk(content=text) for text in ("", "Certainly", "!")]
    expected.append(
        AIMessageChunk(
            content="",
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": call_id,
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[call_chunk("multiply", "", call_id)],
        )
    )
    expected.append(
        AIMessageChunk(
            content="",
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[call_chunk(None, '{"', None)],
        )
    )
    for piece in ('a": 5, "b": 2', "}"):
        expected.append(
            AIMessageChunk(
                content="",
                invalid_tool_calls=[invalid_call(piece)],
                tool_call_chunks=[call_chunk(None, piece, None)],
            )
        )
    expected.append(AIMessageChunk(content=""))

    actual = list(pipeline.stream({}))
    assert expected == actual
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_block_stream(stream: Callable) -> None:
    """Stream either fixture through an OpenAI/block formatter.

    Expected text chunks carry a single ``text`` block and every other chunk
    has an empty content list. The tool-call fields keep the index of the
    source fixture: 1 for ``_stream_anthropic``, 0 otherwise.
    """
    formatter = format_messages(format="langchain-openai", text_format="block")
    pipeline = RunnableLambda(stream) | formatter
    idx = 1 if stream == _stream_anthropic else 0
    call_id = "call_hr4yN4ZN7zv9Vmc5Cuahp4K8"

    def call_chunk(name: Any, args: str, id_: Any) -> dict:
        return {
            "name": name,
            "args": args,
            "id": id_,
            "index": idx,
            "type": "tool_call_chunk",
        }

    def invalid_call(args: str) -> dict:
        return {
            "name": None,
            "args": args,
            "id": None,
            "error": None,
            "type": "invalid_tool_call",
        }

    expected = [
        AIMessageChunk(content=[]),
        AIMessageChunk(content=[{"type": "text", "text": "Certainly"}]),
        AIMessageChunk(content=[{"type": "text", "text": "!"}]),
        AIMessageChunk(
            content=[],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": call_id,
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[call_chunk("multiply", "", call_id)],
        ),
        AIMessageChunk(
            content=[],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[call_chunk(None, '{"', None)],
        ),
    ]
    for piece in ('a": 5, "b": 2', "}"):
        expected.append(
            AIMessageChunk(
                content=[],
                invalid_tool_calls=[invalid_call(piece)],
                tool_call_chunks=[call_chunk(None, piece, None)],
            )
        )
    expected.append(AIMessageChunk(content=[]))

    actual = list(pipeline.stream({}))
    assert expected == actual
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_block_stream(stream: Callable) -> None:
    """Stream either fixture through an Anthropic/block formatter.

    Expected content is always a block list: indexed ``text`` blocks for text
    and ``tool_use`` blocks at index 1 for the tool call. The first
    ``tool_use`` block carries ``input`` for ``_stream_anthropic`` and an
    empty ``partial_json`` otherwise.
    """
    formatter = format_messages(format="langchain-anthropic", text_format="block")
    pipeline = RunnableLambda(stream) | formatter
    call_id = "call_hr4yN4ZN7zv9Vmc5Cuahp4K8"
    if stream == _stream_anthropic:
        first_block_extra = {"input": {}}
    else:
        first_block_extra = {"partial_json": ""}

    def text_block(text: str) -> dict:
        return {"text": text, "type": "text", "index": 0}

    def partial_block(piece: str) -> dict:
        return {"partial_json": piece, "type": "tool_use", "index": 1}

    def call_chunk(name: Any, args: str, id_: Any) -> dict:
        return {
            "name": name,
            "args": args,
            "id": id_,
            "index": 1,
            "type": "tool_call_chunk",
        }

    def invalid_call(args: str) -> dict:
        return {
            "name": None,
            "args": args,
            "id": None,
            "error": None,
            "type": "invalid_tool_call",
        }

    expected = [
        AIMessageChunk(content=[]),
        AIMessageChunk(content=[text_block("Certainly")]),
        AIMessageChunk(content=[text_block("!")]),
        AIMessageChunk(
            content=[
                {
                    "id": call_id,
                    "name": "multiply",
                    "type": "tool_use",
                    "index": 1,
                    **first_block_extra,
                }
            ],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": call_id,
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[call_chunk("multiply", "", call_id)],
        ),
        AIMessageChunk(
            content=[partial_block('{"')],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[call_chunk(None, '{"', None)],
        ),
    ]
    for piece in ('a": 5, "b": 2', "}"):
        expected.append(
            AIMessageChunk(
                content=[partial_block(piece)],
                invalid_tool_calls=[invalid_call(piece)],
                tool_call_chunks=[call_chunk(None, piece, None)],
            )
        )
    expected.append(AIMessageChunk(content=[]))

    actual = list(pipeline.stream({}))
    assert expected == actual
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_string_stream(stream: Callable) -> None:
    """Stream either fixture through an Anthropic/string formatter.

    Expected text-only chunks have plain string content, while chunks that
    carry tool-call data keep a list of ``tool_use`` blocks at index 1. The
    first ``tool_use`` block carries ``input`` for ``_stream_anthropic`` and
    an empty ``partial_json`` otherwise.
    """
    formatter = format_messages(format="langchain-anthropic", text_format="string")
    pipeline = RunnableLambda(stream) | formatter
    call_id = "call_hr4yN4ZN7zv9Vmc5Cuahp4K8"
    if stream == _stream_anthropic:
        first_block_extra = {"input": {}}
    else:
        first_block_extra = {"partial_json": ""}

    def partial_block(piece: str) -> dict:
        return {"type": "tool_use", "partial_json": piece, "index": 1}

    def call_chunk(name: Any, args: str, id_: Any) -> dict:
        return {
            "name": name,
            "args": args,
            "id": id_,
            "index": 1,
            "type": "tool_call_chunk",
        }

    def invalid_call(args: str) -> dict:
        return {
            "name": None,
            "args": args,
            "id": None,
            "error": None,
            "type": "invalid_tool_call",
        }

    expected = [AIMessageChunk(content=text) for text in ("", "Certainly", "!")]
    expected.append(
        AIMessageChunk(
            content=[
                {
                    "type": "tool_use",
                    "id": call_id,
                    "name": "multiply",
                    "index": 1,
                    **first_block_extra,
                },
            ],
            tool_calls=[
                {
                    "name": "multiply",
                    "args": {},
                    "id": call_id,
                    "type": "tool_call",
                }
            ],
            tool_call_chunks=[call_chunk("multiply", "", call_id)],
        )
    )
    expected.append(
        AIMessageChunk(
            content=[partial_block('{"')],
            tool_calls=[{"name": "", "args": {}, "id": None, "type": "tool_call"}],
            tool_call_chunks=[call_chunk(None, '{"', None)],
        )
    )
    for piece in ('a": 5, "b": 2', "}"):
        expected.append(
            AIMessageChunk(
                content=[partial_block(piece)],
                invalid_tool_calls=[invalid_call(piece)],
                tool_call_chunks=[call_chunk(None, piece, None)],
            )
        )
    expected.append(AIMessageChunk(content=""))

    actual = list(pipeline.stream({}))
    assert expected == actual

View File

@@ -30,7 +30,7 @@ from langchain_core.language_models.chat_models import (
generate_from_stream,
)
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
@@ -530,7 +530,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
return model
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
config = config or {}
config = ensure_config(config)
model_params = {
_remove_prefix(k, self._config_prefix): v
for k, v in config.get("configurable", {}).items()

View File

@@ -1216,12 +1216,12 @@ def _make_message_chunk_from_anthropic_event(
content_block = event.delta.model_dump()
content_block["index"] = event.index
content_block["type"] = "tool_use"
tool_call_chunk = {
"index": event.index,
"id": None,
"name": None,
"args": event.delta.partial_json,
}
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=None,
name=None,
args=event.delta.partial_json,
)
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore