From 18de77cc8cdf37385e0ab2791d877be673296380 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 11 Mar 2024 21:53:56 -0700 Subject: [PATCH] core[minor]: add streaming support to OAI tool parsers (#18940) Co-authored-by: Erick Friis --- .../output_parsers/openai_tools.py | 64 ++- .../output_parsers/test_openai_tools.py | 483 ++++++++++++++++++ 2 files changed, 527 insertions(+), 20 deletions(-) create mode 100644 libs/core/tests/unit_tests/output_parsers/test_openai_tools.py diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 3e405e76276..fb1f88aca24 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -4,13 +4,13 @@ from json import JSONDecodeError from typing import Any, List, Type from langchain_core.exceptions import OutputParserException -from langchain_core.output_parsers import BaseGenerationOutputParser +from langchain_core.output_parsers import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.json import parse_partial_json from langchain_core.outputs import ChatGeneration, Generation -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel, ValidationError -class JsonOutputToolsParser(BaseGenerationOutputParser[Any]): +class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]): """Parse tools from OpenAI response.""" strict: bool = False @@ -50,22 +50,25 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]): for tool_call in tool_calls: if "function" not in tool_call: continue - try: - if partial: + if partial: + try: function_args = parse_partial_json( tool_call["function"]["arguments"], strict=self.strict ) - else: + except JSONDecodeError: + continue + else: + try: function_args = json.loads( tool_call["function"]["arguments"], strict=self.strict ) - except JSONDecodeError as e: - exceptions.append( - f"Function {tool_call['function']['name']} arguments:\n\n" - f"{tool_call['function']['arguments']}\n\nare not valid JSON. " - f"Received JSONDecodeError {e}" - ) - continue + except JSONDecodeError as e: + exceptions.append( + f"Function {tool_call['function']['name']} arguments:\n\n" + f"{tool_call['function']['arguments']}\n\nare not valid JSON. " + f"Received JSONDecodeError {e}" + ) + continue parsed = { "type": tool_call["function"]["name"], "args": function_args, @@ -79,6 +82,9 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]): return final_tools[0] if final_tools else None return final_tools + def parse(self, text: str) -> Any: + raise NotImplementedError() + class JsonOutputKeyToolsParser(JsonOutputToolsParser): """Parse tools from OpenAI response.""" @@ -88,6 +94,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser): def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: parsed_result = super().parse_result(result, partial=partial) + if self.first_tool_only: single_result = ( parsed_result @@ -111,13 +118,30 @@ class PydanticToolsParser(JsonOutputToolsParser): tools: List[Type[BaseModel]] + # TODO: Support more granular streaming of objects. Currently only streams once all + # Pydantic object fields are present. def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: - parsed_result = super().parse_result(result, partial=partial) + json_results = super().parse_result(result, partial=partial) + if not json_results: + return None if self.first_tool_only else [] + + json_results = [json_results] if self.first_tool_only else json_results name_dict = {tool.__name__: tool for tool in self.tools} + pydantic_objects = [] + for res in json_results: + try: + if not isinstance(res["args"], dict): + raise ValueError( + f"Tool arguments must be specified as a dict, received: " + f"{res['args']}" + ) + pydantic_objects.append(name_dict[res["type"]](**res["args"])) + except (ValidationError, ValueError) as e: + if partial: + continue + else: + raise e if self.first_tool_only: - return ( - name_dict[parsed_result["type"]](**parsed_result["args"]) - if parsed_result - else None - ) - return [name_dict[res["type"]](**res["args"]) for res in parsed_result] + return pydantic_objects[0] if pydantic_objects else None + else: + return pydantic_objects diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py new file mode 100644 index 00000000000..0ba52d4ff01 --- /dev/null +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -0,0 +1,483 @@ +from typing import Any, AsyncIterator, Iterator, List + +from langchain_core.messages import AIMessageChunk, BaseMessage +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + JsonOutputToolsParser, + PydanticToolsParser, +) +from langchain_core.pydantic_v1 import BaseModel, Field + +STREAMED_MESSAGES: list = [ + AIMessageChunk(content=""), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl", + "function": {"arguments": "", "name": "NameCollector"}, + "type": "function", + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": '{"na', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'mes":', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ' ["suz', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'y", ', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": '"jerm', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'aine",', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ' "al', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'ex"],', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ' "pers', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'on":', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ' {"ag', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'e": 39', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ', "h', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": "air_c", "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'olor":', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ' "br', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'own",', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ' "job"', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ': "c', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": "oncie", "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'rge"}}', "name": None}, + "type": None, + } + ] + }, + ), + AIMessageChunk(content=""), +] + + +EXPECTED_STREAMED_JSON = [ + {}, + {"names": ["suz"]}, + {"names": ["suzy"]}, + {"names": ["suzy", "jerm"]}, + {"names": ["suzy", "jermaine"]}, + {"names": ["suzy", "jermaine", "al"]}, + {"names": ["suzy", "jermaine", "alex"]}, + {"names": ["suzy", "jermaine", "alex"], "person": {}}, + {"names": ["suzy", "jermaine", "alex"], "person": {"age": 39}}, + {"names": ["suzy", "jermaine", "alex"], "person": {"age": 39, "hair_color": "br"}}, + { + "names": ["suzy", "jermaine", "alex"], + "person": {"age": 39, "hair_color": "brown"}, + }, + { + "names": ["suzy", "jermaine", "alex"], + "person": {"age": 39, "hair_color": "brown", "job": "c"}, + }, + { + "names": ["suzy", "jermaine", "alex"], + "person": {"age": 39, "hair_color": "brown", "job": "concie"}, + }, + { + "names": ["suzy", "jermaine", "alex"], + "person": {"age": 39, "hair_color": "brown", "job": "concierge"}, + }, +] + + +def test_partial_json_output_parser() -> None: + def input_iter(_: Any) -> Iterator[BaseMessage]: + for msg in STREAMED_MESSAGES: + yield msg + + chain = input_iter | JsonOutputToolsParser() + + actual = list(chain.stream(None)) + expected: list = [[]] + [ + [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON + ] + assert actual == expected + + +async def test_partial_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: + for token in STREAMED_MESSAGES: + yield token + + chain = input_iter | JsonOutputToolsParser() + + actual = [p async for p in chain.astream(None)] + expected: list = [[]] + [ + [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON + ] + assert actual == expected + + +def test_partial_json_output_parser_return_id() -> None: + def input_iter(_: Any) -> Iterator[BaseMessage]: + for msg in STREAMED_MESSAGES: + yield msg + + chain = input_iter | JsonOutputToolsParser(return_id=True) + + actual = list(chain.stream(None)) + expected: list = [[]] + [ + [ + { + "type": "NameCollector", + "args": chunk, + "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl", + } + ] + for chunk in EXPECTED_STREAMED_JSON + ] + assert actual == expected + + +def test_partial_json_output_key_parser() -> None: + def input_iter(_: Any) -> Iterator[BaseMessage]: + for msg in STREAMED_MESSAGES: + yield msg + + chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") + + actual = list(chain.stream(None)) + expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] + assert actual == expected + + +async def test_partial_json_output_parser_key_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: + for token in STREAMED_MESSAGES: + yield token + + chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") + + actual = [p async for p in chain.astream(None)] + expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] + assert actual == expected + + +def test_partial_json_output_key_parser_first_only() -> None: + def input_iter(_: Any) -> Iterator[BaseMessage]: + for msg in STREAMED_MESSAGES: + yield msg + + chain = input_iter | JsonOutputKeyToolsParser( + key_name="NameCollector", first_tool_only=True + ) + + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON + + +async def test_partial_json_output_parser_key_async_first_only() -> None: + async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: + for token in STREAMED_MESSAGES: + yield token + + chain = input_iter | JsonOutputKeyToolsParser( + key_name="NameCollector", first_tool_only=True + ) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +class Person(BaseModel): + age: int + hair_color: str + job: str + + +class NameCollector(BaseModel): + """record names of all people mentioned""" + + names: List[str] = Field(..., description="all names mentioned") + person: Person = Field(..., description="info about the main subject") + + +# Expected to change when we support more granular pydantic streaming. +EXPECTED_STREAMED_PYDANTIC = [ + NameCollector( + names=["suzy", "jermaine", "alex"], + person=Person(age=39, hair_color="brown", job="c"), + ), + NameCollector( + names=["suzy", "jermaine", "alex"], + person=Person(age=39, hair_color="brown", job="concie"), + ), + NameCollector( + names=["suzy", "jermaine", "alex"], + person=Person(age=39, hair_color="brown", job="concierge"), + ), +] + + +def test_partial_pydantic_output_parser() -> None: + def input_iter(_: Any) -> Iterator[BaseMessage]: + for msg in STREAMED_MESSAGES: + yield msg + + chain = input_iter | PydanticToolsParser( + tools=[NameCollector], first_tool_only=True + ) + + actual = list(chain.stream(None)) + assert actual == EXPECTED_STREAMED_PYDANTIC + + +async def test_partial_pydantic_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: + for token in STREAMED_MESSAGES: + yield token + + chain = input_iter | PydanticToolsParser( + tools=[NameCollector], first_tool_only=True + ) + + actual = [p async for p in chain.astream(None)] + assert actual == EXPECTED_STREAMED_PYDANTIC