This commit is contained in:
Bagatur 2024-08-29 17:35:19 -07:00
parent 992884b410
commit cc70f2afe9
2 changed files with 41 additions and 39 deletions

View File

@ -938,7 +938,7 @@ def format_messages_as(
messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
*, *,
format: Literal["openai", "anthropic"], format: Literal["openai", "anthropic"],
text: Literal["string", "block"], text_format: Literal["string", "block"],
) -> Union[BaseMessage, List[BaseMessage]]: ) -> Union[BaseMessage, List[BaseMessage]]:
"""Convert message contents into a standard format. """Convert message contents into a standard format.
@ -948,7 +948,7 @@ def format_messages_as(
messages: Message-like object or iterable of objects whose contents are already messages: Message-like object or iterable of objects whose contents are already
in OpenAI, Anthropic, Bedrock Converse, or VertexAI formats. in OpenAI, Anthropic, Bedrock Converse, or VertexAI formats.
format: Format to convert message contents to. format: Format to convert message contents to.
text: How to format text contents. If ``text='string'`` then any string text_format: How to format text contents. If ``text='string'`` then any string
contents are left as strings. If a message has content blocks that are all contents are left as strings. If a message has content blocks that are all
of type 'text', these are joined with a newline to make a single string. If of type 'text', these are joined with a newline to make a single string. If
a message has content blocks and at least one isn't of type 'text', then a message has content blocks and at least one isn't of type 'text', then
@ -1059,9 +1059,9 @@ def format_messages_as(
messages = [messages] messages = [messages]
messages = convert_to_messages(messages, copy=True) messages = convert_to_messages(messages, copy=True)
if format.lower() == "openai": if format.lower() == "openai":
formatted = _format_messages_as_openai(messages, text=text) formatted = _format_messages_as_openai(messages, text_format=text_format)
elif format.lower() == "anthropic": elif format.lower() == "anthropic":
formatted = _format_messages_as_anthropic(messages, text=text) formatted = _format_messages_as_anthropic(messages, text_format=text_format)
else: else:
raise ValueError( raise ValueError(
f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')." f"Unrecognized {format=}. Expected one of ('openai', 'anthropic')."
@ -1073,21 +1073,21 @@ def format_messages_as(
def _format_messages_as_openai( def _format_messages_as_openai(
messages: Sequence[BaseMessage], *, text: Literal["string", "block"] messages: Sequence[BaseMessage], *, text_format: Literal["string", "block"]
) -> List[BaseMessage]: ) -> List[BaseMessage]:
"""Mutates messages so their contents match OpenAI messages API.""" """Mutates messages so their contents match OpenAI messages API."""
updated_messages: list = [] updated_messages: list = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
tool_messages: list = [] tool_messages: list = []
if not message.content: if not message.content:
message.content = "" if text == "string" else [] message.content = "" if text_format == "string" else []
elif isinstance(message.content, str): elif isinstance(message.content, str):
if text == "string": if text_format == "string":
pass pass
else: else:
message.content = [{"type": "text", "text": message.content}] message.content = [{"type": "text", "text": message.content}]
else: else:
if text == "string" and all( if text_format == "string" and all(
isinstance(block, str) or block.get("type") == "text" isinstance(block, str) or block.get("type") == "text"
for block in message.content for block in message.content
): ):
@ -1226,7 +1226,9 @@ def _format_messages_as_openai(
) )
# Recurse to make sure tool message contents are OpenAI format. # Recurse to make sure tool message contents are OpenAI format.
tool_messages.extend( tool_messages.extend(
_format_messages_as_openai([tool_message], text=text) _format_messages_as_openai(
[tool_message], text_format=text_format
)
) )
elif (block.get("type") == "json") or "json" in block: elif (block.get("type") == "json") or "json" in block:
if "json" not in block: if "json" not in block:
@ -1292,7 +1294,7 @@ def _format_messages_as_openai(
f"Anthropic, Bedrock Converse, or VertexAI format. Full " f"Anthropic, Bedrock Converse, or VertexAI format. Full "
f"content block:\n\n{block}" f"content block:\n\n{block}"
) )
if text == "string" and not any( if text_format == "string" and not any(
block["type"] != "text" for block in content block["type"] != "text" for block in content
): ):
message.content = "\n".join(block["text"] for block in content) message.content = "\n".join(block["text"] for block in content)
@ -1306,7 +1308,7 @@ _OPTIONAL_ANTHROPIC_KEYS = ("cache_control", "is_error", "index")
def _format_messages_as_anthropic( def _format_messages_as_anthropic(
messages: Sequence[BaseMessage], *, text: Literal["string", "block"] messages: Sequence[BaseMessage], *, text_format: Literal["string", "block"]
) -> List[BaseMessage]: ) -> List[BaseMessage]:
"""Mutates messages so their contents match Anthropic messages API.""" """Mutates messages so their contents match Anthropic messages API."""
updated_messages: List = [] updated_messages: List = []
@ -1328,9 +1330,9 @@ def _format_messages_as_anthropic(
updated_messages.append(HumanMessage([tool_result_block])) updated_messages.append(HumanMessage([tool_result_block]))
continue continue
elif not message.content: elif not message.content:
message.content = "" if text == "string" else [] message.content = "" if text_format == "string" else []
elif isinstance(message.content, str): elif isinstance(message.content, str):
if text == "string": if text_format == "string":
pass pass
else: else:
text_block: dict = {"type": "text", "text": message.content} text_block: dict = {"type": "text", "text": message.content}
@ -1338,7 +1340,7 @@ def _format_messages_as_anthropic(
text_block["index"] = 0 text_block["index"] = 0
message.content = [text_block] message.content = [text_block]
else: else:
if text == "string" and all( if text_format == "string" and all(
isinstance(block, str) isinstance(block, str)
or (block.get("type") == "text" and "cache_control" not in block) or (block.get("type") == "text" and "cache_control" not in block)
for block in message.content for block in message.content

View File

@ -568,7 +568,7 @@ def create_base64_image(format: str = "jpeg") -> str:
def test_format_messages_as_single_message() -> None: def test_format_messages_as_single_message() -> None:
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
result = format_messages_as(message, format="openai", text="string") result = format_messages_as(message, format="openai", text_format="string")
assert isinstance(result, BaseMessage) assert isinstance(result, BaseMessage)
assert result.content == "Hello" assert result.content == "Hello"
@ -579,7 +579,7 @@ def test_format_messages_as_multiple_messages() -> None:
HumanMessage(content="Human message"), HumanMessage(content="Human message"),
AIMessage(content="AI message"), AIMessage(content="AI message"),
] ]
result = format_messages_as(messages, format="openai", text="string") result = format_messages_as(messages, format="openai", text_format="string")
assert isinstance(result, list) assert isinstance(result, list)
assert len(result) == 3 assert len(result) == 3
assert all(isinstance(msg, BaseMessage) for msg in result) assert all(isinstance(msg, BaseMessage) for msg in result)
@ -602,7 +602,7 @@ def test_format_messages_as_openai_string() -> None:
content=[{"type": "text", "text": "Hi"}, {"type": "text", "text": "there"}] content=[{"type": "text", "text": "Hi"}, {"type": "text", "text": "there"}]
), ),
] ]
result = format_messages_as(messages, format="openai", text="string") result = format_messages_as(messages, format="openai", text_format="string")
assert [msg.content for msg in result] == ["Hello\nWorld", "Hi\nthere"] assert [msg.content for msg in result] == ["Hello\nWorld", "Hi\nthere"]
@ -611,7 +611,7 @@ def test_format_messages_as_openai_block() -> None:
HumanMessage(content="Hello"), HumanMessage(content="Hello"),
AIMessage(content="Hi there"), AIMessage(content="Hi there"),
] ]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert [msg.content for msg in result] == [ assert [msg.content for msg in result] == [
[{"type": "text", "text": "Hello"}], [{"type": "text", "text": "Hello"}],
[{"type": "text", "text": "Hi there"}], [{"type": "text", "text": "Hi there"}],
@ -630,7 +630,7 @@ def test_format_messages_as_anthropic_string() -> None:
content=[{"type": "text", "text": "Hi"}, {"type": "text", "text": "there"}] content=[{"type": "text", "text": "Hi"}, {"type": "text", "text": "there"}]
), ),
] ]
result = format_messages_as(messages, format="anthropic", text="string") result = format_messages_as(messages, format="anthropic", text_format="string")
assert [msg.content for msg in result] == ["Hello\nWorld", "Hi\nthere"] assert [msg.content for msg in result] == ["Hello\nWorld", "Hi\nthere"]
@ -639,7 +639,7 @@ def test_format_messages_as_anthropic_block() -> None:
HumanMessage(content="Hello"), HumanMessage(content="Hello"),
AIMessage(content="Hi there"), AIMessage(content="Hi there"),
] ]
result = format_messages_as(messages, format="anthropic", text="block") result = format_messages_as(messages, format="anthropic", text_format="block")
assert [msg.content for msg in result] == [ assert [msg.content for msg in result] == [
[{"type": "text", "text": "Hello"}], [{"type": "text", "text": "Hello"}],
[{"type": "text", "text": "Hi there"}], [{"type": "text", "text": "Hi there"}],
@ -649,7 +649,7 @@ def test_format_messages_as_anthropic_block() -> None:
def test_format_messages_as_invalid_format() -> None: def test_format_messages_as_invalid_format() -> None:
with pytest.raises(ValueError, match="Unrecognized format="): with pytest.raises(ValueError, match="Unrecognized format="):
format_messages_as( format_messages_as(
[HumanMessage(content="Hello")], format="invalid", text="string" [HumanMessage(content="Hello")], format="invalid", text_format="string"
) )
@ -663,7 +663,7 @@ def test_format_messages_as_openai_image() -> None:
] ]
) )
] ]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert result[0].content[1]["type"] == "image_url" assert result[0].content[1]["type"] == "image_url"
assert result[0].content[1]["image_url"]["url"] == base64_image assert result[0].content[1]["image_url"]["url"] == base64_image
@ -678,7 +678,7 @@ def test_format_messages_as_anthropic_image() -> None:
] ]
) )
] ]
result = format_messages_as(messages, format="anthropic", text="block") result = format_messages_as(messages, format="anthropic", text_format="block")
assert result[0].content[1]["type"] == "image" assert result[0].content[1]["type"] == "image"
assert result[0].content[1]["source"]["type"] == "base64" assert result[0].content[1]["source"]["type"] == "base64"
assert result[0].content[1]["source"]["media_type"] == "image/jpeg" assert result[0].content[1]["source"]["media_type"] == "image/jpeg"
@ -686,7 +686,7 @@ def test_format_messages_as_anthropic_image() -> None:
def test_format_messages_as_tool_message() -> None: def test_format_messages_as_tool_message() -> None:
tool_message = ToolMessage(content="Tool result", tool_call_id="123") tool_message = ToolMessage(content="Tool result", tool_call_id="123")
result = format_messages_as([tool_message], format="openai", text="block") result = format_messages_as([tool_message], format="openai", text_format="block")
assert isinstance(result[0], ToolMessage) assert isinstance(result[0], ToolMessage)
assert result[0].content == [{"type": "text", "text": "Tool result"}] assert result[0].content == [{"type": "text", "text": "Tool result"}]
assert result[0].tool_call_id == "123" assert result[0].tool_call_id == "123"
@ -700,7 +700,7 @@ def test_format_messages_as_tool_use() -> None:
] ]
) )
] ]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert result[0].tool_calls[0]["id"] == "123" assert result[0].tool_calls[0]["id"] == "123"
assert result[0].tool_calls[0]["name"] == "calculator" assert result[0].tool_calls[0]["name"] == "calculator"
assert result[0].tool_calls[0]["args"] == "2+2" assert result[0].tool_calls[0]["args"] == "2+2"
@ -709,7 +709,7 @@ def test_format_messages_as_tool_use() -> None:
def test_format_messages_as_json() -> None: def test_format_messages_as_json() -> None:
json_data = {"key": "value"} json_data = {"key": "value"}
messages = [HumanMessage(content=[{"type": "json", "json": json_data}])] messages = [HumanMessage(content=[{"type": "json", "json": json_data}])]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert result[0].content[0]["type"] == "text" assert result[0].content[0]["type"] == "text"
assert json.loads(result[0].content[0]["text"]) == json_data assert json.loads(result[0].content[0]["text"]) == json_data
@ -725,7 +725,7 @@ def test_format_messages_as_guard_content() -> None:
] ]
) )
] ]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert result[0].content[0]["type"] == "text" assert result[0].content[0]["type"] == "text"
assert result[0].content[0]["text"] == "Protected content" assert result[0].content[0]["text"] == "Protected content"
@ -738,7 +738,7 @@ def test_format_messages_as_vertexai_image() -> None:
] ]
) )
] ]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert result[0].content[0]["type"] == "image_url" assert result[0].content[0]["type"] == "image_url"
assert ( assert (
result[0].content[0]["image_url"]["url"] result[0].content[0]["image_url"]["url"]
@ -749,20 +749,20 @@ def test_format_messages_as_vertexai_image() -> None:
def test_format_messages_as_invalid_block() -> None: def test_format_messages_as_invalid_block() -> None:
messages = [HumanMessage(content=[{"type": "invalid", "foo": "bar"}])] messages = [HumanMessage(content=[{"type": "invalid", "foo": "bar"}])]
with pytest.raises(ValueError, match="Unrecognized content block"): with pytest.raises(ValueError, match="Unrecognized content block"):
format_messages_as(messages, format="openai", text="block") format_messages_as(messages, format="openai", text_format="block")
with pytest.raises(ValueError, match="Unrecognized content block"): with pytest.raises(ValueError, match="Unrecognized content block"):
format_messages_as(messages, format="anthropic", text="block") format_messages_as(messages, format="anthropic", text_format="block")
def test_format_messages_as_empty_message() -> None: def test_format_messages_as_empty_message() -> None:
result = format_messages_as( result = format_messages_as(
HumanMessage(content=""), format="openai", text="string" HumanMessage(content=""), format="openai", text_format="string"
) )
assert result.content == "" assert result.content == ""
def test_format_messages_as_empty_list() -> None: def test_format_messages_as_empty_list() -> None:
result = format_messages_as([], format="openai", text="string") result = format_messages_as([], format="openai", text_format="string")
assert result == [] assert result == []
@ -776,7 +776,7 @@ def test_format_messages_as_mixed_content_types() -> None:
] ]
) )
] ]
result = format_messages_as(messages, format="openai", text="block") result = format_messages_as(messages, format="openai", text_format="block")
assert len(result[0].content) == 3 assert len(result[0].content) == 3
assert isinstance(result[0].content[0], dict) assert isinstance(result[0].content[0], dict)
assert isinstance(result[0].content[1], dict) assert isinstance(result[0].content[1], dict)
@ -790,7 +790,7 @@ def test_format_messages_as_anthropic_tool_calls() -> None:
{"type": "tool_call", "name": "foo", "id": "1", "args": {"bar": "baz"}} {"type": "tool_call", "name": "foo", "id": "1", "args": {"bar": "baz"}}
], ],
) )
result = format_messages_as(message, format="anthropic", text="string") result = format_messages_as(message, format="anthropic", text_format="string")
assert result.content == [ assert result.content == [
{"type": "text", "text": "blah"}, {"type": "text", "text": "blah"},
{"type": "tool_use", "id": "1", "name": "foo", "input": {"bar": "baz"}}, {"type": "tool_use", "id": "1", "name": "foo", "input": {"bar": "baz"}},
@ -799,7 +799,7 @@ def test_format_messages_as_anthropic_tool_calls() -> None:
def test_format_messages_as_declarative() -> None: def test_format_messages_as_declarative() -> None:
formatter = format_messages_as(format="openai", text="block") formatter = format_messages_as(format="openai", text_format="block")
base64_image = create_base64_image() base64_image = create_base64_image()
messages = [ messages = [
HumanMessage( HumanMessage(
@ -996,7 +996,7 @@ def _stream_anthropic(input_: Any) -> Iterator:
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) @pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_string_stream(stream: Callable) -> None: def test_format_messages_openai_string_stream(stream: Callable) -> None:
formatter = format_messages_as(format="openai", text="string") formatter = format_messages_as(format="openai", text_format="string")
chain = RunnableLambda(stream) | formatter chain = RunnableLambda(stream) | formatter
tool_call_idx = 1 if stream == _stream_anthropic else 0 tool_call_idx = 1 if stream == _stream_anthropic else 0
@ -1090,7 +1090,7 @@ def test_format_messages_openai_string_stream(stream: Callable) -> None:
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) @pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_openai_block_stream(stream: Callable) -> None: def test_format_messages_openai_block_stream(stream: Callable) -> None:
formatter = format_messages_as(format="openai", text="block") formatter = format_messages_as(format="openai", text_format="block")
chain = RunnableLambda(stream) | formatter chain = RunnableLambda(stream) | formatter
tool_call_idx = 1 if stream == _stream_anthropic else 0 tool_call_idx = 1 if stream == _stream_anthropic else 0
@ -1183,7 +1183,7 @@ def test_format_messages_openai_block_stream(stream: Callable) -> None:
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) @pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_block_stream(stream: Callable) -> None: def test_format_messages_anthropic_block_stream(stream: Callable) -> None:
formatter = format_messages_as(format="anthropic", text="block") formatter = format_messages_as(format="anthropic", text_format="block")
chain = RunnableLambda(stream) | formatter chain = RunnableLambda(stream) | formatter
expected = [ expected = [
@ -1285,7 +1285,7 @@ def test_format_messages_anthropic_block_stream(stream: Callable) -> None:
@pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic]) @pytest.mark.parametrize("stream", [_stream_oai, _stream_anthropic])
def test_format_messages_anthropic_string_stream(stream: Callable) -> None: def test_format_messages_anthropic_string_stream(stream: Callable) -> None:
formatter = format_messages_as(format="anthropic", text="string") formatter = format_messages_as(format="anthropic", text_format="string")
chain = RunnableLambda(stream) | formatter chain = RunnableLambda(stream) | formatter
expected = [ expected = [