mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
core[patch]: Add pydantic metadata to subset model (#25032)
- **Description:** This includes Pydantic field metadata in `_create_subset_model_v2` so that it gets included in the final serialized form that get sent out. - **Issue:** #25031 - **Dependencies:** n/a - **Twitter handle:** @gramliu --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
8f33fce871
commit
88a9a6a758
@ -1055,10 +1055,12 @@ class StructuredTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Type args_schema as TypeBaseModel if we can get mypy to correctly recognize
|
||||||
|
# pydantic v2 BaseModel classes.
|
||||||
def tool(
|
def tool(
|
||||||
*args: Union[str, Callable, Runnable],
|
*args: Union[str, Callable, Runnable],
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[Type[BaseModel]] = None,
|
args_schema: Optional[Type] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
|
@ -29,8 +29,6 @@ if PYDANTIC_MAJOR_VERSION == 1:
|
|||||||
PydanticBaseModel = pydantic.BaseModel
|
PydanticBaseModel = pydantic.BaseModel
|
||||||
TypeBaseModel = Type[BaseModel]
|
TypeBaseModel = Type[BaseModel]
|
||||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
||||||
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
|
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
|
||||||
@ -199,12 +197,12 @@ def _create_subset_model_v1(
|
|||||||
|
|
||||||
def _create_subset_model_v2(
|
def _create_subset_model_v2(
|
||||||
name: str,
|
name: str,
|
||||||
model: Type[BaseModel],
|
model: Type[pydantic.BaseModel],
|
||||||
field_names: List[str],
|
field_names: List[str],
|
||||||
*,
|
*,
|
||||||
descriptions: Optional[dict] = None,
|
descriptions: Optional[dict] = None,
|
||||||
fn_description: Optional[str] = None,
|
fn_description: Optional[str] = None,
|
||||||
) -> Type[BaseModel]:
|
) -> Type[pydantic.BaseModel]:
|
||||||
"""Create a pydantic model with a subset of the model fields."""
|
"""Create a pydantic model with a subset of the model fields."""
|
||||||
from pydantic import create_model # pydantic: ignore
|
from pydantic import create_model # pydantic: ignore
|
||||||
from pydantic.fields import FieldInfo # pydantic: ignore
|
from pydantic.fields import FieldInfo # pydantic: ignore
|
||||||
@ -214,10 +212,10 @@ def _create_subset_model_v2(
|
|||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
field = model.model_fields[field_name] # type: ignore
|
field = model.model_fields[field_name] # type: ignore
|
||||||
description = descriptions_.get(field_name, field.description)
|
description = descriptions_.get(field_name, field.description)
|
||||||
fields[field_name] = (
|
field_info = FieldInfo(description=description, default=field.default)
|
||||||
field.annotation,
|
if field.metadata:
|
||||||
FieldInfo(description=description, default=field.default),
|
field_info.metadata = field.metadata
|
||||||
)
|
fields[field_name] = (field.annotation, field_info)
|
||||||
rtn = create_model(name, **fields) # type: ignore
|
rtn = create_model(name, **fields) # type: ignore
|
||||||
|
|
||||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||||
@ -230,7 +228,7 @@ def _create_subset_model_v2(
|
|||||||
# However, can't find a way to type hint this.
|
# However, can't find a way to type hint this.
|
||||||
def _create_subset_model(
|
def _create_subset_model(
|
||||||
name: str,
|
name: str,
|
||||||
model: Type[BaseModel],
|
model: TypeBaseModel,
|
||||||
field_names: List[str],
|
field_names: List[str],
|
||||||
*,
|
*,
|
||||||
descriptions: Optional[dict] = None,
|
descriptions: Optional[dict] = None,
|
||||||
|
@ -1863,3 +1863,41 @@ def test__get_all_basemodel_annotations_v1() -> None:
|
|||||||
}
|
}
|
||||||
actual = _get_all_basemodel_annotations(ModelD[int])
|
actual = _get_all_basemodel_annotations(ModelD[int])
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
||||||
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
|
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||||
|
from pydantic import Field as FieldV2 # pydantic: ignore
|
||||||
|
from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModelV2):
|
||||||
|
x: List[int] = FieldV2(
|
||||||
|
description="List of integers", min_length=10, max_length=15
|
||||||
|
)
|
||||||
|
|
||||||
|
@tool(args_schema=Foo)
|
||||||
|
def foo(x): # type: ignore[no-untyped-def]
|
||||||
|
"""foo"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
assert foo.tool_call_schema.schema() == {
|
||||||
|
"description": "foo",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"description": "List of integers",
|
||||||
|
"items": {"type": "integer"},
|
||||||
|
"maxItems": 15,
|
||||||
|
"minItems": 10,
|
||||||
|
"title": "X",
|
||||||
|
"type": "array",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
"title": "foo",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert foo.invoke({"x": [0] * 10})
|
||||||
|
with pytest.raises(ValidationErrorV2):
|
||||||
|
foo.invoke({"x": [0] * 9})
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
"""Test for some custom pydantic decorators."""
|
"""Test for some custom pydantic decorators."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
PYDANTIC_MAJOR_VERSION,
|
PYDANTIC_MAJOR_VERSION,
|
||||||
|
_create_subset_model_v2,
|
||||||
is_basemodel_instance,
|
is_basemodel_instance,
|
||||||
is_basemodel_subclass,
|
is_basemodel_subclass,
|
||||||
pre_init,
|
pre_init,
|
||||||
@ -121,3 +124,32 @@ def test_is_basemodel_instance() -> None:
|
|||||||
assert is_basemodel_instance(Bar(x=5))
|
assert is_basemodel_instance(Bar(x=5))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
|
||||||
|
def test_with_field_metadata() -> None:
|
||||||
|
"""Test pydantic with field metadata"""
|
||||||
|
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||||
|
from pydantic import Field as FieldV2 # pydantic: ignore
|
||||||
|
|
||||||
|
class Foo(BaseModelV2):
|
||||||
|
x: List[int] = FieldV2(
|
||||||
|
description="List of integers", min_length=10, max_length=15
|
||||||
|
)
|
||||||
|
|
||||||
|
subset_model = _create_subset_model_v2("Foo", Foo, ["x"])
|
||||||
|
assert subset_model.model_json_schema() == {
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"description": "List of integers",
|
||||||
|
"items": {"type": "integer"},
|
||||||
|
"maxItems": 15,
|
||||||
|
"minItems": 10,
|
||||||
|
"title": "X",
|
||||||
|
"type": "array",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
"title": "Foo",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
@ -18,6 +18,8 @@ from langchain_core.output_parsers import StrOutputParser
|
|||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
from pydantic import BaseModel as RawBaseModel
|
||||||
|
from pydantic import Field as RawField
|
||||||
|
|
||||||
from langchain_standard_tests.unit_tests.chat_models import (
|
from langchain_standard_tests.unit_tests.chat_models import (
|
||||||
ChatModelTests,
|
ChatModelTests,
|
||||||
@ -26,7 +28,11 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
|||||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
|
|
||||||
@tool
|
class MagicFunctionSchema(RawBaseModel):
|
||||||
|
input: int = RawField(..., gt=-1000, lt=1000)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(args_schema=MagicFunctionSchema)
|
||||||
def magic_function(input: int) -> int:
|
def magic_function(input: int) -> int:
|
||||||
"""Applies a magic function to an input."""
|
"""Applies a magic function to an input."""
|
||||||
return input + 2
|
return input + 2
|
||||||
|
Loading…
Reference in New Issue
Block a user