From 677da6a0fd7ac8e6c3127487b09b6bd3e1d620d8 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Thu, 17 Aug 2023 14:25:43 +0100
Subject: [PATCH] Add support for async funcs in RunnableSequence

---
 .../langchain/schema/runnable/base.py         |  40 ++++++-
 .../runnable/__snapshots__/test_runnable.ambr | 116 ++++++++++++++++++
 .../schema/runnable/test_runnable.py          |  44 +++++++
 3 files changed, 196 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
index c7dd38865fa..722a7c376b2 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import threading
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
@@ -1343,8 +1344,17 @@ class RunnableLambda(Runnable[Input, Output]):
     A runnable that runs a callable.
     """
 
-    def __init__(self, func: Callable[[Input], Output]) -> None:
-        if callable(func):
+    def __init__(
+        self,
+        func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
+        afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
+    ) -> None:
+        if afunc is not None:
+            self.afunc = afunc
+
+        if inspect.iscoroutinefunction(func):
+            self.afunc = func
+        elif callable(func):
             self.func = func
         else:
             raise TypeError(
@@ -1354,7 +1364,12 @@ class RunnableLambda(Runnable[Input, Output]):
 
     def __eq__(self, other: Any) -> bool:
         if isinstance(other, RunnableLambda):
-            return self.func == other.func
+            if hasattr(self, "func") and hasattr(other, "func"):
+                return self.func == other.func
+            elif hasattr(self, "afunc") and hasattr(other, "afunc"):
+                return self.afunc == other.afunc
+            else:
+                return False
         else:
             return False
 
@@ -1364,10 +1379,27 @@ class RunnableLambda(Runnable[Input, Output]):
     def invoke(
         self,
         input: Input,
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Output:
-        return self._call_with_config(self.func, input, config)
+        if hasattr(self, "func"):
+            return self._call_with_config(self.func, input, config)
+        else:
+            raise TypeError(
+                "Cannot invoke a coroutine function synchronously. "
+                "Use `ainvoke` instead."
+            )
+
+    async def ainvoke(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> Output:
+        if hasattr(self, "afunc"):
+            return await self._acall_with_config(self.afunc, input, config)
+        else:
+            return await super().ainvoke(input, config)
 
 
 class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
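
Usage sketch (not part of the patch): a minimal illustration of the dispatch rules the hunks above introduce, assuming they are applied and that RunnableLambda is re-exported from langchain.schema.runnable as in the tests below; add_one and add_one_async are made-up example names.

import asyncio

from langchain.schema.runnable import RunnableLambda


def add_one(x: int) -> int:
    return x + 1


async def add_one_async(x: int) -> int:
    return x + 1


# A plain callable is stored on self.func, as before.
sync_lambda = RunnableLambda(add_one)
assert sync_lambda.invoke(1) == 2

# A coroutine function is detected with inspect.iscoroutinefunction and
# stored on self.afunc, so only the async path is defined.
async_lambda = RunnableLambda(add_one_async)
assert asyncio.run(async_lambda.ainvoke(1)) == 2

# invoke() on an async-only lambda raises the new TypeError.
try:
    async_lambda.invoke(1)
except TypeError:
    pass  # "Cannot invoke a coroutine function synchronously. Use `ainvoke` instead."

# A sync-only lambda still supports ainvoke() through the base-class
# fallback (super().ainvoke runs the sync func in an executor).
assert asyncio.run(sync_lambda.ainvoke(1)) == 2

# Passing both func and afunc gives a dedicated implementation per path.
both = RunnableLambda(add_one, afunc=add_one_async)
assert both.invoke(1) == 2 and asyncio.run(both.ainvoke(1)) == 2
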
diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr
index 828227fd308..ac8deb3de25 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr
+++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr
@@ -1187,6 +1187,122 @@
     Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your favorite color?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]),
   ])
 # ---
+# name: test_prompt_with_llm_and_async_lambda
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "langchain",
+      "schema",
+      "runnable",
+      "RunnableSequence"
+    ],
+    "kwargs": {
+      "first": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "prompts",
+          "chat",
+          "ChatPromptTemplate"
+        ],
+        "kwargs": {
+          "messages": [
+            {
+              "lc": 1,
+              "type": "constructor",
+              "id": [
+                "langchain",
+                "prompts",
+                "chat",
+                "SystemMessagePromptTemplate"
+              ],
+              "kwargs": {
+                "prompt": {
+                  "lc": 1,
+                  "type": "constructor",
+                  "id": [
+                    "langchain",
+                    "prompts",
+                    "prompt",
+                    "PromptTemplate"
+                  ],
+                  "kwargs": {
+                    "input_variables": [],
+                    "template": "You are a nice assistant.",
+                    "template_format": "f-string",
+                    "partial_variables": {}
+                  }
+                }
+              }
+            },
+            {
+              "lc": 1,
+              "type": "constructor",
+              "id": [
+                "langchain",
+                "prompts",
+                "chat",
+                "HumanMessagePromptTemplate"
+              ],
+              "kwargs": {
+                "prompt": {
+                  "lc": 1,
+                  "type": "constructor",
+                  "id": [
+                    "langchain",
+                    "prompts",
+                    "prompt",
+                    "PromptTemplate"
+                  ],
+                  "kwargs": {
+                    "input_variables": [
+                      "question"
+                    ],
+                    "template": "{question}",
+                    "template_format": "f-string",
+                    "partial_variables": {}
+                  }
+                }
+              }
+            }
+          ]
+        }
+      },
+      "middle": [
+        {
+          "lc": 1,
+          "type": "not_implemented",
+          "id": [
+            "langchain",
+            "llms",
+            "fake",
+            "FakeListLLM"
+          ]
+        }
+      ],
+      "last": {
+        "lc": 1,
+        "type": "not_implemented",
+        "id": [
+          "langchain",
+          "schema",
+          "runnable",
+          "base",
+          "RunnableLambda"
+        ]
+      }
+    }
+  }
+  '''
+# ---
+# name: test_prompt_with_llm_and_async_lambda.1
+  list([
+    Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]),
+  ])
+# ---
 # name: test_router_runnable
   '''
   {
diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
index 544b42da693..8c616770efe 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
+++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -438,6 +438,50 @@ async def test_prompt_with_llm(
     )
 
 
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_prompt_with_llm_and_async_lambda(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeListLLM(responses=["foo", "bar"])
+
+    async def passthrough(input: Any) -> Any:
+        return input
+
+    chain = prompt | llm | passthrough
+
+    assert isinstance(chain, RunnableSequence)
+    assert chain.first == prompt
+    assert chain.middle == [llm]
+    assert chain.last == RunnableLambda(func=passthrough)
+    assert dumps(chain, pretty=True) == snapshot
+
+    # Test ainvoke
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    llm_spy = mocker.spy(llm.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert (
+        await chain.ainvoke(
+            {"question": "What is your name?"}, dict(callbacks=[tracer])
+        )
+        == "foo"
+    )
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert llm_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(llm_spy)
+
+
 @freeze_time("2023-01-01")
 def test_prompt_with_chat_model_and_parser(
     mocker: MockerFixture, snapshot: SnapshotAssertion
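
End-to-end sketch (not part of the patch): this mirrors what the new test exercises, under the same assumptions as the sketch above. An async def on the right-hand side of | is coerced into a RunnableLambda whose coroutine function lands on .afunc, so the whole sequence can be awaited. Import paths are inferred from the test and the serialized ids in the snapshot and may need adjusting.

import asyncio
from typing import Any

from langchain.llms.fake import FakeListLLM
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.schema.runnable import RunnableSequence


async def passthrough(input: Any) -> Any:
    # Async identity step; piping it coerces it into a RunnableLambda.
    return input


prompt = (
    SystemMessagePromptTemplate.from_template("You are a nice assistant.")
    + "{question}"
)
llm = FakeListLLM(responses=["foo", "bar"])

chain = prompt | llm | passthrough
assert isinstance(chain, RunnableSequence)

# The async lambda step now runs through RunnableLambda.ainvoke directly,
# instead of bouncing a sync callable through an executor.
result = asyncio.run(chain.ainvoke({"question": "What is your name?"}))
assert result == "foo"
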