From e51bccdb2890fa193ce7eb5bf7e13c28afef4dc4 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 19 Aug 2023 22:02:12 -0400 Subject: [PATCH] Add strict flag to the JSON parser (#9471) This updates the default configuration since I think it's almost always what we want to happen. But we should evaluate whether there are any issues. --- .../output_parsers/openai_functions.py | 22 +++- .../output_parsers/test_openai_functions.py | 105 ++++++++++++++---- 2 files changed, 101 insertions(+), 26 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index f916c4836ba..cabafd599de 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -37,17 +37,33 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): class JsonOutputFunctionsParser(OutputFunctionsParser): """Parse an output as the Json object.""" + strict: bool = False + """Whether to allow non-JSON-compliant strings. + + See: https://docs.python.org/3/library/json.html#encoders-and-decoders + + Useful when the parsed output may include unicode characters or new lines. + """ + def parse_result(self, result: List[Generation]) -> Any: function_call_info = super().parse_result(result) if self.args_only: try: - return json.loads(function_call_info) + return json.loads(function_call_info, strict=self.strict) except (json.JSONDecodeError, TypeError) as exc: raise OutputParserException( f"Could not parse function call data: {exc}" ) - function_call_info["arguments"] = json.loads(function_call_info["arguments"]) - return function_call_info + else: + try: + function_call_info["arguments"] = json.loads( + function_call_info["arguments"], strict=self.strict + ) + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) + return function_call_info class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py index 7c6411e61f0..364f5a5a22e 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py @@ -1,4 +1,4 @@ -import json +from typing import Any, Dict import pytest @@ -9,42 +9,101 @@ from langchain.schema import BaseMessage, ChatGeneration, OutputParserException from langchain.schema.messages import AIMessage, HumanMessage -@pytest.fixture -def ai_message() -> AIMessage: - """Return a simple AIMessage.""" - content = "This is a test message" - - args = json.dumps( - { - "arg1": "value1", - } +def test_json_output_function_parser() -> None: + """Test the JSON output function parser is configured with robust defaults.""" + message = AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": { + "name": "function_name", + "arguments": '{"arg1": "code\ncode"}', + } + }, ) - - function_call = {"name": "function_name", "arguments": args} - additional_kwargs = {"function_call": function_call} - return AIMessage(content=content, additional_kwargs=additional_kwargs) - - -def test_json_output_function_parser(ai_message: AIMessage) -> None: - """Test that the JsonOutputFunctionsParser with full output.""" - chat_generation = ChatGeneration(message=ai_message) + chat_generation = ChatGeneration(message=message) # Full output + # Test that the parsers defaults are configured to parse in non-strict mode parser = JsonOutputFunctionsParser(args_only=False) result = parser.parse_result([chat_generation]) - assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"} + assert result == {"arguments": {"arg1": "code\ncode"}, "name": "function_name"} # Args only parser = JsonOutputFunctionsParser(args_only=True) result = parser.parse_result([chat_generation]) - assert result == {"arg1": "value1"} + assert result == {"arg1": "code\ncode"} # Verify that the original message is not modified - assert ai_message.additional_kwargs == { - "function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'} + assert message.additional_kwargs == { + "function_call": { + "name": "function_name", + "arguments": '{"arg1": "code\ncode"}', + } } +@pytest.mark.parametrize( + "config", + [ + { + "args_only": False, + "strict": False, + "args": '{"arg1": "value1"}', + "result": {"arguments": {"arg1": "value1"}, "name": "function_name"}, + "exception": None, + }, + { + "args_only": True, + "strict": False, + "args": '{"arg1": "value1"}', + "result": {"arg1": "value1"}, + "exception": None, + }, + { + "args_only": True, + "strict": False, + "args": '{"code": "print(2+\n2)"}', + "result": {"code": "print(2+\n2)"}, + "exception": None, + }, + { + "args_only": True, + "strict": False, + "args": '{"code": "你好)"}', + "result": {"code": "你好)"}, + "exception": None, + }, + { + "args_only": True, + "strict": True, + "args": '{"code": "print(2+\n2)"}', + "exception": OutputParserException, + }, + ], +) +def test_json_output_function_parser_strictness(config: Dict[str, Any]) -> None: + """Test parsing with JSON strictness on and off.""" + args = config["args"] + + message = AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": {"name": "function_name", "arguments": args} + }, + ) + chat_generation = ChatGeneration(message=message) + + # Full output + parser = JsonOutputFunctionsParser( + strict=config["strict"], args_only=config["args_only"] + ) + if config["exception"] is not None: + with pytest.raises(config["exception"]): + parser.parse_result([chat_generation]) + else: + assert parser.parse_result([chat_generation]) == config["result"] + + @pytest.mark.parametrize( "bad_message", [