core[patch]: handle some optional cases in tools (#16954)

The primary problem in pydantic still exists: `Optional[str]` is still
turned into `string` in the JSON schema returned by `.schema()`.

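Also fixes the `SchemaSchema` naming issue, where the inferred args schema title had `Schema` appended twice (e.g. `fooSchemaSchema` instead of `fooSchema`).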

---------

Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
Erick Friis 2024-02-02 15:05:54 -08:00 committed by GitHub
parent f8943e8739
commit 06660bc78c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 87 additions and 13 deletions

View File

@ -39,14 +39,21 @@ class SchemaAnnotationError(TypeError):
def _create_subset_model( def _create_subset_model(
name: str, model: BaseModel, field_names: list name: str, model: Type[BaseModel], field_names: list
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields.""" """Create a pydantic model with only a subset of model's fields."""
fields = {} fields = {}
for field_name in field_names: for field_name in field_names:
field = model.__fields__[field_name] field = model.__fields__[field_name]
fields[field_name] = (field.outer_type_, field.field_info) t = (
return create_model(name, **fields) # type: ignore # this isn't perfect but should work for most functions
field.outer_type_
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
return rtn
def _get_filtered_args( def _get_filtered_args(
@ -764,7 +771,8 @@ class StructuredTool(BaseTool):
description = f"{name}{sig} - {description.strip()}" description = f"{name}{sig} - {description.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", source_function) # schema name is appended within function
_args_schema = create_schema_from_function(name, source_function)
return cls( return cls(
name=name, name=name,
func=func, func=func,

View File

@ -1,4 +1,5 @@
"""Test the base tool implementation.""" """Test the base tool implementation."""
import json import json
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@ -18,6 +19,7 @@ from langchain_core.tools import (
StructuredTool, StructuredTool,
Tool, Tool,
ToolException, ToolException,
_create_subset_model,
tool, tool,
) )
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@ -321,7 +323,7 @@ def test_structured_tool_from_function_docstring() -> None:
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string"},
}, },
"title": "fooSchemaSchema", "title": "fooSchema",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
@ -354,7 +356,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
}, },
"title": "fooSchemaSchema", "title": "fooSchema",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
@ -454,7 +456,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "string"}, "baz": {"title": "Baz", "type": "string"},
}, },
"title": "fooSchemaSchema", "title": "fooSchema",
"type": "object", "type": "object",
"required": ["bar", "baz"], "required": ["bar", "baz"],
} }
@ -685,7 +687,7 @@ def test_structured_tool_from_function() -> None:
} }
assert structured_tool.args_schema.schema() == { assert structured_tool.args_schema.schema() == {
"title": "fooSchemaSchema", "title": "fooSchema",
"type": "object", "type": "object",
"properties": { "properties": {
"bar": {"title": "Bar", "type": "integer"}, "bar": {"title": "Bar", "type": "integer"},
@ -821,3 +823,51 @@ async def test_async_validation_error_handling_non_validation_error(
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
await _tool.arun({}) await _tool.arun({})
def test_optional_subset_model_rewrite() -> None:
class MyModel(BaseModel):
a: Optional[str]
b: str
c: Optional[List[Optional[str]]]
model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])
assert "a" not in model2.schema()["required"] # should be optional
assert "b" in model2.schema()["required"] # should be required
assert "c" not in model2.schema()["required"] # should be optional
@pytest.mark.parametrize(
"inputs, expected",
[
# Check not required
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
# Check overwritten
(
{"bar": "bar", "baz": 4, "buzz": "not-buzz"},
{"bar": "bar", "baz": 4, "buzz": "not-buzz"},
),
# Check validation error when missing
({}, None),
# Check validation error when wrong type
({"bar": "bar", "baz": "not-an-int"}, None),
# Check OK when None explicitly passed
({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}),
],
)
def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> None:
@tool
def foo(bar: str, baz: Optional[int] = 3, buzz: Optional[str] = "buzz") -> dict:
"""The foo."""
return {
"bar": bar,
"baz": baz,
"buzz": buzz,
}
if expected is not None:
assert foo.invoke(inputs) == expected # type: ignore
else:
with pytest.raises(ValidationError):
foo.invoke(inputs) # type: ignore

View File

@ -1,9 +1,9 @@
from typing import Any, Callable, Literal, Type from typing import Any, Callable, List, Literal, Optional, Type
import pytest import pytest
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool, tool
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
@ -33,7 +33,7 @@ def function() -> Callable:
@pytest.fixture() @pytest.fixture()
def tool() -> BaseTool: def dummy_tool() -> BaseTool:
class Schema(BaseModel): class Schema(BaseModel):
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
@ -50,7 +50,7 @@ def tool() -> BaseTool:
def test_convert_to_openai_function( def test_convert_to_openai_function(
pydantic: Type[BaseModel], function: Callable, tool: BaseTool pydantic: Type[BaseModel], function: Callable, dummy_tool: BaseTool
) -> None: ) -> None:
expected = { expected = {
"name": "dummy_function", "name": "dummy_function",
@ -69,6 +69,22 @@ def test_convert_to_openai_function(
}, },
} }
for fn in (pydantic, function, tool, expected): for fn in (pydantic, function, dummy_tool, expected):
actual = convert_to_openai_function(fn) # type: ignore actual = convert_to_openai_function(fn) # type: ignore
assert actual == expected assert actual == expected
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
def test_function_optional_param() -> None:
@tool
def func5(
a: Optional[str],
b: str,
c: Optional[List[Optional[str]]],
) -> None:
"""A test function"""
pass
func = convert_to_openai_function(func5)
req = func["parameters"]["required"]
assert set(req) == {"b"}