standard-tests: split tool calling test (#20803)

just making it a bit easier to grok
This commit is contained in:
Erick Friis 2024-04-23 13:59:45 -07:00 committed by GitHub
parent 6622829c67
commit ddc2274aea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -117,13 +117,16 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
def test_tool_message_histories( def test_tool_message_histories_string_content(
self, self,
chat_model_class: Type[BaseChatModel], chat_model_class: Type[BaseChatModel],
chat_model_params: dict, chat_model_params: dict,
chat_model_has_tool_calling: bool, chat_model_has_tool_calling: bool,
) -> None: ) -> None:
"""Test that message histories are compatible across providers.""" """
Test that message histories are compatible with string tool contents
(e.g. OpenAI).
"""
if not chat_model_has_tool_calling: if not chat_model_has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params) model = chat_model_class(**chat_model_params)
@ -131,15 +134,10 @@ class ChatModelIntegrationTests(ABC):
function_name = "my_adder_tool" function_name = "my_adder_tool"
function_args = {"a": "1", "b": "2"} function_args = {"a": "1", "b": "2"}
human_message = HumanMessage(content="What is 1 + 2") messages_string_content = [
tool_message = ToolMessage( HumanMessage(content="What is 1 + 2"),
name=function_name, # string content (e.g. OpenAI)
content=json.dumps({"result": 3}), AIMessage(
tool_call_id="abc123",
)
# String content (e.g., OpenAI)
string_content_msg = AIMessage(
content="", content="",
tool_calls=[ tool_calls=[
{ {
@ -148,17 +146,37 @@ class ChatModelIntegrationTests(ABC):
"id": "abc123", "id": "abc123",
}, },
], ],
) ),
messages = [ ToolMessage(
human_message, name=function_name,
string_content_msg, content=json.dumps({"result": 3}),
tool_message, tool_call_id="abc123",
),
] ]
result = model_with_tools.invoke(messages) result_string_content = model_with_tools.invoke(messages_string_content)
assert isinstance(result, AIMessage) assert isinstance(result_string_content, AIMessage)
def test_tool_message_histories_list_content(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
"""
Test that message histories are compatible with list tool contents
(e.g. Anthropic).
"""
if not chat_model_has_tool_calling:
pytest.skip("Test requires tool calling.")
model = chat_model_class(**chat_model_params)
model_with_tools = model.bind_tools([my_adder_tool])
function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2}
messages_list_content = [
HumanMessage(content="What is 1 + 2"),
# List content (e.g., Anthropic) # List content (e.g., Anthropic)
list_content_msg = AIMessage( AIMessage(
content=[ content=[
{"type": "text", "text": "some text"}, {"type": "text", "text": "some text"},
{ {
@ -175,11 +193,12 @@ class ChatModelIntegrationTests(ABC):
"id": "abc123", "id": "abc123",
}, },
], ],
) ),
messages = [ ToolMessage(
human_message, name=function_name,
list_content_msg, content=json.dumps({"result": 3}),
tool_message, tool_call_id="abc123",
),
] ]
result = model_with_tools.invoke(messages) result_list_content = model_with_tools.invoke(messages_list_content)
assert isinstance(result, AIMessage) assert isinstance(result_list_content, AIMessage)