mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
core[minor]: Support all versions of pydantic base model in argsschema (#24418)
This adds support to any pydantic base model for tools. The only potential issue is that `get_input_schema()` will not always return a v1 base model.
This commit is contained in:
parent
b2bc15e640
commit
f62b323108
@ -42,11 +42,10 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import Annotated, get_args, get_origin
|
from typing_extensions import Annotated, cast, get_args, get_origin
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -89,6 +88,10 @@ from langchain_core.runnables.config import (
|
|||||||
run_in_executor,
|
run_in_executor,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.utils import accepts_context
|
from langchain_core.runnables.utils import accepts_context
|
||||||
|
from langchain_core.utils.pydantic import (
|
||||||
|
_create_subset_model,
|
||||||
|
is_basemodel_subclass,
|
||||||
|
)
|
||||||
|
|
||||||
FILTERED_ARGS = ("run_manager", "callbacks")
|
FILTERED_ARGS = ("run_manager", "callbacks")
|
||||||
|
|
||||||
@ -110,34 +113,6 @@ def _get_annotation_description(arg_type: Type) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _create_subset_model(
|
|
||||||
name: str,
|
|
||||||
model: Type[BaseModel],
|
|
||||||
field_names: list,
|
|
||||||
*,
|
|
||||||
descriptions: Optional[dict] = None,
|
|
||||||
fn_description: Optional[str] = None,
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
"""Create a pydantic model with only a subset of model's fields."""
|
|
||||||
fields = {}
|
|
||||||
|
|
||||||
for field_name in field_names:
|
|
||||||
field = model.__fields__[field_name]
|
|
||||||
t = (
|
|
||||||
# this isn't perfect but should work for most functions
|
|
||||||
field.outer_type_
|
|
||||||
if field.required and not field.allow_none
|
|
||||||
else Optional[field.outer_type_]
|
|
||||||
)
|
|
||||||
if descriptions and field_name in descriptions:
|
|
||||||
field.field_info.description = descriptions[field_name]
|
|
||||||
fields[field_name] = (t, field.field_info)
|
|
||||||
|
|
||||||
rtn = create_model(name, **fields) # type: ignore
|
|
||||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
|
||||||
return rtn
|
|
||||||
|
|
||||||
|
|
||||||
def _get_filtered_args(
|
def _get_filtered_args(
|
||||||
inferred_model: Type[BaseModel],
|
inferred_model: Type[BaseModel],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
@ -403,6 +378,16 @@ class ChildTool(BaseTool):
|
|||||||
two-tuple corresponding to the (content, artifact) of a ToolMessage.
|
two-tuple corresponding to the (content, artifact) of a ToolMessage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize the tool."""
|
||||||
|
if "args_schema" in kwargs and kwargs["args_schema"] is not None:
|
||||||
|
if not is_basemodel_subclass(kwargs["args_schema"]):
|
||||||
|
raise TypeError(
|
||||||
|
f"args_schema must be a subclass of pydantic BaseModel. "
|
||||||
|
f"Got: {kwargs['args_schema']}."
|
||||||
|
)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
class Config(Serializable.Config):
|
class Config(Serializable.Config):
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
"""Utilities for tests."""
|
"""Utilities for tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Dict, Type
|
from typing import Any, Callable, Dict, List, Optional, Type
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||||
|
|
||||||
@ -19,6 +23,66 @@ def get_pydantic_major_version() -> int:
|
|||||||
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
||||||
|
|
||||||
|
|
||||||
|
def is_basemodel_subclass(cls: Type) -> bool:
|
||||||
|
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||||
|
|
||||||
|
Check if the given class is a subclass of any of the following:
|
||||||
|
|
||||||
|
* pydantic.BaseModel in Pydantic 1.x
|
||||||
|
* pydantic.BaseModel in Pydantic 2.x
|
||||||
|
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||||
|
"""
|
||||||
|
# Before we can use issubclass on the cls we need to check if it is a class
|
||||||
|
if not inspect.isclass(cls):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
from pydantic import BaseModel as BaseModelV1Proper
|
||||||
|
|
||||||
|
if issubclass(cls, BaseModelV1Proper):
|
||||||
|
return True
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
|
if issubclass(cls, BaseModelV2):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if issubclass(cls, BaseModelV1):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_basemodel_instance(obj: Any) -> bool:
|
||||||
|
"""Check if the given class is an instance of Pydantic BaseModel.
|
||||||
|
|
||||||
|
Check if the given class is an instance of any of the following:
|
||||||
|
|
||||||
|
* pydantic.BaseModel in Pydantic 1.x
|
||||||
|
* pydantic.BaseModel in Pydantic 2.x
|
||||||
|
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||||
|
"""
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
from pydantic import BaseModel as BaseModelV1Proper
|
||||||
|
|
||||||
|
if isinstance(obj, BaseModelV1Proper):
|
||||||
|
return True
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
|
if isinstance(obj, BaseModelV2):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(obj, BaseModelV1):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# How to type hint this?
|
# How to type hint this?
|
||||||
def pre_init(func: Callable) -> Any:
|
def pre_init(func: Callable) -> Any:
|
||||||
"""Decorator to run a function before model initialization.
|
"""Decorator to run a function before model initialization.
|
||||||
@ -64,3 +128,106 @@ def pre_init(func: Callable) -> Any:
|
|||||||
return func(cls, values)
|
return func(cls, values)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _create_subset_model_v1(
|
||||||
|
name: str,
|
||||||
|
model: Type[BaseModel],
|
||||||
|
field_names: list,
|
||||||
|
*,
|
||||||
|
descriptions: Optional[dict] = None,
|
||||||
|
fn_description: Optional[str] = None,
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
"""Create a pydantic model with only a subset of model's fields."""
|
||||||
|
from langchain_core.pydantic_v1 import create_model
|
||||||
|
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
for field_name in field_names:
|
||||||
|
field = model.__fields__[field_name]
|
||||||
|
t = (
|
||||||
|
# this isn't perfect but should work for most functions
|
||||||
|
field.outer_type_
|
||||||
|
if field.required and not field.allow_none
|
||||||
|
else Optional[field.outer_type_]
|
||||||
|
)
|
||||||
|
if descriptions and field_name in descriptions:
|
||||||
|
field.field_info.description = descriptions[field_name]
|
||||||
|
fields[field_name] = (t, field.field_info)
|
||||||
|
|
||||||
|
rtn = create_model(name, **fields) # type: ignore
|
||||||
|
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||||
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
|
def _create_subset_model_v2(
|
||||||
|
name: str,
|
||||||
|
model: Type[BaseModel],
|
||||||
|
field_names: List[str],
|
||||||
|
*,
|
||||||
|
descriptions: Optional[dict] = None,
|
||||||
|
fn_description: Optional[str] = None,
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
"""Create a pydantic model with a subset of the model fields."""
|
||||||
|
from pydantic import create_model # pydantic: ignore
|
||||||
|
from pydantic.fields import FieldInfo # pydantic: ignore
|
||||||
|
|
||||||
|
descriptions_ = descriptions or {}
|
||||||
|
fields = {}
|
||||||
|
for field_name in field_names:
|
||||||
|
field = model.model_fields[field_name] # type: ignore
|
||||||
|
description = descriptions_.get(field_name, field.description)
|
||||||
|
fields[field_name] = (
|
||||||
|
field.annotation,
|
||||||
|
FieldInfo(description=description, default=field.default),
|
||||||
|
)
|
||||||
|
rtn = create_model(name, **fields) # type: ignore
|
||||||
|
|
||||||
|
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||||
|
return rtn
|
||||||
|
|
||||||
|
|
||||||
|
# Private functionality to create a subset model that's compatible across
|
||||||
|
# different versions of pydantic.
|
||||||
|
# Handles pydantic versions 1.x and 2.x. including v1 of pydantic in 2.x.
|
||||||
|
# However, can't find a way to type hint this.
|
||||||
|
def _create_subset_model(
|
||||||
|
name: str,
|
||||||
|
model: Type[BaseModel],
|
||||||
|
field_names: List[str],
|
||||||
|
*,
|
||||||
|
descriptions: Optional[dict] = None,
|
||||||
|
fn_description: Optional[str] = None,
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
"""Create subset model using the same pydantic version as the input model."""
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
return _create_subset_model_v1(
|
||||||
|
name,
|
||||||
|
model,
|
||||||
|
field_names,
|
||||||
|
descriptions=descriptions,
|
||||||
|
fn_description=fn_description,
|
||||||
|
)
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||||
|
|
||||||
|
if issubclass(model, BaseModelV1):
|
||||||
|
return _create_subset_model_v1(
|
||||||
|
name,
|
||||||
|
model,
|
||||||
|
field_names,
|
||||||
|
descriptions=descriptions,
|
||||||
|
fn_description=fn_description,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return _create_subset_model_v2(
|
||||||
|
name,
|
||||||
|
model,
|
||||||
|
field_names,
|
||||||
|
descriptions=descriptions,
|
||||||
|
fn_description=fn_description,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||||
|
)
|
||||||
|
@ -31,10 +31,10 @@ from langchain_core.tools import (
|
|||||||
StructuredTool,
|
StructuredTool,
|
||||||
Tool,
|
Tool,
|
||||||
ToolException,
|
ToolException,
|
||||||
_create_subset_model,
|
|
||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
|
from langchain_core.utils.pydantic import _create_subset_model
|
||||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
@ -1417,3 +1417,112 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
|||||||
"required": ["x"],
|
"required": ["x"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_models() -> List[Any]:
|
||||||
|
"""Generate a list of base models depending on the pydantic version."""
|
||||||
|
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
||||||
|
|
||||||
|
class FooProper(BaseModelProper):
|
||||||
|
a: int
|
||||||
|
b: str
|
||||||
|
|
||||||
|
return [FooProper]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_backwards_compatible_v1() -> List[Any]:
|
||||||
|
"""Generate a model with pydantic 2 from the v1 namespace."""
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||||
|
|
||||||
|
class FooV1Namespace(BaseModelV1):
|
||||||
|
a: int
|
||||||
|
b: str
|
||||||
|
|
||||||
|
return [FooV1Namespace]
|
||||||
|
|
||||||
|
|
||||||
|
# This generates a list of models that can be used for testing that our APIs
|
||||||
|
# behave well with either pydantic 1 proper,
|
||||||
|
# pydantic v1 from pydantic 2,
|
||||||
|
# or pydantic 2 proper.
|
||||||
|
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||||
|
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||||
|
class SomeTool(BaseTool):
|
||||||
|
args_schema: Type[pydantic_model] = pydantic_model
|
||||||
|
|
||||||
|
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
tool = SomeTool(
|
||||||
|
name="some_tool", description="some description", args_schema=pydantic_model
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool.get_input_schema().schema() == {
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
"b": {"title": "B", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
"title": pydantic_model.__name__,
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert tool.tool_call_schema.schema() == {
|
||||||
|
"description": "some description",
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
"b": {"title": "B", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
"title": "some_tool",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_args_schema_explicitly_typed() -> None:
|
||||||
|
"""This should test that one can type the args schema as a pydantic model.
|
||||||
|
|
||||||
|
Please note that this will test using pydantic 2 even though BaseTool
|
||||||
|
is a pydantic 1 model!
|
||||||
|
"""
|
||||||
|
# Check with whatever pydantic model is passed in and not via v1 namespace
|
||||||
|
from pydantic import BaseModel # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModel):
|
||||||
|
a: int
|
||||||
|
b: str
|
||||||
|
|
||||||
|
class SomeTool(BaseTool):
|
||||||
|
# type ignoring here since we're allowing overriding a type
|
||||||
|
# signature of pydantic.v1.BaseModel with pydantic.BaseModel
|
||||||
|
# for pydantic 2!
|
||||||
|
args_schema: Type[BaseModel] = Foo # type: ignore[assignment]
|
||||||
|
|
||||||
|
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
tool = SomeTool(name="some_tool", description="some description")
|
||||||
|
|
||||||
|
assert tool.get_input_schema().schema() == {
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
"b": {"title": "B", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
"title": "Foo",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert tool.tool_call_schema.schema() == {
|
||||||
|
"description": "some description",
|
||||||
|
"properties": {
|
||||||
|
"a": {"title": "A", "type": "integer"},
|
||||||
|
"b": {"title": "B", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
"title": "some_tool",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
@ -3,7 +3,12 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.utils.pydantic import pre_init
|
from langchain_core.utils.pydantic import (
|
||||||
|
PYDANTIC_MAJOR_VERSION,
|
||||||
|
is_basemodel_instance,
|
||||||
|
is_basemodel_subclass,
|
||||||
|
pre_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_pre_init_decorator() -> None:
|
def test_pre_init_decorator() -> None:
|
||||||
@ -73,3 +78,46 @@ def test_with_aliases() -> None:
|
|||||||
foo = Foo(y=2) # type: ignore
|
foo = Foo(y=2) # type: ignore
|
||||||
assert foo.x == 2
|
assert foo.x == 2
|
||||||
assert foo.z == 2
|
assert foo.z == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_basemodel_subclass() -> None:
|
||||||
|
"""Test pydantic."""
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||||
|
|
||||||
|
assert is_basemodel_subclass(BaseModelV1Proper)
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||||
|
|
||||||
|
assert is_basemodel_subclass(BaseModelV2)
|
||||||
|
|
||||||
|
assert is_basemodel_subclass(BaseModelV1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_basemodel_instance() -> None:
|
||||||
|
"""Test pydantic."""
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||||
|
|
||||||
|
class FooV1(BaseModelV1Proper):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
assert is_basemodel_instance(FooV1(x=5))
|
||||||
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModelV2):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
assert is_basemodel_instance(Foo(x=5))
|
||||||
|
|
||||||
|
class Bar(BaseModelV1):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
assert is_basemodel_instance(Bar(x=5))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||||
|
Loading…
Reference in New Issue
Block a user