mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 06:24:47 +00:00
Add a streaming json parser
This commit is contained in:
parent
2387647d30
commit
f672b39cc9
@ -6,6 +6,8 @@ from json import JSONDecodeError
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
from langchain.schema.output import ChatGeneration, Generation
|
||||||
|
from langchain.schema.output_parser import BaseCumulativeTransformOutputParser
|
||||||
|
|
||||||
|
|
||||||
def _replace_new_line(match: re.Match[str]) -> str:
|
def _replace_new_line(match: re.Match[str]) -> str:
|
||||||
@ -38,6 +40,66 @@ def _custom_parser(multiline_string: str) -> str:
|
|||||||
return multiline_string
|
return multiline_string
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
|
||||||
|
# MIT License
|
||||||
|
def parse_partial_json(s: str) -> Any:
    """Parse a JSON document that may have been cut off part-way through.

    The string is first parsed as-is. If that fails, it is scanned character
    by character to repair it: raw newlines inside string literals are
    replaced by the ``\\n`` escape sequence, an unterminated string literal is
    closed, and any objects/arrays still open at the end are closed in reverse
    order of opening. The repaired text is then parsed again.

    Args:
        s: The (possibly truncated) JSON text.

    Returns:
        The parsed value, or ``None`` when the text contains a mismatched
        closing bracket or still cannot be parsed after the repair.
    """
    # Attempt to parse the string as-is.
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        pass

    # Initialize variables.
    new_chars = []  # Pieces of the repaired document; joined once at the end.
    stack = []  # Closing characters owed for every structure still open.
    is_inside_string = False
    escaped = False  # True when the previous character was an active backslash.

    # Process each character in the string one at a time.
    for char in s:
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                # Replace the newline character with the escape sequence.
                char = "\\n"
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        else:
            if char == '"':
                is_inside_string = True
                escaped = False
            elif char == "{":
                stack.append("}")
            elif char == "[":
                stack.append("]")
            elif char == "}" or char == "]":
                if stack and stack[-1] == char:
                    stack.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return None

        # Append the processed character to the repaired output.
        new_chars.append(char)

    # If we're still inside a string at the end of processing, we need to
    # close the string.
    if is_inside_string:
        if escaped:
            # The input stopped right after a backslash. Drop it, otherwise the
            # quote added below would be read as an escaped quote and the
            # string would remain unterminated.
            new_chars.pop()
        new_chars.append('"')

    # Close any remaining open structures in the reverse order that they were
    # opened.
    new_chars.extend(reversed(stack))

    # Attempt to parse the modified string as JSON.
    try:
        return json.loads("".join(new_chars))
    except json.JSONDecodeError:
        # If we still can't parse the string as JSON, return None to indicate
        # failure.
        return None
|
||||||
|
|
||||||
|
|
||||||
def parse_json_markdown(json_string: str) -> dict:
|
def parse_json_markdown(json_string: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Parse a JSON string from a Markdown string.
|
Parse a JSON string from a Markdown string.
|
||||||
@ -65,7 +127,7 @@ def parse_json_markdown(json_string: str) -> dict:
|
|||||||
json_str = _custom_parser(json_str)
|
json_str = _custom_parser(json_str)
|
||||||
|
|
||||||
# Parse the JSON string into a Python dictionary
|
# Parse the JSON string into a Python dictionary
|
||||||
parsed = json.loads(json_str)
|
parsed = parse_partial_json(json_str)
|
||||||
|
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
@ -101,10 +163,42 @@ class SimpleJsonOutputParser(BaseOutputParser[Any]):
|
|||||||
def parse(self, text: str) -> Any:
    """Parse ``text`` as JSON, tolerating output that was cut off early.

    The stripped text is parsed strictly first. If that fails, it is handed to
    ``parse_partial_json`` to attempt a repair.

    Args:
        text: The raw text to parse.

    Returns:
        The parsed JSON value.

    Raises:
        OutputParserException: If the text is neither valid JSON nor
            repairable by ``parse_partial_json``.
    """
    text = text.strip()
    try:
        return json.loads(text)
    except JSONDecodeError as e:
        # parse_partial_json signals failure by returning None rather than
        # raising, so the failure has to be turned into an exception here;
        # otherwise invalid output would be silently returned as None.
        parsed = parse_partial_json(text)
        if parsed is None:
            raise OutputParserException(f"Invalid json output: {text}") from e
        return parsed
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
return "simple_json_output_parser"
|
return "simple_json_output_parser"
|
||||||
|
|
||||||
|
|
||||||
|
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Output parser that parses possibly incomplete JSON.

    The JSON is read either from the generation text or, when
    ``from_function_args`` is set, from the ``arguments`` entry of the chat
    message's ``function_call``. Parsing is delegated to
    ``parse_partial_json``.
    """

    # When True, parse the "arguments" of the message's "function_call"
    # (found in ``message.additional_kwargs``) instead of the generation text.
    from_function_args: bool = False

    @property
    def _type(self) -> str:
        return "partial_json"

    def parse_result(self, result: List[Generation]) -> Any:
        """Parse the accumulated generation into a JSON value.

        Args:
            result: The generations to parse. In function-args mode this must
                be exactly one ``ChatGeneration``.

        Returns:
            The value returned by ``parse_partial_json``, or ``None`` when the
            function call or its arguments are not present.

        Raises:
            OutputParserException: In function-args mode, when ``result`` does
                not hold exactly one generation or it is not a
                ``ChatGeneration``.
        """
        if not self.from_function_args:
            # Text mode: parse the text of the first generation. Without this
            # branch the method would return None for every input.
            return self.parse(result[0].text)

        if len(result) != 1:
            raise OutputParserException(
                f"Expected exactly one result, but got {len(result)}"
            )
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            function_call = message.additional_kwargs["function_call"]
        except KeyError:
            # No function call on the message; nothing to parse.
            return None
        try:
            return parse_partial_json(function_call["arguments"])
        except KeyError:
            # The function call carries no arguments.
            return None

    def parse(self, text: str) -> Any:
        """Parse ``text`` as (possibly incomplete) JSON.

        Returns:
            The parsed value, or ``None`` if ``parse_partial_json`` cannot
            parse the text.
        """
        return parse_partial_json(text)
|
||||||
|
@ -17,8 +17,13 @@ from typing import (
|
|||||||
from typing_extensions import get_args
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.messages import AnyMessage, BaseMessage
|
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
||||||
from langchain.schema.output import ChatGeneration, Generation
|
from langchain.schema.output import (
|
||||||
|
ChatGeneration,
|
||||||
|
ChatGenerationChunk,
|
||||||
|
Generation,
|
||||||
|
GenerationChunk,
|
||||||
|
)
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||||
|
|
||||||
@ -329,6 +334,54 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
    """Base class for an output parser that can handle streaming input.

    Every incoming chunk is merged into a single accumulated generation, and
    the accumulated generation is passed to ``parse_result`` after each chunk.
    Results that are not ``None`` are yielded.
    """

    @staticmethod
    def _to_generation_chunk(
        chunk: Union[str, BaseMessage]
    ) -> Union[GenerationChunk, ChatGenerationChunk]:
        """Wrap one input chunk in a generation chunk that supports ``+=``."""
        if isinstance(chunk, BaseMessageChunk):
            return ChatGenerationChunk(message=chunk)
        if isinstance(chunk, BaseMessage):
            # Rebuild a plain message as a message chunk from its own fields.
            return ChatGenerationChunk(message=BaseMessageChunk(**chunk.dict()))
        return GenerationChunk(text=chunk)

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
        """Yield the parse of everything received so far after each chunk."""
        acc_gen = None
        for chunk in input:
            chunk_gen = self._to_generation_chunk(chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen])
            if parsed is not None:
                yield parsed

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        """Async counterpart of ``_transform``."""
        acc_gen = None
        # An AsyncIterator must be consumed with ``async for``; a plain ``for``
        # raises TypeError because the object is not a synchronous iterable.
        async for chunk in input:
            chunk_gen = self._to_generation_chunk(chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen])
            if parsed is not None:
                yield parsed
|
||||||
|
|
||||||
|
|
||||||
class StrOutputParser(BaseTransformOutputParser[str]):
|
class StrOutputParser(BaseTransformOutputParser[str]):
|
||||||
"""OutputParser that parses LLMResult into the top likely string."""
|
"""OutputParser that parses LLMResult into the top likely string."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user