mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 21:08:59 +00:00
Add strict flag to the JSON parser (#9471)
This changes the parser's default to non-strict JSON decoding (`strict=False`), since that is almost always the behavior we want. We should still evaluate whether the new default causes any issues.
This commit is contained in:
parent
09a92bb9bf
commit
e51bccdb28
@ -37,17 +37,33 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|||||||
class JsonOutputFunctionsParser(OutputFunctionsParser):
|
class JsonOutputFunctionsParser(OutputFunctionsParser):
|
||||||
"""Parse an output as the Json object."""
|
"""Parse an output as the Json object."""
|
||||||
|
|
||||||
|
strict: bool = False
|
||||||
|
"""Whether to allow non-JSON-compliant strings.
|
||||||
|
|
||||||
|
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
|
||||||
|
|
||||||
|
Useful when the parsed output may include unicode characters or new lines.
|
||||||
|
"""
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation]) -> Any:
|
def parse_result(self, result: List[Generation]) -> Any:
|
||||||
function_call_info = super().parse_result(result)
|
function_call_info = super().parse_result(result)
|
||||||
if self.args_only:
|
if self.args_only:
|
||||||
try:
|
try:
|
||||||
return json.loads(function_call_info)
|
return json.loads(function_call_info, strict=self.strict)
|
||||||
except (json.JSONDecodeError, TypeError) as exc:
|
except (json.JSONDecodeError, TypeError) as exc:
|
||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
f"Could not parse function call data: {exc}"
|
f"Could not parse function call data: {exc}"
|
||||||
)
|
)
|
||||||
function_call_info["arguments"] = json.loads(function_call_info["arguments"])
|
else:
|
||||||
return function_call_info
|
try:
|
||||||
|
function_call_info["arguments"] = json.loads(
|
||||||
|
function_call_info["arguments"], strict=self.strict
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, TypeError) as exc:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Could not parse function call data: {exc}"
|
||||||
|
)
|
||||||
|
return function_call_info
|
||||||
|
|
||||||
|
|
||||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import json
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -9,42 +9,101 @@ from langchain.schema import BaseMessage, ChatGeneration, OutputParserException
|
|||||||
from langchain.schema.messages import AIMessage, HumanMessage
|
from langchain.schema.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def test_json_output_function_parser() -> None:
|
||||||
def ai_message() -> AIMessage:
|
"""Test the JSON output function parser is configured with robust defaults."""
|
||||||
"""Return a simple AIMessage."""
|
message = AIMessage(
|
||||||
content = "This is a test message"
|
content="This is a test message",
|
||||||
|
additional_kwargs={
|
||||||
args = json.dumps(
|
"function_call": {
|
||||||
{
|
"name": "function_name",
|
||||||
"arg1": "value1",
|
"arguments": '{"arg1": "code\ncode"}',
|
||||||
}
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
chat_generation = ChatGeneration(message=message)
|
||||||
function_call = {"name": "function_name", "arguments": args}
|
|
||||||
additional_kwargs = {"function_call": function_call}
|
|
||||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_output_function_parser(ai_message: AIMessage) -> None:
|
|
||||||
"""Test that the JsonOutputFunctionsParser with full output."""
|
|
||||||
chat_generation = ChatGeneration(message=ai_message)
|
|
||||||
|
|
||||||
# Full output
|
# Full output
|
||||||
|
# Test that the parser's defaults are configured to parse in non-strict mode
|
||||||
parser = JsonOutputFunctionsParser(args_only=False)
|
parser = JsonOutputFunctionsParser(args_only=False)
|
||||||
result = parser.parse_result([chat_generation])
|
result = parser.parse_result([chat_generation])
|
||||||
assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"}
|
assert result == {"arguments": {"arg1": "code\ncode"}, "name": "function_name"}
|
||||||
|
|
||||||
# Args only
|
# Args only
|
||||||
parser = JsonOutputFunctionsParser(args_only=True)
|
parser = JsonOutputFunctionsParser(args_only=True)
|
||||||
result = parser.parse_result([chat_generation])
|
result = parser.parse_result([chat_generation])
|
||||||
assert result == {"arg1": "value1"}
|
assert result == {"arg1": "code\ncode"}
|
||||||
|
|
||||||
# Verify that the original message is not modified
|
# Verify that the original message is not modified
|
||||||
assert ai_message.additional_kwargs == {
|
assert message.additional_kwargs == {
|
||||||
"function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'}
|
"function_call": {
|
||||||
|
"name": "function_name",
|
||||||
|
"arguments": '{"arg1": "code\ncode"}',
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"args_only": False,
|
||||||
|
"strict": False,
|
||||||
|
"args": '{"arg1": "value1"}',
|
||||||
|
"result": {"arguments": {"arg1": "value1"}, "name": "function_name"},
|
||||||
|
"exception": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"args_only": True,
|
||||||
|
"strict": False,
|
||||||
|
"args": '{"arg1": "value1"}',
|
||||||
|
"result": {"arg1": "value1"},
|
||||||
|
"exception": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"args_only": True,
|
||||||
|
"strict": False,
|
||||||
|
"args": '{"code": "print(2+\n2)"}',
|
||||||
|
"result": {"code": "print(2+\n2)"},
|
||||||
|
"exception": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"args_only": True,
|
||||||
|
"strict": False,
|
||||||
|
"args": '{"code": "你好)"}',
|
||||||
|
"result": {"code": "你好)"},
|
||||||
|
"exception": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"args_only": True,
|
||||||
|
"strict": True,
|
||||||
|
"args": '{"code": "print(2+\n2)"}',
|
||||||
|
"exception": OutputParserException,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_json_output_function_parser_strictness(config: Dict[str, Any]) -> None:
|
||||||
|
"""Test parsing with JSON strictness on and off."""
|
||||||
|
args = config["args"]
|
||||||
|
|
||||||
|
message = AIMessage(
|
||||||
|
content="This is a test message",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {"name": "function_name", "arguments": args}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_generation = ChatGeneration(message=message)
|
||||||
|
|
||||||
|
# Full output
|
||||||
|
parser = JsonOutputFunctionsParser(
|
||||||
|
strict=config["strict"], args_only=config["args_only"]
|
||||||
|
)
|
||||||
|
if config["exception"] is not None:
|
||||||
|
with pytest.raises(config["exception"]):
|
||||||
|
parser.parse_result([chat_generation])
|
||||||
|
else:
|
||||||
|
assert parser.parse_result([chat_generation]) == config["result"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"bad_message",
|
"bad_message",
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user