diff --git a/libs/standard-tests/langchain_tests/unit_tests/tools.py b/libs/standard-tests/langchain_tests/unit_tests/tools.py index 46d0de7e8ea..93701437da3 100644 --- a/libs/standard-tests/langchain_tests/unit_tests/tools.py +++ b/libs/standard-tests/langchain_tests/unit_tests/tools.py @@ -1,6 +1,6 @@ import os from abc import abstractmethod -from typing import Callable, Tuple, Type, Union +from typing import Tuple, Type, Union from unittest import mock import pytest @@ -13,7 +13,7 @@ from langchain_tests.base import BaseStandardTests class ToolsTests(BaseStandardTests): @property @abstractmethod - def tool_constructor(self) -> Union[Type[BaseTool], Callable, BaseTool]: ... + def tool_constructor(self) -> Union[Type[BaseTool], BaseTool]: ... @property def tool_constructor_params(self) -> dict: @@ -44,7 +44,10 @@ class ToolsTests(BaseStandardTests): class ToolsUnitTests(ToolsTests): def test_init(self) -> None: - tool = self.tool_constructor(**self.tool_constructor_params) + if isinstance(self.tool_constructor, BaseTool): + tool = self.tool_constructor + else: + tool = self.tool_constructor(**self.tool_constructor_params) assert tool is not None @property diff --git a/libs/standard-tests/tests/unit_tests/test_decorated_tool.py b/libs/standard-tests/tests/unit_tests/test_decorated_tool.py new file mode 100644 index 00000000000..ecc8af3b00c --- /dev/null +++ b/libs/standard-tests/tests/unit_tests/test_decorated_tool.py @@ -0,0 +1,42 @@ +from langchain_core.tools import BaseTool, tool + +from langchain_tests.integration_tests import ToolsIntegrationTests +from langchain_tests.unit_tests import ToolsUnitTests + + +@tool +def parrot_multiply_tool(a: int, b: int) -> int: + """Multiply two numbers like a parrot. Parrots always add eighty for their matey.""" + return a * b + 80 + + +class TestParrotMultiplyToolUnit(ToolsUnitTests): + @property + def tool_constructor(self) -> BaseTool: + return parrot_multiply_tool + + @property + def tool_invoke_params_example(self) -> dict: + """ + Returns a dictionary representing the "args" of an example tool call. + + This should NOT be a ToolCall dict - i.e. it should not + have {"name", "id", "args"} keys. + """ + return {"a": 2, "b": 3} + + +class TestParrotMultiplyToolIntegration(ToolsIntegrationTests): + @property + def tool_constructor(self) -> BaseTool: + return parrot_multiply_tool + + @property + def tool_invoke_params_example(self) -> dict: + """ + Returns a dictionary representing the "args" of an example tool call. + + This should NOT be a ToolCall dict - i.e. it should not + have {"name", "id", "args"} keys. + """ + return {"a": 2, "b": 3}