standard-tests: fix decorator init test (#28246)

This commit is contained in:
Erick Friis 2024-11-20 19:35:43 -08:00 committed by GitHub
parent 60e572f591
commit 4bdf1d7d1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 3 deletions

View File

@ -1,6 +1,6 @@
import os import os
from abc import abstractmethod from abc import abstractmethod
from typing import Callable, Tuple, Type, Union from typing import Tuple, Type, Union
from unittest import mock from unittest import mock
import pytest import pytest
@ -13,7 +13,7 @@ from langchain_tests.base import BaseStandardTests
class ToolsTests(BaseStandardTests): class ToolsTests(BaseStandardTests):
@property @property
@abstractmethod @abstractmethod
def tool_constructor(self) -> Union[Type[BaseTool], Callable, BaseTool]: ... def tool_constructor(self) -> Union[Type[BaseTool], BaseTool]: ...
@property @property
def tool_constructor_params(self) -> dict: def tool_constructor_params(self) -> dict:
@ -44,7 +44,10 @@ class ToolsTests(BaseStandardTests):
class ToolsUnitTests(ToolsTests): class ToolsUnitTests(ToolsTests):
def test_init(self) -> None: def test_init(self) -> None:
tool = self.tool_constructor(**self.tool_constructor_params) if isinstance(self.tool_constructor, BaseTool):
tool = self.tool_constructor
else:
tool = self.tool_constructor(**self.tool_constructor_params)
assert tool is not None assert tool is not None
@property @property

View File

@ -0,0 +1,42 @@
from langchain_core.tools import BaseTool, tool
from langchain_tests.integration_tests import ToolsIntegrationTests
from langchain_tests.unit_tests import ToolsUnitTests
@tool
def parrot_multiply_tool(a: int, b: int) -> int:
"""Multiply two numbers like a parrot. Parrots always add eighty for their matey."""
return a * b + 80
class TestParrotMultiplyToolUnit(ToolsUnitTests):
@property
def tool_constructor(self) -> BaseTool:
return parrot_multiply_tool
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {"a": 2, "b": 3}
class TestParrotMultiplyToolIntegration(ToolsIntegrationTests):
@property
def tool_constructor(self) -> BaseTool:
return parrot_multiply_tool
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {"a": 2, "b": 3}