From f672b39cc9a73f716cf3d118634966bda502067c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 20:14:17 +0100 Subject: [PATCH 01/15] Add a streaming json parser --- .../langchain/output_parsers/json.py | 98 ++++++++++++++++++- .../langchain/schema/output_parser.py | 57 ++++++++++- 2 files changed, 151 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 7465aba2fe1..0168f1ae364 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -6,6 +6,8 @@ from json import JSONDecodeError from typing import Any, List from langchain.schema import BaseOutputParser, OutputParserException +from langchain.schema.output import ChatGeneration, Generation +from langchain.schema.output_parser import BaseCumulativeTransformOutputParser def _replace_new_line(match: re.Match[str]) -> str: @@ -38,6 +40,66 @@ def _custom_parser(multiline_string: str) -> str: return multiline_string +# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py +# MIT License +def parse_partial_json(s): + # Attempt to parse the string as-is. + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + # Initialize variables. + new_s = "" + stack = [] + is_inside_string = False + escaped = False + + # Process each character in the string one at a time. + for char in s: + if is_inside_string: + if char == '"' and not escaped: + is_inside_string = False + elif char == "\n" and not escaped: + char = "\\n" # Replace the newline character with the escape sequence. + elif char == "\\": + escaped = not escaped + else: + escaped = False + else: + if char == '"': + is_inside_string = True + escaped = False + elif char == "{": + stack.append("}") + elif char == "[": + stack.append("]") + elif char == "}" or char == "]": + if stack and stack[-1] == char: + stack.pop() + else: + # Mismatched closing character; the input is malformed. + return None + + # Append the processed character to the new string. + new_s += char + + # If we're still inside a string at the end of processing, we need to close the string. + if is_inside_string: + new_s += '"' + + # Close any remaining open structures in the reverse order that they were opened. + for closing_char in reversed(stack): + new_s += closing_char + + # Attempt to parse the modified string as JSON. + try: + return json.loads(new_s) + except json.JSONDecodeError: + # If we still can't parse the string as JSON, return None to indicate failure. + return None + + def parse_json_markdown(json_string: str) -> dict: """ Parse a JSON string from a Markdown string. @@ -65,7 +127,7 @@ def parse_json_markdown(json_string: str) -> dict: json_str = _custom_parser(json_str) # Parse the JSON string into a Python dictionary - parsed = json.loads(json_str) + parsed = parse_partial_json(json_str) return parsed @@ -101,10 +163,42 @@ class SimpleJsonOutputParser(BaseOutputParser[Any]): def parse(self, text: str) -> Any: text = text.strip() try: - return json.loads(text) + return parse_partial_json(text) except JSONDecodeError as e: raise OutputParserException(f"Invalid json output: {text}") from e @property def _type(self) -> str: return "simple_json_output_parser" + + +class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): + from_function_args: bool = False + + @property + def _type(self) -> str: + return "partial_json" + + def parse_result(self, result: List[Generation]) -> Any: + if self.from_function_args: + if len(result) != 1: + raise OutputParserException( + f"Expected exactly one result, but got {len(result)}" + ) + generation = result[0] + if not isinstance(generation, ChatGeneration): + raise OutputParserException( + "This output parser can only be used with a chat generation." + ) + message = generation.message + try: + function_call = message.additional_kwargs["function_call"] + except KeyError: + return None + try: + return parse_partial_json(function_call["arguments"]) + except KeyError: + return None + + def parse(self, text: str) -> Any: + pass diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 46a8d9def03..5f09ee48ae6 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -17,8 +17,13 @@ from typing import ( from typing_extensions import get_args from langchain.load.serializable import Serializable -from langchain.schema.messages import AnyMessage, BaseMessage -from langchain.schema.output import ChatGeneration, Generation +from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk +from langchain.schema.output import ( + ChatGeneration, + ChatGenerationChunk, + Generation, + GenerationChunk, +) from langchain.schema.prompt import PromptValue from langchain.schema.runnable import Runnable, RunnableConfig @@ -329,6 +334,54 @@ class BaseTransformOutputParser(BaseOutputParser[T]): yield chunk +class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): + """Base class for an output parser that can handle streaming input.""" + + def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: + acc_gen = None + for chunk in input: + if isinstance(chunk, BaseMessageChunk): + chunk_gen = ChatGenerationChunk(message=chunk) + elif isinstance(chunk, BaseMessage): + chunk_gen = ChatGenerationChunk( + message=BaseMessageChunk(**chunk.dict()) + ) + else: + chunk_gen = GenerationChunk(text=chunk) + + if acc_gen is None: + acc_gen = chunk_gen + else: + acc_gen += chunk_gen + + parsed = self.parse_result([acc_gen]) + if parsed is not None: + yield parsed + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[T]: + acc_gen = None + for chunk in input: + if isinstance(chunk, BaseMessageChunk): + chunk_gen = ChatGenerationChunk(message=chunk) + elif isinstance(chunk, BaseMessage): + chunk_gen = ChatGenerationChunk( + message=BaseMessageChunk(**chunk.dict()) + ) + else: + chunk_gen = GenerationChunk(text=chunk) + + if acc_gen is None: + acc_gen = chunk_gen + else: + acc_gen += chunk_gen + + parsed = self.parse_result([acc_gen]) + if parsed is not None: + yield parsed + + class StrOutputParser(BaseTransformOutputParser[str]): """OutputParser that parses LLMResult into the top likely string.""" From 63f2ef8d1cb26e27371516785cedf2659e777227 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 20:20:47 +0100 Subject: [PATCH 02/15] Implement str one --- .../langchain/output_parsers/json.py | 52 +++++++++++-------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 0168f1ae364..f512ed6e47a 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -172,33 +172,39 @@ class SimpleJsonOutputParser(BaseOutputParser[Any]): return "simple_json_output_parser" -class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): - from_function_args: bool = False - +class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): @property def _type(self) -> str: - return "partial_json" + return "partial_functions_json" def parse_result(self, result: List[Generation]) -> Any: - if self.from_function_args: - if len(result) != 1: - raise OutputParserException( - f"Expected exactly one result, but got {len(result)}" - ) - generation = result[0] - if not isinstance(generation, ChatGeneration): - raise OutputParserException( - "This output parser can only be used with a chat generation." - ) - message = generation.message - try: - function_call = message.additional_kwargs["function_call"] - except KeyError: - return None - try: - return parse_partial_json(function_call["arguments"]) - except KeyError: - return None + if len(result) != 1: + raise OutputParserException( + f"Expected exactly one result, but got {len(result)}" + ) + generation = result[0] + if not isinstance(generation, ChatGeneration): + raise OutputParserException( + "This output parser can only be used with a chat generation." + ) + message = generation.message + try: + function_call = message.additional_kwargs["function_call"] + except KeyError: + return None + try: + return parse_partial_json(function_call["arguments"]) + except KeyError: + return None def parse(self, text: str) -> Any: pass + + +class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): + @property + def _type(self) -> str: + return "partial_functions_json" + + def parse(self, text: str) -> Any: + return parse_json_markdown(text) From 6c0a6b70e082314871c20ed00349d4da376512e9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 11:43:31 +0100 Subject: [PATCH 03/15] =?UTF-8?q?WIP=20Add=20tests=C2=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../unit_tests/output_parsers/test_json.py | 136 +++++++++++++++++- 1 file changed, 135 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 21bc600fd05..fea7b290d05 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -1,6 +1,8 @@ +import json +from typing import Iterator, Tuple import pytest -from langchain.output_parsers.json import parse_json_markdown +from langchain.output_parsers.json import parse_json_markdown, parse_partial_json GOOD_JSON = """```json { @@ -183,3 +185,135 @@ def test_parse_json_with_python_dict() -> None: "action": "Final Answer", "action_input": {"foo": "bar", "bar": "foo"}, } + + +TEST_CASES_PARTIAL = [ + ('{"foo": "bar", "bar": "foo"}', '{"foo": "bar", "bar": "foo"}'), + ('{"foo": "bar", "bar": "foo', '{"foo": "bar", "bar": "foo"}'), + ('{"foo": "bar", "bar": "foo}', '{"foo": "bar", "bar": "foo}"}'), + ('{"foo": "bar", "bar": "foo[', '{"foo": "bar", "bar": "foo["}'), + ('{"foo": "bar", "bar": "foo\\"', '{"foo": "bar", "bar": "foo\\""}'), +] + + +@pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL) +def test_parse_partial_json(json_strings: Tuple[str, str]) -> None: + case, expected = json_strings + parsed = parse_partial_json(case) + assert parsed == json.loads(expected) + + +STREAMED_TOKENS = """ +{ + + + " +setup +": + " +Why + did + the + bears + go + on + a + picnic +?", + + + " +p +unch +line +": + " +Because + they + wanted + to + have + a + bear +-y + good + time +!" + +} +""".splitlines() + +EXPECTED_STREAMED_JSON = [ + {}, + {}, + {"setup": ""}, + {"setup": "Why"}, + {"setup": "Why did"}, + {"setup": "Why did the"}, + {"setup": "Why did the bears"}, + {"setup": "Why did the bears start"}, + {"setup": "Why did the bears start a"}, + {"setup": "Why did the bears start a band"}, + {"setup": "Why did the bears start a band called"}, + {"setup": "Why did the bears start a band called Bears"}, + {"setup": "Why did the bears start a band called Bears Bears"}, + {"setup": "Why did the bears start a band called Bears Bears Bears"}, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear-y", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear-y good", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear-y good music", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear-y good music!", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear-y good music!", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears?", + "punchline": "Because they wanted to play bear-y good music!", + }, +] + + +def test_partial_text_json_output_parser() -> None: + def input_iter() -> Iterator[str]: + for token in STREAMED_TOKENS: + yield token From 5cbe2b7b6aef6299d0ccaadc76dcaab02bf9102f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 14:06:07 +0100 Subject: [PATCH 04/15] Implement diff --- .../langchain/output_parsers/json.py | 15 +- .../unit_tests/output_parsers/test_json.py | 251 +++++++++++++++--- 2 files changed, 225 insertions(+), 41 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index f512ed6e47a..7a4660ef6b8 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -3,7 +3,9 @@ from __future__ import annotations import json import re from json import JSONDecodeError -from typing import Any, List +from typing import Any, List, Optional + +import jsonpatch from langchain.schema import BaseOutputParser, OutputParserException from langchain.schema.output import ChatGeneration, Generation @@ -42,7 +44,7 @@ def _custom_parser(multiline_string: str) -> str: # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py # MIT License -def parse_partial_json(s): +def parse_partial_json(s: str) -> Any: # Attempt to parse the string as-is. try: return json.loads(s) @@ -84,7 +86,8 @@ def parse_partial_json(s): # Append the processed character to the new string. new_s += char - # If we're still inside a string at the end of processing, we need to close the string. + # If we're still inside a string at the end of processing, + # we need to close the string. if is_inside_string: new_s += '"' @@ -197,6 +200,9 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]) except KeyError: return None + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch + def parse(self, text: str) -> Any: pass @@ -206,5 +212,8 @@ class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): def _type(self) -> str: return "partial_functions_json" + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch + def parse(self, text: str) -> Any: return parse_json_markdown(text) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index fea7b290d05..00dbd4d3b36 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -1,8 +1,15 @@ import json -from typing import Iterator, Tuple +from typing import Any, Iterator, Tuple + import pytest -from langchain.output_parsers.json import parse_json_markdown, parse_partial_json +from langchain.output_parsers.json import ( + PartialFunctionsJsonOutputParser, + PartialJsonOutputParser, + parse_json_markdown, + parse_partial_json, +) +from langchain.schema.messages import AIMessageChunk GOOD_JSON = """```json { @@ -206,7 +213,6 @@ def test_parse_partial_json(json_strings: Tuple[str, str]) -> None: STREAMED_TOKENS = """ { - " setup ": @@ -215,36 +221,50 @@ Why did the bears - go - on + start a - picnic -?", - - + band + called + Bears + Bears + Bears + ? +" +, " -p -unch -line +punchline ": " Because they wanted to - have - a + play bear --y + -y good - time -!" + music + ! +" +, + " +audience +": + [ +" +Haha +" +, + " +So + funny +" +] } """.splitlines() EXPECTED_STREAMED_JSON = [ - {}, {}, {"setup": ""}, {"setup": "Why"}, @@ -258,62 +278,217 @@ EXPECTED_STREAMED_JSON = [ {"setup": "Why did the bears start a band called Bears"}, {"setup": "Why did the bears start a band called Bears Bears"}, {"setup": "Why did the bears start a band called Bears Bears Bears"}, + {"setup": "Why did the bears start a band called Bears Bears Bears ?"}, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "Because", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "Because they", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "Because they wanted", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "Because they wanted to", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "Because they wanted to play", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", + "setup": "Why did the bears start a band called Bears Bears Bears ?", "punchline": "Because they wanted to play bear", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", - "punchline": "Because they wanted to play bear-y", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", - "punchline": "Because they wanted to play bear-y good", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y good", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", - "punchline": "Because they wanted to play bear-y good music", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y good music", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", - "punchline": "Because they wanted to play bear-y good music!", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y good music !", }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", - "punchline": "Because they wanted to play bear-y good music!", + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": [], }, { - "setup": "Why did the bears start a band called Bears Bears Bears?", - "punchline": "Because they wanted to play bear-y good music!", + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": [""], }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha"], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha", ""], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha", "So"], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha", "So funny"], + }, +] + +EXPECTED_STREAMED_JSON_DIFF = [ + [{"op": "replace", "path": "", "value": {}}], + [{"op": "add", "path": "/setup", "value": ""}], + [{"op": "replace", "path": "/setup", "value": "Why"}], + [{"op": "replace", "path": "/setup", "value": "Why did"}], + [{"op": "replace", "path": "/setup", "value": "Why did the"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears start"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears Bears", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears Bears Bears", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears Bears Bears ?", + } + ], + [{"op": "add", "path": "/punchline", "value": ""}], + [{"op": "replace", "path": "/punchline", "value": "Because"}], + [{"op": "replace", "path": "/punchline", "value": "Because they"}], + [{"op": "replace", "path": "/punchline", "value": "Because they wanted"}], + [{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}], + [{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y good", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y good music", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y good music !", + } + ], + [{"op": "add", "path": "/audience", "value": []}], + [{"op": "add", "path": "/audience/0", "value": ""}], + [{"op": "replace", "path": "/audience/0", "value": "Haha"}], + [{"op": "add", "path": "/audience/1", "value": ""}], + [{"op": "replace", "path": "/audience/1", "value": "So"}], + [{"op": "replace", "path": "/audience/1", "value": "So funny"}], ] def test_partial_text_json_output_parser() -> None: - def input_iter() -> Iterator[str]: + def input_iter(_: Any) -> Iterator[str]: for token in STREAMED_TOKENS: yield token + + chain = input_iter | PartialJsonOutputParser() + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON + + +def test_partial_functions_json_output_parser() -> None: + def input_iter(_: Any) -> Iterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | PartialFunctionsJsonOutputParser() + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON + + +def test_partial_text_json_output_parser_diff() -> None: + def input_iter(_: Any) -> Iterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | PartialJsonOutputParser(diff=True) + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF + + +def test_partial_functions_json_output_parser_diff() -> None: + def input_iter(_: Any) -> Iterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF From 4e28a7a5132e4e49bffe6bd261aa40d22214b813 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 14:12:48 +0100 Subject: [PATCH 05/15] Implement diff --- .../langchain/schema/output_parser.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 5f09ee48ae6..89a065e9ad6 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]): class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): """Base class for an output parser that can handle streaming input.""" + diff: bool = False + + def _diff(self, prev: Optional[T], next: T) -> T: + raise NotImplementedError() + def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: + prev_parsed = None acc_gen = None for chunk in input: if isinstance(chunk, BaseMessageChunk): - chunk_gen = ChatGenerationChunk(message=chunk) + chunk_gen: Generation = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): chunk_gen = ChatGenerationChunk( message=BaseMessageChunk(**chunk.dict()) @@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): acc_gen += chunk_gen parsed = self.parse_result([acc_gen]) - if parsed is not None: - yield parsed + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[T]: + prev_parsed = None acc_gen = None - for chunk in input: + async for chunk in input: if isinstance(chunk, BaseMessageChunk): - chunk_gen = ChatGenerationChunk(message=chunk) + chunk_gen: Generation = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): chunk_gen = ChatGenerationChunk( message=BaseMessageChunk(**chunk.dict()) @@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): acc_gen += chunk_gen parsed = self.parse_result([acc_gen]) - if parsed is not None: - yield parsed + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed class StrOutputParser(BaseTransformOutputParser[str]): From 091d8845d5246875a29a67858f8148b28e287530 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 14:18:38 +0100 Subject: [PATCH 06/15] Backwards compat --- libs/langchain/langchain/output_parsers/json.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 7a4660ef6b8..d4423a0d2f4 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import re from json import JSONDecodeError -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional import jsonpatch @@ -103,7 +103,9 @@ def parse_partial_json(s: str) -> Any: return None -def parse_json_markdown(json_string: str) -> dict: +def parse_json_markdown( + json_string: str, parser: Callable[[str], Any] = json.loads +) -> dict: """ Parse a JSON string from a Markdown string. @@ -130,7 +132,7 @@ def parse_json_markdown(json_string: str) -> dict: json_str = _custom_parser(json_str) # Parse the JSON string into a Python dictionary - parsed = parse_partial_json(json_str) + parsed = parser(json_str) return parsed @@ -216,4 +218,4 @@ class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): return jsonpatch.make_patch(prev, next).patch def parse(self, text: str) -> Any: - return parse_json_markdown(text) + return parse_json_markdown(text, parse_partial_json) From 3d8aa88e26b8f28c32f47ef9ca3c266b7ed15975 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 15:28:46 +0100 Subject: [PATCH 07/15] Add async tests and comments --- .../langchain/output_parsers/json.py | 4 +- .../langchain/schema/output_parser.py | 2 + .../unit_tests/output_parsers/test_json.py | 50 ++++++++++++++++++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index d4423a0d2f4..aafaedb67de 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -205,8 +205,10 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]) def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch + # This method would be called by the default implementation of `parse_result` + # but we're overriding that method so it's not needed. def parse(self, text: str) -> Any: - pass + raise NotImplementedError() class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 89a065e9ad6..6d2e3888933 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -340,6 +340,8 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): diff: bool = False def _diff(self, prev: Optional[T], next: T) -> T: + """Convert parsed outputs into a diff format. The semantics of this are + up to the output parser.""" raise NotImplementedError() def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 00dbd4d3b36..b9daee1a51c 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -1,5 +1,5 @@ import json -from typing import Any, Iterator, Tuple +from typing import Any, AsyncIterator, Iterator, Tuple import pytest @@ -492,3 +492,51 @@ def test_partial_functions_json_output_parser_diff() -> None: chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF + + +@pytest.mark.asyncio +async def test_partial_text_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | PartialJsonOutputParser() + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +@pytest.mark.asyncio +async def test_partial_functions_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | PartialFunctionsJsonOutputParser() + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +@pytest.mark.asyncio +async def test_partial_text_json_output_parser_diff_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | PartialJsonOutputParser(diff=True) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF + + +@pytest.mark.asyncio +async def test_partial_functions_json_output_parser_diff_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF From 4b8442896b9adefa6536255df6c995ab57fd27de Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 16:50:00 +0100 Subject: [PATCH 08/15] Make test deterministic --- .../tests/unit_tests/schema/runnable/test_runnable.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 4a63f92ff2a..9d1d131f0d3 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -582,7 +582,9 @@ async def test_with_config(mocker: MockerFixture) -> None: ) == [5, 7] assert len(spy.call_args_list) == 2 - for i, call in enumerate(spy.call_args_list): + for i, call in enumerate( + sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1) + ): assert call.args[0] == ("hello" if i == 0 else "wooorld") if i == 0: assert call.args[1].get("recursion_limit") == 5 From c9d0f2b9843d0fbf7b2374a8204505941d824d76 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 17:55:30 +0100 Subject: [PATCH 09/15] Combine with existing json output parsers --- .../langchain/output_parsers/json.py | 75 +++++-------------- .../output_parsers/openai_functions.py | 65 ++++++++++------ .../langchain/schema/output_parser.py | 3 + .../unit_tests/output_parsers/test_json.py | 20 ++--- 4 files changed, 77 insertions(+), 86 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index aafaedb67de..946aeb6408b 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -7,9 +7,10 @@ from typing import Any, Callable, List, Optional import jsonpatch -from langchain.schema import BaseOutputParser, OutputParserException -from langchain.schema.output import ChatGeneration, Generation -from langchain.schema.output_parser import BaseCumulativeTransformOutputParser +from langchain.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + OutputParserException, +) def _replace_new_line(match: re.Match[str]) -> str: @@ -44,10 +45,10 @@ def _custom_parser(multiline_string: str) -> str: # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py # MIT License -def parse_partial_json(s: str) -> Any: +def parse_partial_json(s: str, *, strict: bool = False) -> Any: # Attempt to parse the string as-is. try: - return json.loads(s) + return json.loads(s, strict=strict) except json.JSONDecodeError: pass @@ -97,7 +98,7 @@ def parse_partial_json(s: str) -> Any: # Attempt to parse the modified string as JSON. try: - return json.loads(new_s) + return json.loads(new_s, strict=strict) except json.JSONDecodeError: # If we still can't parse the string as JSON, return None to indicate failure. return None @@ -162,62 +163,26 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: return json_obj -class SimpleJsonOutputParser(BaseOutputParser[Any]): - """Parse the output of an LLM call to a JSON object.""" +class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): + """Parse the output of an LLM call to a JSON object. + + When used in streaming mode, it will yield partial JSON objects containing + all the keys that have been returned so far. + + In streaming, if `diff` is set to `True`, yields JSONPatch operations + describing the difference between the previous and the current object. + """ + + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch def parse(self, text: str) -> Any: text = text.strip() try: - return parse_partial_json(text) + return parse_json_markdown(text.strip(), parse_partial_json) except JSONDecodeError as e: raise OutputParserException(f"Invalid json output: {text}") from e @property def _type(self) -> str: return "simple_json_output_parser" - - -class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): - @property - def _type(self) -> str: - return "partial_functions_json" - - def parse_result(self, result: List[Generation]) -> Any: - if len(result) != 1: - raise OutputParserException( - f"Expected exactly one result, but got {len(result)}" - ) - generation = result[0] - if not isinstance(generation, ChatGeneration): - raise OutputParserException( - "This output parser can only be used with a chat generation." - ) - message = generation.message - try: - function_call = message.additional_kwargs["function_call"] - except KeyError: - return None - try: - return parse_partial_json(function_call["arguments"]) - except KeyError: - return None - - def _diff(self, prev: Optional[Any], next: Any) -> Any: - return jsonpatch.make_patch(prev, next).patch - - # This method would be called by the default implementation of `parse_result` - # but we're overriding that method so it's not needed. - def parse(self, text: str) -> Any: - raise NotImplementedError() - - -class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): - @property - def _type(self) -> str: - return "partial_functions_json" - - def _diff(self, prev: Optional[Any], next: Any) -> Any: - return jsonpatch.make_patch(prev, next).patch - - def parse(self, text: str) -> Any: - return parse_json_markdown(text, parse_partial_json) diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index cabafd599de..f0016b3e337 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,14 +1,20 @@ import copy import json -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type, Union +import jsonpatch + +from langchain.output_parsers.json import parse_partial_json from langchain.pydantic_v1 import BaseModel, root_validator from langchain.schema import ( ChatGeneration, Generation, OutputParserException, ) -from langchain.schema.output_parser import BaseGenerationOutputParser +from langchain.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + BaseGenerationOutputParser, +) class OutputFunctionsParser(BaseGenerationOutputParser[Any]): @@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): return func_call -class JsonOutputFunctionsParser(OutputFunctionsParser): +class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): """Parse an output as the Json object.""" strict: bool = False @@ -45,25 +51,42 @@ class JsonOutputFunctionsParser(OutputFunctionsParser): Useful when the parsed output may include unicode characters or new lines. """ + args_only: bool = True + """Whether to only return the arguments to the function call.""" + + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch + def parse_result(self, result: List[Generation]) -> Any: - function_call_info = super().parse_result(result) - if self.args_only: - try: - return json.loads(function_call_info, strict=self.strict) - except (json.JSONDecodeError, TypeError) as exc: - raise OutputParserException( - f"Could not parse function call data: {exc}" - ) - else: - try: - function_call_info["arguments"] = json.loads( - function_call_info["arguments"], strict=self.strict - ) - except (json.JSONDecodeError, TypeError) as exc: - raise OutputParserException( - f"Could not parse function call data: {exc}" - ) - return function_call_info + if len(result) != 1: + raise OutputParserException( + f"Expected exactly one result, but got {len(result)}" + ) + generation = result[0] + if not isinstance(generation, ChatGeneration): + raise OutputParserException( + "This output parser can only be used with a chat generation." + ) + message = generation.message + try: + function_call = message.additional_kwargs["function_call"] + except KeyError: + return None + try: + if self.args_only: + return parse_partial_json(function_call["arguments"]) + else: + return { + **function_call, + "arguments": parse_partial_json(function_call["arguments"]), + } + except KeyError: + return None + + # This method would be called by the default implementation of `parse_result` + # but we're overriding that method so it's not needed. + def parse(self, text: str) -> Any: + raise NotImplementedError() class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 6d2e3888933..58587302562 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -338,6 +338,9 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): """Base class for an output parser that can handle streaming input.""" diff: bool = False + """In streaming mode, whether to yield diffs between the previous and current + parsed output, or just the current parsed output. + """ def _diff(self, prev: Optional[T], next: T) -> T: """Convert parsed outputs into a diff format. The semantics of this are diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index b9daee1a51c..90b2d5a7da5 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -4,12 +4,12 @@ from typing import Any, AsyncIterator, Iterator, Tuple import pytest from langchain.output_parsers.json import ( - PartialFunctionsJsonOutputParser, - PartialJsonOutputParser, + SimpleJsonOutputParser, parse_json_markdown, parse_partial_json, ) from langchain.schema.messages import AIMessageChunk +from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser GOOD_JSON = """```json { @@ -455,7 +455,7 @@ def test_partial_text_json_output_parser() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser() + chain = input_iter | SimpleJsonOutputParser() assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON @@ -467,7 +467,7 @@ def test_partial_functions_json_output_parser() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser() + chain = input_iter | JsonOutputFunctionsParser() assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON @@ -477,7 +477,7 @@ def test_partial_text_json_output_parser_diff() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser(diff=True) + chain = input_iter | SimpleJsonOutputParser(diff=True) assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF @@ -489,7 +489,7 @@ def test_partial_functions_json_output_parser_diff() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + chain = input_iter | JsonOutputFunctionsParser(diff=True) assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF @@ -500,7 +500,7 @@ async def test_partial_text_json_output_parser_async() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser() + chain = input_iter | SimpleJsonOutputParser() assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON @@ -513,7 +513,7 @@ async def test_partial_functions_json_output_parser_async() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser() + chain = input_iter | JsonOutputFunctionsParser() assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON @@ -524,7 +524,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser(diff=True) + chain = input_iter | SimpleJsonOutputParser(diff=True) assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF @@ -537,6 +537,6 @@ async def test_partial_functions_json_output_parser_diff_async() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + chain = input_iter | JsonOutputFunctionsParser(diff=True) assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF From 1f30e25681332ab6e827a1b1fd59b1e79784871a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 18:03:41 +0100 Subject: [PATCH 10/15] Lint --- libs/langchain/langchain/output_parsers/openai_functions.py | 1 - libs/langchain/tests/unit_tests/output_parsers/test_json.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index f0016b3e337..098f4ab3723 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,5 +1,4 @@ import copy -import json from typing import Any, Dict, List, Optional, Type, Union import jsonpatch diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 90b2d5a7da5..6fe7cb27dd1 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -8,8 +8,8 @@ from langchain.output_parsers.json import ( parse_json_markdown, parse_partial_json, ) -from langchain.schema.messages import AIMessageChunk from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser +from langchain.schema.messages import AIMessageChunk GOOD_JSON = """```json { From aa8b4120a8e0ae447f3a6f450b5223fd4bf9fca2 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 19:21:27 +0100 Subject: [PATCH 11/15] Keep exceptions when not in streaming mode --- .../agents/output_parsers/openai_functions.py | 4 +- .../output_parsers/openai_functions.py | 59 ++++++++++++++----- .../langchain/schema/output_parser.py | 16 +++-- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index 7a2a2702d6b..9ba41963a89 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -75,7 +75,9 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser): return_values={"output": message.content}, log=message.content ) - def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]: + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> Union[AgentAction, AgentFinish]: if not isinstance(result[0], ChatGeneration): raise ValueError("This output parser only works on ChatGeneration output") message = result[0].message diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index 098f4ab3723..8724035e4b0 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,4 +1,5 @@ import copy +import json from typing import Any, Dict, List, Optional, Type, Union import jsonpatch @@ -22,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): args_only: bool = True """Whether to only return the arguments to the function call.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: generation = result[0] if not isinstance(generation, ChatGeneration): raise OutputParserException( @@ -56,7 +57,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: if len(result) != 1: raise OutputParserException( f"Expected exactly one result, but got {len(result)}" @@ -69,16 +70,46 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): message = generation.message try: function_call = message.additional_kwargs["function_call"] - except KeyError: - return None - try: - if self.args_only: - return parse_partial_json(function_call["arguments"]) + except KeyError as exc: + if partial: + return None else: - return { - **function_call, - "arguments": parse_partial_json(function_call["arguments"]), - } + raise OutputParserException(f"Could not parse function call: {exc}") + try: + if partial: + if self.args_only: + return parse_partial_json( + function_call["arguments"], strict=self.strict + ) + else: + return { + **function_call, + "arguments": parse_partial_json( + function_call["arguments"], strict=self.strict + ), + } + else: + if self.args_only: + try: + return json.loads( + function_call["arguments"], strict=self.strict + ) + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) + else: + try: + return { + **function_call, + "arguments": json.loads( + function_call["arguments"], strict=self.strict + ), + } + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) except KeyError: return None @@ -94,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): key_name: str """The name of the key to return.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: res = super().parse_result(result) return res[self.key_name] @@ -119,7 +150,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): ) return values - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: _result = super().parse_result(result) if self.args_only: pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore @@ -136,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): attr_name: str """The name of the attribute to return.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: result = super().parse_result(result) return getattr(result, self.attr_name) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 58587302562..157c1cd5f0f 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -34,7 +34,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod - def parse_result(self, result: List[Generation]) -> T: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -45,7 +45,9 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): Structured output. """ - async def aparse_result(self, result: List[Generation]) -> T: + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -205,7 +207,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] run_type="parser", ) - def parse_result(self, result: List[Generation]) -> T: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -231,7 +233,9 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] Structured output. """ - async def aparse_result(self, result: List[Generation]) -> T: + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -365,7 +369,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): else: acc_gen += chunk_gen - parsed = self.parse_result([acc_gen]) + parsed = self.parse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: if self.diff: yield self._diff(prev_parsed, parsed) @@ -393,7 +397,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): else: acc_gen += chunk_gen - parsed = self.parse_result([acc_gen]) + parsed = self.parse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: if self.diff: yield self._diff(prev_parsed, parsed) From cbe18057b0c5db3508bee7c9606a2ff3e399f2bb Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 19:34:27 +0100 Subject: [PATCH 12/15] Update json.py Co-authored-by: Eugene Yurtsev --- libs/langchain/langchain/output_parsers/json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 946aeb6408b..9efa829bd27 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -179,7 +179,7 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): def parse(self, text: str) -> Any: text = text.strip() try: - return parse_json_markdown(text.strip(), parse_partial_json) + return parse_json_markdown(text.strip(), parser=parse_partial_json) except JSONDecodeError as e: raise OutputParserException(f"Invalid json output: {text}") from e From f6b0b065d361f7735e10c7257d0b0d416a31be47 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 19:34:35 +0100 Subject: [PATCH 13/15] Update json.py Co-authored-by: Eugene Yurtsev --- libs/langchain/langchain/output_parsers/json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 9efa829bd27..557151f61ab 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -105,7 +105,7 @@ def parse_partial_json(s: str, *, strict: bool = False) -> Any: def parse_json_markdown( - json_string: str, parser: Callable[[str], Any] = json.loads + json_string: str, *, parser: Callable[[str], Any] = json.loads ) -> dict: """ Parse a JSON string from a Markdown string. From f3f3f7181132f78f71f71bff7d6a20d76edc72d2 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 19:57:34 +0100 Subject: [PATCH 14/15] Lint --- .../langchain/agents/output_parsers/openai_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index 9ba41963a89..3f82ceff962 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -84,7 +84,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser): return self._parse_ai_message(message) async def aparse_result( - self, result: List[Generation] + self, result: List[Generation], *, partial: bool = False ) -> Union[AgentAction, AgentFinish]: return await asyncio.get_running_loop().run_in_executor( None, self.parse_result, result From ee56c616ffdec07ec12510c8a0d568d7ace2588b Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 20:05:33 +0100 Subject: [PATCH 15/15] Remove flawed test - It is not possible to access properties on classes, only on instances, therefore this test is not something we can implement --- .../output_parsers/test_base_output_parser.py | 55 ------------------- 1 file changed, 55 deletions(-) delete mode 100644 libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py deleted file mode 100644 index 3ce6fb976d5..00000000000 --- a/libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test the BaseOutputParser class and its sub-classes.""" -from abc import ABC -from collections import defaultdict -from typing import List, Optional, Set, Type - -import pytest - -from langchain.schema import BaseOutputParser - - -def non_abstract_subclasses( - cls: Type[ABC], to_skip: Optional[Set] = None -) -> List[Type]: - """Recursively find all non-abstract subclasses of a class.""" - _to_skip = to_skip or set() - subclasses = [] - for subclass in cls.__subclasses__(): - if not getattr(subclass, "__abstractmethods__", None): - if subclass.__name__ not in _to_skip: - subclasses.append(subclass) - subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip)) - return subclasses - - -# parsers defined not in the output_parsers module: -_PARSERS_TO_SKIP = { - "FakeOutputParser", - "BaseOutputParser", - "FinishedOutputParser", - "RouterOutputParser", - "TrajectoryRunEvalOutputParser", -} -_NON_ABSTRACT_PARSERS = non_abstract_subclasses( - BaseOutputParser, to_skip=_PARSERS_TO_SKIP -) - - -@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS) -def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None: - try: - cls._type - except NotImplementedError: - pytest.fail(f"_type property is not implemented in class {cls.__name__}") - - -def test_all_subclasses_implement_unique_type() -> None: - types = defaultdict(list) - for cls in _NON_ABSTRACT_PARSERS: - try: - types[cls._type].append(cls.__name__) - except NotImplementedError: - # This is handled in the previous test - pass - dups = {t: names for t, names in types.items() if len(names) > 1} - assert not dups, f"Duplicate types: {dups}"