Add strict flag to the JSON parser (#9471)

This updates the default configuration since I think it's almost always
what we want to happen. But we should evaluate whether there are any issues.
This commit is contained in:
Eugene Yurtsev 2023-08-19 22:02:12 -04:00 committed by GitHub
parent 09a92bb9bf
commit e51bccdb28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 101 additions and 26 deletions

View File

@ -37,17 +37,33 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
class JsonOutputFunctionsParser(OutputFunctionsParser): class JsonOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as the Json object.""" """Parse an output as the Json object."""
strict: bool = False
"""Whether to allow non-JSON-compliant strings.
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
Useful when the parsed output may include unicode characters or new lines.
"""
def parse_result(self, result: List[Generation]) -> Any: def parse_result(self, result: List[Generation]) -> Any:
function_call_info = super().parse_result(result) function_call_info = super().parse_result(result)
if self.args_only: if self.args_only:
try: try:
return json.loads(function_call_info) return json.loads(function_call_info, strict=self.strict)
except (json.JSONDecodeError, TypeError) as exc: except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException( raise OutputParserException(
f"Could not parse function call data: {exc}" f"Could not parse function call data: {exc}"
) )
function_call_info["arguments"] = json.loads(function_call_info["arguments"]) else:
return function_call_info try:
function_call_info["arguments"] = json.loads(
function_call_info["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
return function_call_info
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):

View File

@ -1,4 +1,4 @@
import json from typing import Any, Dict
import pytest import pytest
@ -9,42 +9,101 @@ from langchain.schema import BaseMessage, ChatGeneration, OutputParserException
from langchain.schema.messages import AIMessage, HumanMessage from langchain.schema.messages import AIMessage, HumanMessage
@pytest.fixture def test_json_output_function_parser() -> None:
def ai_message() -> AIMessage: """Test the JSON output function parser is configured with robust defaults."""
"""Return a simple AIMessage.""" message = AIMessage(
content = "This is a test message" content="This is a test message",
additional_kwargs={
args = json.dumps( "function_call": {
{ "name": "function_name",
"arg1": "value1", "arguments": '{"arg1": "code\ncode"}',
} }
},
) )
chat_generation = ChatGeneration(message=message)
function_call = {"name": "function_name", "arguments": args}
additional_kwargs = {"function_call": function_call}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
def test_json_output_function_parser(ai_message: AIMessage) -> None:
"""Test that the JsonOutputFunctionsParser with full output."""
chat_generation = ChatGeneration(message=ai_message)
# Full output # Full output
# Test that the parsers defaults are configured to parse in non-strict mode
parser = JsonOutputFunctionsParser(args_only=False) parser = JsonOutputFunctionsParser(args_only=False)
result = parser.parse_result([chat_generation]) result = parser.parse_result([chat_generation])
assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"} assert result == {"arguments": {"arg1": "code\ncode"}, "name": "function_name"}
# Args only # Args only
parser = JsonOutputFunctionsParser(args_only=True) parser = JsonOutputFunctionsParser(args_only=True)
result = parser.parse_result([chat_generation]) result = parser.parse_result([chat_generation])
assert result == {"arg1": "value1"} assert result == {"arg1": "code\ncode"}
# Verify that the original message is not modified # Verify that the original message is not modified
assert ai_message.additional_kwargs == { assert message.additional_kwargs == {
"function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'} "function_call": {
"name": "function_name",
"arguments": '{"arg1": "code\ncode"}',
}
} }
@pytest.mark.parametrize(
"config",
[
{
"args_only": False,
"strict": False,
"args": '{"arg1": "value1"}',
"result": {"arguments": {"arg1": "value1"}, "name": "function_name"},
"exception": None,
},
{
"args_only": True,
"strict": False,
"args": '{"arg1": "value1"}',
"result": {"arg1": "value1"},
"exception": None,
},
{
"args_only": True,
"strict": False,
"args": '{"code": "print(2+\n2)"}',
"result": {"code": "print(2+\n2)"},
"exception": None,
},
{
"args_only": True,
"strict": False,
"args": '{"code": "你好)"}',
"result": {"code": "你好)"},
"exception": None,
},
{
"args_only": True,
"strict": True,
"args": '{"code": "print(2+\n2)"}',
"exception": OutputParserException,
},
],
)
def test_json_output_function_parser_strictness(config: Dict[str, Any]) -> None:
"""Test parsing with JSON strictness on and off."""
args = config["args"]
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": args}
},
)
chat_generation = ChatGeneration(message=message)
# Full output
parser = JsonOutputFunctionsParser(
strict=config["strict"], args_only=config["args_only"]
)
if config["exception"] is not None:
with pytest.raises(config["exception"]):
parser.parse_result([chat_generation])
else:
assert parser.parse_result([chat_generation]) == config["result"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"bad_message", "bad_message",
[ [