mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core[patch]: handle some optional cases in tools (#16954)
primary problem in pydantic still exists, where `Optional[str]` gets turned to `string` in the jsonschema `.schema()` Also fixes the `SchemaSchema` naming issue --------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
parent
f8943e8739
commit
06660bc78c
@ -39,14 +39,21 @@ class SchemaAnnotationError(TypeError):
|
|||||||
|
|
||||||
|
|
||||||
def _create_subset_model(
|
def _create_subset_model(
|
||||||
name: str, model: BaseModel, field_names: list
|
name: str, model: Type[BaseModel], field_names: list
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
fields = {}
|
fields = {}
|
||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
field = model.__fields__[field_name]
|
field = model.__fields__[field_name]
|
||||||
fields[field_name] = (field.outer_type_, field.field_info)
|
t = (
|
||||||
return create_model(name, **fields) # type: ignore
|
# this isn't perfect but should work for most functions
|
||||||
|
field.outer_type_
|
||||||
|
if field.required and not field.allow_none
|
||||||
|
else Optional[field.outer_type_]
|
||||||
|
)
|
||||||
|
fields[field_name] = (t, field.field_info)
|
||||||
|
rtn = create_model(name, **fields) # type: ignore
|
||||||
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
def _get_filtered_args(
|
def _get_filtered_args(
|
||||||
@ -764,7 +771,8 @@ class StructuredTool(BaseTool):
|
|||||||
description = f"{name}{sig} - {description.strip()}"
|
description = f"{name}{sig} - {description.strip()}"
|
||||||
_args_schema = args_schema
|
_args_schema = args_schema
|
||||||
if _args_schema is None and infer_schema:
|
if _args_schema is None and infer_schema:
|
||||||
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
|
# schema name is appended within function
|
||||||
|
_args_schema = create_schema_from_function(name, source_function)
|
||||||
return cls(
|
return cls(
|
||||||
name=name,
|
name=name,
|
||||||
func=func,
|
func=func,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test the base tool implementation."""
|
"""Test the base tool implementation."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -18,6 +19,7 @@ from langchain_core.tools import (
|
|||||||
StructuredTool,
|
StructuredTool,
|
||||||
Tool,
|
Tool,
|
||||||
ToolException,
|
ToolException,
|
||||||
|
_create_subset_model,
|
||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||||
@ -321,7 +323,7 @@ def test_structured_tool_from_function_docstring() -> None:
|
|||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
},
|
},
|
||||||
"title": "fooSchemaSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
@ -354,7 +356,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
|
|||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||||
},
|
},
|
||||||
"title": "fooSchemaSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
@ -454,7 +456,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
|
|||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
"baz": {"title": "Baz", "type": "string"},
|
"baz": {"title": "Baz", "type": "string"},
|
||||||
},
|
},
|
||||||
"title": "fooSchemaSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["bar", "baz"],
|
"required": ["bar", "baz"],
|
||||||
}
|
}
|
||||||
@ -685,7 +687,7 @@ def test_structured_tool_from_function() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert structured_tool.args_schema.schema() == {
|
assert structured_tool.args_schema.schema() == {
|
||||||
"title": "fooSchemaSchema",
|
"title": "fooSchema",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"bar": {"title": "Bar", "type": "integer"},
|
"bar": {"title": "Bar", "type": "integer"},
|
||||||
@ -821,3 +823,51 @@ async def test_async_validation_error_handling_non_validation_error(
|
|||||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
await _tool.arun({})
|
await _tool.arun({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_subset_model_rewrite() -> None:
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
a: Optional[str]
|
||||||
|
b: str
|
||||||
|
c: Optional[List[Optional[str]]]
|
||||||
|
|
||||||
|
model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"])
|
||||||
|
|
||||||
|
assert "a" not in model2.schema()["required"] # should be optional
|
||||||
|
assert "b" in model2.schema()["required"] # should be required
|
||||||
|
assert "c" not in model2.schema()["required"] # should be optional
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"inputs, expected",
|
||||||
|
[
|
||||||
|
# Check not required
|
||||||
|
({"bar": "bar"}, {"bar": "bar", "baz": 3, "buzz": "buzz"}),
|
||||||
|
# Check overwritten
|
||||||
|
(
|
||||||
|
{"bar": "bar", "baz": 4, "buzz": "not-buzz"},
|
||||||
|
{"bar": "bar", "baz": 4, "buzz": "not-buzz"},
|
||||||
|
),
|
||||||
|
# Check validation error when missing
|
||||||
|
({}, None),
|
||||||
|
# Check validation error when wrong type
|
||||||
|
({"bar": "bar", "baz": "not-an-int"}, None),
|
||||||
|
# Check OK when None explicitly passed
|
||||||
|
({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> None:
|
||||||
|
@tool
|
||||||
|
def foo(bar: str, baz: Optional[int] = 3, buzz: Optional[str] = "buzz") -> dict:
|
||||||
|
"""The foo."""
|
||||||
|
return {
|
||||||
|
"bar": bar,
|
||||||
|
"baz": baz,
|
||||||
|
"buzz": buzz,
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected is not None:
|
||||||
|
assert foo.invoke(inputs) == expected # type: ignore
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
foo.invoke(inputs) # type: ignore
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from typing import Any, Callable, Literal, Type
|
from typing import Any, Callable, List, Literal, Optional, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool, tool
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ def function() -> Callable:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def tool() -> BaseTool:
|
def dummy_tool() -> BaseTool:
|
||||||
class Schema(BaseModel):
|
class Schema(BaseModel):
|
||||||
arg1: int = Field(..., description="foo")
|
arg1: int = Field(..., description="foo")
|
||||||
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||||
@ -50,7 +50,7 @@ def tool() -> BaseTool:
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_to_openai_function(
|
def test_convert_to_openai_function(
|
||||||
pydantic: Type[BaseModel], function: Callable, tool: BaseTool
|
pydantic: Type[BaseModel], function: Callable, dummy_tool: BaseTool
|
||||||
) -> None:
|
) -> None:
|
||||||
expected = {
|
expected = {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
@ -69,6 +69,22 @@ def test_convert_to_openai_function(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for fn in (pydantic, function, tool, expected):
|
for fn in (pydantic, function, dummy_tool, expected):
|
||||||
actual = convert_to_openai_function(fn) # type: ignore
|
actual = convert_to_openai_function(fn) # type: ignore
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
|
||||||
|
def test_function_optional_param() -> None:
|
||||||
|
@tool
|
||||||
|
def func5(
|
||||||
|
a: Optional[str],
|
||||||
|
b: str,
|
||||||
|
c: Optional[List[Optional[str]]],
|
||||||
|
) -> None:
|
||||||
|
"""A test function"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
func = convert_to_openai_function(func5)
|
||||||
|
req = func["parameters"]["required"]
|
||||||
|
assert set(req) == {"b"}
|
||||||
|
Loading…
Reference in New Issue
Block a user