mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
core[patch]: handle some optional cases in tools (#16954)
primary problem in pydantic still exists, where `Optional[str]` gets turned to `string` in the jsonschema `.schema()` Also fixes the `SchemaSchema` naming issue --------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
parent
f8943e8739
commit
06660bc78c
@ -39,14 +39,21 @@ class SchemaAnnotationError(TypeError):
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
name: str, model: Type[BaseModel], field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {}
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
fields[field_name] = (field.outer_type_, field.field_info)
|
||||
return create_model(name, **fields) # type: ignore
|
||||
t = (
|
||||
# this isn't perfect but should work for most functions
|
||||
field.outer_type_
|
||||
if field.required and not field.allow_none
|
||||
else Optional[field.outer_type_]
|
||||
)
|
||||
fields[field_name] = (t, field.field_info)
|
||||
rtn = create_model(name, **fields) # type: ignore
|
||||
return rtn
|
||||
|
||||
|
||||
def _get_filtered_args(
|
||||
@ -764,7 +771,8 @@ class StructuredTool(BaseTool):
|
||||
description = f"{name}{sig} - {description.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
|
||||
# schema name is appended within function
|
||||
_args_schema = create_schema_from_function(name, source_function)
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test the base tool implementation."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@ -18,6 +19,7 @@ from langchain_core.tools import (
|
||||
StructuredTool,
|
||||
Tool,
|
||||
ToolException,
|
||||
_create_subset_model,
|
||||
tool,
|
||||
)
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
@ -321,7 +323,7 @@ def test_structured_tool_from_function_docstring() -> None:
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
},
|
||||
"title": "fooSchemaSchema",
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
@ -354,7 +356,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"title": "fooSchemaSchema",
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
@ -454,7 +456,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "string"},
|
||||
},
|
||||
"title": "fooSchemaSchema",
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
@ -685,7 +687,7 @@ def test_structured_tool_from_function() -> None:
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"title": "fooSchemaSchema",
|
||||
"title": "fooSchema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
@ -821,3 +823,51 @@ async def test_async_validation_error_handling_non_validation_error(
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||
with pytest.raises(NotImplementedError):
|
||||
await _tool.arun({})
|
||||
|
||||
|
||||
def test_optional_subset_model_rewrite() -> None:
|
||||
class MyModel(BaseModel):
|
||||
a: Optional[str]
|
||||
b: str
|
||||
c: Optional[List[Optional[str]]]
|
||||
|
||||
model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])
|
||||
|
||||
assert "a" not in model2.schema()["required"] # should be optional
|
||||
assert "b" in model2.schema()["required"] # should be required
|
||||
assert "c" not in model2.schema()["required"] # should be optional
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs, expected",
|
||||
[
|
||||
# Check not required
|
||||
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
|
||||
# Check overwritten
|
||||
(
|
||||
{"bar": "bar", "baz": 4, "buzz": "not-buzz"},
|
||||
{"bar": "bar", "baz": 4, "buzz": "not-buzz"},
|
||||
),
|
||||
# Check validation error when missing
|
||||
({}, None),
|
||||
# Check validation error when wrong type
|
||||
({"bar": "bar", "baz": "not-an-int"}, None),
|
||||
# Check OK when None explicitly passed
|
||||
({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}),
|
||||
],
|
||||
)
|
||||
def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> None:
|
||||
@tool
|
||||
def foo(bar: str, baz: Optional[int] = 3, buzz: Optional[str] = "buzz") -> dict:
|
||||
"""The foo."""
|
||||
return {
|
||||
"bar": bar,
|
||||
"baz": baz,
|
||||
"buzz": buzz,
|
||||
}
|
||||
|
||||
if expected is not None:
|
||||
assert foo.invoke(inputs) == expected # type: ignore
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
foo.invoke(inputs) # type: ignore
|
||||
|
@ -1,9 +1,9 @@
|
||||
from typing import Any, Callable, Literal, Type
|
||||
from typing import Any, Callable, List, Literal, Optional, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ def function() -> Callable:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tool() -> BaseTool:
|
||||
def dummy_tool() -> BaseTool:
|
||||
class Schema(BaseModel):
|
||||
arg1: int = Field(..., description="foo")
|
||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||
@ -50,7 +50,7 @@ def tool() -> BaseTool:
|
||||
|
||||
|
||||
def test_convert_to_openai_function(
|
||||
pydantic: Type[BaseModel], function: Callable, tool: BaseTool
|
||||
pydantic: Type[BaseModel], function: Callable, dummy_tool: BaseTool
|
||||
) -> None:
|
||||
expected = {
|
||||
"name": "dummy_function",
|
||||
@ -69,6 +69,22 @@ def test_convert_to_openai_function(
|
||||
},
|
||||
}
|
||||
|
||||
for fn in (pydantic, function, tool, expected):
|
||||
for fn in (pydantic, function, dummy_tool, expected):
|
||||
actual = convert_to_openai_function(fn) # type: ignore
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
|
||||
def test_function_optional_param() -> None:
|
||||
@tool
|
||||
def func5(
|
||||
a: Optional[str],
|
||||
b: str,
|
||||
c: Optional[List[Optional[str]]],
|
||||
) -> None:
|
||||
"""A test function"""
|
||||
pass
|
||||
|
||||
func = convert_to_openai_function(func5)
|
||||
req = func["parameters"]["required"]
|
||||
assert set(req) == {"b"}
|
||||
|
Loading…
Reference in New Issue
Block a user