mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 19:18:53 +00:00
Add async tests and comments
This commit is contained in:
parent
091d8845d5
commit
3d8aa88e26
@ -205,8 +205,10 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any])
|
|||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
|
# This method would be called by the default implementation of `parse_result`
|
||||||
|
# but we're overriding that method so it's not needed.
|
||||||
def parse(self, text: str) -> Any:
|
def parse(self, text: str) -> Any:
|
||||||
pass
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||||
|
@ -340,6 +340,8 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
diff: bool = False
|
diff: bool = False
|
||||||
|
|
||||||
def _diff(self, prev: Optional[T], next: T) -> T:
|
def _diff(self, prev: Optional[T], next: T) -> T:
|
||||||
|
"""Convert parsed outputs into a diff format. The semantics of this are
|
||||||
|
up to the output parser."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Iterator, Tuple
|
from typing import Any, AsyncIterator, Iterator, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -492,3 +492,51 @@ def test_partial_functions_json_output_parser_diff() -> None:
|
|||||||
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
||||||
|
|
||||||
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_text_json_output_parser_async() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[str]:
|
||||||
|
for token in STREAMED_TOKENS:
|
||||||
|
yield token
|
||||||
|
|
||||||
|
chain = input_iter | PartialJsonOutputParser()
|
||||||
|
|
||||||
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_functions_json_output_parser_async() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
|
||||||
|
for token in STREAMED_TOKENS:
|
||||||
|
yield AIMessageChunk(
|
||||||
|
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = input_iter | PartialFunctionsJsonOutputParser()
|
||||||
|
|
||||||
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_text_json_output_parser_diff_async() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[str]:
|
||||||
|
for token in STREAMED_TOKENS:
|
||||||
|
yield token
|
||||||
|
|
||||||
|
chain = input_iter | PartialJsonOutputParser(diff=True)
|
||||||
|
|
||||||
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_functions_json_output_parser_diff_async() -> None:
|
||||||
|
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
|
||||||
|
for token in STREAMED_TOKENS:
|
||||||
|
yield AIMessageChunk(
|
||||||
|
content="", additional_kwargs={"function_call": {"arguments": token}}
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
|
||||||
|
|
||||||
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
||||||
|
Loading…
Reference in New Issue
Block a user