mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 19:48:26 +00:00
Fix: Recognize List
at from_function
(#7178)
- Description: pydantic's `ModelField.type_` only exposes the native data type but not complex type hints like `List`. Thus, generating a Tool with `from_function` through function signature produces incorrect argument schemas (e.g., `str` instead of `List[str]`) - Issue: N/A - Dependencies: N/A - Tag maintainer: @hinthornw - Twitter handle: `mapped` All the unittest (with an additional one in this PR) passed, though I didn't try integration tests...
This commit is contained in:
parent
ec10787bc7
commit
d642609a23
@ -71,7 +71,7 @@ def _create_subset_model(
|
||||
fields = {}
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
fields[field_name] = (field.type_, field.field_info)
|
||||
fields[field_name] = (field.outer_type_, field.field_info)
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Type, Union
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
@ -349,6 +349,39 @@ def test_structured_tool_from_function_docstring() -> None:
|
||||
assert structured_tool.description == prefix + foo.__doc__.strip()
|
||||
|
||||
|
||||
def test_structured_tool_from_function_docstring_complex_args() -> None:
|
||||
"""Test that structured tools can be created from functions."""
|
||||
|
||||
def foo(bar: int, baz: List[str]) -> str:
|
||||
"""Docstring
|
||||
Args:
|
||||
bar: int
|
||||
baz: List[str]
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
structured_tool = StructuredTool.from_function(foo)
|
||||
assert structured_tool.name == "foo"
|
||||
assert structured_tool.args == {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
}
|
||||
|
||||
assert structured_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "integer"},
|
||||
"baz": {"title": "Baz", "type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"title": "fooSchemaSchema",
|
||||
"type": "object",
|
||||
"required": ["bar", "baz"],
|
||||
}
|
||||
|
||||
prefix = "foo(bar: int, baz: List[str]) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + foo.__doc__.strip()
|
||||
|
||||
|
||||
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = StructuredTool.from_function(
|
||||
|
Loading…
Reference in New Issue
Block a user