mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 11:39:18 +00:00
core[minor]: Relax constraints on type checking for tools and parsers (#24459)
This will allow tools and parsers to accept pydantic models from any of the following namespaces: * pydantic.BaseModel with pydantic 1 * pydantic.BaseModel with pydantic 2 * pydantic.v1.BaseModel with pydantic 2
This commit is contained in:
@@ -1,12 +1,21 @@
|
||||
from typing import Any, AsyncIterator, Iterator, List
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
JsonOutputToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
STREAMED_MESSAGES: list = [
|
||||
AIMessageChunk(content=""),
|
||||
@@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
||||
|
||||
actual = [p async for p in chain.astream(None)]
|
||||
assert actual == EXPECTED_STREAMED_PYDANTIC
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
class Forecast(pydantic.v1.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
|
||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
forecast: str
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "Forecast",
|
||||
"args": {"temperature": 20, "forecast": "Sunny"},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
|
||||
assert parser.parse_result([generation]) == [
|
||||
Forecast(
|
||||
temperature=20,
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
@@ -6,9 +6,9 @@ import pytest
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
|
||||
|
||||
V1BaseModel = pydantic.BaseModel
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
|
@@ -1526,3 +1526,40 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
"title": "some_tool",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
|
||||
"""This should test that one can type the args schema as a pydantic model."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
def foo(a: int, b: str) -> str:
|
||||
"""Hahaha"""
|
||||
return "foo"
|
||||
|
||||
foo_tool = StructuredTool.from_function(
|
||||
func=foo,
|
||||
args_schema=pydantic_model,
|
||||
)
|
||||
|
||||
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
|
||||
|
||||
assert foo_tool.args_schema.schema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": pydantic_model.__name__,
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
assert foo_tool.get_input_schema().schema() == {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"title": "B", "type": "string"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"title": pydantic_model.__name__,
|
||||
"type": "object",
|
||||
}
|
||||
|
Reference in New Issue
Block a user