mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
core[minor]: Relax constraints on type checking for tools and parsers (#24459)
This will allow tools and parsers to accept pydantic models from any of the following namespaces: * pydantic.BaseModel with pydantic 1 * pydantic.BaseModel with pydantic 2 * pydantic.v1.BaseModel with pydantic 2
This commit is contained in:
parent
838464de25
commit
5e48f35fba
@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
@ -13,8 +13,9 @@ from langchain_core.messages.tool import (
|
||||
)
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.pydantic_v1 import ValidationError
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
|
||||
def parse_tool_call(
|
||||
@ -255,7 +256,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: List[Type[BaseModel]]
|
||||
tools: List[TypeBaseModel]
|
||||
"""The tools to parse."""
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
|
@ -1,23 +1,16 @@
|
||||
import json
|
||||
from typing import Generic, List, Type, TypeVar, Union
|
||||
from typing import Generic, List, Type
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION < 2:
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
|
||||
else:
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
from langchain_core.utils.pydantic import (
|
||||
PYDANTIC_MAJOR_VERSION,
|
||||
PydanticBaseModel,
|
||||
TBaseModel,
|
||||
)
|
||||
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
@ -122,3 +115,10 @@ Here is the output schema:
|
||||
```
|
||||
{schema}
|
||||
```""" # noqa: E501
|
||||
|
||||
# Re-exporting types for backwards compatibility
|
||||
__all__ = [
|
||||
"PydanticBaseModel",
|
||||
"PydanticOutputParser",
|
||||
"TBaseModel",
|
||||
]
|
||||
|
@ -89,6 +89,7 @@ from langchain_core.runnables.config import (
|
||||
)
|
||||
from langchain_core.runnables.utils import accepts_context
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_create_subset_model,
|
||||
is_basemodel_subclass,
|
||||
)
|
||||
@ -332,8 +333,15 @@ class ChildTool(BaseTool):
|
||||
|
||||
You can provide few-shot examples as a part of the description.
|
||||
"""
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||
args_schema: Optional[TypeBaseModel] = None
|
||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||
|
||||
Args schema should be either:
|
||||
|
||||
- A subclass of pydantic.BaseModel.
|
||||
or
|
||||
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
|
||||
"""
|
||||
return_direct: bool = False
|
||||
"""Whether to return the tool's output directly.
|
||||
|
||||
@ -891,7 +899,7 @@ class StructuredTool(BaseTool):
|
||||
"""Tool that can operate on any number of inputs."""
|
||||
|
||||
description: str = ""
|
||||
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
|
||||
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
|
||||
"""The input arguments' schema."""
|
||||
func: Optional[Callable[..., Any]]
|
||||
"""The function to run when the tool is called."""
|
||||
|
@ -5,9 +5,14 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import textwrap
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
root_validator,
|
||||
)
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
@ -23,6 +28,22 @@ def get_pydantic_major_version() -> int:
|
||||
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
||||
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
TypeBaseModel = Type[BaseModel]
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
||||
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")
|
||||
|
||||
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
def is_basemodel_subclass(cls: Type) -> bool:
|
||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||
|
||||
@ -37,13 +58,13 @@ def is_basemodel_subclass(cls: Type) -> bool:
|
||||
return False
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||
|
||||
if issubclass(cls, BaseModelV1Proper):
|
||||
return True
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
|
||||
if issubclass(cls, BaseModelV2):
|
||||
return True
|
||||
@ -65,13 +86,13 @@ def is_basemodel_instance(obj: Any) -> bool:
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
"""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||
|
||||
if isinstance(obj, BaseModelV1Proper):
|
||||
return True
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
|
||||
if isinstance(obj, BaseModelV2):
|
||||
return True
|
||||
|
@ -1,12 +1,21 @@
|
||||
from typing import Any, AsyncIterator, Iterator, List
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
JsonOutputToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
STREAMED_MESSAGES: list = [
|
||||
AIMessageChunk(content=""),
|
||||
@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
||||
|
||||
actual = [p async for p in chain.astream(None)]
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
class Forecast(pydantic.v1.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
|
||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
@ -6,9 +6,9 @@ import pytest
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
|
||||
|
||||
V1BaseModel = pydantic.BaseModel
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
|
@ -1526,3 +1526,40 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
"title": "some_tool",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
|
||||
"""This should test that one can type the args schema as a pydantic model."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
def foo(a: int, b: str) -> str:
|
||||
"""Hahaha"""
|
||||
return "foo"
|
||||
|
||||
foo_tool = StructuredTool.from_function(
|
||||
func=foo,
|
||||
args_schema=pydantic_model,
|
||||
)
|
||||
|
||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||
|
||||
assert foo_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": pydantic_model.__name__,
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
assert foo_tool.get_input_schema().schema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": pydantic_model.__name__,
|
||||
"type": "object",
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user