langchain: add partial parsing support to JsonOutputToolsParser (#17035)

- **Description:** Add partial parsing support to JsonOutputToolsParser
- **Issue:**
[16736](https://github.com/langchain-ai/langchain/issues/16736)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
T Cramer 2024-02-05 22:18:30 +00:00 committed by GitHub
parent dcf973c22c
commit e022bfaa7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,9 +4,8 @@ from json import JSONDecodeError
from typing import Any, List, Type from typing import Any, List, Type
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import ( from langchain_core.output_parsers import BaseGenerationOutputParser
BaseGenerationOutputParser, from langchain_core.output_parsers.json import parse_partial_json
)
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
@ -42,9 +41,14 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
if "function" not in tool_call: if "function" not in tool_call:
continue continue
try: try:
function_args = json.loads( if partial:
tool_call["function"]["arguments"], strict=self.strict function_args = parse_partial_json(
) tool_call["function"]["arguments"], strict=self.strict
)
else:
function_args = json.loads(
tool_call["function"]["arguments"], strict=self.strict
)
except JSONDecodeError as e: except JSONDecodeError as e:
exceptions.append( exceptions.append(
f"Function {tool_call['function']['name']} arguments:\n\n" f"Function {tool_call['function']['name']} arguments:\n\n"
@ -77,7 +81,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
super().__init__(key_name=key_name, **kwargs) super().__init__(key_name=key_name, **kwargs)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
results = super().parse_result(result) results = super().parse_result(result, partial=partial)
results = [res for res in results if res["type"] == self.key_name] results = [res for res in results if res["type"] == self.key_name]
if not self.return_id: if not self.return_id:
results = [res["args"] for res in results] results = [res["args"] for res in results]
@ -92,6 +96,6 @@ class PydanticToolsParser(JsonOutputToolsParser):
tools: List[Type[BaseModel]] tools: List[Type[BaseModel]]
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
results = super().parse_result(result) results = super().parse_result(result, partial=partial)
name_dict = {tool.__name__: tool for tool in self.tools} name_dict = {tool.__name__: tool for tool in self.tools}
return [name_dict[res["type"]](**res["args"]) for res in results] return [name_dict[res["type"]](**res["args"]) for res in results]