mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
core[patch]: Tool accept RunnableConfig (#24143)
Relies on #24038 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
"""Test the base tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
@@ -19,7 +18,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
SchemaAnnotationError,
|
||||
@@ -914,7 +918,6 @@ async def test_async_tool_pass_context() -> None:
|
||||
@tool
|
||||
async def foo(bar: str) -> str:
|
||||
"""The foo."""
|
||||
await asyncio.sleep(0.0001)
|
||||
config = ensure_config()
|
||||
assert config["configurable"]["foo"] == "not-bar"
|
||||
assert bar == "baz"
|
||||
@@ -925,6 +928,64 @@ async def test_async_tool_pass_context() -> None:
|
||||
)
|
||||
|
||||
|
||||
def assert_bar(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
assert bar_config["configurable"]["foo"] == "not-bar"
|
||||
assert bar == "baz"
|
||||
return bar
|
||||
|
||||
|
||||
@tool
|
||||
def foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@tool
|
||||
async def afoo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def simple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@tool(infer_schema=False)
|
||||
async def asimple_foo(bar: Any, bar_config: RunnableConfig) -> Any:
|
||||
"""The foo."""
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
class FooBase(BaseTool):
|
||||
name: str = "Foo"
|
||||
description: str = "Foo"
|
||||
|
||||
def _run(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
class AFooBase(FooBase):
|
||||
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||
return assert_bar(bar, bar_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool", [foo, simple_foo, FooBase(), AFooBase()])
|
||||
def test_tool_pass_config(tool: BaseTool) -> None:
|
||||
assert tool.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool", [foo, afoo, simple_foo, asimple_foo, FooBase(), AFooBase()]
|
||||
)
|
||||
async def test_async_tool_pass_config(tool: BaseTool) -> None:
|
||||
assert (
|
||||
await tool.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}})
|
||||
== "baz"
|
||||
)
|
||||
|
||||
|
||||
def test_tool_description() -> None:
|
||||
def foo(bar: str) -> str:
|
||||
"""The foo."""
|
||||
|
Reference in New Issue
Block a user