From c417bbc313b2a6123c108bb1410c5d4b7a3587a5 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 11 Sep 2024 11:12:52 -0400 Subject: [PATCH] core[patch]: Add default None to StructuredTool func (#26324) This PR was autogenerated using gritql, tests written manually ```shell grit apply 'class_definition(name=$C, $body, superclasses=$S) where { $C <: ! "Config", // Does not work in this scope, but works after class_definition $body <: block($statements), $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { or { $y <: `Field($z)`, $x <: "model_config" } }, // And has either Any or Optional fields without a default $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { $t <: or { r"Optional.*", r"Any", r"Union[None, .*]", r"Union[.*, None, .*]", r"Union[.*, None]", }, $y <: ., // Match empty node $t => `$t = None`, }, } ' --language python . ``` --- libs/core/langchain_core/tools/structured.py | 6 +++--- libs/core/tests/unit_tests/test_tools.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index 248678d3801..a30defdd77b 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -40,7 +40,7 @@ class StructuredTool(BaseTool): ..., description="The tool schema." ) """The input arguments' schema.""" - func: Optional[Callable[..., Any]] + func: Optional[Callable[..., Any]] = None """The function to run when the tool is called.""" coroutine: Optional[Callable[..., Awaitable[Any]]] = None """The asynchronous version of the function.""" @@ -98,8 +98,8 @@ class StructuredTool(BaseTool): kwargs[config_param] = config return await self.coroutine(*args, **kwargs) - # NOTE: this code is unreachable since _arun is only called if coroutine is not - # None. + # If self.coroutine is None, then this will delegate to the default + # implementation which is expected to delegate to _run on a separate thread. return await super()._arun( *args, config=config, run_manager=run_manager, **kwargs ) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index c1d07762cf9..e2dabc1c125 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -54,7 +54,6 @@ from langchain_core.tools.base import ( ) from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model -from langchain_core.utils.pydantic import TypeBaseModel as TypeBaseModel from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.pydantic_utils import _schema @@ -1978,3 +1977,19 @@ def test_imports() -> None: ] for module_name in expected_all: assert hasattr(tools, module_name) and getattr(tools, module_name) is not None + + +def test_structured_tool_direct_init() -> None: + def foo(bar: str) -> str: + return bar + + async def asyncFoo(bar: str) -> str: + return bar + + class fooSchema(BaseModel): + bar: str = Field(..., description="The bar") + + tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo) + + with pytest.raises(NotImplementedError): + assert tool.invoke("hello") == "hello"