mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 14:26:48 +00:00
core[patch]: Add default None to StructuredTool func (#26324)
This PR was autogenerated using gritql, tests written manually ```shell grit apply 'class_definition(name=$C, $body, superclasses=$S) where { $C <: ! "Config", // Does not work in this scope, but works after class_definition $body <: block($statements), $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { or { $y <: `Field($z)`, $x <: "model_config" } }, // And has either Any or Optional fields without a default $statements <: some bubble assignment(left=$x, right=$y, type=$t) as $A where { $t <: or { r"Optional.*", r"Any", r"Union[None, .*]", r"Union[.*, None, .*]", r"Union[.*, None]", }, $y <: ., // Match empty node $t => `$t = None`, }, } ' --language python . ```
This commit is contained in:
@@ -40,7 +40,7 @@ class StructuredTool(BaseTool):
|
|||||||
..., description="The tool schema."
|
..., description="The tool schema."
|
||||||
)
|
)
|
||||||
"""The input arguments' schema."""
|
"""The input arguments' schema."""
|
||||||
func: Optional[Callable[..., Any]]
|
func: Optional[Callable[..., Any]] = None
|
||||||
"""The function to run when the tool is called."""
|
"""The function to run when the tool is called."""
|
||||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||||
"""The asynchronous version of the function."""
|
"""The asynchronous version of the function."""
|
||||||
@@ -98,8 +98,8 @@ class StructuredTool(BaseTool):
|
|||||||
kwargs[config_param] = config
|
kwargs[config_param] = config
|
||||||
return await self.coroutine(*args, **kwargs)
|
return await self.coroutine(*args, **kwargs)
|
||||||
|
|
||||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
# If self.coroutine is None, then this will delegate to the default
|
||||||
# None.
|
# implementation which is expected to delegate to _run on a separate thread.
|
||||||
return await super()._arun(
|
return await super()._arun(
|
||||||
*args, config=config, run_manager=run_manager, **kwargs
|
*args, config=config, run_manager=run_manager, **kwargs
|
||||||
)
|
)
|
||||||
|
@@ -54,7 +54,6 @@ from langchain_core.tools.base import (
|
|||||||
)
|
)
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
|
||||||
from langchain_core.utils.pydantic import TypeBaseModel as TypeBaseModel
|
|
||||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||||
from tests.unit_tests.pydantic_utils import _schema
|
from tests.unit_tests.pydantic_utils import _schema
|
||||||
|
|
||||||
@@ -1978,3 +1977,19 @@ def test_imports() -> None:
|
|||||||
]
|
]
|
||||||
for module_name in expected_all:
|
for module_name in expected_all:
|
||||||
assert hasattr(tools, module_name) and getattr(tools, module_name) is not None
|
assert hasattr(tools, module_name) and getattr(tools, module_name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_direct_init() -> None:
|
||||||
|
def foo(bar: str) -> str:
|
||||||
|
return bar
|
||||||
|
|
||||||
|
async def asyncFoo(bar: str) -> str:
|
||||||
|
return bar
|
||||||
|
|
||||||
|
class fooSchema(BaseModel):
|
||||||
|
bar: str = Field(..., description="The bar")
|
||||||
|
|
||||||
|
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
assert tool.invoke("hello") == "hello"
|
||||||
|
Reference in New Issue
Block a user