core[minor]: add streaming support to OAI tool parsers (#18940)
Co-authored-by: Erick Friis <erick@langchain.dev>
libs/core/langchain_core/output_parsers/openai_tools.py
@@ -4,13 +4,13 @@ from json import JSONDecodeError
 from typing import Any, List, Type
 
 from langchain_core.exceptions import OutputParserException
-from langchain_core.output_parsers import BaseGenerationOutputParser
+from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
 from langchain_core.output_parsers.json import parse_partial_json
 from langchain_core.outputs import ChatGeneration, Generation
-from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.pydantic_v1 import BaseModel, ValidationError
 
 
-class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
+class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse tools from OpenAI response."""
 
     strict: bool = False
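Swapping the base class from BaseGenerationOutputParser to BaseCumulativeTransformOutputParser is what gives these parsers `.stream()`/`.astream()` support: the runnable machinery accumulates message chunks and re-invokes `parse_result(..., partial=True)` at each step. A minimal sketch of that contract (the tool name `search` and the call id below are made up for illustration):

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.outputs import ChatGeneration

parser = JsonOutputToolsParser()
# A half-streamed tool call: the arguments string is still incomplete JSON.
message = AIMessage(
    content="",
    additional_kwargs={
        "tool_calls": [
            {
                "index": 0,
                "id": "call_demo",  # hypothetical id
                "function": {"arguments": '{"query": "lang', "name": "search"},
                "type": "function",
            }
        ]
    },
)
# With partial=True the truncated arguments are repaired rather than rejected:
# [{'type': 'search', 'args': {'query': 'lang'}}]
print(parser.parse_result([ChatGeneration(message=message)], partial=True))
```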
@@ -50,22 +50,25 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
         for tool_call in tool_calls:
             if "function" not in tool_call:
                 continue
-            try:
-                if partial:
+            if partial:
+                try:
                     function_args = parse_partial_json(
                         tool_call["function"]["arguments"], strict=self.strict
                     )
-                else:
+                except JSONDecodeError:
+                    continue
+            else:
+                try:
                     function_args = json.loads(
                         tool_call["function"]["arguments"], strict=self.strict
                     )
-            except JSONDecodeError as e:
-                exceptions.append(
-                    f"Function {tool_call['function']['name']} arguments:\n\n"
-                    f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
-                    f"Received JSONDecodeError {e}"
-                )
-                continue
+                except JSONDecodeError as e:
+                    exceptions.append(
+                        f"Function {tool_call['function']['name']} arguments:\n\n"
+                        f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
+                        f"Received JSONDecodeError {e}"
+                    )
+                    continue
             parsed = {
                 "type": tool_call["function"]["name"],
                 "args": function_args,
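For reference, `parse_partial_json` repairs JSON that is truncated but well formed so far by closing any open strings, arrays, and objects (dropping a trailing fragment that cannot be completed); when even that fails it raises JSONDecodeError, which the new partial branch turns into a skip for that chunk. The values below mirror the streamed fragments exercised by the tests added in this commit:

```python
from langchain_core.output_parsers.json import parse_partial_json

# An unfinished key is dropped entirely; an unfinished string value is closed.
assert parse_partial_json('{"na') == {}
assert parse_partial_json('{"names": ["suz') == {"names": ["suz"]}
assert parse_partial_json('{"names": ["suzy", "jerm') == {"names": ["suzy", "jerm"]}
```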
@@ -79,6 +82,9 @@ class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
             return final_tools[0] if final_tools else None
         return final_tools
 
+    def parse(self, text: str) -> Any:
+        raise NotImplementedError()
+
 
 class JsonOutputKeyToolsParser(JsonOutputToolsParser):
     """Parse tools from OpenAI response."""
@@ -88,6 +94,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
 
     def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
         parsed_result = super().parse_result(result, partial=partial)
+
         if self.first_tool_only:
             single_result = (
                 parsed_result
@@ -111,13 +118,30 @@ class PydanticToolsParser(JsonOutputToolsParser):
 
     tools: List[Type[BaseModel]]
 
+    # TODO: Support more granular streaming of objects. Currently only streams once all
+    # Pydantic object fields are present.
     def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
-        parsed_result = super().parse_result(result, partial=partial)
+        json_results = super().parse_result(result, partial=partial)
+        if not json_results:
+            return None if self.first_tool_only else []
+
+        json_results = [json_results] if self.first_tool_only else json_results
         name_dict = {tool.__name__: tool for tool in self.tools}
+        pydantic_objects = []
+        for res in json_results:
+            try:
+                if not isinstance(res["args"], dict):
+                    raise ValueError(
+                        f"Tool arguments must be specified as a dict, received: "
+                        f"{res['args']}"
+                    )
+                pydantic_objects.append(name_dict[res["type"]](**res["args"]))
+            except (ValidationError, ValueError) as e:
+                if partial:
+                    continue
+                else:
+                    raise e
         if self.first_tool_only:
-            return (
-                name_dict[parsed_result["type"]](**parsed_result["args"])
-                if parsed_result
-                else None
-            )
-        return [name_dict[res["type"]](**res["args"]) for res in parsed_result]
+            return pydantic_objects[0] if pydantic_objects else None
+        else:
+            return pydantic_objects
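Taken together, these changes let tool-calling chains stream structured output as it is generated. A hypothetical end-to-end sketch (assumes langchain_openai is installed and OPENAI_API_KEY is set; the `Search` schema is invented for illustration):

```python
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI


class Search(BaseModel):
    """Search for information on the web."""

    query: str = Field(..., description="search query")


llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0).bind_tools([Search])
chain = llm | PydanticToolsParser(tools=[Search], first_tool_only=True)

# Each chunk is a progressively more complete Search object; per the TODO
# above, a Pydantic object is only emitted once all required fields parse.
for chunk in chain.stream("What is the weather in SF?"):
    print(chunk)
```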
libs/core/tests/unit_tests/output_parsers/test_openai_tools.py (new file, 483 lines)
@@ -0,0 +1,483 @@
+from typing import Any, AsyncIterator, Iterator, List
+
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    JsonOutputToolsParser,
+    PydanticToolsParser,
+)
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+STREAMED_MESSAGES: list = [
+    AIMessageChunk(content=""),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
+                    "function": {"arguments": "", "name": "NameCollector"},
+                    "type": "function",
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": '{"na', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'mes":', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ' ["suz', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'y", ', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": '"jerm', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'aine",', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ' "al', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'ex"],', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ' "pers', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'on":', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ' {"ag', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'e": 39', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ', "h', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": "air_c", "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'olor":', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ' "br', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'own",', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ' "job"', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": ': "c', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": "oncie", "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "tool_calls": [
+                {
+                    "index": 0,
+                    "id": None,
+                    "function": {"arguments": 'rge"}}', "name": None},
+                    "type": None,
+                }
+            ]
+        },
+    ),
+    AIMessageChunk(content=""),
+]
+
+
+EXPECTED_STREAMED_JSON = [
+    {},
+    {"names": ["suz"]},
+    {"names": ["suzy"]},
+    {"names": ["suzy", "jerm"]},
+    {"names": ["suzy", "jermaine"]},
+    {"names": ["suzy", "jermaine", "al"]},
+    {"names": ["suzy", "jermaine", "alex"]},
+    {"names": ["suzy", "jermaine", "alex"], "person": {}},
+    {"names": ["suzy", "jermaine", "alex"], "person": {"age": 39}},
+    {"names": ["suzy", "jermaine", "alex"], "person": {"age": 39, "hair_color": "br"}},
+    {
+        "names": ["suzy", "jermaine", "alex"],
+        "person": {"age": 39, "hair_color": "brown"},
+    },
+    {
+        "names": ["suzy", "jermaine", "alex"],
+        "person": {"age": 39, "hair_color": "brown", "job": "c"},
+    },
+    {
+        "names": ["suzy", "jermaine", "alex"],
+        "person": {"age": 39, "hair_color": "brown", "job": "concie"},
+    },
+    {
+        "names": ["suzy", "jermaine", "alex"],
+        "person": {"age": 39, "hair_color": "brown", "job": "concierge"},
+    },
+]
+
+
+def test_partial_json_output_parser() -> None:
+    def input_iter(_: Any) -> Iterator[BaseMessage]:
+        for msg in STREAMED_MESSAGES:
+            yield msg
+
+    chain = input_iter | JsonOutputToolsParser()
+
+    actual = list(chain.stream(None))
+    expected: list = [[]] + [
+        [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
+    ]
+    assert actual == expected
+
+
+async def test_partial_json_output_parser_async() -> None:
+    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
+        for token in STREAMED_MESSAGES:
+            yield token
+
+    chain = input_iter | JsonOutputToolsParser()
+
+    actual = [p async for p in chain.astream(None)]
+    expected: list = [[]] + [
+        [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
+    ]
+    assert actual == expected
+
+
+def test_partial_json_output_parser_return_id() -> None:
+    def input_iter(_: Any) -> Iterator[BaseMessage]:
+        for msg in STREAMED_MESSAGES:
+            yield msg
+
+    chain = input_iter | JsonOutputToolsParser(return_id=True)
+
+    actual = list(chain.stream(None))
+    expected: list = [[]] + [
+        [
+            {
+                "type": "NameCollector",
+                "args": chunk,
+                "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
+            }
+        ]
+        for chunk in EXPECTED_STREAMED_JSON
+    ]
+    assert actual == expected
+
+
+def test_partial_json_output_key_parser() -> None:
+    def input_iter(_: Any) -> Iterator[BaseMessage]:
+        for msg in STREAMED_MESSAGES:
+            yield msg
+
+    chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
+
+    actual = list(chain.stream(None))
+    expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
+    assert actual == expected
+
+
+async def test_partial_json_output_parser_key_async() -> None:
+    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
+        for token in STREAMED_MESSAGES:
+            yield token
+
+    chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
+
+    actual = [p async for p in chain.astream(None)]
+    expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
+    assert actual == expected
+
+
+def test_partial_json_output_key_parser_first_only() -> None:
+    def input_iter(_: Any) -> Iterator[BaseMessage]:
+        for msg in STREAMED_MESSAGES:
+            yield msg
+
+    chain = input_iter | JsonOutputKeyToolsParser(
+        key_name="NameCollector", first_tool_only=True
+    )
+
+    assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
+
+
+async def test_partial_json_output_parser_key_async_first_only() -> None:
+    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
+        for token in STREAMED_MESSAGES:
+            yield token
+
+    chain = input_iter | JsonOutputKeyToolsParser(
+        key_name="NameCollector", first_tool_only=True
+    )
+
+    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
+
+
+class Person(BaseModel):
+    age: int
+    hair_color: str
+    job: str
+
+
+class NameCollector(BaseModel):
+    """record names of all people mentioned"""
+
+    names: List[str] = Field(..., description="all names mentioned")
+    person: Person = Field(..., description="info about the main subject")
+
+
+# Expected to change when we support more granular pydantic streaming.
+EXPECTED_STREAMED_PYDANTIC = [
+    NameCollector(
+        names=["suzy", "jermaine", "alex"],
+        person=Person(age=39, hair_color="brown", job="c"),
+    ),
+    NameCollector(
+        names=["suzy", "jermaine", "alex"],
+        person=Person(age=39, hair_color="brown", job="concie"),
+    ),
+    NameCollector(
+        names=["suzy", "jermaine", "alex"],
+        person=Person(age=39, hair_color="brown", job="concierge"),
+    ),
+]
+
+
+def test_partial_pydantic_output_parser() -> None:
+    def input_iter(_: Any) -> Iterator[BaseMessage]:
+        for msg in STREAMED_MESSAGES:
+            yield msg
+
+    chain = input_iter | PydanticToolsParser(
+        tools=[NameCollector], first_tool_only=True
+    )
+
+    actual = list(chain.stream(None))
+    assert actual == EXPECTED_STREAMED_PYDANTIC
+
+
+async def test_partial_pydantic_output_parser_async() -> None:
+    async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
+        for token in STREAMED_MESSAGES:
+            yield token
+
+    chain = input_iter | PydanticToolsParser(
+        tools=[NameCollector], first_tool_only=True
+    )
+
+    actual = [p async for p in chain.astream(None)]
+    assert actual == EXPECTED_STREAMED_PYDANTIC
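The fixtures above work because adding AIMessageChunk objects concatenates the "arguments" fragments of tool-call deltas that share an "index"; the cumulative parser re-parses that growing string on every chunk. A rough illustration of the accumulation (not the library's internal code), reusing STREAMED_MESSAGES from the test module above:

```python
from functools import reduce
from operator import add

# Merging all chunks reassembles the complete tool call.
merged = reduce(add, STREAMED_MESSAGES)
print(merged.additional_kwargs["tool_calls"][0]["function"]["arguments"])
# (shown wrapped)
# {"names": ["suzy", "jermaine", "alex"],
#  "person": {"age": 39, "hair_color": "brown", "job": "concierge"}}
```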