mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 06:18:05 +00:00
core[patch]: support pydantic v1 annotations in tool arguments (#26336)
If all pydantic annotations in function signature are V1, use V1 `validate_arguments`.
This commit is contained in:
@@ -38,6 +38,7 @@ from pydantic import (
|
||||
validate_arguments,
|
||||
)
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@@ -164,6 +165,35 @@ def _infer_arg_descriptions(
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
|
||||
"""Determine if a type annotation is a Pydantic model."""
|
||||
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
||||
try:
|
||||
return issubclass(annotation, base_model_class)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _function_annotations_are_pydantic_v1(
|
||||
signature: inspect.Signature, func: Callable
|
||||
) -> bool:
|
||||
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
||||
any_v1_annotations = any(
|
||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
any_v2_annotations = any(
|
||||
_is_pydantic_annotation(parameter.annotation, pydantic_version="v2")
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
if any_v1_annotations and any_v2_annotations:
|
||||
raise NotImplementedError(
|
||||
f"Function {func} contains a mix of Pydantic v1 and v2 annotations. "
|
||||
"Only one version of Pydantic annotations per function is supported."
|
||||
)
|
||||
return any_v1_annotations and not any_v2_annotations
|
||||
|
||||
|
||||
class _SchemaConfig:
|
||||
"""Configuration for the pydantic model.
|
||||
|
||||
@@ -208,16 +238,19 @@ def create_schema_from_function(
|
||||
Returns:
|
||||
A pydantic model with the same arguments as the function.
|
||||
"""
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
with warnings.catch_warnings():
|
||||
# We are using deprecated functionality here.
|
||||
# This code should be re-written to simply construct a pydantic model
|
||||
# using inspect.signature and create_model.
|
||||
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
if _function_annotations_are_pydantic_v1(sig, func):
|
||||
validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore
|
||||
else:
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
with warnings.catch_warnings():
|
||||
# We are using deprecated functionality here.
|
||||
# This code should be re-written to simply construct a pydantic model
|
||||
# using inspect.signature and create_model.
|
||||
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
|
||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||
if func.__qualname__ and "." in func.__qualname__:
|
||||
# Then it likely belongs in a class namespace
|
||||
|
@@ -23,7 +23,7 @@ from typing import (
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import Annotated, TypedDict, TypeVar
|
||||
|
||||
from langchain_core import tools
|
||||
@@ -80,6 +80,14 @@ class _MockSchema(BaseModel):
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockSchemaV1(BaseModelV1):
|
||||
"""Return the arguments directly."""
|
||||
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseTool):
|
||||
name: str = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
@@ -169,6 +177,13 @@ def test_decorator_with_specified_schema() -> None:
|
||||
assert isinstance(tool_func, BaseTool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
@tool(args_schema=_MockSchemaV1)
|
||||
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func_v1, BaseTool)
|
||||
assert tool_func_v1.args_schema == _MockSchemaV1
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
@@ -302,6 +317,51 @@ def test_structured_tool_types_parsed() -> None:
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_structured_tool_types_parsed_pydantic_v1() -> None:
|
||||
"""Test the non-primitive types are correctly passed to structured tools."""
|
||||
|
||||
class SomeBaseModel(BaseModelV1):
|
||||
foo: str
|
||||
|
||||
class AnotherBaseModel(BaseModelV1):
|
||||
bar: str
|
||||
|
||||
@tool
|
||||
def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:
|
||||
"""Return the arguments directly."""
|
||||
return AnotherBaseModel(bar=some_base_model.foo)
|
||||
|
||||
assert isinstance(structured_tool, StructuredTool)
|
||||
|
||||
expected = AnotherBaseModel(bar="baz")
|
||||
for arg in [
|
||||
SomeBaseModel(foo="baz"),
|
||||
SomeBaseModel(foo="baz").dict(),
|
||||
]:
|
||||
args = {"some_base_model": arg}
|
||||
result = structured_tool.run(args)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_structured_tool_types_parsed_pydantic_mixed() -> None:
|
||||
"""Test handling of tool with mixed Pydantic version arguments."""
|
||||
|
||||
class SomeBaseModel(BaseModelV1):
|
||||
foo: str
|
||||
|
||||
class AnotherBaseModel(BaseModel):
|
||||
bar: str
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
|
||||
@tool
|
||||
def structured_tool(
|
||||
some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel
|
||||
) -> None:
|
||||
"""Return the arguments directly."""
|
||||
pass
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||
|
||||
@@ -1562,7 +1622,7 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
||||
def generate_models() -> List[Any]:
|
||||
"""Generate a list of base models depending on the pydantic version."""
|
||||
|
||||
class FooProper(BaseModelProper):
|
||||
class FooProper(BaseModel):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
|
Reference in New Issue
Block a user