core[minor]: Relax constraints on type checking for tools and parsers (#24459)

This will allow tools and parsers to accept pydantic models from any of the
following namespaces:

* pydantic.BaseModel with pydantic 1
* pydantic.BaseModel with pydantic 2
* pydantic.v1.BaseModel with pydantic 2
This commit is contained in:
Eugene Yurtsev 2024-07-19 21:47:34 -04:00 committed by GitHub
parent 838464de25
commit 5e48f35fba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 211 additions and 30 deletions

View File

@ -1,7 +1,7 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
@ -13,8 +13,9 @@ from langchain_core.messages.tool import (
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
def parse_tool_call(
@ -255,7 +256,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: List[Type[BaseModel]]
tools: List[TypeBaseModel]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all

View File

@ -1,23 +1,16 @@
import json
from typing import Generic, List, Type, TypeVar, Union
from typing import Generic, List, Type
import pydantic # pydantic: ignore
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
if PYDANTIC_MAJOR_VERSION < 2:
PydanticBaseModel = pydantic.BaseModel
else:
from pydantic.v1 import BaseModel # pydantic: ignore
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
PydanticBaseModel,
TBaseModel,
)
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
@ -122,3 +115,10 @@ Here is the output schema:
```
{schema}
```""" # noqa: E501
# Re-exporting types for backwards compatibility.
# PydanticBaseModel and TBaseModel are now imported from
# langchain_core.utils.pydantic; listing them here keeps them part of this
# module's declared public names alongside PydanticOutputParser.
__all__ = [
    "PydanticBaseModel",
    "PydanticOutputParser",
    "TBaseModel",
]

View File

@ -89,6 +89,7 @@ from langchain_core.runnables.config import (
)
from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
is_basemodel_subclass,
)
@ -332,8 +333,15 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
args_schema: Optional[TypeBaseModel] = None
"""Pydantic model class to validate and parse the tool's input arguments.
Args schema should be either:
- A subclass of pydantic.BaseModel.
or
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
"""
return_direct: bool = False
"""Whether to return the tool's output directly.
@ -891,7 +899,7 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""

View File

@ -5,9 +5,14 @@ from __future__ import annotations
import inspect
import textwrap
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from langchain_core.pydantic_v1 import BaseModel, root_validator
import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import (
BaseModel,
root_validator,
)
def get_pydantic_major_version() -> int:
@ -23,6 +28,22 @@ def get_pydantic_major_version() -> int:
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# Select the model aliases that match the installed pydantic major version:
#   * PydanticBaseModel - the accepted model base class(es)
#   * TypeBaseModel     - the accepted model *class* type(s), for annotations
if PYDANTIC_MAJOR_VERSION == 1:
    # pydantic 1: `BaseModel` is the name imported from
    # langchain_core.pydantic_v1 at the top of this module.
    PydanticBaseModel = pydantic.BaseModel
    TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
    # pydantic 2: accept both the pydantic.v1 namespace model and the
    # top-level pydantic.BaseModel.
    # NOTE: this import rebinds the module-level name `BaseModel`.
    from pydantic.v1 import BaseModel  # pydantic: ignore

    # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
    PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]  # type: ignore
    TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]]  # type: ignore
else:
    # Any other major version is rejected at import time.
    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")


# Type variable bounded by whichever PydanticBaseModel alias was selected above.
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.
@ -37,13 +58,13 @@ def is_basemodel_subclass(cls: Type) -> bool:
return False
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
if issubclass(cls, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
if issubclass(cls, BaseModelV2):
return True
@ -65,13 +86,13 @@ def is_basemodel_instance(obj: Any) -> bool:
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
if isinstance(obj, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
if isinstance(obj, BaseModelV2):
return True

View File

@ -1,12 +1,21 @@
from typing import Any, AsyncIterator, Iterator, List
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ToolCallChunk,
)
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None:
actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
    """Parse a tool call into a pydantic.v1.BaseModel tool under pydantic 2."""
    import pydantic  # pydantic: ignore

    class Forecast(pydantic.v1.BaseModel):
        temperature: int
        forecast: str

    tool_call = {
        "id": "call_OwL7f5PE",
        "name": "Forecast",
        "args": {"temperature": 20, "forecast": "Sunny"},
    }

    # The ignore below is needed because of the odd typing that comes from
    # trying to support both v1 and v2 models in the same codebase.
    tools_parser = PydanticToolsParser(tools=[Forecast])  # type: ignore[list-item]
    ai_message = AIMessage(content="", tool_calls=[tool_call])

    parsed = tools_parser.parse_result([ChatGeneration(message=ai_message)])

    assert parsed == [Forecast(temperature=20, forecast="Sunny")]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
    """Parse a tool call into a top-level pydantic.BaseModel tool under pydantic 2."""
    import pydantic  # pydantic: ignore

    class Forecast(pydantic.BaseModel):
        temperature: int
        forecast: str

    tool_call = {
        "id": "call_OwL7f5PE",
        "name": "Forecast",
        "args": {"temperature": 20, "forecast": "Sunny"},
    }

    # The ignore below is needed because of the odd typing that comes from
    # trying to support both v1 and v2 models in the same codebase.
    tools_parser = PydanticToolsParser(tools=[Forecast])  # type: ignore[list-item]
    ai_message = AIMessage(content="", tool_calls=[tool_call])

    parsed = tools_parser.parse_result([ChatGeneration(message=ai_message)])

    assert parsed == [Forecast(temperature=20, forecast="Sunny")]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
    """Parse a tool call into a pydantic.BaseModel tool under pydantic 1."""
    import pydantic  # pydantic: ignore

    class Forecast(pydantic.BaseModel):
        temperature: int
        forecast: str

    tool_call = {
        "id": "call_OwL7f5PE",
        "name": "Forecast",
        "args": {"temperature": 20, "forecast": "Sunny"},
    }

    # The ignore below is needed because of the odd typing that comes from
    # trying to support both v1 and v2 models in the same codebase.
    tools_parser = PydanticToolsParser(tools=[Forecast])  # type: ignore[list-item]
    ai_message = AIMessage(content="", tool_calls=[tool_call])

    parsed = tools_parser.parse_result([ChatGeneration(message=ai_message)])

    assert parsed == [Forecast(temperature=20, forecast="Sunny")]

View File

@ -6,9 +6,9 @@ import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
V1BaseModel = pydantic.BaseModel
if PYDANTIC_MAJOR_VERSION == 2:

View File

@ -1526,3 +1526,40 @@ def test_args_schema_explicitly_typed() -> None:
"title": "some_tool",
"type": "object",
}
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
    """StructuredTool should accept every model in TEST_MODELS as args_schema."""
    from langchain_core.tools import StructuredTool

    def foo(a: int, b: str) -> str:
        """Hahaha"""
        return "foo"

    # Both the raw args schema and the tool's input schema must match this.
    expected_schema = {
        "properties": {
            "a": {"title": "A", "type": "integer"},
            "b": {"title": "B", "type": "string"},
        },
        "required": ["a", "b"],
        "title": pydantic_model.__name__,
        "type": "object",
    }

    foo_tool = StructuredTool.from_function(func=foo, args_schema=pydantic_model)

    assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
    assert foo_tool.args_schema.schema() == expected_schema
    assert foo_tool.get_input_schema().schema() == expected_schema