This commit is contained in:
Eugene Yurtsev
2024-01-17 12:17:45 -05:00
parent c56aa0a024
commit ce987cdc68
3 changed files with 47 additions and 9 deletions

View File

@@ -333,7 +333,11 @@ class ChildTool(BaseTool):
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
inputs=tool_input,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
# And will need to address issue via a patch.
inputs=None if isinstance(tool_input, str) else tool_input,
**kwargs,
)
try:
@@ -742,7 +746,7 @@ def tool(
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
) -> Callable[[Union[str, Callable, Runnable]], BaseTool]:
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
Args:
@@ -772,9 +776,7 @@ def tool(
return
"""
def _make_with_name(
tool_name: str
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
if isinstance(dec_func, Runnable):
runnable = dec_func

View File

@@ -112,7 +112,7 @@ class LangChainTracer(BaseTracer):
metadata: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
) -> Any:
"""Start a trace for an LLM run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)

View File

@@ -17,7 +17,7 @@ from langchain_core.runnables import (
RunnableLambda,
)
from langchain_core.runnables.utils import StreamEvent
from langchain_core.tools import tool
from langchain_core.tools import BaseTool, tool
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
@@ -942,5 +942,41 @@ async def test_event_stream_with_retry() -> None:
]
def test_foo() -> None:
assert "Add tool test with single input (old style)" == 1
async def test_with_inheritance_based_tool() -> None:
"""Test with a tool created using inheritance."""
class CustomTool(BaseTool):
name = "is_cat_good_for_me"
description = "Find out what having a cat leads to!"
def _run(self, inputs: str) -> str:
return "success"
tool = CustomTool()
events = await _collect_events(tool.astream_events("q"))
assert events == [
{
"data": {"input": "q"},
"event": "on_tool_start",
"metadata": {},
"name": "is_cat_good_for_me",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": "success"},
"event": "on_tool_stream",
"metadata": {},
"name": "is_cat_good_for_me",
"run_id": "",
"tags": [],
},
{
"data": {"output": "success"},
"event": "on_tool_end",
"metadata": {},
"name": "is_cat_good_for_me",
"run_id": "",
"tags": [],
},
]