Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
0508176982 Update lambda typing 2023-09-08 15:49:00 -07:00
3 changed files with 33 additions and 2 deletions

View File

@@ -1692,8 +1692,18 @@ class RunnableLambda(Runnable[Input, Output]):
def __init__(
self,
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
func: Union[
Callable[[Input], Output],
Callable[[Input, RunnableConfig], Output],
Callable[[Input], Awaitable[Output]],
Callable[[Input, RunnableConfig], Awaitable[Output]],
],
afunc: Optional[
Union[
Callable[[Input], Awaitable[Output]],
Callable[[Input, RunnableConfig], Awaitable[Output]],
]
] = None,
) -> None:
if afunc is not None:
self.afunc = afunc

View File

@@ -143,6 +143,7 @@ def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, RunnableConfig], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
@@ -166,6 +167,10 @@ async def acall_func_with_variable_args(
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
Awaitable[Output],
],
Callable[
[Input, RunnableConfig],
Awaitable[Output],
],
],
input: Input,
run_manager: AsyncCallbackManagerForChainRun,

View File

@@ -1,3 +1,4 @@
import asyncio
from operator import itemgetter
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
@@ -1785,3 +1786,18 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
assert parent_run_qux.outputs["output"] == "quxaaaa"
assert len(parent_run_qux.child_runs) == 4
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
@pytest.mark.asyncio
async def test_lambda_accept_config() -> None:
def sync_with_config(x: str, config: RunnableConfig) -> str:
return x
RunnableLambda(sync_with_config).invoke("foo")
async def async_with_config(x: str, config: RunnableConfig) -> str:
asyncio.sleep(0.001)
return x
await RunnableLambda(async_with_config).ainvoke("foo")
await RunnableLambda(sync_with_config, async_with_config).abatch(["foo"])