core[patch]: Tool accept RunnableConfig (#24143)

Relies on #24038

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur
2024-07-11 15:13:17 -07:00
committed by GitHub
parent 5fd1e67808
commit 8d100c58de
2 changed files with 117 additions and 9 deletions

View File

@@ -1,6 +1,5 @@
"""Test the base tool implementation."""
import asyncio
import inspect
import json
import sys
@@ -19,7 +18,12 @@ from langchain_core.callbacks import (
)
from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,
)
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
@@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None:
@tool
async def foo(bar: str) -> str:
"""The foo."""
await asyncio.sleep(0.0001)
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
@@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None:
)
def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:
assert bar_config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar
@tool
def foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool
async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool(infer_schema=False)
def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
@tool(infer_schema=False)
async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
"""The foo."""
return assert_bar(bar, bar_config)
class FooBase(BaseTool):
name: str = "Foo"
description: str = "Foo"
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config)
class AFooBase(FooBase):
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config)
@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])
def test_tool_pass_config(tool: BaseTool) -> None:
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
@pytest.mark.parametrize(
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
)
async def test_async_tool_pass_config(tool: BaseTool) -> None:
assert (
await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})
== "baz"
)
def test_tool_description() -> None:
def foo(bar: str) -> str:
"""The foo."""