mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-08 13:50:00 +00:00
core[patch]: support pydantic v1 annotations in tool arguments (#26336)
If all pydantic annotations in function signature are V1, use V1 `validate_arguments`.
This commit is contained in:
@@ -38,6 +38,7 @@ from pydantic import (
|
|||||||
validate_arguments,
|
validate_arguments,
|
||||||
)
|
)
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
@@ -164,6 +165,35 @@ def _infer_arg_descriptions(
|
|||||||
return description, arg_descriptions
|
return description, arg_descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bool:
|
||||||
|
"""Determine if a type annotation is a Pydantic model."""
|
||||||
|
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
||||||
|
try:
|
||||||
|
return issubclass(annotation, base_model_class)
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _function_annotations_are_pydantic_v1(
|
||||||
|
signature: inspect.Signature, func: Callable
|
||||||
|
) -> bool:
|
||||||
|
"""Determine if all Pydantic annotations in a function signature are from V1."""
|
||||||
|
any_v1_annotations = any(
|
||||||
|
_is_pydantic_annotation(parameter.annotation, pydantic_version="v1")
|
||||||
|
for parameter in signature.parameters.values()
|
||||||
|
)
|
||||||
|
any_v2_annotations = any(
|
||||||
|
_is_pydantic_annotation(parameter.annotation, pydantic_version="v2")
|
||||||
|
for parameter in signature.parameters.values()
|
||||||
|
)
|
||||||
|
if any_v1_annotations and any_v2_annotations:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Function {func} contains a mix of Pydantic v1 and v2 annotations. "
|
||||||
|
"Only one version of Pydantic annotations per function is supported."
|
||||||
|
)
|
||||||
|
return any_v1_annotations and not any_v2_annotations
|
||||||
|
|
||||||
|
|
||||||
class _SchemaConfig:
|
class _SchemaConfig:
|
||||||
"""Configuration for the pydantic model.
|
"""Configuration for the pydantic model.
|
||||||
|
|
||||||
@@ -208,6 +238,11 @@ def create_schema_from_function(
|
|||||||
Returns:
|
Returns:
|
||||||
A pydantic model with the same arguments as the function.
|
A pydantic model with the same arguments as the function.
|
||||||
"""
|
"""
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
if _function_annotations_are_pydantic_v1(sig, func):
|
||||||
|
validated = validate_arguments_v1(func, config=_SchemaConfig) # type: ignore
|
||||||
|
else:
|
||||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
# We are using deprecated functionality here.
|
# We are using deprecated functionality here.
|
||||||
@@ -216,8 +251,6 @@ def create_schema_from_function(
|
|||||||
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
||||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||||
|
|
||||||
sig = inspect.signature(func)
|
|
||||||
|
|
||||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||||
if func.__qualname__ and "." in func.__qualname__:
|
if func.__qualname__ and "." in func.__qualname__:
|
||||||
# Then it likely belongs in a class namespace
|
# Then it likely belongs in a class namespace
|
||||||
|
@@ -23,7 +23,7 @@ from typing import (
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from pydantic import BaseModel as BaseModelProper
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from typing_extensions import Annotated, TypedDict, TypeVar
|
from typing_extensions import Annotated, TypedDict, TypeVar
|
||||||
|
|
||||||
from langchain_core import tools
|
from langchain_core import tools
|
||||||
@@ -80,6 +80,14 @@ class _MockSchema(BaseModel):
|
|||||||
arg3: Optional[dict] = None
|
arg3: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class _MockSchemaV1(BaseModelV1):
|
||||||
|
"""Return the arguments directly."""
|
||||||
|
|
||||||
|
arg1: int
|
||||||
|
arg2: bool
|
||||||
|
arg3: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class _MockStructuredTool(BaseTool):
|
class _MockStructuredTool(BaseTool):
|
||||||
name: str = "structured_api"
|
name: str = "structured_api"
|
||||||
args_schema: Type[BaseModel] = _MockSchema
|
args_schema: Type[BaseModel] = _MockSchema
|
||||||
@@ -169,6 +177,13 @@ def test_decorator_with_specified_schema() -> None:
|
|||||||
assert isinstance(tool_func, BaseTool)
|
assert isinstance(tool_func, BaseTool)
|
||||||
assert tool_func.args_schema == _MockSchema
|
assert tool_func.args_schema == _MockSchema
|
||||||
|
|
||||||
|
@tool(args_schema=_MockSchemaV1)
|
||||||
|
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
|
assert isinstance(tool_func_v1, BaseTool)
|
||||||
|
assert tool_func_v1.args_schema == _MockSchemaV1
|
||||||
|
|
||||||
|
|
||||||
def test_decorated_function_schema_equivalent() -> None:
|
def test_decorated_function_schema_equivalent() -> None:
|
||||||
"""Test that a BaseTool without a schema meets expectations."""
|
"""Test that a BaseTool without a schema meets expectations."""
|
||||||
@@ -302,6 +317,51 @@ def test_structured_tool_types_parsed() -> None:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_types_parsed_pydantic_v1() -> None:
|
||||||
|
"""Test the non-primitive types are correctly passed to structured tools."""
|
||||||
|
|
||||||
|
class SomeBaseModel(BaseModelV1):
|
||||||
|
foo: str
|
||||||
|
|
||||||
|
class AnotherBaseModel(BaseModelV1):
|
||||||
|
bar: str
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def structured_tool(some_base_model: SomeBaseModel) -> AnotherBaseModel:
|
||||||
|
"""Return the arguments directly."""
|
||||||
|
return AnotherBaseModel(bar=some_base_model.foo)
|
||||||
|
|
||||||
|
assert isinstance(structured_tool, StructuredTool)
|
||||||
|
|
||||||
|
expected = AnotherBaseModel(bar="baz")
|
||||||
|
for arg in [
|
||||||
|
SomeBaseModel(foo="baz"),
|
||||||
|
SomeBaseModel(foo="baz").dict(),
|
||||||
|
]:
|
||||||
|
args = {"some_base_model": arg}
|
||||||
|
result = structured_tool.run(args)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_tool_types_parsed_pydantic_mixed() -> None:
|
||||||
|
"""Test handling of tool with mixed Pydantic version arguments."""
|
||||||
|
|
||||||
|
class SomeBaseModel(BaseModelV1):
|
||||||
|
foo: str
|
||||||
|
|
||||||
|
class AnotherBaseModel(BaseModel):
|
||||||
|
bar: str
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def structured_tool(
|
||||||
|
some_base_model: SomeBaseModel, another_base_model: AnotherBaseModel
|
||||||
|
) -> None:
|
||||||
|
"""Return the arguments directly."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_base_tool_inheritance_base_schema() -> None:
|
def test_base_tool_inheritance_base_schema() -> None:
|
||||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||||
|
|
||||||
@@ -1562,7 +1622,7 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
|||||||
def generate_models() -> List[Any]:
|
def generate_models() -> List[Any]:
|
||||||
"""Generate a list of base models depending on the pydantic version."""
|
"""Generate a list of base models depending on the pydantic version."""
|
||||||
|
|
||||||
class FooProper(BaseModelProper):
|
class FooProper(BaseModel):
|
||||||
a: int
|
a: int
|
||||||
b: str
|
b: str
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user