Add support for async funcs in RunnableSequence

This commit is contained in:
Nuno Campos 2023-08-17 14:25:43 +01:00
parent 64a958c85d
commit 677da6a0fd
3 changed files with 199 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import inspect
import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait
@@ -12,6 +13,7 @@ from typing import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
@@ -1343,8 +1345,17 @@ class RunnableLambda(Runnable[Input, Output]):
A runnable that runs a callable.
"""
def __init__(self, func: Callable[[Input], Output]) -> None:
if callable(func):
def __init__(
self,
func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
afunc: Optional[Coroutine[Input, Any, Output]] = None,
) -> None:
if afunc is not None:
self.afunc = afunc
if inspect.iscoroutinefunction(func):
self.afunc = func
elif callable(func):
self.func = func
else:
raise TypeError(
@@ -1354,7 +1365,12 @@ class RunnableLambda(Runnable[Input, Output]):
def __eq__(self, other: Any) -> bool:
    """Two RunnableLambdas are equal iff they wrap equal callables.

    Compares the sync ``func`` attribute when both sides have one,
    otherwise the async ``afunc`` attribute; anything else is unequal.
    """
    if not isinstance(other, RunnableLambda):
        return False
    # Guard clauses: compare whichever callable attribute both carry.
    if hasattr(self, "func") and hasattr(other, "func"):
        return self.func == other.func
    if hasattr(self, "afunc") and hasattr(other, "afunc"):
        return self.afunc == other.afunc
    return False
@@ -1364,7 +1380,24 @@ class RunnableLambda(Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "func"):
return self._call_with_config(self.func, input, config)
else:
raise TypeError(
"Cannot invoke a coroutine function synchronously."
"Use `ainvoke` instead."
)
async def ainvoke(
    self,
    input: Input,
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> Output:
    """Invoke this runnable asynchronously.

    Uses the wrapped async callable (``afunc``) when one was provided;
    otherwise delegates to the base-class ``ainvoke``, which handles the
    sync ``func`` path.
    """
    if not hasattr(self, "afunc"):
        # No native coroutine was supplied: fall back to the default
        # async implementation on the base class.
        return await super().ainvoke(input, config)
    return await self._acall_with_config(self.afunc, input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):

File diff suppressed because one or more lines are too long

View File

@@ -1,3 +1,4 @@
from ast import Not
from typing import Any, Dict, List, Optional
from uuid import UUID
@ -38,6 +39,7 @@ from langchain.schema.runnable import (
RunnablePassthrough,
RunnableSequence,
RunnableWithFallbacks,
passthrough,
)
@@ -438,6 +440,50 @@ async def test_prompt_with_llm(
)
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    """An async lambda piped onto a chain is wrapped in a RunnableLambda
    and awaited by ``ainvoke``, forwarding the LLM output unchanged."""
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeListLLM(responses=["foo", "bar"])

    # Named `apassthrough` so it does not shadow the `passthrough`
    # helper imported from langchain.schema.runnable at module level.
    async def apassthrough(input: Any) -> Any:
        return input

    chain = prompt | llm | apassthrough
    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [llm]
    # RunnableLambda equality compares the wrapped callable.
    assert chain.last == RunnableLambda(func=apassthrough)
    assert dumps(chain, pretty=True) == snapshot
    # Test invoke: spy on the async entry points of both components.
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "ainvoke")
    tracer = FakeTracer()
    assert (
        await chain.ainvoke(
            {"question": "What is your name?"}, dict(callbacks=[tracer])
        )
        == "foo"  # first canned FakeListLLM response
    )
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)
@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, snapshot: SnapshotAssertion