diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index 7a2a2702d6b..3f82ceff962 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -75,14 +75,16 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser): return_values={"output": message.content}, log=message.content ) - def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]: + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> Union[AgentAction, AgentFinish]: if not isinstance(result[0], ChatGeneration): raise ValueError("This output parser only works on ChatGeneration output") message = result[0].message return self._parse_ai_message(message) async def aparse_result( - self, result: List[Generation] + self, result: List[Generation], *, partial: bool = False ) -> Union[AgentAction, AgentFinish]: return await asyncio.get_running_loop().run_in_executor( None, self.parse_result, result diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 7465aba2fe1..557151f61ab 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -3,9 +3,14 @@ from __future__ import annotations import json import re from json import JSONDecodeError -from typing import Any, List +from typing import Any, Callable, List, Optional -from langchain.schema import BaseOutputParser, OutputParserException +import jsonpatch + +from langchain.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + OutputParserException, +) def _replace_new_line(match: re.Match[str]) -> str: @@ -38,7 +43,70 @@ def _custom_parser(multiline_string: str) -> str: return multiline_string -def parse_json_markdown(json_string: str) -> dict: +# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py +# MIT License +def parse_partial_json(s: str, *, strict: bool = False) -> Any: + # Attempt to parse the string as-is. + try: + return json.loads(s, strict=strict) + except json.JSONDecodeError: + pass + + # Initialize variables. + new_s = "" + stack = [] + is_inside_string = False + escaped = False + + # Process each character in the string one at a time. + for char in s: + if is_inside_string: + if char == '"' and not escaped: + is_inside_string = False + elif char == "\n" and not escaped: + char = "\\n" # Replace the newline character with the escape sequence. + elif char == "\\": + escaped = not escaped + else: + escaped = False + else: + if char == '"': + is_inside_string = True + escaped = False + elif char == "{": + stack.append("}") + elif char == "[": + stack.append("]") + elif char == "}" or char == "]": + if stack and stack[-1] == char: + stack.pop() + else: + # Mismatched closing character; the input is malformed. + return None + + # Append the processed character to the new string. + new_s += char + + # If we're still inside a string at the end of processing, + # we need to close the string. + if is_inside_string: + new_s += '"' + + # Close any remaining open structures in the reverse order that they were opened. + for closing_char in reversed(stack): + new_s += closing_char + + # Attempt to parse the modified string as JSON. + try: + return json.loads(new_s, strict=strict) + except json.JSONDecodeError: + # If we still can't parse the string as JSON, return None to indicate failure. + return None + + +def parse_json_markdown( + json_string: str, *, parser: Callable[[str], Any] = json.loads +) -> dict: """ Parse a JSON string from a Markdown string. @@ -65,7 +133,7 @@ def parse_json_markdown(json_string: str) -> dict: json_str = _custom_parser(json_str) # Parse the JSON string into a Python dictionary - parsed = json.loads(json_str) + parsed = parser(json_str) return parsed @@ -95,13 +163,23 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: return json_obj -class SimpleJsonOutputParser(BaseOutputParser[Any]): - """Parse the output of an LLM call to a JSON object.""" +class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): + """Parse the output of an LLM call to a JSON object. + + When used in streaming mode, it will yield partial JSON objects containing + all the keys that have been returned so far. + + In streaming, if `diff` is set to `True`, yields JSONPatch operations + describing the difference between the previous and the current object. + """ + + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch def parse(self, text: str) -> Any: text = text.strip() try: - return json.loads(text) + return parse_json_markdown(text.strip(), parser=parse_partial_json) except JSONDecodeError as e: raise OutputParserException(f"Invalid json output: {text}") from e diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index cabafd599de..8724035e4b0 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,14 +1,20 @@ import copy import json -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type, Union +import jsonpatch + +from langchain.output_parsers.json import parse_partial_json from langchain.pydantic_v1 import BaseModel, root_validator from langchain.schema import ( ChatGeneration, Generation, OutputParserException, ) -from langchain.schema.output_parser import BaseGenerationOutputParser +from langchain.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + BaseGenerationOutputParser, +) class OutputFunctionsParser(BaseGenerationOutputParser[Any]): @@ -17,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): args_only: bool = True """Whether to only return the arguments to the function call.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: generation = result[0] if not isinstance(generation, ChatGeneration): raise OutputParserException( @@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): return func_call -class JsonOutputFunctionsParser(OutputFunctionsParser): +class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): """Parse an output as the Json object.""" strict: bool = False @@ -45,25 +51,72 @@ class JsonOutputFunctionsParser(OutputFunctionsParser): Useful when the parsed output may include unicode characters or new lines. """ - def parse_result(self, result: List[Generation]) -> Any: - function_call_info = super().parse_result(result) - if self.args_only: - try: - return json.loads(function_call_info, strict=self.strict) - except (json.JSONDecodeError, TypeError) as exc: - raise OutputParserException( - f"Could not parse function call data: {exc}" - ) - else: - try: - function_call_info["arguments"] = json.loads( - function_call_info["arguments"], strict=self.strict - ) - except (json.JSONDecodeError, TypeError) as exc: - raise OutputParserException( - f"Could not parse function call data: {exc}" - ) - return function_call_info + args_only: bool = True + """Whether to only return the arguments to the function call.""" + + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch + + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + if len(result) != 1: + raise OutputParserException( + f"Expected exactly one result, but got {len(result)}" + ) + generation = result[0] + if not isinstance(generation, ChatGeneration): + raise OutputParserException( + "This output parser can only be used with a chat generation." + ) + message = generation.message + try: + function_call = message.additional_kwargs["function_call"] + except KeyError as exc: + if partial: + return None + else: + raise OutputParserException(f"Could not parse function call: {exc}") + try: + if partial: + if self.args_only: + return parse_partial_json( + function_call["arguments"], strict=self.strict + ) + else: + return { + **function_call, + "arguments": parse_partial_json( + function_call["arguments"], strict=self.strict + ), + } + else: + if self.args_only: + try: + return json.loads( + function_call["arguments"], strict=self.strict + ) + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) + else: + try: + return { + **function_call, + "arguments": json.loads( + function_call["arguments"], strict=self.strict + ), + } + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) + except KeyError: + return None + + # This method would be called by the default implementation of `parse_result` + # but we're overriding that method so it's not needed. + def parse(self, text: str) -> Any: + raise NotImplementedError() class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): @@ -72,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): key_name: str """The name of the key to return.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: res = super().parse_result(result) return res[self.key_name] @@ -97,7 +150,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): ) return values - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: _result = super().parse_result(result) if self.args_only: pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore @@ -114,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): attr_name: str """The name of the attribute to return.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: result = super().parse_result(result) return getattr(result, self.attr_name) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 46a8d9def03..157c1cd5f0f 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -17,8 +17,13 @@ from typing import ( from typing_extensions import get_args from langchain.load.serializable import Serializable -from langchain.schema.messages import AnyMessage, BaseMessage -from langchain.schema.output import ChatGeneration, Generation +from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk +from langchain.schema.output import ( + ChatGeneration, + ChatGenerationChunk, + Generation, + GenerationChunk, +) from langchain.schema.prompt import PromptValue from langchain.schema.runnable import Runnable, RunnableConfig @@ -29,7 +34,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod - def parse_result(self, result: List[Generation]) -> T: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -40,7 +45,9 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): Structured output. """ - async def aparse_result(self, result: List[Generation]) -> T: + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -200,7 +207,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] run_type="parser", ) - def parse_result(self, result: List[Generation]) -> T: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -226,7 +233,9 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] Structured output. """ - async def aparse_result(self, result: List[Generation]) -> T: + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -329,6 +338,74 @@ class BaseTransformOutputParser(BaseOutputParser[T]): yield chunk +class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): + """Base class for an output parser that can handle streaming input.""" + + diff: bool = False + """In streaming mode, whether to yield diffs between the previous and current + parsed output, or just the current parsed output. + """ + + def _diff(self, prev: Optional[T], next: T) -> T: + """Convert parsed outputs into a diff format. The semantics of this are + up to the output parser.""" + raise NotImplementedError() + + def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: + prev_parsed = None + acc_gen = None + for chunk in input: + if isinstance(chunk, BaseMessageChunk): + chunk_gen: Generation = ChatGenerationChunk(message=chunk) + elif isinstance(chunk, BaseMessage): + chunk_gen = ChatGenerationChunk( + message=BaseMessageChunk(**chunk.dict()) + ) + else: + chunk_gen = GenerationChunk(text=chunk) + + if acc_gen is None: + acc_gen = chunk_gen + else: + acc_gen += chunk_gen + + parsed = self.parse_result([acc_gen], partial=True) + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[T]: + prev_parsed = None + acc_gen = None + async for chunk in input: + if isinstance(chunk, BaseMessageChunk): + chunk_gen: Generation = ChatGenerationChunk(message=chunk) + elif isinstance(chunk, BaseMessage): + chunk_gen = ChatGenerationChunk( + message=BaseMessageChunk(**chunk.dict()) + ) + else: + chunk_gen = GenerationChunk(text=chunk) + + if acc_gen is None: + acc_gen = chunk_gen + else: + acc_gen += chunk_gen + + parsed = self.parse_result([acc_gen], partial=True) + if parsed is not None and parsed != prev_parsed: + if self.diff: + yield self._diff(prev_parsed, parsed) + else: + yield parsed + prev_parsed = parsed + + class StrOutputParser(BaseTransformOutputParser[str]): """OutputParser that parses LLMResult into the top likely string.""" diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py deleted file mode 100644 index 3ce6fb976d5..00000000000 --- a/libs/langchain/tests/unit_tests/output_parsers/test_base_output_parser.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test the BaseOutputParser class and its sub-classes.""" -from abc import ABC -from collections import defaultdict -from typing import List, Optional, Set, Type - -import pytest - -from langchain.schema import BaseOutputParser - - -def non_abstract_subclasses( - cls: Type[ABC], to_skip: Optional[Set] = None -) -> List[Type]: - """Recursively find all non-abstract subclasses of a class.""" - _to_skip = to_skip or set() - subclasses = [] - for subclass in cls.__subclasses__(): - if not getattr(subclass, "__abstractmethods__", None): - if subclass.__name__ not in _to_skip: - subclasses.append(subclass) - subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip)) - return subclasses - - -# parsers defined not in the output_parsers module: -_PARSERS_TO_SKIP = { - "FakeOutputParser", - "BaseOutputParser", - "FinishedOutputParser", - "RouterOutputParser", - "TrajectoryRunEvalOutputParser", -} -_NON_ABSTRACT_PARSERS = non_abstract_subclasses( - BaseOutputParser, to_skip=_PARSERS_TO_SKIP -) - - -@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS) -def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None: - try: - cls._type - except NotImplementedError: - pytest.fail(f"_type property is not implemented in class {cls.__name__}") - - -def test_all_subclasses_implement_unique_type() -> None: - types = defaultdict(list) - for cls in _NON_ABSTRACT_PARSERS: - try: - types[cls._type].append(cls.__name__) - except NotImplementedError: - # This is handled in the previous test - pass - dups = {t: names for t, names in types.items() if len(names) > 1} - assert not dups, f"Duplicate types: {dups}" diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 21bc600fd05..6fe7cb27dd1 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -1,6 +1,15 @@ +import json +from typing import Any, AsyncIterator, Iterator, Tuple + import pytest -from langchain.output_parsers.json import parse_json_markdown +from langchain.output_parsers.json import ( + SimpleJsonOutputParser, + parse_json_markdown, + parse_partial_json, +) +from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser +from langchain.schema.messages import AIMessageChunk GOOD_JSON = """```json { @@ -183,3 +192,351 @@ def test_parse_json_with_python_dict() -> None: "action": "Final Answer", "action_input": {"foo": "bar", "bar": "foo"}, } + + +TEST_CASES_PARTIAL = [ + ('{"foo": "bar", "bar": "foo"}', '{"foo": "bar", "bar": "foo"}'), + ('{"foo": "bar", "bar": "foo', '{"foo": "bar", "bar": "foo"}'), + ('{"foo": "bar", "bar": "foo}', '{"foo": "bar", "bar": "foo}"}'), + ('{"foo": "bar", "bar": "foo[', '{"foo": "bar", "bar": "foo["}'), + ('{"foo": "bar", "bar": "foo\\"', '{"foo": "bar", "bar": "foo\\""}'), +] + + +@pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL) +def test_parse_partial_json(json_strings: Tuple[str, str]) -> None: + case, expected = json_strings + parsed = parse_partial_json(case) + assert parsed == json.loads(expected) + + +STREAMED_TOKENS = """ +{ + + " +setup +": + " +Why + did + the + bears + start + a + band + called + Bears + Bears + Bears + ? +" +, + " +punchline +": + " +Because + they + wanted + to + play + bear + -y + good + music + ! +" +, + " +audience +": + [ +" +Haha +" +, + " +So + funny +" +] + +} +""".splitlines() + +EXPECTED_STREAMED_JSON = [ + {}, + {"setup": ""}, + {"setup": "Why"}, + {"setup": "Why did"}, + {"setup": "Why did the"}, + {"setup": "Why did the bears"}, + {"setup": "Why did the bears start"}, + {"setup": "Why did the bears start a"}, + {"setup": "Why did the bears start a band"}, + {"setup": "Why did the bears start a band called"}, + {"setup": "Why did the bears start a band called Bears"}, + {"setup": "Why did the bears start a band called Bears Bears"}, + {"setup": "Why did the bears start a band called Bears Bears Bears"}, + {"setup": "Why did the bears start a band called Bears Bears Bears ?"}, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y good", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y good music", + }, + { + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "punchline": "Because they wanted to play bear -y good music !", + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": [], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": [""], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha"], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha", ""], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha", "So"], + }, + { + "punchline": "Because they wanted to play bear -y good music !", + "setup": "Why did the bears start a band called Bears Bears Bears ?", + "audience": ["Haha", "So funny"], + }, +] + +EXPECTED_STREAMED_JSON_DIFF = [ + [{"op": "replace", "path": "", "value": {}}], + [{"op": "add", "path": "/setup", "value": ""}], + [{"op": "replace", "path": "/setup", "value": "Why"}], + [{"op": "replace", "path": "/setup", "value": "Why did"}], + [{"op": "replace", "path": "/setup", "value": "Why did the"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears start"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}], + [{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears Bears", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears Bears Bears", + } + ], + [ + { + "op": "replace", + "path": "/setup", + "value": "Why did the bears start a band called Bears Bears Bears ?", + } + ], + [{"op": "add", "path": "/punchline", "value": ""}], + [{"op": "replace", "path": "/punchline", "value": "Because"}], + [{"op": "replace", "path": "/punchline", "value": "Because they"}], + [{"op": "replace", "path": "/punchline", "value": "Because they wanted"}], + [{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}], + [{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y good", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y good music", + } + ], + [ + { + "op": "replace", + "path": "/punchline", + "value": "Because they wanted to play bear -y good music !", + } + ], + [{"op": "add", "path": "/audience", "value": []}], + [{"op": "add", "path": "/audience/0", "value": ""}], + [{"op": "replace", "path": "/audience/0", "value": "Haha"}], + [{"op": "add", "path": "/audience/1", "value": ""}], + [{"op": "replace", "path": "/audience/1", "value": "So"}], + [{"op": "replace", "path": "/audience/1", "value": "So funny"}], +] + + +def test_partial_text_json_output_parser() -> None: + def input_iter(_: Any) -> Iterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | SimpleJsonOutputParser() + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON + + +def test_partial_functions_json_output_parser() -> None: + def input_iter(_: Any) -> Iterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | JsonOutputFunctionsParser() + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON + + +def test_partial_text_json_output_parser_diff() -> None: + def input_iter(_: Any) -> Iterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | SimpleJsonOutputParser(diff=True) + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF + + +def test_partial_functions_json_output_parser_diff() -> None: + def input_iter(_: Any) -> Iterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | JsonOutputFunctionsParser(diff=True) + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF + + +@pytest.mark.asyncio +async def test_partial_text_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | SimpleJsonOutputParser() + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +@pytest.mark.asyncio +async def test_partial_functions_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | JsonOutputFunctionsParser() + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +@pytest.mark.asyncio +async def test_partial_text_json_output_parser_diff_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | SimpleJsonOutputParser(diff=True) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF + + +@pytest.mark.asyncio +async def test_partial_functions_json_output_parser_diff_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | JsonOutputFunctionsParser(diff=True) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index f632a23e9dd..29df3726f9a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -582,7 +582,9 @@ async def test_with_config(mocker: MockerFixture) -> None: ) == [5, 7] assert len(spy.call_args_list) == 2 - for i, call in enumerate(spy.call_args_list): + for i, call in enumerate( + sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1) + ): assert call.args[0] == ("hello" if i == 0 else "wooorld") if i == 0: assert call.args[1].get("recursion_limit") == 5