mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
Add a streaming json parser (#11193)
<img width="1728" alt="Screenshot 2023-09-28 at 20 15 01" src="https://github.com/langchain-ai/langchain/assets/56902/ed0644c3-6db7-41b9-9543-e34fce46d3e5"> <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
1ddf9f74b2
@ -75,14 +75,16 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
|
||||
return_values={"output": message.content}, log=message.content
|
||||
)
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
raise ValueError("This output parser only works on ChatGeneration output")
|
||||
message = result[0].message
|
||||
return self._parse_ai_message(message)
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation]
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
|
@ -3,9 +3,14 @@ from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
import jsonpatch
|
||||
|
||||
from langchain.schema.output_parser import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
OutputParserException,
|
||||
)
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
@ -38,7 +43,70 @@ def _custom_parser(multiline_string: str) -> str:
|
||||
return multiline_string
|
||||
|
||||
|
||||
def parse_json_markdown(json_string: str) -> dict:
|
||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
|
||||
# MIT License
|
||||
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON document that may have been cut off part-way through.

    The string is first handed to ``json.loads`` unchanged. If that fails, the
    text is scanned once to work out which string, object and array
    delimiters are still open; the missing closers are appended and the
    repaired text is parsed again.

    Args:
        s: The (possibly truncated) JSON text.
        strict: Forwarded to ``json.loads``; when false, control characters
            are allowed inside strings.

    Returns:
        The parsed value, or ``None`` when the input has a mismatched closing
        bracket or still cannot be parsed after repair.
    """
    # Attempt to parse the string as-is.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    # Collect output pieces in a list and join once at the end instead of
    # growing a string character by character.
    pieces = []
    # Closers we still owe, innermost last.
    stack = []
    is_inside_string = False
    escaped = False

    # Process each character in the string one at a time.
    for char in s:
        piece = char
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                # A raw newline inside a string is replaced with its escape
                # sequence.
                piece = "\\n"
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        elif char == '"':
            is_inside_string = True
            escaped = False
        elif char == "{":
            stack.append("}")
        elif char == "[":
            stack.append("]")
        elif char in ("}", "]"):
            if stack and stack[-1] == char:
                stack.pop()
            else:
                # Mismatched closing character; the input is malformed.
                return None

        pieces.append(piece)

    # If we're still inside a string at the end of processing, close it.
    if is_inside_string:
        if escaped:
            # The input stopped right after a lone backslash. Keeping it would
            # turn the quote we are about to add into an escaped quote and
            # leave the string open, so drop the unfinished escape.
            pieces.pop()
        pieces.append('"')

    # Close any remaining open structures in the reverse order that they
    # were opened.
    pieces.extend(reversed(stack))

    # Attempt to parse the repaired string as JSON.
    try:
        return json.loads("".join(pieces), strict=strict)
    except json.JSONDecodeError:
        # Still not parseable; return None to indicate failure.
        return None
|
||||
|
||||
|
||||
def parse_json_markdown(
|
||||
json_string: str, *, parser: Callable[[str], Any] = json.loads
|
||||
) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string.
|
||||
|
||||
@ -65,7 +133,7 @@ def parse_json_markdown(json_string: str) -> dict:
|
||||
json_str = _custom_parser(json_str)
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
parsed = json.loads(json_str)
|
||||
parsed = parser(json_str)
|
||||
|
||||
return parsed
|
||||
|
||||
@ -95,13 +163,23 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
return json_obj
|
||||
|
||||
|
||||
class SimpleJsonOutputParser(BaseOutputParser[Any]):
|
||||
"""Parse the output of an LLM call to a JSON object."""
|
||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse the output of an LLM call to a JSON object.
|
||||
|
||||
When used in streaming mode, it will yield partial JSON objects containing
|
||||
all the keys that have been returned so far.
|
||||
|
||||
In streaming, if `diff` is set to `True`, yields JSONPatch operations
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse(self, text: str) -> Any:
|
||||
text = text.strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
return parse_json_markdown(text.strip(), parser=parse_partial_json)
|
||||
except JSONDecodeError as e:
|
||||
raise OutputParserException(f"Invalid json output: {text}") from e
|
||||
|
||||
|
@ -1,14 +1,20 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import jsonpatch
|
||||
|
||||
from langchain.output_parsers.json import parse_partial_json
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
Generation,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.output_parser import BaseGenerationOutputParser
|
||||
from langchain.schema.output_parser import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseGenerationOutputParser,
|
||||
)
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
@ -17,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
args_only: bool = True
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
raise OutputParserException(
|
||||
@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
return func_call
|
||||
|
||||
|
||||
class JsonOutputFunctionsParser(OutputFunctionsParser):
|
||||
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse an output as the Json object."""
|
||||
|
||||
strict: bool = False
|
||||
@ -45,25 +51,72 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
|
||||
Useful when the parsed output may include unicode characters or new lines.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
function_call_info = super().parse_result(result)
|
||||
if self.args_only:
|
||||
try:
|
||||
return json.loads(function_call_info, strict=self.strict)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
function_call_info["arguments"] = json.loads(
|
||||
function_call_info["arguments"], strict=self.strict
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
return function_call_info
|
||||
args_only: bool = True
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
    """Parse the function-call arguments carried by a single chat generation.

    Args:
        result: Must contain exactly one generation, and that generation
            must be a ``ChatGeneration``.
        partial: When true, the arguments are parsed with
            ``parse_partial_json`` and a message without a
            ``"function_call"`` entry yields ``None``. When false, the
            arguments are parsed with ``json.loads``.

    Returns:
        The parsed arguments when ``args_only`` is set; otherwise a copy of
        the function-call mapping with its ``"arguments"`` entry replaced by
        the parsed value. ``None`` when the function call has no
        ``"arguments"`` key.

    Raises:
        OutputParserException: If ``result`` does not hold exactly one
            ``ChatGeneration``; if ``partial`` is false and the message has
            no ``"function_call"`` entry; or if ``partial`` is false and the
            arguments cannot be decoded.
    """
    if len(result) != 1:
        raise OutputParserException(
            f"Expected exactly one result, but got {len(result)}"
        )
    generation = result[0]
    if not isinstance(generation, ChatGeneration):
        raise OutputParserException(
            "This output parser can only be used with a chat generation."
        )
    message = generation.message
    try:
        function_call = message.additional_kwargs["function_call"]
    except KeyError as exc:
        if partial:
            return None
        raise OutputParserException(
            f"Could not parse function call: {exc}"
        ) from exc

    try:
        if partial:
            arguments = parse_partial_json(
                function_call["arguments"], strict=self.strict
            )
        else:
            try:
                arguments = json.loads(
                    function_call["arguments"], strict=self.strict
                )
            except (json.JSONDecodeError, TypeError) as exc:
                # Chain the cause so the decode position is not lost.
                raise OutputParserException(
                    f"Could not parse function call data: {exc}"
                ) from exc
    except KeyError:
        # The function call has no "arguments" entry.
        return None

    if self.args_only:
        return arguments
    return {**function_call, "arguments": arguments}
|
||||
|
||||
# This method would be called by the default implementation of `parse_result`
|
||||
# but we're overriding that method so it's not needed.
|
||||
def parse(self, text: str) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
@ -72,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
key_name: str
|
||||
"""The name of the key to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
    """Return the value stored under ``key_name`` in the parsed function call.

    Args:
        result: Generations to parse; handed to the parent parser unchanged.
        partial: Forwarded to the parent parser. When true, a missing parse
            result or a key that is not present yet yields ``None`` instead
            of raising.

    Returns:
        The value found under ``key_name``. In partial mode, ``None`` when
        the parent returned ``None`` or the key is absent.
    """
    # Forward `partial` so the parent uses its partial-parsing path too.
    res = super().parse_result(result, partial=partial)
    if partial:
        if res is None:
            return None
        return res.get(self.key_name)
    return res[self.key_name]
|
||||
|
||||
@ -97,7 +150,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
)
|
||||
return values
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
_result = super().parse_result(result)
|
||||
if self.args_only:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
@ -114,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
attr_name: str
|
||||
"""The name of the attribute to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
    """Return the ``attr_name`` attribute of the parent parser's result.

    Args:
        result: Generations to parse; handed to the parent parser unchanged.
        partial: Forwarded to the parent parser rather than dropped.

    Returns:
        ``getattr(parsed, attr_name)`` where ``parsed`` is whatever the
        parent's ``parse_result`` returned.
    """
    # Use a separate name instead of rebinding the `result` parameter to a
    # value of a different type.
    parsed = super().parse_result(result, partial=partial)
    return getattr(parsed, self.attr_name)
|
||||
|
@ -17,8 +17,13 @@ from typing import (
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import AnyMessage, BaseMessage
|
||||
from langchain.schema.output import ChatGeneration, Generation
|
||||
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
||||
from langchain.schema.output import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
Generation,
|
||||
GenerationChunk,
|
||||
)
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
|
||||
@ -29,7 +34,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
@ -40,7 +45,9 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(self, result: List[Generation]) -> T:
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
@ -200,7 +207,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
@ -226,7 +233,9 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
async def aparse_result(self, result: List[Generation]) -> T:
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
@ -329,6 +338,74 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
    """Base class for an output parser that can handle streaming input.

    Incoming chunks are accumulated, the accumulated generation is re-parsed
    with ``partial=True`` after every chunk, and a value is emitted each time
    the parse result is not ``None`` and differs from the previous one.
    """

    diff: bool = False
    """In streaming mode, whether to yield diffs between the previous and current
    parsed output, or just the current parsed output.
    """

    def _diff(self, prev: Optional[T], next: T) -> T:
        """Convert parsed outputs into a diff format. The semantics of this are
        up to the output parser."""
        raise NotImplementedError()

    def _to_generation_chunk(self, chunk: Union[str, BaseMessage]) -> Generation:
        """Wrap one streamed input item in a generation chunk.

        Message chunks are wrapped as they are, complete messages are first
        rebuilt as a ``BaseMessageChunk`` from their ``dict()``, and anything
        else is used as the text of a ``GenerationChunk``.
        """
        if isinstance(chunk, BaseMessageChunk):
            return ChatGenerationChunk(message=chunk)
        if isinstance(chunk, BaseMessage):
            return ChatGenerationChunk(message=BaseMessageChunk(**chunk.dict()))
        return GenerationChunk(text=chunk)

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
        """Yield a parse result (or its diff) whenever the accumulated input
        parses to something new."""
        prev_parsed = None
        acc_gen = None
        for chunk in input:
            chunk_gen: Generation = self._to_generation_chunk(chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen], partial=True)
            # Skip chunks that yield no parse or leave the result unchanged.
            if parsed is not None and parsed != prev_parsed:
                if self.diff:
                    yield self._diff(prev_parsed, parsed)
                else:
                    yield parsed
                prev_parsed = parsed

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        """Async counterpart of ``_transform`` with the same emission rule."""
        prev_parsed = None
        acc_gen = None
        async for chunk in input:
            chunk_gen: Generation = self._to_generation_chunk(chunk)

            if acc_gen is None:
                acc_gen = chunk_gen
            else:
                acc_gen += chunk_gen

            parsed = self.parse_result([acc_gen], partial=True)
            # Skip chunks that yield no parse or leave the result unchanged.
            if parsed is not None and parsed != prev_parsed:
                if self.diff:
                    yield self._diff(prev_parsed, parsed)
                else:
                    yield parsed
                prev_parsed = parsed
|
||||
|
||||
|
||||
class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
"""OutputParser that parses LLMResult into the top likely string."""
|
||||
|
||||
|
@ -1,55 +0,0 @@
|
||||
"""Test the BaseOutputParser class and its sub-classes."""
|
||||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Set, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
def non_abstract_subclasses(
|
||||
cls: Type[ABC], to_skip: Optional[Set] = None
|
||||
) -> List[Type]:
|
||||
"""Recursively find all non-abstract subclasses of a class."""
|
||||
_to_skip = to_skip or set()
|
||||
subclasses = []
|
||||
for subclass in cls.__subclasses__():
|
||||
if not getattr(subclass, "__abstractmethods__", None):
|
||||
if subclass.__name__ not in _to_skip:
|
||||
subclasses.append(subclass)
|
||||
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
|
||||
return subclasses
|
||||
|
||||
|
||||
# parsers defined not in the output_parsers module:
|
||||
_PARSERS_TO_SKIP = {
|
||||
"FakeOutputParser",
|
||||
"BaseOutputParser",
|
||||
"FinishedOutputParser",
|
||||
"RouterOutputParser",
|
||||
"TrajectoryRunEvalOutputParser",
|
||||
}
|
||||
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
|
||||
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
|
||||
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
|
||||
try:
|
||||
cls._type
|
||||
except NotImplementedError:
|
||||
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
|
||||
|
||||
|
||||
def test_all_subclasses_implement_unique_type() -> None:
|
||||
types = defaultdict(list)
|
||||
for cls in _NON_ABSTRACT_PARSERS:
|
||||
try:
|
||||
types[cls._type].append(cls.__name__)
|
||||
except NotImplementedError:
|
||||
# This is handled in the previous test
|
||||
pass
|
||||
dups = {t: names for t, names in types.items() if len(names) > 1}
|
||||
assert not dups, f"Duplicate types: {dups}"
|
@ -1,6 +1,15 @@
|
||||
import json
|
||||
from typing import Any, AsyncIterator, Iterator, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
from langchain.output_parsers.json import (
|
||||
SimpleJsonOutputParser,
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
|
||||
GOOD_JSON = """```json
|
||||
{
|
||||
@ -183,3 +192,351 @@ def test_parse_json_with_python_dict() -> None:
|
||||
"action": "Final Answer",
|
||||
"action_input": {"foo": "bar", "bar": "foo"},
|
||||
}
|
||||
|
||||
|
||||
TEST_CASES_PARTIAL = [
|
||||
('{"foo": "bar", "bar": "foo"}', '{"foo": "bar", "bar": "foo"}'),
|
||||
('{"foo": "bar", "bar": "foo', '{"foo": "bar", "bar": "foo"}'),
|
||||
('{"foo": "bar", "bar": "foo}', '{"foo": "bar", "bar": "foo}"}'),
|
||||
('{"foo": "bar", "bar": "foo[', '{"foo": "bar", "bar": "foo["}'),
|
||||
('{"foo": "bar", "bar": "foo\\"', '{"foo": "bar", "bar": "foo\\""}'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL)
|
||||
def test_parse_partial_json(json_strings: Tuple[str, str]) -> None:
|
||||
case, expected = json_strings
|
||||
parsed = parse_partial_json(case)
|
||||
assert parsed == json.loads(expected)
|
||||
|
||||
|
||||
STREAMED_TOKENS = """
|
||||
{
|
||||
|
||||
"
|
||||
setup
|
||||
":
|
||||
"
|
||||
Why
|
||||
did
|
||||
the
|
||||
bears
|
||||
start
|
||||
a
|
||||
band
|
||||
called
|
||||
Bears
|
||||
Bears
|
||||
Bears
|
||||
?
|
||||
"
|
||||
,
|
||||
"
|
||||
punchline
|
||||
":
|
||||
"
|
||||
Because
|
||||
they
|
||||
wanted
|
||||
to
|
||||
play
|
||||
bear
|
||||
-y
|
||||
good
|
||||
music
|
||||
!
|
||||
"
|
||||
,
|
||||
"
|
||||
audience
|
||||
":
|
||||
[
|
||||
"
|
||||
Haha
|
||||
"
|
||||
,
|
||||
"
|
||||
So
|
||||
funny
|
||||
"
|
||||
]
|
||||
|
||||
}
|
||||
""".splitlines()
|
||||
|
||||
EXPECTED_STREAMED_JSON = [
|
||||
{},
|
||||
{"setup": ""},
|
||||
{"setup": "Why"},
|
||||
{"setup": "Why did"},
|
||||
{"setup": "Why did the"},
|
||||
{"setup": "Why did the bears"},
|
||||
{"setup": "Why did the bears start"},
|
||||
{"setup": "Why did the bears start a"},
|
||||
{"setup": "Why did the bears start a band"},
|
||||
{"setup": "Why did the bears start a band called"},
|
||||
{"setup": "Why did the bears start a band called Bears"},
|
||||
{"setup": "Why did the bears start a band called Bears Bears"},
|
||||
{"setup": "Why did the bears start a band called Bears Bears Bears"},
|
||||
{"setup": "Why did the bears start a band called Bears Bears Bears ?"},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y good",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y good music",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": [],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": [""],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha"],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha", ""],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha", "So"],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha", "So funny"],
|
||||
},
|
||||
]
|
||||
|
||||
EXPECTED_STREAMED_JSON_DIFF = [
|
||||
[{"op": "replace", "path": "", "value": {}}],
|
||||
[{"op": "add", "path": "/setup", "value": ""}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears start"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears Bears",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears Bears Bears",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
}
|
||||
],
|
||||
[{"op": "add", "path": "/punchline", "value": ""}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they wanted"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y good",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y good music",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y good music !",
|
||||
}
|
||||
],
|
||||
[{"op": "add", "path": "/audience", "value": []}],
|
||||
[{"op": "add", "path": "/audience/0", "value": ""}],
|
||||
[{"op": "replace", "path": "/audience/0", "value": "Haha"}],
|
||||
[{"op": "add", "path": "/audience/1", "value": ""}],
|
||||
[{"op": "replace", "path": "/audience/1", "value": "So"}],
|
||||
[{"op": "replace", "path": "/audience/1", "value": "So funny"}],
|
||||
]
|
||||
|
||||
|
||||
def test_partial_text_json_output_parser() -> None:
|
||||
def input_iter(_: Any) -> Iterator[str]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | SimpleJsonOutputParser()
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
def test_partial_functions_json_output_parser() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunk]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | JsonOutputFunctionsParser()
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
def test_partial_text_json_output_parser_diff() -> None:
|
||||
def input_iter(_: Any) -> Iterator[str]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | SimpleJsonOutputParser(diff=True)
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
|
||||
def test_partial_functions_json_output_parser_diff() -> None:
|
||||
def input_iter(_: Any) -> Iterator[AIMessageChunk]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | JsonOutputFunctionsParser(diff=True)
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_text_json_output_parser_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[str]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | SimpleJsonOutputParser()
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_functions_json_output_parser_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | JsonOutputFunctionsParser()
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_text_json_output_parser_diff_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[str]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | SimpleJsonOutputParser(diff=True)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_functions_json_output_parser_diff_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | JsonOutputFunctionsParser(diff=True)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
@ -582,7 +582,9 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
) == [5, 7]
|
||||
|
||||
assert len(spy.call_args_list) == 2
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
for i, call in enumerate(
|
||||
sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)
|
||||
):
|
||||
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||
if i == 0:
|
||||
assert call.args[1].get("recursion_limit") == 5
|
||||
|
Loading…
Reference in New Issue
Block a user