standard-tests: fix decorator init test (#28246)

This commit is contained in:
Erick Friis 2024-11-20 19:35:43 -08:00 committed by GitHub
parent 60e572f591
commit 4bdf1d7d1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 3 deletions

View File

@ -1,6 +1,6 @@
import os
from abc import abstractmethod
from typing import Callable, Tuple, Type, Union
from typing import Tuple, Type, Union
from unittest import mock
import pytest
@ -13,7 +13,7 @@ from langchain_tests.base import BaseStandardTests
class ToolsTests(BaseStandardTests):
@property
@abstractmethod
def tool_constructor(self) -> Union[Type[BaseTool], Callable, BaseTool]: ...
def tool_constructor(self) -> Union[Type[BaseTool], BaseTool]: ...
@property
def tool_constructor_params(self) -> dict:
@ -44,6 +44,9 @@ class ToolsTests(BaseStandardTests):
class ToolsUnitTests(ToolsTests):
def test_init(self) -> None:
if isinstance(self.tool_constructor, BaseTool):
tool = self.tool_constructor
else:
tool = self.tool_constructor(**self.tool_constructor_params)
assert tool is not None

View File

@ -0,0 +1,42 @@
from langchain_core.tools import BaseTool, tool
from langchain_tests.integration_tests import ToolsIntegrationTests
from langchain_tests.unit_tests import ToolsUnitTests
@tool
def parrot_multiply_tool(a: int, b: int) -> int:
"""Multiply two numbers like a parrot. Parrots always add eighty for their matey."""
return a * b + 80
class TestParrotMultiplyToolUnit(ToolsUnitTests):
@property
def tool_constructor(self) -> BaseTool:
return parrot_multiply_tool
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {"a": 2, "b": 3}
class TestParrotMultiplyToolIntegration(ToolsIntegrationTests):
@property
def tool_constructor(self) -> BaseTool:
return parrot_multiply_tool
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {"a": 2, "b": 3}