mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 10:39:23 +00:00
Add async tests and comments
This commit is contained in:
parent
091d8845d5
commit
3d8aa88e26
@ -205,8 +205,10 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any])
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
# This method would be called by the default implementation of `parse_result`
|
||||
# but we're overriding that method so it's not needed.
|
||||
def parse(self, text: str) -> Any:
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
|
@ -340,6 +340,8 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
diff: bool = False
|
||||
|
||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||
"""Convert parsed outputs into a diff format. The semantics of this are
|
||||
up to the output parser."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Iterator, Tuple
|
||||
from typing import Any, AsyncIterator, Iterator, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
@ -492,3 +492,51 @@ def test_partial_functions_json_output_parser_diff() -> None:
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
||||
|
||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_text_json_output_parser_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[str]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | PartialJsonOutputParser()
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_functions_json_output_parser_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser()
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_text_json_output_parser_diff_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[str]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield token
|
||||
|
||||
chain = input_iter | PartialJsonOutputParser(diff=True)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_functions_json_output_parser_diff_async() -> None:
|
||||
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
|
||||
for token in STREAMED_TOKENS:
|
||||
yield AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||
)
|
||||
|
||||
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
||||
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
Loading…
Reference in New Issue
Block a user