Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-01 12:38:45 +00:00)
Nc/runnable lambda recurse (#9390)
Commit fa05e18278
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import threading
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
@@ -1343,9 +1344,18 @@ class RunnableLambda(Runnable[Input, Output]):
     A runnable that runs a callable.
     """
 
-    def __init__(self, func: Callable[[Input], Output]) -> None:
-        if callable(func):
-            self.func = func
+    def __init__(
+        self,
+        func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
+        afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
+    ) -> None:
+        if afunc is not None:
+            self.afunc = afunc
+
+        if inspect.iscoroutinefunction(func):
+            self.afunc = func
+        elif callable(func):
+            self.func = cast(Callable[[Input], Output], func)
         else:
             raise TypeError(
                 "Expected a callable type for `func`."
@@ -1354,17 +1364,89 @@ class RunnableLambda(Runnable[Input, Output]):
 
     def __eq__(self, other: Any) -> bool:
         if isinstance(other, RunnableLambda):
-            return self.func == other.func
+            if hasattr(self, "func") and hasattr(other, "func"):
+                return self.func == other.func
+            elif hasattr(self, "afunc") and hasattr(other, "afunc"):
+                return self.afunc == other.afunc
+            else:
+                return False
         else:
             return False
 
+    def _invoke(
+        self,
+        input: Input,
+        run_manager: CallbackManagerForChainRun,
+        config: RunnableConfig,
+    ) -> Output:
+        output = self.func(input)
+        # If the output is a runnable, invoke it
+        if isinstance(output, Runnable):
+            recursion_limit = config["recursion_limit"]
+            if recursion_limit <= 0:
+                raise RecursionError(
+                    f"Recursion limit reached when invoking {self} with input {input}."
+                )
+            output = output.invoke(
+                input,
+                patch_config(
+                    config,
+                    callbacks=run_manager.get_child(),
+                    recursion_limit=recursion_limit - 1,
+                ),
+            )
+        return output
+
+    async def _ainvoke(
+        self,
+        input: Input,
+        run_manager: AsyncCallbackManagerForChainRun,
+        config: RunnableConfig,
+    ) -> Output:
+        output = await self.afunc(input)
+        # If the output is a runnable, invoke it
+        if isinstance(output, Runnable):
+            recursion_limit = config["recursion_limit"]
+            if recursion_limit <= 0:
+                raise RecursionError(
+                    f"Recursion limit reached when invoking {self} with input {input}."
+                )
+            output = await output.ainvoke(
+                input,
+                patch_config(
+                    config,
+                    callbacks=run_manager.get_child(),
+                    recursion_limit=recursion_limit - 1,
+                ),
+            )
+        return output
+
     def invoke(
         self,
         input: Input,
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Output:
-        return self._call_with_config(self.func, input, config)
+        if hasattr(self, "func"):
+            return self._call_with_config(self._invoke, input, config)
+        else:
+            raise TypeError(
+                "Cannot invoke a coroutine function synchronously."
+                "Use `ainvoke` instead."
+            )
 
+    async def ainvoke(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> Output:
+        if hasattr(self, "afunc"):
+            return await self._acall_with_config(self._ainvoke, input, config)
+        else:
+            # Delegating to super implementation of ainvoke.
+            # Uses asyncio executor to run the sync version (invoke)
+            return await super().ainvoke(input, config)
+
 
 class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
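For reference, a minimal usage sketch of the reworked RunnableLambda (not part of the diff; it assumes RunnableLambda is importable from langchain.schema.runnable at this commit, the same import surface the test module below relies on). A plain callable is stored as `func`, a coroutine function is stored as `afunc`, both may be supplied together, and when the wrapped callable returns another Runnable, `_invoke`/`_ainvoke` invoke it with `recursion_limit` reduced by one.

# Illustrative sketch only; the import path is an assumption based on this diff's tests.
import asyncio

from langchain.schema.runnable import RunnableLambda


def add_one(x: int) -> int:
    return x + 1


async def aadd_one(x: int) -> int:
    return x + 1


# A plain callable is stored as `func`; invoke() routes it through _invoke.
assert RunnableLambda(add_one).invoke(1) == 2

# A coroutine function is stored as `afunc`; only ainvoke() works, and calling
# invoke() on it raises TypeError ("Use `ainvoke` instead.").
assert asyncio.run(RunnableLambda(aadd_one).ainvoke(1)) == 2

# Both can be supplied explicitly: func serves invoke(), afunc serves ainvoke().
both = RunnableLambda(add_one, afunc=aadd_one)
assert both.invoke(1) == 2
assert asyncio.run(both.ainvoke(1)) == 2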
@@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
     ThreadPoolExecutor will be created.
     """
 
+    recursion_limit: int
+    """
+    Maximum number of times a call can recurse. If not provided, defaults to 10.
+    """
+
 
 def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
     empty = RunnableConfig(
@@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
         metadata={},
         callbacks=None,
         _locals={},
+        recursion_limit=10,
     )
     if config is not None:
         empty.update(config)
@@ -66,6 +72,7 @@ def patch_config(
     deep_copy_locals: bool = False,
     callbacks: Optional[BaseCallbackManager] = None,
     executor: Optional[Executor] = None,
+    recursion_limit: Optional[int] = None,
 ) -> RunnableConfig:
     config = ensure_config(config)
     if deep_copy_locals:
@@ -74,6 +81,8 @@
         config["callbacks"] = callbacks
     if executor is not None:
         config["executor"] = executor
+    if recursion_limit is not None:
+        config["recursion_limit"] = recursion_limit
     return config
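For reference, a minimal sketch of how the new recursion_limit field flows through these helpers (not part of the diff; it assumes ensure_config and patch_config are importable from langchain.schema.runnable.config at this commit).

# Illustrative sketch only; the import path is an assumption based on this diff.
from langchain.schema.runnable.config import ensure_config, patch_config

# With no config supplied, ensure_config fills in the default of 10.
assert ensure_config(None)["recursion_limit"] == 10

# A caller-supplied value wins over the default.
config = ensure_config({"recursion_limit": 9})
assert config["recursion_limit"] == 9

# patch_config is what RunnableLambda uses to hand each nested Runnable a
# decremented limit, so deeply nested lambdas eventually raise RecursionError.
child = patch_config(config, recursion_limit=config["recursion_limit"] - 1)
assert child["recursion_limit"] == 8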
File diff suppressed because one or more lines are too long
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Optional
+from operator import itemgetter
+from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
 import pytest
@@ -176,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
                 tags=[],
                 callbacks=None,
                 _locals={},
+                recursion_limit=10,
             ),
         ),
         mocker.call(
@@ -185,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
                 tags=[],
                 callbacks=None,
                 _locals={},
+                recursion_limit=10,
             ),
         ),
     ]
@@ -438,6 +441,50 @@ async def test_prompt_with_llm(
     )
 
 
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_prompt_with_llm_and_async_lambda(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeListLLM(responses=["foo", "bar"])
+
+    async def passthrough(input: Any) -> Any:
+        return input
+
+    chain = prompt | llm | passthrough
+
+    assert isinstance(chain, RunnableSequence)
+    assert chain.first == prompt
+    assert chain.middle == [llm]
+    assert chain.last == RunnableLambda(func=passthrough)
+    assert dumps(chain, pretty=True) == snapshot
+
+    # Test invoke
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    llm_spy = mocker.spy(llm.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert (
+        await chain.ainvoke(
+            {"question": "What is your name?"}, dict(callbacks=[tracer])
+        )
+        == "foo"
+    )
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert llm_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(llm_spy)
+
+
 @freeze_time("2023-01-01")
 def test_prompt_with_chat_model_and_parser(
     mocker: MockerFixture, snapshot: SnapshotAssertion
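The test above goes through mocker spies and snapshots; a standalone distillation of what it verifies (not part of the diff, and assuming the same import path as in the sketches above) is that piping a coroutine function into a chain coerces it into a RunnableLambda with afunc set, so the sequence's ainvoke awaits it natively while sync steps still fall back to the default executor-backed ainvoke.

# Standalone sketch of mixing sync and async steps; the RunnableLambda import path
# and the `|` coercion of bare callables are assumptions based on this diff's tests.
import asyncio

from langchain.schema.runnable import RunnableLambda


def shout(text: str) -> str:
    return text.upper()


async def passthrough(text: str) -> str:
    return text


# The async callable becomes a RunnableLambda whose afunc is set, so ainvoke
# awaits it directly; the sync step runs via the thread-pool fallback.
chain = RunnableLambda(shout) | passthrough
assert asyncio.run(chain.ainvoke("hello")) == "HELLO"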
@@ -722,6 +769,105 @@ async def test_router_runnable(
     assert len(router_run.child_runs) == 2
 
 
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_higher_order_lambda_runnable(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    math_chain = ChatPromptTemplate.from_template(
+        "You are a math genius. Answer the question: {question}"
+    ) | FakeListLLM(responses=["4"])
+    english_chain = ChatPromptTemplate.from_template(
+        "You are an english major. Answer the question: {question}"
+    ) | FakeListLLM(responses=["2"])
+    input_map: Runnable = RunnableMap(
+        {  # type: ignore[arg-type]
+            "key": lambda x: x["key"],
+            "input": {"question": lambda x: x["question"]},
+        }
+    )
+
+    def router(input: Dict[str, Any]) -> Runnable:
+        if input["key"] == "math":
+            return itemgetter("input") | math_chain
+        elif input["key"] == "english":
+            return itemgetter("input") | english_chain
+        else:
+            raise ValueError(f"Unknown key: {input['key']}")
+
+    chain: Runnable = input_map | router
+    assert dumps(chain, pretty=True) == snapshot
+
+    result = chain.invoke({"key": "math", "question": "2 + 2"})
+    assert result == "4"
+
+    result2 = chain.batch(
+        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
+    )
+    assert result2 == ["4", "2"]
+
+    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
+    assert result == "4"
+
+    result2 = await chain.abatch(
+        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
+    )
+    assert result2 == ["4", "2"]
+
+    # Test invoke
+    math_spy = mocker.spy(math_chain.__class__, "invoke")
+    tracer = FakeTracer()
+    assert (
+        chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
+        == "4"
+    )
+    assert math_spy.call_args.args[1] == {
+        "key": "math",
+        "input": {"question": "2 + 2"},
+    }
+    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
+    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
+    assert len(parent_run.child_runs) == 2
+    router_run = parent_run.child_runs[1]
+    assert router_run.name == "RunnableLambda"
+    assert len(router_run.child_runs) == 1
+    math_run = router_run.child_runs[0]
+    assert math_run.name == "RunnableSequence"
+    assert len(math_run.child_runs) == 3
+
+    # Test ainvoke
+    async def arouter(input: Dict[str, Any]) -> Runnable:
+        if input["key"] == "math":
+            return itemgetter("input") | math_chain
+        elif input["key"] == "english":
+            return itemgetter("input") | english_chain
+        else:
+            raise ValueError(f"Unknown key: {input['key']}")
+
+    achain: Runnable = input_map | arouter
+    math_spy = mocker.spy(math_chain.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert (
+        await achain.ainvoke(
+            {"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
+        )
+        == "4"
+    )
+    assert math_spy.call_args.args[1] == {
+        "key": "math",
+        "input": {"question": "2 + 2"},
+    }
+    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
+    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
+    assert len(parent_run.child_runs) == 2
+    router_run = parent_run.child_runs[1]
+    assert router_run.name == "RunnableLambda"
+    assert len(router_run.child_runs) == 1
+    math_run = router_run.child_runs[0]
+    assert math_run.name == "RunnableSequence"
+    assert len(math_run.child_runs) == 3
+
+
 @freeze_time("2023-01-01")
 def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
     passthrough = mocker.Mock(side_effect=lambda x: x)
@@ -1136,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None:
         "test",
         "this",
     ]
+
+
+def test_recursive_lambda() -> None:
+    def _simple_recursion(x: int) -> Union[int, Runnable]:
+        if x < 10:
+            return RunnableLambda(lambda *args: _simple_recursion(x + 1))
+        else:
+            return x
+
+    runnable = RunnableLambda(_simple_recursion)
+    assert runnable.invoke(5) == 10
+
+    with pytest.raises(RecursionError):
+        runnable.invoke(0, {"recursion_limit": 9})
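A note on the final assertion above: each nested RunnableLambda consumes one unit of recursion_limit before the `recursion_limit <= 0` check, and climbing from 0 to 10 takes ten nested invocations. The default limit of 10 therefore covers invoke(5), which needs only five levels, while an explicit {"recursion_limit": 9} falls one level short of reaching 10 from 0 and raises RecursionError.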