core[minor]: Support all versions of pydantic base model in argsschema (#24418)

This adds support to any pydantic base model for tools.

The only potential issue is that `get_input_schema()` will not always
return a v1 base model.
This commit is contained in:
Eugene Yurtsev 2024-07-18 17:14:23 -04:00 committed by GitHub
parent b2bc15e640
commit f62b323108
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 342 additions and 33 deletions

View File

@ -42,11 +42,10 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
get_type_hints, get_type_hints,
) )
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, cast, get_args, get_origin
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -89,6 +88,10 @@ from langchain_core.runnables.config import (
run_in_executor, run_in_executor,
) )
from langchain_core.runnables.utils import accepts_context from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.pydantic import (
_create_subset_model,
is_basemodel_subclass,
)
FILTERED_ARGS = ("run_manager", "callbacks") FILTERED_ARGS = ("run_manager", "callbacks")
@ -110,34 +113,6 @@ def _get_annotation_description(arg_type: Type) -> str | None:
return None return None
def _create_subset_model(
name: str,
model: Type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
t = (
# this isn't perfect but should work for most functions
field.outer_type_
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
def _get_filtered_args( def _get_filtered_args(
inferred_model: Type[BaseModel], inferred_model: Type[BaseModel],
func: Callable, func: Callable,
@ -403,6 +378,16 @@ class ChildTool(BaseTool):
two-tuple corresponding to the (content, artifact) of a ToolMessage. two-tuple corresponding to the (content, artifact) of a ToolMessage.
""" """
def __init__(self, **kwargs: Any) -> None:
"""Initialize the tool."""
if "args_schema" in kwargs and kwargs["args_schema"] is not None:
if not is_basemodel_subclass(kwargs["args_schema"]):
raise TypeError(
f"args_schema must be a subclass of pydantic BaseModel. "
f"Got: {kwargs['args_schema']}."
)
super().__init__(**kwargs)
class Config(Serializable.Config): class Config(Serializable.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""

View File

@ -1,7 +1,11 @@
"""Utilities for tests.""" """Utilities for tests."""
from __future__ import annotations
import inspect
import textwrap
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, Type from typing import Any, Callable, Dict, List, Optional, Type
from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.pydantic_v1 import BaseModel, root_validator
@ -19,6 +23,66 @@ def get_pydantic_major_version() -> int:
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.
Check if the given class is a subclass of any of the following:
* pydantic.BaseModel in Pydantic 1.x
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
# Before we can use issubclass on the cls we need to check if it is a class
if not inspect.isclass(cls):
return False
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper
if issubclass(cls, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV2):
return True
if issubclass(cls, BaseModelV1):
return True
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
return False
def is_basemodel_instance(obj: Any) -> bool:
"""Check if the given class is an instance of Pydantic BaseModel.
Check if the given class is an instance of any of the following:
* pydantic.BaseModel in Pydantic 1.x
* pydantic.BaseModel in Pydantic 2.x
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper
if isinstance(obj, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if isinstance(obj, BaseModelV2):
return True
if isinstance(obj, BaseModelV1):
return True
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
return False
# How to type hint this? # How to type hint this?
def pre_init(func: Callable) -> Any: def pre_init(func: Callable) -> Any:
"""Decorator to run a function before model initialization. """Decorator to run a function before model initialization.
@ -64,3 +128,106 @@ def pre_init(func: Callable) -> Any:
return func(cls, values) return func(cls, values)
return wrapper return wrapper
def _create_subset_model_v1(
name: str,
model: Type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
from langchain_core.pydantic_v1 import create_model
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
t = (
# this isn't perfect but should work for most functions
field.outer_type_
if field.required and not field.allow_none
else Optional[field.outer_type_]
)
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
fields[field_name] = (t, field.field_info)
rtn = create_model(name, **fields) # type: ignore
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
def _create_subset_model_v2(
name: str,
model: Type[BaseModel],
field_names: List[str],
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
from pydantic import create_model # pydantic: ignore
from pydantic.fields import FieldInfo # pydantic: ignore
descriptions_ = descriptions or {}
fields = {}
for field_name in field_names:
field = model.model_fields[field_name] # type: ignore
description = descriptions_.get(field_name, field.description)
fields[field_name] = (
field.annotation,
FieldInfo(description=description, default=field.default),
)
rtn = create_model(name, **fields) # type: ignore
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
# Private functionality to create a subset model that's compatible across
# different versions of pydantic.
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
# However, can't find a way to type hint this.
def _create_subset_model(
name: str,
model: Type[BaseModel],
field_names: List[str],
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create subset model using the same pydantic version as the input model."""
if PYDANTIC_MAJOR_VERSION == 1:
return _create_subset_model_v1(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
if issubclass(model, BaseModelV1):
return _create_subset_model_v1(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
else:
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
else:
raise NotImplementedError(
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
)

View File

@ -31,10 +31,10 @@ from langchain_core.tools import (
StructuredTool, StructuredTool,
Tool, Tool,
ToolException, ToolException,
_create_subset_model,
tool, tool,
) )
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import _create_subset_model
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@ -1417,3 +1417,112 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
"required": ["x"], "required": ["x"],
}, },
} }
def generate_models() -> List[Any]:
"""Generate a list of base models depending on the pydantic version."""
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
class FooProper(BaseModelProper):
a: int
b: str
return [FooProper]
def generate_backwards_compatible_v1() -> List[Any]:
"""Generate a model with pydantic 2 from the v1 namespace."""
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
class FooV1Namespace(BaseModelV1):
a: int
b: str
return [FooV1Namespace]
# This generates a list of models that can be used for testing that our APIs
# behave well with either pydantic 1 proper,
# pydantic v1 from pydantic 2,
# or pydantic 2 proper.
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
class SomeTool(BaseTool):
args_schema: Type[pydantic_model] = pydantic_model
def _run(self, *args: Any, **kwargs: Any) -> str:
return "foo"
tool = SomeTool(
name="some_tool", description="some description", args_schema=pydantic_model
)
assert tool.get_input_schema().schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": pydantic_model.__name__,
"type": "object",
}
assert tool.tool_call_schema.schema() == {
"description": "some description",
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": "some_tool",
"type": "object",
}
def test_args_schema_explicitly_typed() -> None:
"""This should test that one can type the args schema as a pydantic model.
Please note that this will test using pydantic 2 even though BaseTool
is a pydantic 1 model!
"""
# Check with whatever pydantic model is passed in and not via v1 namespace
from pydantic import BaseModel # pydantic: ignore
class Foo(BaseModel):
a: int
b: str
class SomeTool(BaseTool):
# type ignoring here since we're allowing overriding a type
# signature of pydantic.v1.BaseModel with pydantic.BaseModel
# for pydantic 2!
args_schema: Type[BaseModel] = Foo # type: ignore[assignment]
def _run(self, *args: Any, **kwargs: Any) -> str:
return "foo"
tool = SomeTool(name="some_tool", description="some description")
assert tool.get_input_schema().schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": "Foo",
"type": "object",
}
assert tool.tool_call_schema.schema() == {
"description": "some description",
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": "some_tool",
"type": "object",
}

View File

@ -3,7 +3,12 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import pre_init from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
is_basemodel_instance,
is_basemodel_subclass,
pre_init,
)
def test_pre_init_decorator() -> None: def test_pre_init_decorator() -> None:
@ -73,3 +78,46 @@ def test_with_aliases() -> None:
foo = Foo(y=2) # type: ignore foo = Foo(y=2) # type: ignore
assert foo.x == 2 assert foo.x == 2
assert foo.z == 2 assert foo.z == 2
def test_is_basemodel_subclass() -> None:
"""Test pydantic."""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
assert is_basemodel_subclass(BaseModelV1Proper)
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
assert is_basemodel_subclass(BaseModelV2)
assert is_basemodel_subclass(BaseModelV1)
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
def test_is_basemodel_instance() -> None:
"""Test pydantic."""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
class FooV1(BaseModelV1Proper):
x: int
assert is_basemodel_instance(FooV1(x=5))
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
class Foo(BaseModelV2):
x: int
assert is_basemodel_instance(Foo(x=5))
class Bar(BaseModelV1):
x: int
assert is_basemodel_instance(Bar(x=5))
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")