core[patch]: Tool accept RunnableConfig (#24143)

Relies on #24038

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur 2024-07-11 15:13:17 -07:00 committed by GitHub
parent 5fd1e67808
commit 8d100c58de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 117 additions and 9 deletions

View File

@ -20,6 +20,7 @@ tool for the job.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools
import inspect import inspect
import json import json
import textwrap import textwrap
@ -548,6 +549,9 @@ class ChildTool(BaseTool):
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"): if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs) response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_raw_output": if self.response_format == "content_and_raw_output":
if not isinstance(response, tuple) or len(response) != 2: if not isinstance(response, tuple) or len(response) != 2:
@ -627,10 +631,14 @@ class ChildTool(BaseTool):
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() context = copy_context()
context.run(_set_config_context, child_config) context.run(_set_config_context, child_config)
if self.__class__._arun is BaseTool._arun or signature( func_to_check = (
self._arun self._run if self.__class__._arun is BaseTool._arun else self._arun
).parameters.get("run_manager"): )
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs) coro = context.run(self._arun, *tool_args, **tool_kwargs)
if accepts_context(asyncio.create_task): if accepts_context(asyncio.create_task):
response = await asyncio.create_task(coro, context=context) # type: ignore response = await asyncio.create_task(coro, context=context) # type: ignore
@ -724,6 +732,7 @@ class Tool(BaseTool):
def _run( def _run(
self, self,
*args: Any, *args: Any,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -731,12 +740,15 @@ class Tool(BaseTool):
if self.func: if self.func:
if run_manager and signature(self.func).parameters.get("callbacks"): if run_manager and signature(self.func).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child() kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
raise NotImplementedError("Tool does not support sync invocation.") raise NotImplementedError("Tool does not support sync invocation.")
async def _arun( async def _arun(
self, self,
*args: Any, *args: Any,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -744,11 +756,15 @@ class Tool(BaseTool):
if self.coroutine: if self.coroutine:
if run_manager and signature(self.coroutine).parameters.get("callbacks"): if run_manager and signature(self.coroutine).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child() kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
return await self.coroutine(*args, **kwargs) return await self.coroutine(*args, **kwargs)
# NOTE: this code is unreachable since _arun is only called if coroutine is not # NOTE: this code is unreachable since _arun is only called if coroutine is not
# None. # None.
return await super()._arun(*args, run_manager=run_manager, **kwargs) return await super()._arun(
*args, config=config, run_manager=run_manager, **kwargs
)
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
def __init__( def __init__(
@ -822,6 +838,7 @@ class StructuredTool(BaseTool):
def _run( def _run(
self, self,
*args: Any, *args: Any,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -829,12 +846,15 @@ class StructuredTool(BaseTool):
if self.func: if self.func:
if run_manager and signature(self.func).parameters.get("callbacks"): if run_manager and signature(self.func).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child() kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.func):
kwargs[config_param] = config
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
raise NotImplementedError("StructuredTool does not support sync invocation.") raise NotImplementedError("StructuredTool does not support sync invocation.")
async def _arun( async def _arun(
self, self,
*args: Any, *args: Any,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -842,11 +862,15 @@ class StructuredTool(BaseTool):
if self.coroutine: if self.coroutine:
if run_manager and signature(self.coroutine).parameters.get("callbacks"): if run_manager and signature(self.coroutine).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child() kwargs["callbacks"] = run_manager.get_child()
if config_param := _get_runnable_config_param(self.coroutine):
kwargs[config_param] = config
return await self.coroutine(*args, **kwargs) return await self.coroutine(*args, **kwargs)
# NOTE: this code is unreachable since _arun is only called if coroutine is not # NOTE: this code is unreachable since _arun is only called if coroutine is not
# None. # None.
return await super()._arun(*args, run_manager=run_manager, **kwargs) return await super()._arun(
*args, config=config, run_manager=run_manager, **kwargs
)
@classmethod @classmethod
def from_function( def from_function(
@ -923,12 +947,21 @@ class StructuredTool(BaseTool):
description_ = f"{description_.strip()}" description_ = f"{description_.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
if config_param := _get_runnable_config_param(source_function):
filter_args: Tuple[str, ...] = (
config_param,
"run_manager",
"callbacks",
)
else:
filter_args = ("run_manager", "callbacks")
# schema name is appended within function # schema name is appended within function
_args_schema = create_schema_from_function( _args_schema = create_schema_from_function(
name, name,
source_function, source_function,
parse_docstring=parse_docstring, parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring, error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=filter_args,
) )
return cls( return cls(
name=name, name=name,
@ -1112,7 +1145,7 @@ def tool(
) )
# If someone doesn't want a schema applied, we must treat it as # If someone doesn't want a schema applied, we must treat it as
# a simple string->string function # a simple string->string function
if func.__doc__ is None: if dec_func.__doc__ is None:
raise ValueError( raise ValueError(
"Function must have a docstring if " "Function must have a docstring if "
"description not provided and infer_schema is False." "description not provided and infer_schema is False."
@ -1447,3 +1480,17 @@ def convert_runnable_to_tool(
description=description, description=description,
args_schema=args_schema, args_schema=args_schema,
) )
def _get_runnable_config_param(func: Callable) -> Optional[str]:
if isinstance(func, functools.partial):
func = func.func
try:
type_hints = get_type_hints(func)
except Exception:
return None
else:
for name, type_ in type_hints.items():
if type_ is RunnableConfig:
return name
return None

View File

@ -1,6 +1,5 @@
"""Test the base tool implementation.""" """Test the base tool implementation."""
import asyncio
import inspect import inspect
import json import json
import sys import sys
@ -19,7 +18,12 @@ from langchain_core.callbacks import (
) )
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,
)
from langchain_core.tools import ( from langchain_core.tools import (
BaseTool, BaseTool,
SchemaAnnotationError, SchemaAnnotationError,
@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None:
@tool @tool
async def foo(bar: str) -> str: async def foo(bar: str) -> str:
"""The foo.""" """The foo."""
await asyncio.sleep(0.0001)
config = ensure_config() config = ensure_config()
assert config["configurable"]["foo"] == "not-bar" assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz" assert bar == "baz"
@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None:
) )
def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:
assert bar_config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
@tool
def foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool
async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool(infer_schema=False)
def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool(infer_schema=False)
async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
class FooBase(BaseTool):
name: str = "Foo"
description: str = "Foo"
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config)
class AFooBase(FooBase):
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config)
@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])
def test_tool_pass_config(tool: BaseTool) -> None:
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
@pytest.mark.parametrize(
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
)
async def test_async_tool_pass_config(tool: BaseTool) -> None:
assert (
await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})
== "baz"
)
def test_tool_description() -> None: def test_tool_description() -> None:
def foo(bar: str) -> str: def foo(bar: str) -> str:
"""The foo.""" """The foo."""