From 6096c80b71fd294822cad996d9ffb4f2e0d0df68 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Thu, 22 Aug 2024 18:00:09 -0700 Subject: [PATCH] core: pydantic output parser streaming fix (#24415) Co-authored-by: Bagatur --- .../langchain_core/output_parsers/pydantic.py | 13 +- .../output_parsers/test_pydantic_parser.py | 123 +++++++++++++++++- .../output_parsers/test_pydantic_parser.py | 123 ------------------ 3 files changed, 130 insertions(+), 129 deletions(-) delete mode 100644 libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index b8f782ffa2f..e3ee62e1b0c 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,5 +1,5 @@ import json -from typing import Generic, List, Type +from typing import Generic, List, Optional, Type import pydantic # pydantic: ignore @@ -49,7 +49,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): def parse_result( self, result: List[Generation], *, partial: bool = False - ) -> TBaseModel: + ) -> Optional[TBaseModel]: """Parse the result of an LLM call to a pydantic object. Args: @@ -62,8 +62,13 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): Returns: The parsed pydantic object. """ - json_object = super().parse_result(result) - return self._parse_obj(json_object) + try: + json_object = super().parse_result(result) + return self._parse_obj(json_object) + except OutputParserException as e: + if partial: + return None + raise e def parse(self, text: str) -> TBaseModel: """Parse the output of an LLM call to a pydantic object. diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index eb45a65e6ed..11e11b06e6d 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -1,13 +1,17 @@ -from typing import Literal +"""Test PydanticOutputParser""" + +from enum import Enum +from typing import Literal, Optional import pydantic # pydantic: ignore import pytest from langchain_core.exceptions import OutputParserException from langchain_core.language_models import ParrotFakeChatModel +from langchain_core.output_parsers import PydanticOutputParser from langchain_core.output_parsers.json import JsonOutputParser -from langchain_core.output_parsers.pydantic import PydanticOutputParser from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel V1BaseModel = pydantic.BaseModel @@ -96,3 +100,118 @@ def test_json_parser_chaining( assert res["f_or_c"] == "C" assert res["temperature"] == 20 assert res["forecast"] == "Sunny" + + +class Actions(Enum): + SEARCH = "Search" + CREATE = "Create" + UPDATE = "Update" + DELETE = "Delete" + + +class TestModel(BaseModel): + action: Actions = Field(description="Action to be performed") + action_input: str = Field(description="Input to be used in the action") + additional_fields: Optional[str] = Field( + description="Additional fields", default=None + ) + for_new_lines: str = Field(description="To be used to test newlines") + + +# Prevent pytest from trying to run tests on TestModel +TestModel.__test__ = False # type: ignore[attr-defined] + + +DEF_RESULT = """{ + "action": "Update", + "action_input": "The PydanticOutputParser class is powerful", + "additional_fields": null, + "for_new_lines": "not_escape_newline:\n escape_newline: \\n" +}""" + +# action 'update' with a lowercase 'u' to test schema validation failure. +DEF_RESULT_FAIL = """{ + "action": "update", + "action_input": "The PydanticOutputParser class is powerful", + "additional_fields": null +}""" + +DEF_EXPECTED_RESULT = TestModel( + action=Actions.UPDATE, + action_input="The PydanticOutputParser class is powerful", + additional_fields=None, + for_new_lines="not_escape_newline:\n escape_newline: \n", +) + + +def test_pydantic_output_parser() -> None: + """Test PydanticOutputParser.""" + + pydantic_parser: PydanticOutputParser = PydanticOutputParser( + pydantic_object=TestModel + ) + + result = pydantic_parser.parse(DEF_RESULT) + print("parse_result:", result) # noqa: T201 + assert DEF_EXPECTED_RESULT == result + assert pydantic_parser.OutputType is TestModel + + +def test_pydantic_output_parser_fail() -> None: + """Test PydanticOutputParser where completion result fails schema validation.""" + + pydantic_parser: PydanticOutputParser = PydanticOutputParser( + pydantic_object=TestModel + ) + + try: + pydantic_parser.parse(DEF_RESULT_FAIL) + except OutputParserException as e: + print("parse_result:", e) # noqa: T201 + assert "Failed to parse TestModel from completion" in str(e) + else: + assert False, "Expected OutputParserException" + + +def test_pydantic_output_parser_type_inference() -> None: + """Test pydantic output parser type inference.""" + + class SampleModel(BaseModel): + foo: int + bar: str + + # Ignoring mypy error that appears in python 3.8, but not 3.11. + # This seems to be functionally correct, so we'll ignore the error. + pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel) # type: ignore + schema = pydantic_parser.get_output_schema().schema() + + assert schema == { + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "foo": {"title": "Foo", "type": "integer"}, + }, + "required": ["foo", "bar"], + "title": "SampleModel", + "type": "object", + } + + +def test_format_instructions_preserves_language() -> None: + """Test format instructions does not attempt to encode into ascii.""" + from langchain_core.pydantic_v1 import BaseModel, Field + + description = ( + "你好, こんにちは, नमस्ते, Bonjour, Hola, " + "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" + ) + + class Foo(BaseModel): + hello: str = Field( + description=( + "你好, こんにちは, नमस्ते, Bonjour, Hola, " + "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" + ) + ) + + parser = PydanticOutputParser(pydantic_object=Foo) # type: ignore + assert description in parser.get_format_instructions() diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py deleted file mode 100644 index d390ee541b6..00000000000 --- a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Test PydanticOutputParser""" - -from enum import Enum -from typing import Optional - -from langchain_core.exceptions import OutputParserException -from langchain_core.output_parsers import PydanticOutputParser -from langchain_core.pydantic_v1 import BaseModel, Field - - -class Actions(Enum): - SEARCH = "Search" - CREATE = "Create" - UPDATE = "Update" - DELETE = "Delete" - - -class TestModel(BaseModel): - action: Actions = Field(description="Action to be performed") - action_input: str = Field(description="Input to be used in the action") - additional_fields: Optional[str] = Field( - description="Additional fields", default=None - ) - for_new_lines: str = Field(description="To be used to test newlines") - - -# Prevent pytest from trying to run tests on TestModel -TestModel.__test__ = False # type: ignore[attr-defined] - - -DEF_RESULT = """{ - "action": "Update", - "action_input": "The PydanticOutputParser class is powerful", - "additional_fields": null, - "for_new_lines": "not_escape_newline:\n escape_newline: \\n" -}""" - -# action 'update' with a lowercase 'u' to test schema validation failure. -DEF_RESULT_FAIL = """{ - "action": "update", - "action_input": "The PydanticOutputParser class is powerful", - "additional_fields": null -}""" - -DEF_EXPECTED_RESULT = TestModel( - action=Actions.UPDATE, - action_input="The PydanticOutputParser class is powerful", - additional_fields=None, - for_new_lines="not_escape_newline:\n escape_newline: \n", -) - - -def test_pydantic_output_parser() -> None: - """Test PydanticOutputParser.""" - - pydantic_parser: PydanticOutputParser = PydanticOutputParser( - pydantic_object=TestModel - ) - - result = pydantic_parser.parse(DEF_RESULT) - print("parse_result:", result) # noqa: T201 - assert DEF_EXPECTED_RESULT == result - assert pydantic_parser.OutputType is TestModel - - -def test_pydantic_output_parser_fail() -> None: - """Test PydanticOutputParser where completion result fails schema validation.""" - - pydantic_parser: PydanticOutputParser = PydanticOutputParser( - pydantic_object=TestModel - ) - - try: - pydantic_parser.parse(DEF_RESULT_FAIL) - except OutputParserException as e: - print("parse_result:", e) # noqa: T201 - assert "Failed to parse TestModel from completion" in str(e) - else: - assert False, "Expected OutputParserException" - - -def test_pydantic_output_parser_type_inference() -> None: - """Test pydantic output parser type inference.""" - - class SampleModel(BaseModel): - foo: int - bar: str - - # Ignoring mypy error that appears in python 3.8, but not 3.11. - # This seems to be functionally correct, so we'll ignore the error. - pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel) # type: ignore - schema = pydantic_parser.get_output_schema().schema() - - assert schema == { - "properties": { - "bar": {"title": "Bar", "type": "string"}, - "foo": {"title": "Foo", "type": "integer"}, - }, - "required": ["foo", "bar"], - "title": "SampleModel", - "type": "object", - } - - -def test_format_instructions_preserves_language() -> None: - """Test format instructions does not attempt to encode into ascii.""" - from langchain_core.pydantic_v1 import BaseModel, Field - - description = ( - "你好, こんにちは, नमस्ते, Bonjour, Hola, " - "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" - ) - - class Foo(BaseModel): - hello: str = Field( - description=( - "你好, こんにちは, नमस्ते, Bonjour, Hola, " - "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" - ) - ) - - parser = PydanticOutputParser(pydantic_object=Foo) # type: ignore - assert description in parser.get_format_instructions()