core[minor]: add streaming support to OAI tool parsers (#18940)
Co-authored-by: Erick Friis <erick@langchain.dev>
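What the change enables: `JsonOutputToolsParser` (and its subclasses) now extend `BaseCumulativeTransformOutputParser`, so they work with `.stream()`/`.astream()` and emit progressively more complete tool arguments as chunks arrive. Below is a minimal sketch of that behavior, modeled on the unit tests added in this commit; the generator, the call id, and the argument fragments are illustrative, not part of the commit itself.

from typing import Any, Iterator

from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser


def fake_stream(_: Any) -> Iterator[BaseMessage]:
    # First chunk carries the call id, tool name, and type; later chunks
    # carry only argument fragments (all values here are illustrative).
    yield AIMessageChunk(
        content="",
        additional_kwargs={
            "tool_calls": [
                {
                    "index": 0,
                    "id": "call_example",  # hypothetical id
                    "function": {"arguments": "", "name": "NameCollector"},
                    "type": "function",
                }
            ]
        },
    )
    for fragment in ['{"names":', ' ["suzy"]}']:
        yield AIMessageChunk(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "index": 0,
                        "id": None,
                        "function": {"arguments": fragment, "name": None},
                        "type": None,
                    }
                ]
            },
        )


# A plain generator is coerced into a Runnable when piped into the parser,
# exactly as the new unit tests below do.
chain = fake_stream | JsonOutputToolsParser()
for parsed in chain.stream(None):
    print(parsed)
# The stream ends with the fully parsed call:
# [{'type': 'NameCollector', 'args': {'names': ['suzy']}}]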
libs/core/langchain_core/output_parsers/openai_tools.py
@@ -4,13 +4,13 @@ from json import JSONDecodeError
 from typing import Any, List, Type
 
 from langchain_core.exceptions import OutputParserException
-from langchain_core.output_parsers import BaseGenerationOutputParser
+from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
 from langchain_core.output_parsers.json import parse_partial_json
 from langchain_core.outputs import ChatGeneration, Generation
-from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.pydantic_v1 import BaseModel, ValidationError
 
 
-class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
+class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
     """Parse tools from OpenAI response."""
 
     strict: bool = False
@@ -50,22 +50,25 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
         for tool_call in tool_calls:
             if "function" not in tool_call:
                 continue
-            try:
-                if partial:
+            if partial:
+                try:
                     function_args = parse_partial_json(
                         tool_call["function"]["arguments"], strict=self.strict
                     )
-                else:
+                except JSONDecodeError:
+                    continue
+            else:
+                try:
                     function_args = json.loads(
                         tool_call["function"]["arguments"], strict=self.strict
                     )
-            except JSONDecodeError as e:
-                exceptions.append(
-                    f"Function {tool_call['function']['name']} arguments:\n\n"
-                    f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
-                    f"Received JSONDecodeError {e}"
-                )
-                continue
+                except JSONDecodeError as e:
+                    exceptions.append(
+                        f"Function {tool_call['function']['name']} arguments:\n\n"
+                        f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
+                        f"Received JSONDecodeError {e}"
+                    )
+                    continue
             parsed = {
                 "type": tool_call["function"]["name"],
                 "args": function_args,
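Why the partial branch swallows `JSONDecodeError`: while arguments are still streaming in, the buffered string is usually truncated JSON, and `parse_partial_json` best-effort-completes whatever prefix is parseable. A quick sketch; the fragment below is one of the intermediate states exercised by the new tests:

from langchain_core.output_parsers.json import parse_partial_json

# An argument string as it looks mid-stream, before the closing quote
# and brackets have arrived.
fragment = '{"names": ["suz'
print(parse_partial_json(fragment))  # -> {'names': ['suz']}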
@@ -79,6 +82,9 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
             return final_tools[0] if final_tools else None
         return final_tools
 
+    def parse(self, text: str) -> Any:
+        raise NotImplementedError()
+
 
 class JsonOutputKeyToolsParser(JsonOutputToolsParser):
     """Parse tools from OpenAI response."""
@@ -88,6 +94,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
 
     def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
         parsed_result = super().parse_result(result, partial=partial)
+
         if self.first_tool_only:
             single_result = (
                 parsed_result
@@ -111,13 +118,30 @@ class PydanticToolsParser(JsonOutputToolsParser):
 
     tools: List[Type[BaseModel]]
 
+    # TODO: Support more granular streaming of objects. Currently only streams once all
+    # Pydantic object fields are present.
     def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
-        parsed_result = super().parse_result(result, partial=partial)
+        json_results = super().parse_result(result, partial=partial)
+        if not json_results:
+            return None if self.first_tool_only else []
+
+        json_results = [json_results] if self.first_tool_only else json_results
         name_dict = {tool.__name__: tool for tool in self.tools}
+        pydantic_objects = []
+        for res in json_results:
+            try:
+                if not isinstance(res["args"], dict):
+                    raise ValueError(
+                        f"Tool arguments must be specified as a dict, received: "
+                        f"{res['args']}"
+                    )
+                pydantic_objects.append(name_dict[res["type"]](**res["args"]))
+            except (ValidationError, ValueError) as e:
+                if partial:
+                    continue
+                else:
+                    raise e
         if self.first_tool_only:
-            return (
-                name_dict[parsed_result["type"]](**parsed_result["args"])
-                if parsed_result
-                else None
-            )
-        return [name_dict[res["type"]](**res["args"]) for res in parsed_result]
+            return pydantic_objects[0] if pydantic_objects else None
+        else:
+            return pydantic_objects
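As the TODO above notes, `PydanticToolsParser` only emits once a candidate object validates: while `partial=True`, dicts that fail Pydantic validation (typically because required fields are still missing) are skipped rather than raised. A small sketch of why, using the same `Person` shape as the tests below:

from langchain_core.pydantic_v1 import BaseModel, ValidationError


class Person(BaseModel):  # mirrors the model defined in the tests below
    age: int
    hair_color: str
    job: str


# Mid-stream the args dict is incomplete, so validation fails...
try:
    Person(**{"age": 39, "hair_color": "br"})
except ValidationError:
    pass  # ...and with partial=True the parser skips it and waits.

# Once all required fields have streamed in, an object is emitted
# (even if a string value like job="c" is itself still incomplete).
print(Person(age=39, hair_color="brown", job="c"))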
libs/core/tests/unit_tests/output_parsers/test_openai_tools.py (new file, 483 lines)
@@ -0,0 +1,483 @@
from typing import Any, AsyncIterator, Iterator, List

from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.pydantic_v1 import BaseModel, Field

# The tool-call arguments arrive split across these fragments; concatenated
# they form:
# {"names": ["suzy", "jermaine", "alex"],
#  "person": {"age": 39, "hair_color": "brown", "job": "concierge"}}
_ARGUMENT_FRAGMENTS = [
    '{"na',
    'mes":',
    ' ["suz',
    'y", ',
    '"jerm',
    'aine",',
    ' "al',
    'ex"],',
    ' "pers',
    'on":',
    ' {"ag',
    'e": 39',
    ', "h',
    "air_c",
    'olor":',
    ' "br',
    'own",',
    ' "job"',
    ': "c',
    "oncie",
    'rge"}}',
]

STREAMED_MESSAGES: list = [
    AIMessageChunk(content=""),
    # The first tool-call chunk carries the call id, tool name, and type.
    AIMessageChunk(
        content="",
        additional_kwargs={
            "tool_calls": [
                {
                    "index": 0,
                    "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
                    "function": {"arguments": "", "name": "NameCollector"},
                    "type": "function",
                }
            ]
        },
    ),
    # Subsequent chunks stream only the JSON argument fragments.
    *(
        AIMessageChunk(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "index": 0,
                        "id": None,
                        "function": {"arguments": fragment, "name": None},
                        "type": None,
                    }
                ]
            },
        )
        for fragment in _ARGUMENT_FRAGMENTS
    ),
    AIMessageChunk(content=""),
]

EXPECTED_STREAMED_JSON = [
    {},
    {"names": ["suz"]},
    {"names": ["suzy"]},
    {"names": ["suzy", "jerm"]},
    {"names": ["suzy", "jermaine"]},
    {"names": ["suzy", "jermaine", "al"]},
    {"names": ["suzy", "jermaine", "alex"]},
    {"names": ["suzy", "jermaine", "alex"], "person": {}},
    {"names": ["suzy", "jermaine", "alex"], "person": {"age": 39}},
    {"names": ["suzy", "jermaine", "alex"], "person": {"age": 39, "hair_color": "br"}},
    {
        "names": ["suzy", "jermaine", "alex"],
        "person": {"age": 39, "hair_color": "brown"},
    },
    {
        "names": ["suzy", "jermaine", "alex"],
        "person": {"age": 39, "hair_color": "brown", "job": "c"},
    },
    {
        "names": ["suzy", "jermaine", "alex"],
        "person": {"age": 39, "hair_color": "brown", "job": "concie"},
    },
    {
        "names": ["suzy", "jermaine", "alex"],
        "person": {"age": 39, "hair_color": "brown", "job": "concierge"},
    },
]


def test_partial_json_output_parser() -> None:
    def input_iter(_: Any) -> Iterator[BaseMessage]:
        for msg in STREAMED_MESSAGES:
            yield msg

    chain = input_iter | JsonOutputToolsParser()

    actual = list(chain.stream(None))
    expected: list = [[]] + [
        [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
    ]
    assert actual == expected


async def test_partial_json_output_parser_async() -> None:
    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
        for token in STREAMED_MESSAGES:
            yield token

    chain = input_iter | JsonOutputToolsParser()

    actual = [p async for p in chain.astream(None)]
    expected: list = [[]] + [
        [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
    ]
    assert actual == expected


def test_partial_json_output_parser_return_id() -> None:
    def input_iter(_: Any) -> Iterator[BaseMessage]:
        for msg in STREAMED_MESSAGES:
            yield msg

    chain = input_iter | JsonOutputToolsParser(return_id=True)

    actual = list(chain.stream(None))
    expected: list = [[]] + [
        [
            {
                "type": "NameCollector",
                "args": chunk,
                "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
            }
        ]
        for chunk in EXPECTED_STREAMED_JSON
    ]
    assert actual == expected


def test_partial_json_output_key_parser() -> None:
    def input_iter(_: Any) -> Iterator[BaseMessage]:
        for msg in STREAMED_MESSAGES:
            yield msg

    chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

    actual = list(chain.stream(None))
    expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
    assert actual == expected


async def test_partial_json_output_parser_key_async() -> None:
    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
        for token in STREAMED_MESSAGES:
            yield token

    chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

    actual = [p async for p in chain.astream(None)]
    expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
    assert actual == expected


def test_partial_json_output_key_parser_first_only() -> None:
    def input_iter(_: Any) -> Iterator[BaseMessage]:
        for msg in STREAMED_MESSAGES:
            yield msg

    chain = input_iter | JsonOutputKeyToolsParser(
        key_name="NameCollector", first_tool_only=True
    )

    assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON


async def test_partial_json_output_parser_key_async_first_only() -> None:
    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
        for token in STREAMED_MESSAGES:
            yield token

    chain = input_iter | JsonOutputKeyToolsParser(
        key_name="NameCollector", first_tool_only=True
    )

    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON


class Person(BaseModel):
    age: int
    hair_color: str
    job: str


class NameCollector(BaseModel):
    """record names of all people mentioned"""

    names: List[str] = Field(..., description="all names mentioned")
    person: Person = Field(..., description="info about the main subject")


# Expected to change when we support more granular pydantic streaming.
EXPECTED_STREAMED_PYDANTIC = [
    NameCollector(
        names=["suzy", "jermaine", "alex"],
        person=Person(age=39, hair_color="brown", job="c"),
    ),
    NameCollector(
        names=["suzy", "jermaine", "alex"],
        person=Person(age=39, hair_color="brown", job="concie"),
    ),
    NameCollector(
        names=["suzy", "jermaine", "alex"],
        person=Person(age=39, hair_color="brown", job="concierge"),
    ),
]


def test_partial_pydantic_output_parser() -> None:
    def input_iter(_: Any) -> Iterator[BaseMessage]:
        for msg in STREAMED_MESSAGES:
            yield msg

    chain = input_iter | PydanticToolsParser(
        tools=[NameCollector], first_tool_only=True
    )

    actual = list(chain.stream(None))
    assert actual == EXPECTED_STREAMED_PYDANTIC


async def test_partial_pydantic_output_parser_async() -> None:
    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
        for token in STREAMED_MESSAGES:
            yield token

    chain = input_iter | PydanticToolsParser(
        tools=[NameCollector], first_tool_only=True
    )

    actual = [p async for p in chain.astream(None)]
    assert actual == EXPECTED_STREAMED_PYDANTIC