mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 06:24:47 +00:00
Add tests
This commit is contained in:
parent
b67db8deaa
commit
0318cdd33c
@ -2100,7 +2100,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
params = inspect.signature(func).parameters
|
params = inspect.signature(func).parameters
|
||||||
first_param = next(iter(params.values()), None)
|
first_param = next(iter(params.values()), None)
|
||||||
if first_param and first_param.annotation != inspect.Parameter.empty:
|
if first_param and first_param.annotation != inspect.Parameter.empty:
|
||||||
return first_param.annotation
|
return getattr(first_param.annotation, "__args__", (Any,))[0]
|
||||||
else:
|
else:
|
||||||
return Any
|
return Any
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -2112,7 +2112,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
try:
|
try:
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
return (
|
return (
|
||||||
sig.return_annotation
|
getattr(sig.return_annotation, "__args__", (Any,))[0]
|
||||||
if sig.return_annotation != inspect.Signature.empty
|
if sig.return_annotation != inspect.Signature.empty
|
||||||
else Any
|
else Any
|
||||||
)
|
)
|
||||||
@ -2162,7 +2162,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
final += output
|
final += output
|
||||||
return final
|
return final
|
||||||
|
|
||||||
async def atransform(
|
def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Input],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
@ -2175,7 +2175,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
input, self._atransform, config, **kwargs
|
input, self._atransform, config, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
async def astream(
|
def astream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
|
@ -1,7 +1,18 @@
|
|||||||
import sys
|
import sys
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
from langchain.schema.runnable.base import RunnableGenerator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None:
|
|||||||
"title": "PromptInput",
|
"title": "PromptInput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runnable_gen() -> None:
|
||||||
|
"""Test that a generator can be used as a runnable."""
|
||||||
|
|
||||||
|
def gen(input: Iterator[Any]) -> Iterator[int]:
|
||||||
|
yield 1
|
||||||
|
yield 2
|
||||||
|
yield 3
|
||||||
|
|
||||||
|
runnable = RunnableGenerator(gen)
|
||||||
|
|
||||||
|
assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"}
|
||||||
|
assert runnable.output_schema.schema() == {
|
||||||
|
"title": "RunnableGeneratorOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert runnable.invoke(None) == 6
|
||||||
|
assert list(runnable.stream(None)) == [1, 2, 3]
|
||||||
|
assert runnable.batch([None, None]) == [6, 6]
|
||||||
|
|
||||||
|
async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
|
||||||
|
yield 1
|
||||||
|
yield 2
|
||||||
|
yield 3
|
||||||
|
|
||||||
|
arunnable = RunnableGenerator(agen)
|
||||||
|
|
||||||
|
assert await arunnable.ainvoke(None) == 6
|
||||||
|
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||||
|
assert await arunnable.abatch([None, None]) == [6, 6]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runnable_gen_transform() -> None:
|
||||||
|
"""Test that a generator can be used as a runnable."""
|
||||||
|
|
||||||
|
def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]:
|
||||||
|
for i in range(next(length_iter)):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]:
|
||||||
|
async for length in length_iter:
|
||||||
|
for i in range(length):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
def plus_one(input: Iterator[int]) -> Iterator[int]:
|
||||||
|
for i in input:
|
||||||
|
yield i + 1
|
||||||
|
|
||||||
|
async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
|
||||||
|
async for i in input:
|
||||||
|
yield i + 1
|
||||||
|
|
||||||
|
chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
|
||||||
|
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
|
||||||
|
|
||||||
|
assert chain.input_schema.schema() == {
|
||||||
|
"title": "RunnableGeneratorInput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
assert chain.output_schema.schema() == {
|
||||||
|
"title": "RunnableGeneratorOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
assert achain.input_schema.schema() == {
|
||||||
|
"title": "RunnableGeneratorInput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
assert achain.output_schema.schema() == {
|
||||||
|
"title": "RunnableGeneratorOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert list(chain.stream(3)) == [1, 2, 3]
|
||||||
|
assert [p async for p in achain.astream(4)] == [1, 2, 3, 4]
|
||||||
|
Loading…
Reference in New Issue
Block a user