From 6d19709b65d98e1ff7f37a459357e5470bd1c4bf Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 17 Aug 2023 15:04:36 +0100 Subject: [PATCH] RunnableLambda, if func returns a Runnable, run it --- .../langchain/schema/runnable/base.py | 59 +++++++++- .../langchain/schema/runnable/config.py | 9 ++ .../runnable/__snapshots__/test_runnable.ambr | 82 +++++++++++++- .../schema/runnable/test_runnable.py | 104 +++++++++++++++++- 4 files changed, 245 insertions(+), 9 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 722a7c376b2..3f6148e9494 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -13,7 +13,6 @@ from typing import ( AsyncIterator, Awaitable, Callable, - Coroutine, Dict, Generic, Iterator, @@ -1347,8 +1346,8 @@ class RunnableLambda(Runnable[Input, Output]): def __init__( self, - func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]], - afunc: Optional[Coroutine[Input, Any, Output]] = None, + func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]], + afunc: Optional[Callable[[Input], Awaitable[Output]]] = None, ) -> None: if afunc is not None: self.afunc = afunc @@ -1356,7 +1355,7 @@ class RunnableLambda(Runnable[Input, Output]): if inspect.iscoroutinefunction(func): self.afunc = func elif callable(func): - self.func = func + self.func = cast(Callable[[Input], Output], func) else: raise TypeError( "Expected a callable type for `func`." @@ -1374,6 +1373,54 @@ class RunnableLambda(Runnable[Input, Output]): else: return False + def _invoke( + self, + input: Input, + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + output = self.func(input) + # If the output is a runnable, invoke it + if isinstance(output, Runnable): + recursion_limit = config["recursion_limit"] + if recursion_limit <= 0: + raise RecursionError( + f"Recursion limit reached when invoking {self} with input {input}." + ) + output = output.invoke( + input, + patch_config( + config, + callbacks=run_manager.get_child(), + recursion_limit=recursion_limit - 1, + ), + ) + return output + + async def _ainvoke( + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + output = await self.afunc(input) + # If the output is a runnable, invoke it + if isinstance(output, Runnable): + recursion_limit = config["recursion_limit"] + if recursion_limit <= 0: + raise RecursionError( + f"Recursion limit reached when invoking {self} with input {input}." + ) + output = await output.ainvoke( + input, + patch_config( + config, + callbacks=run_manager.get_child(), + recursion_limit=recursion_limit - 1, + ), + ) + return output + def invoke( self, input: Input, @@ -1381,7 +1428,7 @@ class RunnableLambda(Runnable[Input, Output]): **kwargs: Optional[Any], ) -> Output: if hasattr(self, "func"): - return self._call_with_config(self.func, input, config) + return self._call_with_config(self._invoke, input, config) else: raise TypeError( "Cannot invoke a coroutine function synchronously." @@ -1395,7 +1442,7 @@ class RunnableLambda(Runnable[Input, Output]): **kwargs: Optional[Any], ) -> Output: if hasattr(self, "afunc"): - return await self._acall_with_config(self.afunc, input, config) + return await self._acall_with_config(self._ainvoke, input, config) else: return await super().ainvoke(input, config) diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 4ad3b4fb8d4..ce4e11861e6 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False): ThreadPoolExecutor will be created. """ + recursion_limit: int + """ + Maximum number of times a call can recurse. If not provided, defaults to 10. + """ + def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: empty = RunnableConfig( @@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: metadata={}, callbacks=None, _locals={}, + recursion_limit=10, ) if config is not None: empty.update(config) @@ -66,6 +72,7 @@ def patch_config( deep_copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, executor: Optional[Executor] = None, + recursion_limit: Optional[int] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -74,6 +81,8 @@ def patch_config( config["callbacks"] = callbacks if executor is not None: config["executor"] = executor + if recursion_limit is not None: + config["recursion_limit"] = recursion_limit return config diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index ac8deb3de25..c48d4edbd41 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -599,6 +599,83 @@ } ''' # --- +# name: test_higher_order_lambda_runnable + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "key": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + }, + "input": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "question": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + } + } + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + ''' +# --- # name: test_llm_with_fallbacks[llm_chain_with_fallbacks] ''' { @@ -1268,6 +1345,9 @@ } } } + ], + "input_variables": [ + "question" ] } }, @@ -1300,7 +1380,7 @@ # --- # name: test_prompt_with_llm_and_async_lambda.1 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_router_runnable diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 8c616770efe..0a36408d2c4 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,4 +1,4 @@ -from ast import Not +from operator import itemgetter from typing import Any, Dict, List, Optional from uuid import UUID @@ -39,7 +39,6 @@ from langchain.schema.runnable import ( RunnablePassthrough, RunnableSequence, RunnableWithFallbacks, - passthrough, ) @@ -178,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, _locals={}, + recursion_limit=10, ), ), mocker.call( @@ -187,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, _locals={}, + recursion_limit=10, ), ), ] @@ -768,6 +769,105 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_higher_order_lambda_runnable( + mocker: MockerFixture, snapshot: SnapshotAssertion +) -> None: + math_chain = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + english_chain = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + input_map = RunnableMap( + { # type: ignore[arg-type] + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } + ) + + def router(input: Dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + raise ValueError(f"Unknown key: {input['key']}") + + chain: Runnable = input_map | router + assert dumps(chain, pretty=True) == snapshot + + result = chain.invoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = chain.batch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + # Test invoke + math_spy = mocker.spy(math_chain.__class__, "invoke") + tracer = FakeTracer() + assert ( + chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])) + == "4" + ) + assert math_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 + parent_run = next(r for r in tracer.runs if r.parent_run_id is None) + assert len(parent_run.child_runs) == 2 + router_run = parent_run.child_runs[1] + assert router_run.name == "RunnableLambda" + assert len(router_run.child_runs) == 1 + math_run = router_run.child_runs[0] + assert math_run.name == "RunnableSequence" + assert len(math_run.child_runs) == 3 + + # Test ainvoke + async def arouter(input: Dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + raise ValueError(f"Unknown key: {input['key']}") + + achain: Runnable = input_map | arouter + math_spy = mocker.spy(math_chain.__class__, "ainvoke") + tracer = FakeTracer() + assert ( + await achain.ainvoke( + {"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]) + ) + == "4" + ) + assert math_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 + parent_run = next(r for r in tracer.runs if r.parent_run_id is None) + assert len(parent_run.child_runs) == 2 + router_run = parent_run.child_runs[1] + assert router_run.name == "RunnableLambda" + assert len(router_run.child_runs) == 1 + math_run = router_run.child_runs[0] + assert math_run.name == "RunnableSequence" + assert len(math_run.child_runs) == 3 + + @freeze_time("2023-01-01") def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: passthrough = mocker.Mock(side_effect=lambda x: x)