From e022bfaa7dd54ea962fc11094b37b3fa324c95ea Mon Sep 17 00:00:00 2001 From: T Cramer Date: Mon, 5 Feb 2024 22:18:30 +0000 Subject: [PATCH] langchain: add partial parsing support to JsonOutputToolsParser (#17035) - **Description:** Add partial parsing support to JsonOutputToolsParser - **Issue:** [16736](https://github.com/langchain-ai/langchain/issues/16736) --------- Co-authored-by: Harrison Chase --- .../langchain/output_parsers/openai_tools.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/openai_tools.py b/libs/langchain/langchain/output_parsers/openai_tools.py index 39fd05974a9..045e32686ef 100644 --- a/libs/langchain/langchain/output_parsers/openai_tools.py +++ b/libs/langchain/langchain/output_parsers/openai_tools.py @@ -4,9 +4,8 @@ from json import JSONDecodeError from typing import Any, List, Type from langchain_core.exceptions import OutputParserException -from langchain_core.output_parsers import ( - BaseGenerationOutputParser, -) +from langchain_core.output_parsers import BaseGenerationOutputParser +from langchain_core.output_parsers.json import parse_partial_json from langchain_core.outputs import ChatGeneration, Generation from langchain_core.pydantic_v1 import BaseModel @@ -42,9 +41,14 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]): if "function" not in tool_call: continue try: - function_args = json.loads( - tool_call["function"]["arguments"], strict=self.strict - ) + if partial: + function_args = parse_partial_json( + tool_call["function"]["arguments"], strict=self.strict + ) + else: + function_args = json.loads( + tool_call["function"]["arguments"], strict=self.strict + ) except JSONDecodeError as e: exceptions.append( f"Function {tool_call['function']['name']} arguments:\n\n" @@ -77,7 +81,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser): super().__init__(key_name=key_name, **kwargs) def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: - results = super().parse_result(result) + results = super().parse_result(result, partial=partial) results = [res for res in results if res["type"] == self.key_name] if not self.return_id: results = [res["args"] for res in results] @@ -92,6 +96,6 @@ class PydanticToolsParser(JsonOutputToolsParser): tools: List[Type[BaseModel]] def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: - results = super().parse_result(result) + results = super().parse_result(result, partial=partial) name_dict = {tool.__name__: tool for tool in self.tools} return [name_dict[res["type"]](**res["args"]) for res in results]