mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
core[minor]: Support all versions of pydantic base model in argsschema (#24418)
This adds support to any pydantic base model for tools. The only potential issue is that `get_input_schema()` will not always return a v1 base model.
This commit is contained in:
@@ -31,10 +31,10 @@ from langchain_core.tools import (
|
||||
StructuredTool,
|
||||
Tool,
|
||||
ToolException,
|
||||
_create_subset_model,
|
||||
tool,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.pydantic import _create_subset_model
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
|
||||
|
||||
@@ -1417,3 +1417,112 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def generate_models() -> List[Any]:
|
||||
"""Generate a list of base models depending on the pydantic version."""
|
||||
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
||||
|
||||
class FooProper(BaseModelProper):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
return [FooProper]
|
||||
|
||||
|
||||
def generate_backwards_compatible_v1() -> List[Any]:
|
||||
"""Generate a model with pydantic 2 from the v1 namespace."""
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
|
||||
class FooV1Namespace(BaseModelV1):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
return [FooV1Namespace]
|
||||
|
||||
|
||||
# This generates a list of models that can be used for testing that our APIs
|
||||
# behave well with either pydantic 1 proper,
|
||||
# pydantic v1 from pydantic 2,
|
||||
# or pydantic 2 proper.
|
||||
TEST_MODELS = generate_models() + generate_backwards_compatible_v1()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||
def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
class SomeTool(BaseTool):
|
||||
args_schema: Type[pydantic_model] = pydantic_model
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
return "foo"
|
||||
|
||||
tool = SomeTool(
|
||||
name="some_tool", description="some description", args_schema=pydantic_model
|
||||
)
|
||||
|
||||
assert tool.get_input_schema().schema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": pydantic_model.__name__,
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
assert tool.tool_call_schema.schema() == {
|
||||
"description": "some description",
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": "some_tool",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
def test_args_schema_explicitly_typed() -> None:
|
||||
"""This should test that one can type the args schema as a pydantic model.
|
||||
|
||||
Please note that this will test using pydantic 2 even though BaseTool
|
||||
is a pydantic 1 model!
|
||||
"""
|
||||
# Check with whatever pydantic model is passed in and not via v1 namespace
|
||||
from pydantic import BaseModel # pydantic: ignore
|
||||
|
||||
class Foo(BaseModel):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
class SomeTool(BaseTool):
|
||||
# type ignoring here since we're allowing overriding a type
|
||||
# signature of pydantic.v1.BaseModel with pydantic.BaseModel
|
||||
# for pydantic 2!
|
||||
args_schema: Type[BaseModel] = Foo # type: ignore[assignment]
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
return "foo"
|
||||
|
||||
tool = SomeTool(name="some_tool", description="some description")
|
||||
|
||||
assert tool.get_input_schema().schema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": "Foo",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
assert tool.tool_call_schema.schema() == {
|
||||
"description": "some description",
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": "some_tool",
|
||||
"type": "object",
|
||||
}
|
||||
|
@@ -3,7 +3,12 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import pre_init
|
||||
from langchain_core.utils.pydantic import (
|
||||
PYDANTIC_MAJOR_VERSION,
|
||||
is_basemodel_instance,
|
||||
is_basemodel_subclass,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
|
||||
def test_pre_init_decorator() -> None:
|
||||
@@ -73,3 +78,46 @@ def test_with_aliases() -> None:
|
||||
foo = Foo(y=2) # type: ignore
|
||||
assert foo.x == 2
|
||||
assert foo.z == 2
|
||||
|
||||
|
||||
def test_is_basemodel_subclass() -> None:
|
||||
"""Test pydantic."""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV1Proper)
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV2)
|
||||
|
||||
assert is_basemodel_subclass(BaseModelV1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||
|
||||
|
||||
def test_is_basemodel_instance() -> None:
|
||||
"""Test pydantic."""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||
|
||||
class FooV1(BaseModelV1Proper):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(FooV1(x=5))
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
|
||||
class Foo(BaseModelV2):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(Foo(x=5))
|
||||
|
||||
class Bar(BaseModelV1):
|
||||
x: int
|
||||
|
||||
assert is_basemodel_instance(Bar(x=5))
|
||||
else:
|
||||
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||
|
Reference in New Issue
Block a user