diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py
index 7465aba2fe1..0168f1ae364 100644
--- a/libs/langchain/langchain/output_parsers/json.py
+++ b/libs/langchain/langchain/output_parsers/json.py
@@ -6,6 +6,8 @@ from json import JSONDecodeError
 from typing import Any, List
 
 from langchain.schema import BaseOutputParser, OutputParserException
+from langchain.schema.output import ChatGeneration, Generation
+from langchain.schema.output_parser import BaseCumulativeTransformOutputParser
 
 
 def _replace_new_line(match: re.Match[str]) -> str:
@@ -38,6 +40,86 @@ def _custom_parser(multiline_string: str) -> str:
     return multiline_string
 
 
+# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
+# MIT License
+def parse_partial_json(s: str) -> Any:
+    """Parse a JSON document that may have been cut off before its end.
+
+    The text is first handed to ``json.loads`` unchanged. If that fails, a
+    repaired copy is built: raw newlines inside string literals are escaped,
+    a string that is still open is closed, and a closing bracket is appended
+    for every ``{`` / ``[`` that was never closed. The repaired copy is then
+    parsed.
+
+    Args:
+        s: Complete or truncated JSON text.
+
+    Returns:
+        The decoded value, or ``None`` when a closing bracket has no matching
+        opener or the repaired text is still not valid JSON. A valid JSON
+        ``null`` also decodes to ``None``.
+    """
+    # Attempt to parse the string as-is.
+    try:
+        return json.loads(s)
+    except json.JSONDecodeError:
+        pass
+
+    # Collect the repaired text as a list of pieces and join it once.
+    new_chars: List[str] = []
+    stack: List[str] = []
+    is_inside_string = False
+    escaped = False
+
+    # Process each character in the string one at a time.
+    for char in s:
+        if is_inside_string:
+            if char == '"' and not escaped:
+                is_inside_string = False
+            elif char == "\n" and not escaped:
+                # Replace the newline character with the escape sequence.
+                char = "\\n"
+            elif char == "\\":
+                escaped = not escaped
+            else:
+                escaped = False
+        else:
+            if char == '"':
+                is_inside_string = True
+                escaped = False
+            elif char == "{":
+                stack.append("}")
+            elif char == "[":
+                stack.append("]")
+            elif char == "}" or char == "]":
+                if stack and stack[-1] == char:
+                    stack.pop()
+                else:
+                    # Mismatched closing character; the input is malformed.
+                    return None
+
+        # Append the processed character to the repaired text.
+        new_chars.append(char)
+
+    # If we're still inside a string at the end of processing, close it.
+    if is_inside_string:
+        if escaped:
+            # The text stops right after a backslash; drop that dangling escape
+            # so it cannot turn the closing quote into an escaped quote.
+            new_chars.pop()
+        new_chars.append('"')
+
+    # Close any remaining open structures in the reverse order they were opened.
+    new_chars.extend(reversed(stack))
+
+    # Attempt to parse the modified string as JSON.
+    try:
+        return json.loads("".join(new_chars))
+    except json.JSONDecodeError:
+        # If we still can't parse the string, return None to indicate failure.
+        return None
+
+
 def parse_json_markdown(json_string: str) -> dict:
     """
     Parse a JSON string from a Markdown string.
@@ -65,7 +147,11 @@ def parse_json_markdown(json_string: str) -> dict:
     json_str = _custom_parser(json_str)
 
     # Parse the JSON string into a Python dictionary
-    parsed = json.loads(json_str)
+    parsed = parse_partial_json(json_str)
+    if parsed is None:
+        # The input could not be repaired (or is a literal ``null``): run the
+        # strict parser so invalid JSON still raises ``JSONDecodeError``.
+        parsed = json.loads(json_str)
 
     return parsed
 
@@ -101,10 +187,66 @@ class SimpleJsonOutputParser(BaseOutputParser[Any]):
     def parse(self, text: str) -> Any:
         text = text.strip()
         try:
-            return json.loads(text)
+            parsed = parse_partial_json(text)
+            if parsed is None:
+                # The input could not be repaired (or is a literal ``null``):
+                # run the strict parser so invalid JSON reaches the handler.
+                parsed = json.loads(text)
+            return parsed
         except JSONDecodeError as e:
             raise OutputParserException(f"Invalid json output: {text}") from e
 
     @property
     def _type(self) -> str:
         return "simple_json_output_parser"
+
+
+class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
+    """Parse JSON that may still be incomplete, e.g. while it is streamed.
+
+    By default the JSON is read from the generated text. With
+    ``from_function_args`` set it is read from the ``arguments`` entry of the
+    message's ``function_call`` instead.
+    """
+
+    from_function_args: bool = False
+
+    @property
+    def _type(self) -> str:
+        return "partial_json"
+
+    def parse_result(self, result: List[Generation]) -> Any:
+        """Parse the single generation in ``result``.
+
+        Returns ``None`` when there is nothing parseable yet.
+
+        Raises:
+            OutputParserException: If ``result`` does not hold exactly one
+                generation, or ``from_function_args`` is set and that
+                generation is not a ``ChatGeneration``.
+        """
+        if len(result) != 1:
+            raise OutputParserException(
+                f"Expected exactly one result, but got {len(result)}"
+            )
+        generation = result[0]
+        if not self.from_function_args:
+            return self.parse(generation.text)
+        if not isinstance(generation, ChatGeneration):
+            raise OutputParserException(
+                "This output parser can only be used with a chat generation."
+            )
+        message = generation.message
+        try:
+            arguments = message.additional_kwargs["function_call"]["arguments"]
+        except KeyError:
+            # No function call, or no arguments present on the message.
+            return None
+        return parse_partial_json(arguments)
+
+    def parse(self, text: str) -> Any:
+        """Parse ``text`` as complete or truncated JSON.
+
+        Returns ``None`` when the text cannot be parsed.
+        """
+        return parse_partial_json(text.strip())
diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py
index 46a8d9def03..5f09ee48ae6 100644
--- a/libs/langchain/langchain/schema/output_parser.py
+++ b/libs/langchain/langchain/schema/output_parser.py
@@ -17,8 +17,13 @@ from typing import (
 from typing_extensions import get_args
 
 from langchain.load.serializable import Serializable
-from langchain.schema.messages import AnyMessage, BaseMessage
-from langchain.schema.output import ChatGeneration, Generation
+from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
+from langchain.schema.output import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    Generation,
+    GenerationChunk,
+)
 from langchain.schema.prompt import PromptValue
 from langchain.schema.runnable import Runnable, RunnableConfig
 
@@ -329,6 +334,56 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
             yield chunk
 
 
+class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
+    """Base class for an output parser that can handle streaming input.
+
+    Incoming chunks are merged into one accumulated generation. After every
+    chunk the accumulated generation is passed to ``parse_result`` and the
+    result is yielded unless it is ``None``.
+    """
+
+    def _to_generation_chunk(
+        self, chunk: Union[str, BaseMessage]
+    ) -> Union[GenerationChunk, ChatGenerationChunk]:
+        """Wrap one stream item in a generation chunk so it can be accumulated."""
+        if isinstance(chunk, BaseMessageChunk):
+            return ChatGenerationChunk(message=chunk)
+        if isinstance(chunk, BaseMessage):
+            # Plain messages are rebuilt as message chunks first.
+            return ChatGenerationChunk(message=BaseMessageChunk(**chunk.dict()))
+        return GenerationChunk(text=chunk)
+
+    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
+        acc_gen = None
+        for chunk in input:
+            chunk_gen = self._to_generation_chunk(chunk)
+            if acc_gen is None:
+                acc_gen = chunk_gen
+            else:
+                acc_gen += chunk_gen
+
+            parsed = self.parse_result([acc_gen])
+            if parsed is not None:
+                yield parsed
+
+    async def _atransform(
+        self, input: AsyncIterator[Union[str, BaseMessage]]
+    ) -> AsyncIterator[T]:
+        acc_gen = None
+        # ``input`` is an async iterator, so it has to be consumed with
+        # ``async for``; a plain ``for`` raises ``TypeError``.
+        async for chunk in input:
+            chunk_gen = self._to_generation_chunk(chunk)
+            if acc_gen is None:
+                acc_gen = chunk_gen
+            else:
+                acc_gen += chunk_gen
+
+            parsed = self.parse_result([acc_gen])
+            if parsed is not None:
+                yield parsed
+
+
 class StrOutputParser(BaseTransformOutputParser[str]):
     """OutputParser that parses LLMResult into the top likely string."""
 