diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c7dd38865fa..a130dc62b8c 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import inspect import threading from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, wait @@ -1343,9 +1344,18 @@ class RunnableLambda(Runnable[Input, Output]): A runnable that runs a callable. """ - def __init__(self, func: Callable[[Input], Output]) -> None: - if callable(func): - self.func = func + def __init__( + self, + func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]], + afunc: Optional[Callable[[Input], Awaitable[Output]]] = None, + ) -> None: + if afunc is not None: + self.afunc = afunc + + if inspect.iscoroutinefunction(func): + self.afunc = func + elif callable(func): + self.func = cast(Callable[[Input], Output], func) else: raise TypeError( "Expected a callable type for `func`." @@ -1354,17 +1364,89 @@ class RunnableLambda(Runnable[Input, Output]): def __eq__(self, other: Any) -> bool: if isinstance(other, RunnableLambda): - return self.func == other.func + if hasattr(self, "func") and hasattr(other, "func"): + return self.func == other.func + elif hasattr(self, "afunc") and hasattr(other, "afunc"): + return self.afunc == other.afunc + else: + return False else: return False + def _invoke( + self, + input: Input, + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + output = self.func(input) + # If the output is a runnable, invoke it + if isinstance(output, Runnable): + recursion_limit = config["recursion_limit"] + if recursion_limit <= 0: + raise RecursionError( + f"Recursion limit reached when invoking {self} with input {input}." + ) + output = output.invoke( + input, + patch_config( + config, + callbacks=run_manager.get_child(), + recursion_limit=recursion_limit - 1, + ), + ) + return output + + async def _ainvoke( + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + output = await self.afunc(input) + # If the output is a runnable, invoke it + if isinstance(output, Runnable): + recursion_limit = config["recursion_limit"] + if recursion_limit <= 0: + raise RecursionError( + f"Recursion limit reached when invoking {self} with input {input}." + ) + output = await output.ainvoke( + input, + patch_config( + config, + callbacks=run_manager.get_child(), + recursion_limit=recursion_limit - 1, + ), + ) + return output + def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return self._call_with_config(self.func, input, config) + if hasattr(self, "func"): + return self._call_with_config(self._invoke, input, config) + else: + raise TypeError( + "Cannot invoke a coroutine function synchronously." + "Use `ainvoke` instead." + ) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + if hasattr(self, "afunc"): + return await self._acall_with_config(self._ainvoke, input, config) + else: + # Delegating to super implementation of ainvoke. + # Uses asyncio executor to run the sync version (invoke) + return await super().ainvoke(input, config) class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 4ad3b4fb8d4..ce4e11861e6 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False): ThreadPoolExecutor will be created. """ + recursion_limit: int + """ + Maximum number of times a call can recurse. If not provided, defaults to 10. + """ + def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: empty = RunnableConfig( @@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: metadata={}, callbacks=None, _locals={}, + recursion_limit=10, ) if config is not None: empty.update(config) @@ -66,6 +72,7 @@ def patch_config( deep_copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, executor: Optional[Executor] = None, + recursion_limit: Optional[int] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -74,6 +81,8 @@ def patch_config( config["callbacks"] = callbacks if executor is not None: config["executor"] = executor + if recursion_limit is not None: + config["recursion_limit"] = recursion_limit return config diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 828227fd308..c48d4edbd41 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -599,6 +599,83 @@ } ''' # --- +# name: test_higher_order_lambda_runnable + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "key": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + }, + "input": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "question": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + } + } + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + ''' +# --- # name: test_llm_with_fallbacks[llm_chain_with_fallbacks] ''' { @@ -1187,6 +1264,125 @@ Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your favorite color?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- +# name: test_prompt_with_llm_and_async_lambda + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string", + "partial_variables": {} + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string", + "partial_variables": {} + } + } + } + } + ], + "input_variables": [ + "question" + ] + } + }, + "middle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeListLLM" + ] + } + ], + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + ''' +# --- +# name: test_prompt_with_llm_and_async_lambda.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), + ]) +# --- # name: test_router_runnable ''' { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 544b42da693..f2447533107 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional +from operator import itemgetter +from typing import Any, Dict, List, Optional, Union from uuid import UUID import pytest @@ -176,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, _locals={}, + recursion_limit=10, ), ), mocker.call( @@ -185,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, _locals={}, + recursion_limit=10, ), ), ] @@ -438,6 +441,50 @@ async def test_prompt_with_llm( ) +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_prompt_with_llm_and_async_lambda( + mocker: MockerFixture, snapshot: SnapshotAssertion +) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeListLLM(responses=["foo", "bar"]) + + async def passthrough(input: Any) -> Any: + return input + + chain = prompt | llm | passthrough + + assert isinstance(chain, RunnableSequence) + assert chain.first == prompt + assert chain.middle == [llm] + assert chain.last == RunnableLambda(func=passthrough) + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "ainvoke") + llm_spy = mocker.spy(llm.__class__, "ainvoke") + tracer = FakeTracer() + assert ( + await chain.ainvoke( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) + == "foo" + ) + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert llm_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert tracer.runs == snapshot + mocker.stop(prompt_spy) + mocker.stop(llm_spy) + + @freeze_time("2023-01-01") def test_prompt_with_chat_model_and_parser( mocker: MockerFixture, snapshot: SnapshotAssertion @@ -722,6 +769,105 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_higher_order_lambda_runnable( + mocker: MockerFixture, snapshot: SnapshotAssertion +) -> None: + math_chain = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + english_chain = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + input_map: Runnable = RunnableMap( + { # type: ignore[arg-type] + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } + ) + + def router(input: Dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + raise ValueError(f"Unknown key: {input['key']}") + + chain: Runnable = input_map | router + assert dumps(chain, pretty=True) == snapshot + + result = chain.invoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = chain.batch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + # Test invoke + math_spy = mocker.spy(math_chain.__class__, "invoke") + tracer = FakeTracer() + assert ( + chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])) + == "4" + ) + assert math_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 + parent_run = next(r for r in tracer.runs if r.parent_run_id is None) + assert len(parent_run.child_runs) == 2 + router_run = parent_run.child_runs[1] + assert router_run.name == "RunnableLambda" + assert len(router_run.child_runs) == 1 + math_run = router_run.child_runs[0] + assert math_run.name == "RunnableSequence" + assert len(math_run.child_runs) == 3 + + # Test ainvoke + async def arouter(input: Dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + raise ValueError(f"Unknown key: {input['key']}") + + achain: Runnable = input_map | arouter + math_spy = mocker.spy(math_chain.__class__, "ainvoke") + tracer = FakeTracer() + assert ( + await achain.ainvoke( + {"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]) + ) + == "4" + ) + assert math_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 + parent_run = next(r for r in tracer.runs if r.parent_run_id is None) + assert len(parent_run.child_runs) == 2 + router_run = parent_run.child_runs[1] + assert router_run.name == "RunnableLambda" + assert len(router_run.child_runs) == 1 + math_run = router_run.child_runs[0] + assert math_run.name == "RunnableSequence" + assert len(math_run.child_runs) == 3 + + @freeze_time("2023-01-01") def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: passthrough = mocker.Mock(side_effect=lambda x: x) @@ -1136,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None: "test", "this", ] + + +def test_recursive_lambda() -> None: + def _simple_recursion(x: int) -> Union[int, Runnable]: + if x < 10: + return RunnableLambda(lambda *args: _simple_recursion(x + 1)) + else: + return x + + runnable = RunnableLambda(_simple_recursion) + assert runnable.invoke(5) == 10 + + with pytest.raises(RecursionError): + runnable.invoke(0, {"recursion_limit": 9})