Implement diff

This commit is contained in:
Nuno Campos 2023-09-29 14:06:07 +01:00
parent 6c0a6b70e0
commit 5cbe2b7b6a
2 changed files with 225 additions and 41 deletions

View File

@ -3,7 +3,9 @@ from __future__ import annotations
import json import json
import re import re
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, List from typing import Any, List, Optional
import jsonpatch
from langchain.schema import BaseOutputParser, OutputParserException from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema.output import ChatGeneration, Generation from langchain.schema.output import ChatGeneration, Generation
@ -42,7 +44,7 @@ def _custom_parser(multiline_string: str) -> str:
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License # MIT License
def parse_partial_json(s): def parse_partial_json(s: str) -> Any:
# Attempt to parse the string as-is. # Attempt to parse the string as-is.
try: try:
return json.loads(s) return json.loads(s)
@ -84,7 +86,8 @@ def parse_partial_json(s):
# Append the processed character to the new string. # Append the processed character to the new string.
new_s += char new_s += char
# If we're still inside a string at the end of processing, we need to close the string. # If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string: if is_inside_string:
new_s += '"' new_s += '"'
@ -197,6 +200,9 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any])
except KeyError: except KeyError:
return None return None
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse(self, text: str) -> Any: def parse(self, text: str) -> Any:
pass pass
@ -206,5 +212,8 @@ class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
def _type(self) -> str: def _type(self) -> str:
return "partial_functions_json" return "partial_functions_json"
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse(self, text: str) -> Any: def parse(self, text: str) -> Any:
return parse_json_markdown(text) return parse_json_markdown(text)

View File

@ -1,8 +1,15 @@
import json import json
from typing import Iterator, Tuple from typing import Any, Iterator, Tuple
import pytest import pytest
from langchain.output_parsers.json import parse_json_markdown, parse_partial_json from langchain.output_parsers.json import (
PartialFunctionsJsonOutputParser,
PartialJsonOutputParser,
parse_json_markdown,
parse_partial_json,
)
from langchain.schema.messages import AIMessageChunk
GOOD_JSON = """```json GOOD_JSON = """```json
{ {
@ -206,7 +213,6 @@ def test_parse_partial_json(json_strings: Tuple[str, str]) -> None:
STREAMED_TOKENS = """ STREAMED_TOKENS = """
{ {
" "
setup setup
": ":
@ -215,36 +221,50 @@ Why
did did
the the
bears bears
go start
on
a a
picnic band
?", called
Bears
Bears
Bears
?
" "
p ,
unch "
line punchline
": ":
" "
Because Because
they they
wanted wanted
to to
have play
a
bear bear
-y -y
good good
time music
!" !
"
,
"
audience
":
[
"
Haha
"
,
"
So
funny
"
]
} }
""".splitlines() """.splitlines()
EXPECTED_STREAMED_JSON = [ EXPECTED_STREAMED_JSON = [
{},
{}, {},
{"setup": ""}, {"setup": ""},
{"setup": "Why"}, {"setup": "Why"},
@ -258,6 +278,7 @@ EXPECTED_STREAMED_JSON = [
{"setup": "Why did the bears start a band called Bears"}, {"setup": "Why did the bears start a band called Bears"},
{"setup": "Why did the bears start a band called Bears Bears"}, {"setup": "Why did the bears start a band called Bears Bears"},
{"setup": "Why did the bears start a band called Bears Bears Bears"}, {"setup": "Why did the bears start a band called Bears Bears Bears"},
{"setup": "Why did the bears start a band called Bears Bears Bears ?"},
{ {
"setup": "Why did the bears start a band called Bears Bears Bears ?", "setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "", "punchline": "",
@ -303,17 +324,171 @@ EXPECTED_STREAMED_JSON = [
"punchline": "Because they wanted to play bear -y good music !", "punchline": "Because they wanted to play bear -y good music !",
}, },
{ {
"setup": "Why did the bears start a band called Bears Bears Bears?",
"punchline": "Because they wanted to play bear -y good music !", "punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": [],
}, },
{ {
"setup": "Why did the bears start a band called Bears Bears Bears?",
"punchline": "Because they wanted to play bear -y good music !", "punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": [""],
}, },
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha"],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha", ""],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha", "So"],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha", "So funny"],
},
]
EXPECTED_STREAMED_JSON_DIFF = [
[{"op": "replace", "path": "", "value": {}}],
[{"op": "add", "path": "/setup", "value": ""}],
[{"op": "replace", "path": "/setup", "value": "Why"}],
[{"op": "replace", "path": "/setup", "value": "Why did"}],
[{"op": "replace", "path": "/setup", "value": "Why did the"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears start"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears Bears",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears Bears Bears",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears Bears Bears ?",
}
],
[{"op": "add", "path": "/punchline", "value": ""}],
[{"op": "replace", "path": "/punchline", "value": "Because"}],
[{"op": "replace", "path": "/punchline", "value": "Because they"}],
[{"op": "replace", "path": "/punchline", "value": "Because they wanted"}],
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}],
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y good",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y good music",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y good music !",
}
],
[{"op": "add", "path": "/audience", "value": []}],
[{"op": "add", "path": "/audience/0", "value": ""}],
[{"op": "replace", "path": "/audience/0", "value": "Haha"}],
[{"op": "add", "path": "/audience/1", "value": ""}],
[{"op": "replace", "path": "/audience/1", "value": "So"}],
[{"op": "replace", "path": "/audience/1", "value": "So funny"}],
] ]
def test_partial_text_json_output_parser() -> None: def test_partial_text_json_output_parser() -> None:
def input_iter() -> Iterator[str]: def input_iter(_: Any) -> Iterator[str]:
for token in STREAMED_TOKENS: for token in STREAMED_TOKENS:
yield token yield token
chain = input_iter | PartialJsonOutputParser()
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_functions_json_output_parser() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | PartialFunctionsJsonOutputParser()
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_text_json_output_parser_diff() -> None:
def input_iter(_: Any) -> Iterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | PartialJsonOutputParser(diff=True)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
def test_partial_functions_json_output_parser_diff() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF