From 888fbc07b54043d85c697c1b6356372d2230aa3b Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 15 Jul 2024 10:51:05 -0400 Subject: [PATCH] core[patch]: support passing `args_schema` through `as_tool` (#24269) Note: this allows the schema to be passed in positionally. ```python from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import RunnableLambda class Add(BaseModel): """Add two integers together.""" a: int = Field(..., description="First integer") b: int = Field(..., description="Second integer") def add(input: dict) -> int: return input["a"] + input["b"] runnable = RunnableLambda(add) as_tool = runnable.as_tool(Add) as_tool.args_schema.schema() ``` ``` {'title': 'Add', 'description': 'Add two integers together.', 'type': 'object', 'properties': {'a': {'title': 'A', 'description': 'First integer', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'Second integer', 'type': 'integer'}}, 'required': ['a', 'b']} ``` --- .../how_to/convert_runnable_to_tool.ipynb | 18 +++++++--- libs/core/langchain_core/runnables/base.py | 34 +++++++++++++++++-- libs/core/langchain_core/tools.py | 4 +++ libs/core/tests/unit_tests/test_tools.py | 16 +++++++-- 4 files changed, 62 insertions(+), 10 deletions(-) diff --git a/docs/docs/how_to/convert_runnable_to_tool.ipynb b/docs/docs/how_to/convert_runnable_to_tool.ipynb index ed4b51e0972..467bafe736d 100644 --- a/docs/docs/how_to/convert_runnable_to_tool.ipynb +++ b/docs/docs/how_to/convert_runnable_to_tool.ipynb @@ -180,7 +180,7 @@ "id": "32b1a992-8997-4c98-8eb2-c9fe9431b799", "metadata": {}, "source": [ - "Alternatively, we can add typing information via [Runnable.with_types](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable.with_types):" + "Alternatively, the schema can be fully specified by directly passing the desired [args_schema](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool.args_schema) for the tool:" ] }, { @@ -190,10 +190,18 @@ "metadata": {}, "outputs": [], "source": [ - "as_tool = runnable.with_types(input_type=Args).as_tool(\n", - " name=\"My tool\",\n", - " description=\"Explanation of when to use tool.\",\n", - ")" + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "\n", + "\n", + "class GSchema(BaseModel):\n", + " \"\"\"Apply a function to an integer and list of integers.\"\"\"\n", + "\n", + " a: int = Field(..., description=\"Integer\")\n", + " b: List[int] = Field(..., description=\"List of ints\")\n", + "\n", + "\n", + "runnable = RunnableLambda(g)\n", + "as_tool = runnable.as_tool(GSchema)" ] }, { diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3b874ca4ff6..1658f61b85f 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2150,6 +2150,7 @@ class Runnable(Generic[Input, Output], ABC): @beta_decorator.beta(message="This API is in beta and may change in the future.") def as_tool( self, + args_schema: Optional[Type[BaseModel]] = None, *, name: Optional[str] = None, description: Optional[str] = None, @@ -2161,9 +2162,11 @@ class Runnable(Generic[Input, Output], ABC): ``args_schema`` from a Runnable. Where possible, schemas are inferred from ``runnable.get_input_schema``. Alternatively (e.g., if the Runnable takes a dict as input and the specific dict keys are not typed), - pass ``arg_types`` to specify the required arguments. + the schema can be specified directly with ``args_schema``. You can also + pass ``arg_types`` to just specify the required arguments and their types. Args: + args_schema: The schema for the tool. Defaults to None. name: The name of the tool. Defaults to None. description: The description of the tool. Defaults to None. arg_types: A dictionary of argument names to types. Defaults to None. @@ -2190,7 +2193,28 @@ class Runnable(Generic[Input, Output], ABC): as_tool = runnable.as_tool() as_tool.invoke({"a": 3, "b": [1, 2]}) - ``dict`` input, specifying schema: + ``dict`` input, specifying schema via ``args_schema``: + + .. code-block:: python + + from typing import Any, Dict, List + from langchain_core.pydantic_v1 import BaseModel, Field + from langchain_core.runnables import RunnableLambda + + def f(x: Dict[str, Any]) -> str: + return str(x["a"] * max(x["b"])) + + class FSchema(BaseModel): + \"\"\"Apply a function to an integer and list of integers.\"\"\" + + a: int = Field(..., description="Integer") + b: List[int] = Field(..., description="List of ints") + + runnable = RunnableLambda(f) + as_tool = runnable.as_tool(FSchema) + as_tool.invoke({"a": 3, "b": [1, 2]}) + + ``dict`` input, specifying schema via ``arg_types``: .. code-block:: python @@ -2226,7 +2250,11 @@ class Runnable(Generic[Input, Output], ABC): from langchain_core.tools import convert_runnable_to_tool return convert_runnable_to_tool( - self, name=name, description=description, arg_types=arg_types + self, + args_schema=args_schema, + name=name, + description=description, + arg_types=arg_types, ) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 0794613e3ae..c5b889d76a0 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -1438,11 +1438,15 @@ def _get_schema_from_runnable_and_arg_types( def convert_runnable_to_tool( runnable: Runnable, + args_schema: Optional[Type[BaseModel]] = None, + *, name: Optional[str] = None, description: Optional[str] = None, arg_types: Optional[Dict[str, Type]] = None, ) -> BaseTool: """Convert a Runnable into a BaseTool.""" + if args_schema: + runnable = runnable.with_types(input_type=args_schema) description = description or _get_description_from_runnable(runnable) name = name or runnable.get_name() diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 9bc15d72d38..13fceb11d94 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -17,7 +17,7 @@ from langchain_core.callbacks import ( CallbackManagerForToolRun, ) from langchain_core.messages import ToolMessage -from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError from langchain_core.runnables import ( Runnable, RunnableConfig, @@ -1222,10 +1222,22 @@ def test_convert_from_runnable_dict() -> None: assert as_tool.name == "my tool" assert as_tool.description == "test description" - # Dict without typed input-- must supply arg types + # Dict without typed input-- must supply schema def g(x: Dict[str, Any]) -> str: return str(x["a"] * max(x["b"])) + # Specify via args_schema: + class GSchema(BaseModel): + """Apply a function to an integer and list of integers.""" + + a: int = Field(..., description="Integer") + b: List[int] = Field(..., description="List of ints") + + runnable = RunnableLambda(g) + as_tool = runnable.as_tool(GSchema) + as_tool.invoke({"a": 3, "b": [1, 2]}) + + # Specify via arg_types: runnable = RunnableLambda(g) as_tool = runnable.as_tool(arg_types={"a": int, "b": List[int]}) result = as_tool.invoke({"a": 3, "b": [1, 2]})