Fix: Recognize List at from_function (#7178)

- Description: pydantic's `ModelField.type_` only exposes the native
data type but not complex type hints like `List`. Thus, generating a
Tool with `from_function` through function signature produces incorrect
argument schemas (e.g., `str` instead of `List[str]`)
  - Issue: N/A
  - Dependencies: N/A
  - Tag maintainer: @hinthornw
  - Twitter handle: `mapped`

All the unittest (with an additional one in this PR) passed, though I
didn't try integration tests...
This commit is contained in:
Jason B. Koh
2023-07-06 14:22:09 -07:00
committed by GitHub
parent ec10787bc7
commit d642609a23
2 changed files with 35 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ import json
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Optional, Type, Union
from typing import Any, List, Optional, Type, Union
import pytest
from pydantic import BaseModel
@@ -349,6 +349,39 @@ def test_structured_tool_from_function_docstring() -> None:
assert structured_tool.description == prefix + foo.__doc__.strip()
def test_structured_tool_from_function_docstring_complex_args() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: List[str]) -> str:
"""Docstring
Args:
bar: int
baz: List[str]
"""
raise NotImplementedError()
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
assert structured_tool.args == {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
}
assert structured_tool.args_schema.schema() == {
"properties": {
"bar": {"title": "Bar", "type": "integer"},
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
},
"title": "fooSchemaSchema",
"type": "object",
"required": ["bar", "baz"],
}
prefix = "foo(bar: int, baz: List[str]) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(