standard-tests[patch]: add test for async tool calling (#28133)

This commit is contained in:
ccurme
2024-11-15 11:09:50 -05:00
committed by GitHub
parent ed4952e475
commit 776e3271e3
2 changed files with 25 additions and 0 deletions

View File

@@ -53,6 +53,10 @@ class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
def test_tool_calling(self, model: BaseChatModel) -> None:
super().test_tool_calling(model)
@pytest.mark.xfail(reason=("Not implemented"))
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)

View File

@@ -286,6 +286,27 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
# Test ainvoke
query = "What is the value of magic_function(3)? Use the tool."
result = await model_with_tools.ainvoke(query)
_validate_tool_call_message(result)
# Test astream
full: Optional[BaseMessageChunk] = None
async for chunk in model_with_tools.astream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")