From ce987cdc681fda25eed7527414f8b85a2d4e14c4 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 17 Jan 2024 12:17:45 -0500 Subject: [PATCH] q --- libs/core/langchain_core/tools.py | 12 +++--- libs/core/langchain_core/tracers/langchain.py | 2 +- .../runnables/test_runnable_events.py | 42 +++++++++++++++++-- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 05af53ced9f..b83a5cc8a7b 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -333,7 +333,11 @@ class ChildTool(BaseTool): tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, - inputs=tool_input, + # Inputs by definition should always be dicts. + # For now, it's unclear whether this assumption is ever violated, + # but if it is we will send a `None` value to the callback instead + # And will need to address issue via a patch. + inputs=None if isinstance(tool_input, str) else tool_input, **kwargs, ) try: @@ -742,7 +746,7 @@ def tool( return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, -) -> Callable[[Union[str, Callable, Runnable]], BaseTool]: +) -> Callable: """Make tools out of functions, can be used with or without arguments. Args: @@ -772,9 +776,7 @@ def tool( return """ - def _make_with_name( - tool_name: str - ) -> Callable[[Union[Callable, Runnable]], BaseTool]: + def _make_with_name(tool_name: str) -> Callable: def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): runnable = dec_func diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index d5daf108085..bca8a3eb1fd 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -112,7 +112,7 @@ class LangChainTracer(BaseTracer): metadata: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, - ) -> None: + ) -> Any: """Start a trace for an LLM run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index cf29109b49f..610325ebf91 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -17,7 +17,7 @@ from langchain_core.runnables import ( RunnableLambda, ) from langchain_core.runnables.utils import StreamEvent -from langchain_core.tools import tool +from langchain_core.tools import BaseTool, tool from tests.unit_tests.fake.chat_model import GenericFakeChatModel @@ -942,5 +942,41 @@ async def test_event_stream_with_retry() -> None: ] -def test_foo() -> None: - assert "Add tool test with single input (old style)" == 1 +async def test_with_inheritance_based_tool() -> None: + """Test with a tool created using inheritance.""" + + class CustomTool(BaseTool): + name = "is_cat_good_for_me" + description = "Find out what having a cat leads to!" + + def _run(self, inputs: str) -> str: + return "success" + + tool = CustomTool() + events = await _collect_events(tool.astream_events("q")) + assert events == [ + { + "data": {"input": "q"}, + "event": "on_tool_start", + "metadata": {}, + "name": "is_cat_good_for_me", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "success"}, + "event": "on_tool_stream", + "metadata": {}, + "name": "is_cat_good_for_me", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "success"}, + "event": "on_tool_end", + "metadata": {}, + "name": "is_cat_good_for_me", + "run_id": "", + "tags": [], + }, + ]