core[minor]: Relax constraints on type checking for tools and parsers (#24459)

This will allow tools and parsers to accept pydantic models from any of the
following namespaces:

* pydantic.BaseModel with pydantic 1
* pydantic.BaseModel with pydantic 2
* pydantic.v1.BaseModel with pydantic 2
This commit is contained in:
Eugene Yurtsev 2024-07-19 21:47:34 -04:00 committed by GitHub
parent 838464de25
commit 5e48f35fba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 211 additions and 30 deletions

View File

@ -1,7 +1,7 @@
import copy import copy
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages import AIMessage, InvalidToolCall
@ -13,8 +13,9 @@ from langchain_core.messages.tool import (
) )
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
def parse_tool_call( def parse_tool_call(
@ -255,7 +256,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser): class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response.""" """Parse tools from OpenAI response."""
tools: List[Type[BaseModel]] tools: List[TypeBaseModel]
"""The tools to parse.""" """The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all # TODO: Support more granular streaming of objects. Currently only streams once all

View File

@ -1,23 +1,16 @@
import json import json
from typing import Generic, List, Type, TypeVar, Union from typing import Generic, List, Type
import pydantic # pydantic: ignore import pydantic # pydantic: ignore
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
if PYDANTIC_MAJOR_VERSION < 2: PydanticBaseModel,
PydanticBaseModel = pydantic.BaseModel TBaseModel,
)
else:
from pydantic.v1 import BaseModel # pydantic: ignore
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
@ -122,3 +115,10 @@ Here is the output schema:
``` ```
{schema} {schema}
```""" # noqa: E501 ```""" # noqa: E501
# Re-exporting types for backwards compatibility
# PydanticBaseModel and TBaseModel are imported into this module from
# langchain_core.utils.pydantic; listing them here keeps them part of this
# module's declared public names alongside PydanticOutputParser.
__all__ = [
    "PydanticBaseModel",
    "PydanticOutputParser",
    "TBaseModel",
]

View File

@ -89,6 +89,7 @@ from langchain_core.runnables.config import (
) )
from langchain_core.runnables.utils import accepts_context from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.pydantic import ( from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model, _create_subset_model,
is_basemodel_subclass, is_basemodel_subclass,
) )
@ -332,8 +333,15 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description. You can provide few-shot examples as a part of the description.
""" """
args_schema: Optional[Type[BaseModel]] = None args_schema: Optional[TypeBaseModel] = None
"""Pydantic model class to validate and parse the tool's input arguments.""" """Pydantic model class to validate and parse the tool's input arguments.
Args schema should be either:
- A subclass of pydantic.BaseModel.
or
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
"""
return_direct: bool = False return_direct: bool = False
"""Whether to return the tool's output directly. """Whether to return the tool's output directly.
@ -891,7 +899,7 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs.""" """Tool that can operate on any number of inputs."""
description: str = "" description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.") args_schema: TypeBaseModel = Field(..., description="The tool schema.")
"""The input arguments' schema.""" """The input arguments' schema."""
func: Optional[Callable[..., Any]] func: Optional[Callable[..., Any]]
"""The function to run when the tool is called.""" """The function to run when the tool is called."""

View File

@ -5,9 +5,14 @@ from __future__ import annotations
import inspect import inspect
import textwrap import textwrap
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from langchain_core.pydantic_v1 import BaseModel, root_validator import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import (
BaseModel,
root_validator,
)
def get_pydantic_major_version() -> int: def get_pydantic_major_version() -> int:
@ -23,6 +28,22 @@ def get_pydantic_major_version() -> int:
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version() PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# Pick the model aliases that match the detected pydantic major version.
#   PydanticBaseModel -- the accepted model type(s)
#   TypeBaseModel     -- the accepted model *class* type(s)
if PYDANTIC_MAJOR_VERSION == 1:
    PydanticBaseModel = pydantic.BaseModel
    # NOTE(review): ``BaseModel`` here is the name imported at module top from
    # langchain_core.pydantic_v1 -- presumably the same class as
    # pydantic.BaseModel under pydantic 1; confirm.
    TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
    # Rebinds the module-level ``BaseModel`` name to the pydantic.v1 class.
    from pydantic.v1 import BaseModel  # pydantic: ignore

    # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
    PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]  # type: ignore
    TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]]  # type: ignore
else:
    # Any other major version is rejected at import time.
    raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")

# Type variable bounded by whichever PydanticBaseModel alias was selected above.
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_basemodel_subclass(cls: Type) -> bool: def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel. """Check if the given class is a subclass of Pydantic BaseModel.
@ -37,13 +58,13 @@ def is_basemodel_subclass(cls: Type) -> bool:
return False return False
if PYDANTIC_MAJOR_VERSION == 1: if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
if issubclass(cls, BaseModelV1Proper): if issubclass(cls, BaseModelV1Proper):
return True return True
elif PYDANTIC_MAJOR_VERSION == 2: elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
if issubclass(cls, BaseModelV2): if issubclass(cls, BaseModelV2):
return True return True
@ -65,13 +86,13 @@ def is_basemodel_instance(obj: Any) -> bool:
* pydantic.v1.BaseModel in Pydantic 2.x * pydantic.v1.BaseModel in Pydantic 2.x
""" """
if PYDANTIC_MAJOR_VERSION == 1: if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
if isinstance(obj, BaseModelV1Proper): if isinstance(obj, BaseModelV1Proper):
return True return True
elif PYDANTIC_MAJOR_VERSION == 2: elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
if isinstance(obj, BaseModelV2): if isinstance(obj, BaseModelV2):
return True return True

View File

@ -1,12 +1,21 @@
from typing import Any, AsyncIterator, Iterator, List from typing import Any, AsyncIterator, Iterator, List
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ToolCallChunk,
)
from langchain_core.output_parsers.openai_tools import ( from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser, JsonOutputKeyToolsParser,
JsonOutputToolsParser, JsonOutputToolsParser,
PydanticToolsParser, PydanticToolsParser,
) )
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
STREAMED_MESSAGES: list = [ STREAMED_MESSAGES: list = [
AIMessageChunk(content=""), AIMessageChunk(content=""),
@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None:
actual = [p async for p in chain.astream(None)] actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC assert actual == EXPECTED_STREAMED_PYDANTIC
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
    """Parse a tool call into a pydantic.v1.BaseModel subclass under pydantic 2."""
    import pydantic  # pydantic: ignore

    class Forecast(pydantic.v1.BaseModel):
        temperature: int
        forecast: str

    # The type checker cannot be satisfied here because of the odd typing
    # involved in trying to support both v1 and v2 in the same codebase,
    # hence the ignore on the next line.
    parser = PydanticToolsParser(tools=[Forecast])  # type: ignore[list-item]

    generation = ChatGeneration(
        message=AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "call_OwL7f5PE",
                    "name": "Forecast",
                    "args": {"temperature": 20, "forecast": "Sunny"},
                }
            ],
        ),
    )

    expected = Forecast(temperature=20, forecast="Sunny")
    assert parser.parse_result([generation]) == [expected]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
    """Parse a tool call into a pydantic.BaseModel subclass under pydantic 2."""
    import pydantic  # pydantic: ignore

    class Forecast(pydantic.BaseModel):
        temperature: int
        forecast: str

    # The type checker cannot be satisfied here because of the odd typing
    # involved in trying to support both v1 and v2 in the same codebase,
    # hence the ignore on the next line.
    parser = PydanticToolsParser(tools=[Forecast])  # type: ignore[list-item]

    generation = ChatGeneration(
        message=AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "call_OwL7f5PE",
                    "name": "Forecast",
                    "args": {"temperature": 20, "forecast": "Sunny"},
                }
            ],
        ),
    )

    expected = Forecast(temperature=20, forecast="Sunny")
    assert parser.parse_result([generation]) == [expected]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
    """Parse a tool call into a pydantic.BaseModel subclass under pydantic 1."""
    import pydantic  # pydantic: ignore

    class Forecast(pydantic.BaseModel):
        temperature: int
        forecast: str

    # The type checker cannot be satisfied here because of the odd typing
    # involved in trying to support both v1 and v2 in the same codebase,
    # hence the ignore on the next line.
    parser = PydanticToolsParser(tools=[Forecast])  # type: ignore[list-item]

    generation = ChatGeneration(
        message=AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "call_OwL7f5PE",
                    "name": "Forecast",
                    "args": {"temperature": 20, "forecast": "Sunny"},
                }
            ],
        ),
    )

    expected = Forecast(temperature=20, forecast="Sunny")
    assert parser.parse_result([generation]) == [expected]

View File

@ -6,9 +6,9 @@ import pytest
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
V1BaseModel = pydantic.BaseModel V1BaseModel = pydantic.BaseModel
if PYDANTIC_MAJOR_VERSION == 2: if PYDANTIC_MAJOR_VERSION == 2:

View File

@ -1526,3 +1526,40 @@ def test_args_schema_explicitly_typed() -> None:
"title": "some_tool", "title": "some_tool",
"type": "object", "type": "object",
} }
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
    """Check StructuredTool with each TEST_MODELS entry as its args schema."""
    from langchain_core.tools import StructuredTool

    def foo(a: int, b: str) -> str:
        """Hahaha"""
        return "foo"

    foo_tool = StructuredTool.from_function(func=foo, args_schema=pydantic_model)

    # The tool must be callable with arguments matching the schema.
    assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"

    # Both the raw args schema and the derived input schema must serialise
    # to the same JSON schema, titled after the model class.
    expected_schema = {
        "properties": {
            "a": {"title": "A", "type": "integer"},
            "b": {"title": "B", "type": "string"},
        },
        "required": ["a", "b"],
        "title": pydantic_model.__name__,
        "type": "object",
    }
    assert foo_tool.args_schema.schema() == expected_schema
    assert foo_tool.get_input_schema().schema() == expected_schema