diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index d660132767c..7631c54dd30 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -217,6 +217,12 @@ class Runnable(Generic[Input, Output], ABC): """ return RunnableBinding(bound=self, kwargs=kwargs) + def each(self) -> Runnable[List[Input], List[Output]]: + """ + Wrap a Runnable to run it on each element of the input sequence. + """ + return RunnableEach(bound=self) + def with_fallbacks( self, fallbacks: Sequence[Runnable[Input, Output]], @@ -1360,6 +1366,41 @@ class RunnableLambda(Runnable[Input, Output]): return self._call_with_config(self.func, input, config) +class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): + """ + A runnable that delegates calls to another runnable with each element of the input sequence. + """ + + bound: Runnable[Input, Output] + + class Config: + arbitrary_types_allowed = True + + @property + def lc_serializable(self) -> bool: + return True + + @property + def lc_namespace(self) -> List[str]: + return self.__class__.__module__.split(".")[:-1] + + def each(self) -> RunnableEach[Input, Output]: # type: ignore[override] + return self + + def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: + return RunnableEach(bound=self.bound.bind(**kwargs)) + + def invoke( + self, input: List[Input], config: Optional[RunnableConfig] = None + ) -> List[Output]: + return self.bound.batch(input, config) + + async def ainvoke( + self, input: List[Input], config: Optional[RunnableConfig] = None + ) -> List[Output]: + return await self.bound.abatch(input, config) + + class RunnableBinding(Serializable, Runnable[Input, Output]): """ A runnable that delegates calls to another runnable with a set of kwargs. diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 11e554a9cc4..44793ce3a03 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -464,6 +464,139 @@ Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nicer assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}], 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': ['baz', 'qux']}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo, bar'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo, bar', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'foo, bar'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='foo, bar', additional_kwargs={}, example=False)}, outputs={'output': ['foo', 'bar']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': ['foo', 'bar']}, outputs={'question': 'foobar'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nicer assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'foobar'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nicer assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'foobar', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='FakeListChatModel', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['baz, qux'], '_type': 'fake-list-chat-model', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'chat_models', 'fake', 'FakeListChatModel']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nicer assistant.\nHuman: foobar']}, outputs={'generations': [[{'text': 'baz, qux', 'generation_info': None, 'message': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'AIMessage'], 'kwargs': {'content': 'baz, qux'}}}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='CommaSeparatedListOutputParser', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='parser', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'output_parsers', 'list', 'CommaSeparatedListOutputParser'], 'kwargs': {}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': AIMessage(content='baz, qux', additional_kwargs={}, example=False)}, outputs={'output': ['baz', 'qux']}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- +# name: test_each + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string", + "partial_variables": {} + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string", + "partial_variables": {} + } + } + } + } + ] + } + }, + "middle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeStreamingListLLM" + ] + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "output_parsers", + "list", + "CommaSeparatedListOutputParser" + ], + "kwargs": {} + } + ], + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableEach" + ], + "kwargs": { + "bound": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeStreamingListLLM" + ] + } + } + } + } + } + ''' +# --- # name: test_llm_with_fallbacks[llm_chain_with_fallbacks] ''' { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 98cab172a56..cb5c0147250 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional from uuid import UUID +from xml.dom import ValidationErr import pytest from freezegun import freeze_time @@ -20,6 +21,7 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) +from langchain.pydantic_v1 import ValidationError from langchain.schema.document import Document from langchain.schema.messages import ( AIMessage, @@ -1086,3 +1088,18 @@ async def test_llm_with_fallbacks( assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(await runnable.ainvoke("hello")) == list("bar") assert dumps(runnable, pretty=True) == snapshot + + +def test_each(snapshot: SnapshotAssertion) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"]) + second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"]) + + chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each() + + assert dumps(chain, pretty=True) == snapshot + output = chain.invoke({"question": "What up"}) + assert output == ["this", "is", "a"]