mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
Nc/runnable lambda recurse (#9390)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. -->
This commit is contained in:
commit
fa05e18278
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import FIRST_COMPLETED, wait
|
from concurrent.futures import FIRST_COMPLETED, wait
|
||||||
@ -1343,9 +1344,18 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
A runnable that runs a callable.
|
A runnable that runs a callable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func: Callable[[Input], Output]) -> None:
|
def __init__(
|
||||||
if callable(func):
|
self,
|
||||||
self.func = func
|
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
|
||||||
|
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
|
||||||
|
) -> None:
|
||||||
|
if afunc is not None:
|
||||||
|
self.afunc = afunc
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(func):
|
||||||
|
self.afunc = func
|
||||||
|
elif callable(func):
|
||||||
|
self.func = cast(Callable[[Input], Output], func)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Expected a callable type for `func`."
|
"Expected a callable type for `func`."
|
||||||
@ -1354,17 +1364,89 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if isinstance(other, RunnableLambda):
|
if isinstance(other, RunnableLambda):
|
||||||
return self.func == other.func
|
if hasattr(self, "func") and hasattr(other, "func"):
|
||||||
|
return self.func == other.func
|
||||||
|
elif hasattr(self, "afunc") and hasattr(other, "afunc"):
|
||||||
|
return self.afunc == other.afunc
|
||||||
|
else:
|
||||||
|
return False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: CallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
output = self.func(input)
|
||||||
|
# If the output is a runnable, invoke it
|
||||||
|
if isinstance(output, Runnable):
|
||||||
|
recursion_limit = config["recursion_limit"]
|
||||||
|
if recursion_limit <= 0:
|
||||||
|
raise RecursionError(
|
||||||
|
f"Recursion limit reached when invoking {self} with input {input}."
|
||||||
|
)
|
||||||
|
output = output.invoke(
|
||||||
|
input,
|
||||||
|
patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
recursion_limit=recursion_limit - 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def _ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Output:
|
||||||
|
output = await self.afunc(input)
|
||||||
|
# If the output is a runnable, invoke it
|
||||||
|
if isinstance(output, Runnable):
|
||||||
|
recursion_limit = config["recursion_limit"]
|
||||||
|
if recursion_limit <= 0:
|
||||||
|
raise RecursionError(
|
||||||
|
f"Recursion limit reached when invoking {self} with input {input}."
|
||||||
|
)
|
||||||
|
output = await output.ainvoke(
|
||||||
|
input,
|
||||||
|
patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(),
|
||||||
|
recursion_limit=recursion_limit - 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return self._call_with_config(self.func, input, config)
|
if hasattr(self, "func"):
|
||||||
|
return self._call_with_config(self._invoke, input, config)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Cannot invoke a coroutine function synchronously."
|
||||||
|
"Use `ainvoke` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Output:
|
||||||
|
if hasattr(self, "afunc"):
|
||||||
|
return await self._acall_with_config(self._ainvoke, input, config)
|
||||||
|
else:
|
||||||
|
# Delegating to super implementation of ainvoke.
|
||||||
|
# Uses asyncio executor to run the sync version (invoke)
|
||||||
|
return await super().ainvoke(input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||||
|
@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
ThreadPoolExecutor will be created.
|
ThreadPoolExecutor will be created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
recursion_limit: int
|
||||||
|
"""
|
||||||
|
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
empty = RunnableConfig(
|
empty = RunnableConfig(
|
||||||
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
metadata={},
|
metadata={},
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
_locals={},
|
||||||
|
recursion_limit=10,
|
||||||
)
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
empty.update(config)
|
empty.update(config)
|
||||||
@ -66,6 +72,7 @@ def patch_config(
|
|||||||
deep_copy_locals: bool = False,
|
deep_copy_locals: bool = False,
|
||||||
callbacks: Optional[BaseCallbackManager] = None,
|
callbacks: Optional[BaseCallbackManager] = None,
|
||||||
executor: Optional[Executor] = None,
|
executor: Optional[Executor] = None,
|
||||||
|
recursion_limit: Optional[int] = None,
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if deep_copy_locals:
|
if deep_copy_locals:
|
||||||
@ -74,6 +81,8 @@ def patch_config(
|
|||||||
config["callbacks"] = callbacks
|
config["callbacks"] = callbacks
|
||||||
if executor is not None:
|
if executor is not None:
|
||||||
config["executor"] = executor
|
config["executor"] = executor
|
||||||
|
if recursion_limit is not None:
|
||||||
|
config["recursion_limit"] = recursion_limit
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1,4 +1,5 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from operator import itemgetter
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -176,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
_locals={},
|
||||||
|
recursion_limit=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
mocker.call(
|
mocker.call(
|
||||||
@ -185,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
_locals={},
|
||||||
|
recursion_limit=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@ -438,6 +441,50 @@ async def test_prompt_with_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_prompt_with_llm_and_async_lambda(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
llm = FakeListLLM(responses=["foo", "bar"])
|
||||||
|
|
||||||
|
async def passthrough(input: Any) -> Any:
|
||||||
|
return input
|
||||||
|
|
||||||
|
chain = prompt | llm | passthrough
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == [llm]
|
||||||
|
assert chain.last == RunnableLambda(func=passthrough)
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
|
llm_spy = mocker.spy(llm.__class__, "ainvoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert (
|
||||||
|
await chain.ainvoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
)
|
||||||
|
== "foo"
|
||||||
|
)
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(llm_spy)
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
def test_prompt_with_chat_model_and_parser(
|
def test_prompt_with_chat_model_and_parser(
|
||||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
@ -722,6 +769,105 @@ async def test_router_runnable(
|
|||||||
assert len(router_run.child_runs) == 2
|
assert len(router_run.child_runs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_higher_order_lambda_runnable(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
math_chain = ChatPromptTemplate.from_template(
|
||||||
|
"You are a math genius. Answer the question: {question}"
|
||||||
|
) | FakeListLLM(responses=["4"])
|
||||||
|
english_chain = ChatPromptTemplate.from_template(
|
||||||
|
"You are an english major. Answer the question: {question}"
|
||||||
|
) | FakeListLLM(responses=["2"])
|
||||||
|
input_map: Runnable = RunnableMap(
|
||||||
|
{ # type: ignore[arg-type]
|
||||||
|
"key": lambda x: x["key"],
|
||||||
|
"input": {"question": lambda x: x["question"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def router(input: Dict[str, Any]) -> Runnable:
|
||||||
|
if input["key"] == "math":
|
||||||
|
return itemgetter("input") | math_chain
|
||||||
|
elif input["key"] == "english":
|
||||||
|
return itemgetter("input") | english_chain
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown key: {input['key']}")
|
||||||
|
|
||||||
|
chain: Runnable = input_map | router
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
result = chain.invoke({"key": "math", "question": "2 + 2"})
|
||||||
|
assert result == "4"
|
||||||
|
|
||||||
|
result2 = chain.batch(
|
||||||
|
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||||
|
)
|
||||||
|
assert result2 == ["4", "2"]
|
||||||
|
|
||||||
|
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
|
||||||
|
assert result == "4"
|
||||||
|
|
||||||
|
result2 = await chain.abatch(
|
||||||
|
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||||
|
)
|
||||||
|
assert result2 == ["4", "2"]
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
math_spy = mocker.spy(math_chain.__class__, "invoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert (
|
||||||
|
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
|
||||||
|
== "4"
|
||||||
|
)
|
||||||
|
assert math_spy.call_args.args[1] == {
|
||||||
|
"key": "math",
|
||||||
|
"input": {"question": "2 + 2"},
|
||||||
|
}
|
||||||
|
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||||
|
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||||
|
assert len(parent_run.child_runs) == 2
|
||||||
|
router_run = parent_run.child_runs[1]
|
||||||
|
assert router_run.name == "RunnableLambda"
|
||||||
|
assert len(router_run.child_runs) == 1
|
||||||
|
math_run = router_run.child_runs[0]
|
||||||
|
assert math_run.name == "RunnableSequence"
|
||||||
|
assert len(math_run.child_runs) == 3
|
||||||
|
|
||||||
|
# Test ainvoke
|
||||||
|
async def arouter(input: Dict[str, Any]) -> Runnable:
|
||||||
|
if input["key"] == "math":
|
||||||
|
return itemgetter("input") | math_chain
|
||||||
|
elif input["key"] == "english":
|
||||||
|
return itemgetter("input") | english_chain
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown key: {input['key']}")
|
||||||
|
|
||||||
|
achain: Runnable = input_map | arouter
|
||||||
|
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert (
|
||||||
|
await achain.ainvoke(
|
||||||
|
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
|
||||||
|
)
|
||||||
|
== "4"
|
||||||
|
)
|
||||||
|
assert math_spy.call_args.args[1] == {
|
||||||
|
"key": "math",
|
||||||
|
"input": {"question": "2 + 2"},
|
||||||
|
}
|
||||||
|
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||||
|
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||||
|
assert len(parent_run.child_runs) == 2
|
||||||
|
router_run = parent_run.child_runs[1]
|
||||||
|
assert router_run.name == "RunnableLambda"
|
||||||
|
assert len(router_run.child_runs) == 1
|
||||||
|
math_run = router_run.child_runs[0]
|
||||||
|
assert math_run.name == "RunnableSequence"
|
||||||
|
assert len(math_run.child_runs) == 3
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
|
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
|
||||||
passthrough = mocker.Mock(side_effect=lambda x: x)
|
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||||
@ -1136,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None:
|
|||||||
"test",
|
"test",
|
||||||
"this",
|
"this",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_lambda() -> None:
|
||||||
|
def _simple_recursion(x: int) -> Union[int, Runnable]:
|
||||||
|
if x < 10:
|
||||||
|
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
runnable = RunnableLambda(_simple_recursion)
|
||||||
|
assert runnable.invoke(5) == 10
|
||||||
|
|
||||||
|
with pytest.raises(RecursionError):
|
||||||
|
runnable.invoke(0, {"recursion_limit": 9})
|
||||||
|
Loading…
Reference in New Issue
Block a user