From 06660bc78cdd2a110af4e0e9ecd0eff65871a63c Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 2 Feb 2024 15:05:54 -0800 Subject: [PATCH] core[patch]: handle some optional cases in tools (#16954) primary problem in pydantic still exists, where `Optional[str]` gets turned to `string` in the jsonschema `.schema()` Also fixes the `SchemaSchema` naming issue --------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> --- libs/core/langchain_core/tools.py | 16 +++-- libs/core/tests/unit_tests/test_tools.py | 58 +++++++++++++++++-- .../unit_tests/utils/test_function_calling.py | 26 +++++++-- 3 files changed, 87 insertions(+), 13 deletions(-) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index b1d9e9ead3c..7e1d82e226e 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -39,14 +39,21 @@ class SchemaAnnotationError(TypeError): def _create_subset_model( - name: str, model: BaseModel, field_names: list + name: str, model: Type[BaseModel], field_names: list ) -> Type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" fields = {} for field_name in field_names: field = model.__fields__[field_name] - fields[field_name] = (field.outer_type_, field.field_info) - return create_model(name, **fields) # type: ignore + t = ( + # this isn't perfect but should work for most functions + field.outer_type_ + if field.required and not field.allow_none + else Optional[field.outer_type_] + ) + fields[field_name] = (t, field.field_info) + rtn = create_model(name, **fields) # type: ignore + return rtn def _get_filtered_args( @@ -764,7 +771,8 @@ class StructuredTool(BaseTool): description = f"{name}{sig} - {description.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", source_function) + # schema name is appended within function + _args_schema = create_schema_from_function(name, source_function) return cls( name=name, func=func, diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 33b5d0ff9b5..d89f8b2657d 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1,4 +1,5 @@ """Test the base tool implementation.""" + import json from datetime import datetime from enum import Enum @@ -18,6 +19,7 @@ from langchain_core.tools import ( StructuredTool, Tool, ToolException, + _create_subset_model, tool, ) from tests.unit_tests.fake.callbacks import FakeCallbackHandler @@ -321,7 +323,7 @@ def test_structured_tool_from_function_docstring() -> None: "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "string"}, }, - "title": "fooSchemaSchema", + "title": "fooSchema", "type": "object", "required": ["bar", "baz"], } @@ -354,7 +356,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None: "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, }, - "title": "fooSchemaSchema", + "title": "fooSchema", "type": "object", "required": ["bar", "baz"], } @@ -454,7 +456,7 @@ def test_structured_tool_from_function_with_run_manager() -> None: "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "string"}, }, - "title": "fooSchemaSchema", + "title": "fooSchema", "type": "object", "required": ["bar", "baz"], } @@ -685,7 +687,7 @@ def test_structured_tool_from_function() -> None: } assert structured_tool.args_schema.schema() == { - "title": "fooSchemaSchema", + "title": "fooSchema", "type": "object", "properties": { "bar": {"title": "Bar", "type": "integer"}, @@ -821,3 +823,51 @@ async def test_async_validation_error_handling_non_validation_error( _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) with pytest.raises(NotImplementedError): await _tool.arun({}) + + +def test_optional_subset_model_rewrite() -> None: + class MyModel(BaseModel): + a: Optional[str] + b: str + c: Optional[List[Optional[str]]] + + model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"]) + + assert "a" not in model2.schema()["required"] # should be optional + assert "b" in model2.schema()["required"] # should be required + assert "c" not in model2.schema()["required"] # should be optional + + +@pytest.mark.parametrize( + "inputs, expected", + [ + # Check not required + ({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}), + # Check overwritten + ( + {"bar": "bar", "baz": 4, "buzz": "not-buzz"}, + {"bar": "bar", "baz": 4, "buzz": "not-buzz"}, + ), + # Check validation error when missing + ({}, None), + # Check validation error when wrong type + ({"bar": "bar", "baz": "not-an-int"}, None), + # Check OK when None explicitly passed + ({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}), + ], +) +def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> None: + @tool + def foo(bar: str, baz: Optional[int] = 3, buzz: Optional[str] = "buzz") -> dict: + """The foo.""" + return { + "bar": bar, + "baz": baz, + "buzz": buzz, + } + + if expected is not None: + assert foo.invoke(inputs) == expected # type: ignore + else: + with pytest.raises(ValidationError): + foo.invoke(inputs) # type: ignore diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 026338614ce..bd03abe2757 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, Literal, Type +from typing import Any, Callable, List, Literal, Optional, Type import pytest from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, tool from langchain_core.utils.function_calling import convert_to_openai_function @@ -33,7 +33,7 @@ def function() -> Callable: @pytest.fixture() -def tool() -> BaseTool: +def dummy_tool() -> BaseTool: class Schema(BaseModel): arg1: int = Field(..., description="foo") arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") @@ -50,7 +50,7 @@ def tool() -> BaseTool: def test_convert_to_openai_function( - pydantic: Type[BaseModel], function: Callable, tool: BaseTool + pydantic: Type[BaseModel], function: Callable, dummy_tool: BaseTool ) -> None: expected = { "name": "dummy_function", @@ -69,6 +69,22 @@ def test_convert_to_openai_function( }, } - for fn in (pydantic, function, tool, expected): + for fn in (pydantic, function, dummy_tool, expected): actual = convert_to_openai_function(fn) # type: ignore assert actual == expected + + +@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()") +def test_function_optional_param() -> None: + @tool + def func5( + a: Optional[str], + b: str, + c: Optional[List[Optional[str]]], + ) -> None: + """A test function""" + pass + + func = convert_to_openai_function(func5) + req = func["parameters"]["required"] + assert set(req) == {"b"}