Nc/runnable lambda recurse (#9390)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-08-23 20:07:08 +01:00 committed by GitHub
commit fa05e18278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 453 additions and 6 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import inspect
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
@ -1343,9 +1344,18 @@ class RunnableLambda(Runnable[Input, Output]):
A runnable that runs a callable. A runnable that runs a callable.
""" """
def __init__(self, func: Callable[[Input], Output]) -> None: def __init__(
if callable(func): self,
self.func = func func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
) -> None:
if afunc is not None:
self.afunc = afunc
if inspect.iscoroutinefunction(func):
self.afunc = func
elif callable(func):
self.func = cast(Callable[[Input], Output], func)
else: else:
raise TypeError( raise TypeError(
"Expected a callable type for `func`." "Expected a callable type for `func`."
@ -1354,9 +1364,62 @@ class RunnableLambda(Runnable[Input, Output]):
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda): if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func return self.func == other.func
elif hasattr(self, "afunc") and hasattr(other, "afunc"):
return self.afunc == other.afunc
else: else:
return False return False
else:
return False
def _invoke(
self,
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = self.func(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = output.invoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
async def _ainvoke(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = await self.afunc(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = await output.ainvoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
def invoke( def invoke(
self, self,
@ -1364,7 +1427,26 @@ class RunnableLambda(Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return self._call_with_config(self.func, input, config) if hasattr(self, "func"):
return self._call_with_config(self._invoke, input, config)
else:
raise TypeError(
"Cannot invoke a coroutine function synchronously."
"Use `ainvoke` instead."
)
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "afunc"):
return await self._acall_with_config(self._ainvoke, input, config)
else:
# Delegating to super implementation of ainvoke.
# Uses asyncio executor to run the sync version (invoke)
return await super().ainvoke(input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):

View File

@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
ThreadPoolExecutor will be created. ThreadPoolExecutor will be created.
""" """
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 10.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig( empty = RunnableConfig(
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
metadata={}, metadata={},
callbacks=None, callbacks=None,
_locals={}, _locals={},
recursion_limit=10,
) )
if config is not None: if config is not None:
empty.update(config) empty.update(config)
@ -66,6 +72,7 @@ def patch_config(
deep_copy_locals: bool = False, deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None, callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None, executor: Optional[Executor] = None,
recursion_limit: Optional[int] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = ensure_config(config) config = ensure_config(config)
if deep_copy_locals: if deep_copy_locals:
@ -74,6 +81,8 @@ def patch_config(
config["callbacks"] = callbacks config["callbacks"] = callbacks
if executor is not None: if executor is not None:
config["executor"] = executor config["executor"] = executor
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
return config return config

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional from operator import itemgetter
from typing import Any, Dict, List, Optional, Union
from uuid import UUID from uuid import UUID
import pytest import pytest
@ -176,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[], tags=[],
callbacks=None, callbacks=None,
_locals={}, _locals={},
recursion_limit=10,
), ),
), ),
mocker.call( mocker.call(
@ -185,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[], tags=[],
callbacks=None, callbacks=None,
_locals={}, _locals={},
recursion_limit=10,
), ),
), ),
] ]
@ -438,6 +441,50 @@ async def test_prompt_with_llm(
) )
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeListLLM(responses=["foo", "bar"])
async def passthrough(input: Any) -> Any:
return input
chain = prompt | llm | passthrough
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [llm]
assert chain.last == RunnableLambda(func=passthrough)
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
== "foo"
)
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser( def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, snapshot: SnapshotAssertion mocker: MockerFixture, snapshot: SnapshotAssertion
@ -722,6 +769,105 @@ async def test_router_runnable(
assert len(router_run.child_runs) == 2 assert len(router_run.child_runs) == 2
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableMap(
{ # type: ignore[arg-type]
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
}
)
def router(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
chain: Runnable = input_map | router
assert dumps(chain, pretty=True) == snapshot
result = chain.invoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = chain.batch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = await chain.abatch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
# Test invoke
math_spy = mocker.spy(math_chain.__class__, "invoke")
tracer = FakeTracer()
assert (
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
# Test ainvoke
async def arouter(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
achain: Runnable = input_map | arouter
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await achain.ainvoke(
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
)
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x) passthrough = mocker.Mock(side_effect=lambda x: x)
@ -1136,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None:
"test", "test",
"this", "this",
] ]
def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10:
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
else:
return x
runnable = RunnableLambda(_simple_recursion)
assert runnable.invoke(5) == 10
with pytest.raises(RecursionError):
runnable.invoke(0, {"recursion_limit": 9})