mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Add support for async funcs in RunnableSequence
This commit is contained in:
parent
64a958c85d
commit
677da6a0fd
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import FIRST_COMPLETED, wait
|
||||
@ -12,6 +13,7 @@ from typing import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
@ -1343,8 +1345,17 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
A runnable that runs a callable.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable[[Input], Output]) -> None:
|
||||
if callable(func):
|
||||
def __init__(
|
||||
self,
|
||||
func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
|
||||
afunc: Optional[Coroutine[Input, Any, Output]] = None,
|
||||
) -> None:
|
||||
if afunc is not None:
|
||||
self.afunc = afunc
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
self.afunc = func
|
||||
elif callable(func):
|
||||
self.func = func
|
||||
else:
|
||||
raise TypeError(
|
||||
@ -1354,7 +1365,12 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, RunnableLambda):
|
||||
if hasattr(self, "func") and hasattr(other, "func"):
|
||||
return self.func == other.func
|
||||
elif hasattr(self, "afunc") and hasattr(other, "afunc"):
|
||||
return self.afunc == other.afunc
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
@ -1364,7 +1380,24 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if hasattr(self, "func"):
|
||||
return self._call_with_config(self.func, input, config)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Cannot invoke a coroutine function synchronously."
|
||||
"Use `ainvoke` instead."
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if hasattr(self, "afunc"):
|
||||
return await self._acall_with_config(self.afunc, input, config)
|
||||
else:
|
||||
return await super().ainvoke(input, config)
|
||||
|
||||
|
||||
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
||||
from ast import Not
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
@ -38,6 +39,7 @@ from langchain.schema.runnable import (
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
passthrough,
|
||||
)
|
||||
|
||||
|
||||
@ -438,6 +440,50 @@ async def test_prompt_with_llm(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_llm_and_async_lambda(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
llm = FakeListLLM(responses=["foo", "bar"])
|
||||
|
||||
async def passthrough(input: Any) -> Any:
|
||||
return input
|
||||
|
||||
chain = prompt | llm | passthrough
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == [llm]
|
||||
assert chain.last == RunnableLambda(func=passthrough)
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||
llm_spy = mocker.spy(llm.__class__, "ainvoke")
|
||||
tracer = FakeTracer()
|
||||
assert (
|
||||
await chain.ainvoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
)
|
||||
== "foo"
|
||||
)
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
assert tracer.runs == snapshot
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(llm_spy)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_prompt_with_chat_model_and_parser(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
|
Loading…
Reference in New Issue
Block a user