core[patch]: support pydantic v1 annotations in tool arguments (#26336)

If all pydantic annotations in function signature are V1, use V1
`validate_arguments`.
This commit is contained in:
ccurme
2024-09-11 16:31:40 -04:00
committed by GitHub
parent d67e1dfe32
commit 284e1a7e9e
2 changed files with 103 additions and 10 deletions

View File

@@ -38,6 +38,7 @@ from pydantic import (
validate_arguments,
)
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import Annotated
from langchain_core._api import deprecated
@@ -164,6 +165,35 @@ def _infer_arg_descriptions(
return description, arg_descriptions
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
"""Determine if a type annotation is a Pydantic model."""
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
try:
return issubclass(annotation, base_model_class)
except TypeError:
return False
def _function_annotations_are_pydantic_v1(
signature: inspect.Signature, func: Callable
) -> bool:
"""Determine if all Pydantic annotations in a function signature are from V1."""
any_v1_annotations = any(
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
for parameter in signature.parameters.values()
)
any_v2_annotations = any(
_is_pydantic_annotation(parameter.annotation, pydantic_version="v2")
for parameter in signature.parameters.values()
)
if any_v1_annotations and any_v2_annotations:
raise NotImplementedError(
f"Function {func} contains a mix of Pydantic v1 and v2 annotations. "
"Only one version of Pydantic annotations per function is supported."
)
return any_v1_annotations and not any_v2_annotations
class _SchemaConfig:
"""Configuration for the pydantic model.
@@ -208,16 +238,19 @@ def create_schema_from_function(
Returns:
A pydantic model with the same arguments as the function.
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
with warnings.catch_warnings():
# We are using deprecated functionality here.
# This code should be re-written to simply construct a pydantic model
# using inspect.signature and create_model.
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
sig = inspect.signature(func)
if _function_annotations_are_pydantic_v1(sig, func):
validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore
else:
# https://docs.pydantic.dev/latest/usage/validation_decorator/
with warnings.catch_warnings():
# We are using deprecated functionality here.
# This code should be re-written to simply construct a pydantic model
# using inspect.signature and create_model.
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
# Let's ignore `self` and `cls` arguments for class and instance methods
if func.__qualname__ and "." in func.__qualname__:
# Then it likely belongs in a class namespace

View File

@@ -23,7 +23,7 @@ from typing import (
import pytest
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel as BaseModelProper
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Annotated, TypedDict, TypeVar
from langchain_core import tools
@@ -80,6 +80,14 @@ class _MockSchema(BaseModel):
arg3: Optional[dict] = None
class _MockSchemaV1(BaseModelV1):
"""Return the arguments directly."""
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name: str = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
@@ -169,6 +177,13 @@ def test_decorator_with_specified_schema() -> None:
assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema
@tool(args_schema=_MockSchemaV1)
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func_v1, BaseTool)
assert tool_func_v1.args_schema == _MockSchemaV1
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@@ -302,6 +317,51 @@ def test_structured_tool_types_parsed() -> None:
assert result == expected
def test_structured_tool_types_parsed_pydantic_v1() -> None:
"""Test the non-primitive types are correctly passed to structured tools."""
class SomeBaseModel(BaseModelV1):
foo: str
class AnotherBaseModel(BaseModelV1):
bar: str
@tool
def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:
"""Return the arguments directly."""
return AnotherBaseModel(bar=some_base_model.foo)
assert isinstance(structured_tool, StructuredTool)
expected = AnotherBaseModel(bar="baz")
for arg in [
SomeBaseModel(foo="baz"),
SomeBaseModel(foo="baz").dict(),
]:
args = {"some_base_model": arg}
result = structured_tool.run(args)
assert result == expected
def test_structured_tool_types_parsed_pydantic_mixed() -> None:
"""Test handling of tool with mixed Pydantic version arguments."""
class SomeBaseModel(BaseModelV1):
foo: str
class AnotherBaseModel(BaseModel):
bar: str
with pytest.raises(NotImplementedError):
@tool
def structured_tool(
some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel
) -> None:
"""Return the arguments directly."""
pass
def test_base_tool_inheritance_base_schema() -> None:
"""Test schema is correctly inferred when inheriting from BaseTool."""
@@ -1562,7 +1622,7 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
def generate_models() -> List[Any]:
"""Generate a list of base models depending on the pydantic version."""
class FooProper(BaseModelProper):
class FooProper(BaseModel):
a: int
b: str