1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-09-23 03:19:38 +00:00

standard-tests: fix decorator init test ()

This commit is contained in:
Erick Friis
2024-11-20 19:35:43 -08:00
committed by GitHub
parent 60e572f591
commit 4bdf1d7d1a
2 changed files with 48 additions and 3 deletions
libs/standard-tests
langchain_tests
unit_tests
tests

@@ -1,6 +1,6 @@
import os import os
from abc import abstractmethod from abc import abstractmethod
from typing import Callable, Tuple, Type, Union from typing import Tuple, Type, Union
from unittest import mock from unittest import mock
import pytest import pytest
@@ -13,7 +13,7 @@ from langchain_tests.base import BaseStandardTests
class ToolsTests(BaseStandardTests): class ToolsTests(BaseStandardTests):
@property @property
@abstractmethod @abstractmethod
def tool_constructor(self) -> Union[Type[BaseTool], Callable, BaseTool]: ... def tool_constructor(self) -> Union[Type[BaseTool], BaseTool]: ...
@property @property
def tool_constructor_params(self) -> dict: def tool_constructor_params(self) -> dict:
@@ -44,6 +44,9 @@ class ToolsTests(BaseStandardTests):
class ToolsUnitTests(ToolsTests): class ToolsUnitTests(ToolsTests):
def test_init(self) -> None: def test_init(self) -> None:
if isinstance(self.tool_constructor, BaseTool):
tool = self.tool_constructor
else:
tool = self.tool_constructor(**self.tool_constructor_params) tool = self.tool_constructor(**self.tool_constructor_params)
assert tool is not None assert tool is not None

@@ -0,0 +1,42 @@
from langchain_core.tools import BaseTool, tool
from langchain_tests.integration_tests import ToolsIntegrationTests
from langchain_tests.unit_tests import ToolsUnitTests
@tool
def parrot_multiply_tool(a: int, b: int) -> int:
"""Multiply two numbers like a parrot. Parrots always add eighty for their matey."""
return a * b + 80
class TestParrotMultiplyToolUnit(ToolsUnitTests):
@property
def tool_constructor(self) -> BaseTool:
return parrot_multiply_tool
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {"a": 2, "b": 3}
class TestParrotMultiplyToolIntegration(ToolsIntegrationTests):
@property
def tool_constructor(self) -> BaseTool:
return parrot_multiply_tool
@property
def tool_invoke_params_example(self) -> dict:
"""
Returns a dictionary representing the "args" of an example tool call.
This should NOT be a ToolCall dict - i.e. it should not
have {"name", "id", "args"} keys.
"""
return {"a": 2, "b": 3}