Add a streaming json parser

This commit is contained in:
Nuno Campos 2023-09-28 20:14:17 +01:00
parent 2387647d30
commit f672b39cc9
2 changed files with 151 additions and 4 deletions

View File

@ -6,6 +6,8 @@ from json import JSONDecodeError
from typing import Any, List from typing import Any, List
from langchain.schema import BaseOutputParser, OutputParserException from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.output_parser import BaseCumulativeTransformOutputParser
def _replace_new_line(match: re.Match[str]) -> str: def _replace_new_line(match: re.Match[str]) -> str:
@ -38,6 +40,66 @@ def _custom_parser(multiline_string: str) -> str:
return multiline_string return multiline_string
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str) -> Any:
    """Parse JSON text that may have been cut off part-way through.

    The text is first handed to ``json.loads`` unchanged. If that fails, it is
    scanned once to work out which string, objects and arrays are still open,
    the missing closers are appended, and decoding is retried. Raw newline
    characters found inside string literals are rewritten as the two-character
    escape sequence (backslash, ``n``) along the way.

    Args:
        s: Complete or truncated JSON text.

    Returns:
        The decoded value, or ``None`` when a closing bracket does not match
        the innermost open structure or the repaired text still cannot be
        decoded. Note that the JSON literal ``null`` also decodes to ``None``.
    """
    # Attempt to parse the string as-is.
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        pass

    # Collect output pieces in a list and join once at the end instead of
    # growing a string character by character.
    new_chars = []
    # Closers still owed, innermost last.
    stack = []
    is_inside_string = False
    escaped = False

    # Process each character in the string one at a time.
    for char in s:
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                # Replace the newline character with the escape sequence.
                char = "\\n"
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        else:
            if char == '"':
                is_inside_string = True
                escaped = False
            elif char == "{":
                stack.append("}")
            elif char == "[":
                stack.append("]")
            elif char == "}" or char == "]":
                if stack and stack[-1] == char:
                    stack.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return None

        # Append the processed character to the output.
        new_chars.append(char)

    # If we're still inside a string at the end of processing, close it.
    if is_inside_string:
        if escaped:
            # The text stops right after a lone backslash. Leaving it in place
            # would escape the quote appended below and keep the string open,
            # so drop the dangling escape character first.
            new_chars.pop()
        new_chars.append('"')

    # Close any remaining open structures in the reverse order they were opened.
    new_chars.extend(reversed(stack))

    # Attempt to parse the modified string as JSON.
    try:
        return json.loads("".join(new_chars))
    except json.JSONDecodeError:
        # If we still can't parse the string as JSON, return None to indicate
        # failure.
        return None
def parse_json_markdown(json_string: str) -> dict: def parse_json_markdown(json_string: str) -> dict:
""" """
Parse a JSON string from a Markdown string. Parse a JSON string from a Markdown string.
@ -65,7 +127,7 @@ def parse_json_markdown(json_string: str) -> dict:
json_str = _custom_parser(json_str) json_str = _custom_parser(json_str)
# Parse the JSON string into a Python dictionary # Parse the JSON string into a Python dictionary
parsed = json.loads(json_str) parsed = parse_partial_json(json_str)
return parsed return parsed
@ -101,10 +163,42 @@ class SimpleJsonOutputParser(BaseOutputParser[Any]):
def parse(self, text: str) -> Any:
    """Decode ``text`` as JSON, tolerating output that was cut off.

    Args:
        text: Raw model output; surrounding whitespace is stripped first.

    Returns:
        The decoded JSON value.

    Raises:
        OutputParserException: If the text is neither valid JSON nor
            repairable as truncated JSON.
    """
    text = text.strip()
    try:
        # Strict decoding first, so a legitimate ``null`` (decoded as None)
        # can be told apart from a failed repair below.
        return json.loads(text)
    except JSONDecodeError as e:
        # parse_partial_json signals failure by returning None rather than
        # raising, so the failure has to be turned into an exception here.
        parsed = parse_partial_json(text)
        if parsed is None:
            raise OutputParserException(f"Invalid json output: {text}") from e
        return parsed
@property @property
def _type(self) -> str: def _type(self) -> str:
return "simple_json_output_parser" return "simple_json_output_parser"
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Output parser that decodes JSON which may still be incomplete.

    The JSON is taken either from the generation text (default) or, when
    ``from_function_args`` is set, from the ``arguments`` entry of the chat
    message's ``function_call`` additional kwarg.
    """

    # When True, read the JSON from the message's function-call arguments
    # instead of from the generation text.
    from_function_args: bool = False

    @property
    def _type(self) -> str:
        """Identifier string for this parser."""
        return "partial_json"

    def parse_result(self, result: List[Generation]) -> Any:
        """Decode the (possibly partial) JSON carried by a single generation.

        Args:
            result: A list holding exactly one generation.

        Returns:
            The decoded value, or ``None`` when nothing decodable is
            available yet (including a message without a ``function_call``
            or without ``arguments`` in function-args mode).

        Raises:
            OutputParserException: If ``result`` does not hold exactly one
                item, or if function-args mode is used with something other
                than a ``ChatGeneration``.
        """
        if len(result) != 1:
            raise OutputParserException(
                f"Expected exactly one result, but got {len(result)}"
            )
        generation = result[0]

        if not self.from_function_args:
            # Plain-text mode: decode whatever text has been produced so far.
            return self.parse(generation.text)

        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        # Keep the try body to the two lookups so that only a missing key is
        # treated as "nothing to parse yet".
        try:
            arguments = message.additional_kwargs["function_call"]["arguments"]
        except KeyError:
            return None
        return parse_partial_json(arguments)

    def parse(self, text: str) -> Any:
        """Decode ``text`` as complete or truncated JSON.

        Returns ``None`` when the text cannot be decoded even after repair.
        """
        return parse_partial_json(text)

View File

@ -17,8 +17,13 @@ from typing import (
from typing_extensions import get_args from typing_extensions import get_args
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import ChatGeneration, Generation from langchain.schema.output import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import Runnable, RunnableConfig
@ -329,6 +334,54 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
yield chunk yield chunk
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
    """Base class for an output parser that can handle streaming input.

    Each incoming chunk is added to a running accumulated generation, and the
    accumulated generation is passed to ``parse_result`` after every chunk.
    Results that are not ``None`` are yielded.
    """

    def _to_generation_chunk(
        self, chunk: Union[str, BaseMessage]
    ) -> Union[GenerationChunk, ChatGenerationChunk]:
        """Wrap one streamed item in a generation chunk object."""
        if isinstance(chunk, BaseMessageChunk):
            return ChatGenerationChunk(message=chunk)
        if isinstance(chunk, BaseMessage):
            # Rebuild a non-chunk message as a message chunk from its fields.
            return ChatGenerationChunk(message=BaseMessageChunk(**chunk.dict()))
        return GenerationChunk(text=chunk)

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
        """Yield the parse of everything received so far after each chunk."""
        acc_gen = None
        for chunk in input:
            chunk_gen = self._to_generation_chunk(chunk)
            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen])
            if parsed is not None:
                yield parsed

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        """Async counterpart of ``_transform``."""
        acc_gen = None
        # The input is an async iterator, so it must be consumed with
        # ``async for``; a plain ``for`` raises TypeError on it.
        async for chunk in input:
            chunk_gen = self._to_generation_chunk(chunk)
            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen])
            if parsed is not None:
                yield parsed
class StrOutputParser(BaseTransformOutputParser[str]): class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string.""" """OutputParser that parses LLMResult into the top likely string."""