mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
Combine with existing json output parsers
This commit is contained in:
parent
4b8442896b
commit
c9d0f2b984
@ -7,9 +7,10 @@ from typing import Any, Callable, List, Optional
|
||||
|
||||
import jsonpatch
|
||||
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.schema.output import ChatGeneration, Generation
|
||||
from langchain.schema.output_parser import BaseCumulativeTransformOutputParser
|
||||
from langchain.schema.output_parser import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
OutputParserException,
|
||||
)
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
@ -44,10 +45,10 @@ def _custom_parser(multiline_string: str) -> str:
|
||||
|
||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
|
||||
# MIT License
|
||||
def parse_partial_json(s: str) -> Any:
|
||||
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
|
||||
# Attempt to parse the string as-is.
|
||||
try:
|
||||
return json.loads(s)
|
||||
return json.loads(s, strict=strict)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
@ -97,7 +98,7 @@ def parse_partial_json(s: str) -> Any:
|
||||
|
||||
# Attempt to parse the modified string as JSON.
|
||||
try:
|
||||
return json.loads(new_s)
|
||||
return json.loads(new_s, strict=strict)
|
||||
except json.JSONDecodeError:
|
||||
# If we still can't parse the string as JSON, return None to indicate failure.
|
||||
return None
|
||||
@ -162,62 +163,26 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
return json_obj
|
||||
|
||||
|
||||
class SimpleJsonOutputParser(BaseOutputParser[Any]):
|
||||
"""Parse the output of an LLM call to a JSON object."""
|
||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse the output of an LLM call to a JSON object.
|
||||
|
||||
When used in streaming mode, it will yield partial JSON objects containing
|
||||
all the keys that have been returned so far.
|
||||
|
||||
In streaming, if `diff` is set to `True`, yields JSONPatch operations
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse(self, text: str) -> Any:
|
||||
text = text.strip()
|
||||
try:
|
||||
return parse_partial_json(text)
|
||||
return parse_json_markdown(text.strip(), parse_partial_json)
|
||||
except JSONDecodeError as e:
|
||||
raise OutputParserException(f"Invalid json output: {text}") from e
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "simple_json_output_parser"
|
||||
|
||||
|
||||
class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "partial_functions_json"
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
if len(result) != 1:
|
||||
raise OutputParserException(
|
||||
f"Expected exactly one result, but got {len(result)}"
|
||||
)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
raise OutputParserException(
|
||||
"This output parser can only be used with a chat generation."
|
||||
)
|
||||
message = generation.message
|
||||
try:
|
||||
function_call = message.additional_kwargs["function_call"]
|
||||
except KeyError:
|
||||
return None
|
||||
try:
|
||||
return parse_partial_json(function_call["arguments"])
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
# This method would be called by the default implementation of `parse_result`
|
||||
# but we're overriding that method so it's not needed.
|
||||
def parse(self, text: str) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "partial_functions_json"
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse(self, text: str) -> Any:
|
||||
return parse_json_markdown(text, parse_partial_json)
|
||||
|
@ -1,14 +1,20 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import jsonpatch
|
||||
|
||||
from langchain.output_parsers.json import parse_partial_json
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
Generation,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.output_parser import BaseGenerationOutputParser
|
||||
from langchain.schema.output_parser import (
|
||||
BaseCumulativeTransformOutputParser,
|
||||
BaseGenerationOutputParser,
|
||||
)
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
return func_call
|
||||
|
||||
|
||||
class JsonOutputFunctionsParser(OutputFunctionsParser):
|
||||
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse an output as the Json object."""
|
||||
|
||||
strict: bool = False
|
||||
@ -45,25 +51,42 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
|
||||
Useful when the parsed output may include unicode characters or new lines.
|
||||
"""
|
||||
|
||||
args_only: bool = True
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> Any:
|
||||
function_call_info = super().parse_result(result)
|
||||
if self.args_only:
|
||||
try:
|
||||
return json.loads(function_call_info, strict=self.strict)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
function_call_info["arguments"] = json.loads(
|
||||
function_call_info["arguments"], strict=self.strict
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
return function_call_info
|
||||
if len(result) != 1:
|
||||
raise OutputParserException(
|
||||
f"Expected exactly one result, but got {len(result)}"
|
||||
)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
raise OutputParserException(
|
||||
"This output parser can only be used with a chat generation."
|
||||
)
|
||||
message = generation.message
|
||||
try:
|
||||
function_call = message.additional_kwargs["function_call"]
|
||||
except KeyError:
|
||||
return None
|
||||
try:
|
||||
if self.args_only:
|
||||
return parse_partial_json(function_call["arguments"])
|
||||
else:
|
||||
return {
|
||||
**function_call,
|
||||
"arguments": parse_partial_json(function_call["arguments"]),
|
||||
}
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
# This method would be called by the default implementation of `parse_result`
|
||||
# but we're overriding that method so it's not needed.
|
||||
def parse(self, text: str) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
|
@ -338,6 +338,9 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
"""Base class for an output parser that can handle streaming input."""
|
||||
|
||||
diff: bool = False
|
||||
"""In streaming mode, whether to yield diffs between the previous and current
|
||||
parsed output, or just the current parsed output.
|
||||
"""
|
||||
|
||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||
"""Convert parsed outputs into a diff format. The semantics of this are
|
||||
|
@ -4,12 +4,12 @@ from typing import Any, AsyncIterator, Iterator, Tuple
|
||||
import pytest
|
||||
|
||||
from langchain.output_parsers.json import (
|
||||
PartialFunctionsJsonOutputParser,
|
||||
PartialJsonOutputParser,
|
||||
SimpleJsonOutputParser,
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
|
||||
GOOD_JSON = """```json
|
||||
{
|
||||
@ -455,7 +455,7 @@ def test_partial_text_json_output_parser() -> None:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | PartialJsonOutputParser()
|
||||
chain = input_iter | SimpleJsonOutputParser()
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
@ -467,7 +467,7 @@ def test_partial_functions_json_output_parser() -> None:
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser()
|
||||
chain = input_iter | JsonOutputFunctionsParser()
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
@ -477,7 +477,7 @@ def test_partial_text_json_output_parser_diff() -> None:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | PartialJsonOutputParser(diff=True)
|
||||
chain = input_iter | SimpleJsonOutputParser(diff=True)
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
@ -489,7 +489,7 @@ def test_partial_functions_json_output_parser_diff() -> None:
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
||||
chain = input_iter | JsonOutputFunctionsParser(diff=True)
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
@ -500,7 +500,7 @@ async def test_partial_text_json_output_parser_async() -> None:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | PartialJsonOutputParser()
|
||||
chain = input_iter | SimpleJsonOutputParser()
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
@ -513,7 +513,7 @@ async def test_partial_functions_json_output_parser_async() -> None:
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser()
|
||||
chain = input_iter | JsonOutputFunctionsParser()
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
@ -524,7 +524,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | PartialJsonOutputParser(diff=True)
|
||||
chain = input_iter | SimpleJsonOutputParser(diff=True)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
@ -537,6 +537,6 @@ async def test_partial_functions_json_output_parser_diff_async() -> None:
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
||||
chain = input_iter | JsonOutputFunctionsParser(diff=True)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
Loading…
Reference in New Issue
Block a user