mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 11:37:12 +00:00
core[patch]: Tool accept RunnableConfig (#24143)
Relies on #24038 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
5fd1e67808
commit
8d100c58de
@ -20,6 +20,7 @@ tool for the job.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
@ -548,6 +549,9 @@ class ChildTool(BaseTool):
|
|||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
if signature(self._run).parameters.get("run_manager"):
|
if signature(self._run).parameters.get("run_manager"):
|
||||||
tool_kwargs["run_manager"] = run_manager
|
tool_kwargs["run_manager"] = run_manager
|
||||||
|
|
||||||
|
if config_param := _get_runnable_config_param(self._run):
|
||||||
|
tool_kwargs[config_param] = config
|
||||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||||
if self.response_format == "content_and_raw_output":
|
if self.response_format == "content_and_raw_output":
|
||||||
if not isinstance(response, tuple) or len(response) != 2:
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
@ -627,10 +631,14 @@ class ChildTool(BaseTool):
|
|||||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
context = copy_context()
|
context = copy_context()
|
||||||
context.run(_set_config_context, child_config)
|
context.run(_set_config_context, child_config)
|
||||||
if self.__class__._arun is BaseTool._arun or signature(
|
func_to_check = (
|
||||||
self._arun
|
self._run if self.__class__._arun is BaseTool._arun else self._arun
|
||||||
).parameters.get("run_manager"):
|
)
|
||||||
|
if signature(func_to_check).parameters.get("run_manager"):
|
||||||
tool_kwargs["run_manager"] = run_manager
|
tool_kwargs["run_manager"] = run_manager
|
||||||
|
if config_param := _get_runnable_config_param(func_to_check):
|
||||||
|
tool_kwargs[config_param] = config
|
||||||
|
|
||||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||||
if accepts_context(asyncio.create_task):
|
if accepts_context(asyncio.create_task):
|
||||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
@ -724,6 +732,7 @@ class Tool(BaseTool):
|
|||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
config: RunnableConfig,
|
||||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -731,12 +740,15 @@ class Tool(BaseTool):
|
|||||||
if self.func:
|
if self.func:
|
||||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||||
kwargs["callbacks"] = run_manager.get_child()
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
|
if config_param := _get_runnable_config_param(self.func):
|
||||||
|
kwargs[config_param] = config
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
raise NotImplementedError("Tool does not support sync invocation.")
|
raise NotImplementedError("Tool does not support sync invocation.")
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
config: RunnableConfig,
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -744,11 +756,15 @@ class Tool(BaseTool):
|
|||||||
if self.coroutine:
|
if self.coroutine:
|
||||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||||
kwargs["callbacks"] = run_manager.get_child()
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
|
if config_param := _get_runnable_config_param(self.coroutine):
|
||||||
|
kwargs[config_param] = config
|
||||||
return await self.coroutine(*args, **kwargs)
|
return await self.coroutine(*args, **kwargs)
|
||||||
|
|
||||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||||
# None.
|
# None.
|
||||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
return await super()._arun(
|
||||||
|
*args, config=config, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: this is for backwards compatibility, remove in future
|
# TODO: this is for backwards compatibility, remove in future
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -822,6 +838,7 @@ class StructuredTool(BaseTool):
|
|||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
config: RunnableConfig,
|
||||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -829,12 +846,15 @@ class StructuredTool(BaseTool):
|
|||||||
if self.func:
|
if self.func:
|
||||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||||
kwargs["callbacks"] = run_manager.get_child()
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
|
if config_param := _get_runnable_config_param(self.func):
|
||||||
|
kwargs[config_param] = config
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
config: RunnableConfig,
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -842,11 +862,15 @@ class StructuredTool(BaseTool):
|
|||||||
if self.coroutine:
|
if self.coroutine:
|
||||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||||
kwargs["callbacks"] = run_manager.get_child()
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
|
if config_param := _get_runnable_config_param(self.coroutine):
|
||||||
|
kwargs[config_param] = config
|
||||||
return await self.coroutine(*args, **kwargs)
|
return await self.coroutine(*args, **kwargs)
|
||||||
|
|
||||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||||
# None.
|
# None.
|
||||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
return await super()._arun(
|
||||||
|
*args, config=config, run_manager=run_manager, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
@ -923,12 +947,21 @@ class StructuredTool(BaseTool):
|
|||||||
description_ = f"{description_.strip()}"
|
description_ = f"{description_.strip()}"
|
||||||
_args_schema = args_schema
|
_args_schema = args_schema
|
||||||
if _args_schema is None and infer_schema:
|
if _args_schema is None and infer_schema:
|
||||||
|
if config_param := _get_runnable_config_param(source_function):
|
||||||
|
filter_args: Tuple[str, ...] = (
|
||||||
|
config_param,
|
||||||
|
"run_manager",
|
||||||
|
"callbacks",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filter_args = ("run_manager", "callbacks")
|
||||||
# schema name is appended within function
|
# schema name is appended within function
|
||||||
_args_schema = create_schema_from_function(
|
_args_schema = create_schema_from_function(
|
||||||
name,
|
name,
|
||||||
source_function,
|
source_function,
|
||||||
parse_docstring=parse_docstring,
|
parse_docstring=parse_docstring,
|
||||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||||
|
filter_args=filter_args,
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
name=name,
|
name=name,
|
||||||
@ -1112,7 +1145,7 @@ def tool(
|
|||||||
)
|
)
|
||||||
# If someone doesn't want a schema applied, we must treat it as
|
# If someone doesn't want a schema applied, we must treat it as
|
||||||
# a simple string->string function
|
# a simple string->string function
|
||||||
if func.__doc__ is None:
|
if dec_func.__doc__ is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Function must have a docstring if "
|
"Function must have a docstring if "
|
||||||
"description not provided and infer_schema is False."
|
"description not provided and infer_schema is False."
|
||||||
@ -1447,3 +1480,17 @@ def convert_runnable_to_tool(
|
|||||||
description=description,
|
description=description,
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
||||||
|
if isinstance(func, functools.partial):
|
||||||
|
func = func.func
|
||||||
|
try:
|
||||||
|
type_hints = get_type_hints(func)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
for name, type_ in type_hints.items():
|
||||||
|
if type_ is RunnableConfig:
|
||||||
|
return name
|
||||||
|
return None
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
"""Test the base tool implementation."""
|
"""Test the base tool implementation."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
@ -19,7 +18,12 @@ from langchain_core.callbacks import (
|
|||||||
)
|
)
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
from langchain_core.runnables import (
|
||||||
|
Runnable,
|
||||||
|
RunnableConfig,
|
||||||
|
RunnableLambda,
|
||||||
|
ensure_config,
|
||||||
|
)
|
||||||
from langchain_core.tools import (
|
from langchain_core.tools import (
|
||||||
BaseTool,
|
BaseTool,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None:
|
|||||||
@tool
|
@tool
|
||||||
async def foo(bar: str) -> str:
|
async def foo(bar: str) -> str:
|
||||||
"""The foo."""
|
"""The foo."""
|
||||||
await asyncio.sleep(0.0001)
|
|
||||||
config = ensure_config()
|
config = ensure_config()
|
||||||
assert config["configurable"]["foo"] == "not-bar"
|
assert config["configurable"]["foo"] == "not-bar"
|
||||||
assert bar == "baz"
|
assert bar == "baz"
|
||||||
@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||||
|
assert bar_config["configurable"]["foo"] == "not-bar"
|
||||||
|
assert bar == "baz"
|
||||||
|
return bar
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||||
|
"""The foo."""
|
||||||
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||||
|
"""The foo."""
|
||||||
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(infer_schema=False)
|
||||||
|
def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||||
|
"""The foo."""
|
||||||
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(infer_schema=False)
|
||||||
|
async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||||
|
"""The foo."""
|
||||||
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
|
class FooBase(BaseTool):
|
||||||
|
name: str = "Foo"
|
||||||
|
description: str = "Foo"
|
||||||
|
|
||||||
|
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||||
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
|
class AFooBase(FooBase):
|
||||||
|
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||||
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])
|
||||||
|
def test_tool_pass_config(tool: BaseTool) -> None:
|
||||||
|
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
|
||||||
|
)
|
||||||
|
async def test_async_tool_pass_config(tool: BaseTool) -> None:
|
||||||
|
assert (
|
||||||
|
await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})
|
||||||
|
== "baz"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_tool_description() -> None:
|
def test_tool_description() -> None:
|
||||||
def foo(bar: str) -> str:
|
def foo(bar: str) -> str:
|
||||||
"""The foo."""
|
"""The foo."""
|
||||||
|
Loading…
Reference in New Issue
Block a user