core[minor]: Relax constraints on type checking for tools and parsers (#24459)

This will allow tools and parsers to accept pydantic models from any of
the
following namespaces:

* pydantic.BaseModel with pydantic 1
* pydantic.BaseModel with pydantic 2
* pydantic.v1.BaseModel with pydantic 2
This commit is contained in:
Eugene Yurtsev
2024-07-19 21:47:34 -04:00
committed by GitHub
parent 838464de25
commit 5e48f35fba
7 changed files with 211 additions and 30 deletions

View File

@@ -1,12 +1,21 @@
from typing import Any, AsyncIterator, Iterator, List
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ToolCallChunk,
)
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None:
actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic # pydantic: ignore
class Forecast(pydantic.v1.BaseModel):
temperature: int
forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)
generation = ChatGeneration(
message=message,
)
assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2."""
import pydantic # pydantic: ignore
class Forecast(pydantic.BaseModel):
temperature: int
forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)
generation = ChatGeneration(
message=message,
)
assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 1."""
import pydantic # pydantic: ignore
class Forecast(pydantic.BaseModel):
temperature: int
forecast: str
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)
generation = ChatGeneration(
message=message,
)
assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]

View File

@@ -6,9 +6,9 @@ import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
V1BaseModel = pydantic.BaseModel
if PYDANTIC_MAJOR_VERSION == 2:

View File

@@ -1526,3 +1526,40 @@ def test_args_schema_explicitly_typed() -> None:
"title": "some_tool",
"type": "object",
}
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
"""This should test that one can type the args schema as a pydantic model."""
from langchain_core.tools import StructuredTool
def foo(a: int, b: str) -> str:
"""Hahaha"""
return "foo"
foo_tool = StructuredTool.from_function(
func=foo,
args_schema=pydantic_model,
)
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
assert foo_tool.args_schema.schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": pydantic_model.__name__,
"type": "object",
}
assert foo_tool.get_input_schema().schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": pydantic_model.__name__,
"type": "object",
}