From 6a5a2fb9c82a07ed91cc155a9c19a3330d8b6bda Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 27 Dec 2023 13:35:34 -0800 Subject: [PATCH] Add .pick and .assign methods to Runnable (#15229) --- .../expression_language/cookbook/sql_db.ipynb | 3 +- .../core/langchain_core/runnables/__init__.py | 8 +- libs/core/langchain_core/runnables/base.py | 68 +++++++++++--- .../langchain_core/runnables/passthrough.py | 33 +++---- .../runnables/__snapshots__/test_graph.ambr | 90 +++++++++---------- .../__snapshots__/test_runnable.ambr | 2 +- .../unit_tests/runnables/test_imports.py | 2 + .../unit_tests/runnables/test_runnable.py | 65 +++++++------- .../chains/history_aware_retriever.py | 4 +- libs/langchain/langchain/chains/retrieval.py | 3 +- 10 files changed, 169 insertions(+), 109 deletions(-) diff --git a/docs/docs/expression_language/cookbook/sql_db.ipynb b/docs/docs/expression_language/cookbook/sql_db.ipynb index 8e872655ecd..71fe2dcc341 100644 --- a/docs/docs/expression_language/cookbook/sql_db.ipynb +++ b/docs/docs/expression_language/cookbook/sql_db.ipynb @@ -152,8 +152,7 @@ "outputs": [], "source": [ "full_chain = (\n", - " RunnablePassthrough.assign(query=sql_response)\n", - " | RunnablePassthrough.assign(\n", + " RunnablePassthrough.assign(query=sql_response).assign(\n", " schema=get_schema,\n", " response=lambda x: db.run(x[\"query\"]),\n", " )\n", diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index b51a94eea3f..e1c9a995cb3 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -31,7 +31,11 @@ from langchain_core.runnables.config import ( patch_config, ) from langchain_core.runnables.fallbacks import RunnableWithFallbacks -from langchain_core.runnables.passthrough import RunnablePassthrough +from langchain_core.runnables.passthrough import ( + RunnableAssign, + RunnablePassthrough, + RunnablePick, +) from langchain_core.runnables.router import RouterInput, RouterRunnable from langchain_core.runnables.utils import ( AddableDict, @@ -60,6 +64,8 @@ __all__ = [ "RunnableMap", "RunnableParallel", "RunnablePassthrough", + "RunnableAssign", + "RunnablePick", "RunnableSequence", "RunnableWithFallbacks", "get_config_list", diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index bd73cc81ec9..d6cf846413e 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -220,9 +220,11 @@ class Runnable(Generic[Input, Output], ABC): name: Optional[str] = None """The name of the runnable. Used for debugging and tracing.""" - def get_name(self, suffix: Optional[str] = None) -> str: + def get_name( + self, suffix: Optional[str] = None, *, name: Optional[str] = None + ) -> str: """Get the name of the runnable.""" - name = self.name or self.__class__.__name__ + name = name or self.name or self.__class__.__name__ if suffix: if name[0].isupper(): return name + suffix.title() @@ -410,6 +412,38 @@ class Runnable(Generic[Input, Output], ABC): """Compose this runnable with another object to create a RunnableSequence.""" return RunnableSequence(coerce_to_runnable(other), self) + def pipe( + self, + *others: Union[Runnable[Any, Other], Callable[[Any], Other]], + name: Optional[str] = None, + ) -> RunnableSerializable[Input, Other]: + """Compose this runnable with another object to create a RunnableSequence.""" + return RunnableSequence(self, *others, name=name) + + def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]: + """Pick keys from the dict output of this runnable. + Returns a new runnable.""" + from langchain_core.runnables.passthrough import RunnablePick + + return self | RunnablePick(keys) + + def assign( + self, + **kwargs: Union[ + Runnable[Dict[str, Any], Any], + Callable[[Dict[str, Any]], Any], + Mapping[ + str, + Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], + ], + ], + ) -> RunnableSerializable[Any, Any]: + """Assigns new fields to the dict output of this runnable. + Returns a new runnable.""" + from langchain_core.runnables.passthrough import RunnableAssign + + return self | RunnableAssign(RunnableParallel(kwargs)) + """ --- Public API --- """ @abstractmethod @@ -1669,7 +1703,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.name + dumpd(self), input, name=config.get("run_name") or self.get_name() ) # invoke all steps in sequence @@ -1703,7 +1737,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.name + dumpd(self), input, name=config.get("run_name") or self.get_name() ) # invoke all steps in sequence @@ -1760,7 +1794,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): cm.on_chain_start( dumpd(self), input, - name=config.get("run_name") or self.name, + name=config.get("run_name") or self.get_name(), ) for cm, input, config in zip(callback_managers, inputs, configs) ] @@ -1884,7 +1918,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): cm.on_chain_start( dumpd(self), input, - name=config.get("run_name") or self.name, + name=config.get("run_name") or self.get_name(), ) for cm, input, config in zip(callback_managers, inputs, configs) ) @@ -2119,6 +2153,12 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): class Config: arbitrary_types_allowed = True + def get_name( + self, suffix: Optional[str] = None, *, name: Optional[str] = None + ) -> str: + name = name or self.name or f"RunnableParallel<{','.join(self.steps.keys())}>" + return super().get_name(suffix, name=name) + @property def InputType(self) -> Any: for step in self.steps.values(): @@ -2214,7 +2254,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): ) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") or self.get_name() ) # gather results from all steps @@ -2254,7 +2294,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") or self.get_name() ) # gather results from all steps @@ -3174,6 +3214,12 @@ class RunnableEach(RunnableEachBase[Input, Output]): """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + def get_name( + self, suffix: Optional[str] = None, *, name: Optional[str] = None + ) -> str: + name = name or self.name or f"RunnableEach<{self.bound.get_name()}>" + return super().get_name(suffix, name=name) + def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: return RunnableEach(bound=self.bound.bind(**kwargs)) @@ -3298,8 +3344,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): **other_kwargs, ) - def get_name(self, suffix: Optional[str] = None) -> str: - return self.bound.get_name(suffix) + def get_name( + self, suffix: Optional[str] = None, *, name: Optional[str] = None + ) -> str: + return self.bound.get_name(suffix, name=name) @property def InputType(self) -> Type[Input]: diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 21f75b94428..66c9b420c41 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -202,21 +202,6 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): """ return RunnableAssign(RunnableParallel(kwargs)) - @classmethod - def pick( - cls, - keys: Union[str, List[str]], - ) -> "RunnablePick": - """Pick keys from the Dict input. - - Args: - keys: A string or list of strings representing the keys to pick. - - Returns: - A runnable that picks keys from the Dict input. - """ - return RunnablePick(keys) - def invoke( self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Other: @@ -335,6 +320,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + def get_name( + self, suffix: Optional[str] = None, *, name: Optional[str] = None + ) -> str: + name = ( + name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>" + ) + return super().get_name(suffix, name=name) + def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: @@ -589,6 +582,16 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + def get_name( + self, suffix: Optional[str] = None, *, name: Optional[str] = None + ) -> str: + name = ( + name + or self.name + or f"RunnablePick<{','.join([self.keys] if isinstance(self.keys, str) else self.keys)}>" # noqa: E501 + ) + return super().get_name(suffix, name=name) + def _pick(self, input: Dict[str, Any]) -> Any: assert isinstance( input, dict diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index fd3fbcf5b5d..a76baa67871 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -32,51 +32,51 @@ # --- # name: test_graph_sequence_map ''' - +-------------+ - | PromptInput | - +-------------+ - * - * - * - +----------------+ - | PromptTemplate | - +----------------+ - * - * - * - +-------------+ - | FakeListLLM | - +-------------+ - * - * - * - +---------------+ - | ParallelInput | - +---------------+****** - ***** ****** - *** ****** - *** ****** - +------------------------------+ *** - | conditional_str_parser_input | * - +------------------------------+ * - *** *** * - *** *** * - ** ** * - +-----------------+ +-----------------+ * - | StrOutputParser | | XMLOutputParser | * - +-----------------+ +-----------------+ * - *** *** * - *** *** * - ** ** * - +-------------------------------+ +--------------------------------+ - | conditional_str_parser_output | | CommaSeparatedListOutputParser | - +-------------------------------+ +--------------------------------+ - ***** ****** - *** ****** - *** *** - +----------------+ - | ParallelOutput | - +----------------+ + +-------------+ + | PromptInput | + +-------------+ + * + * + * + +----------------+ + | PromptTemplate | + +----------------+ + * + * + * + +-------------+ + | FakeListLLM | + +-------------+ + * + * + * + +-------------------------------+ + | ParallelInput | + +-------------------------------+ + ***** ****** + *** ****** + *** ****** + +------------------------------+ **** + | conditional_str_parser_input | * + +------------------------------+ * + *** *** * + *** *** * + ** ** * + +-----------------+ +-----------------+ * + | StrOutputParser | | XMLOutputParser | * + +-----------------+ +-----------------+ * + *** *** * + *** *** * + ** ** * + +-------------------------------+ +--------------------------------+ + | conditional_str_parser_output | | CommaSeparatedListOutputParser | + +-------------------------------+ +--------------------------------+ + ***** ****** + *** ****** + *** **** + +--------------------------------+ + | ParallelOutput | + +--------------------------------+ ''' # --- # name: test_graph_single_runnable diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 5b3b4e4cbb5..f712cce6b75 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -4012,7 +4012,7 @@ 'items': dict({ '$ref': '#/definitions/PromptTemplateOutput', }), - 'title': 'RunnableEachOutput', + 'title': 'RunnableEachOutput', 'type': 'array', }) # --- diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index 935571ed12a..c0bd73cd3ed 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -18,6 +18,8 @@ EXPECTED_ALL = [ "RunnableMap", "RunnableParallel", "RunnablePassthrough", + "RunnableAssign", + "RunnablePick", "RunnableSequence", "RunnableWithFallbacks", "get_config_list", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 6ded9d7cd22..4f4d1749e40 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -64,6 +64,7 @@ from langchain_core.runnables import ( RunnableLambda, RunnableParallel, RunnablePassthrough, + RunnablePick, RunnableSequence, RunnableWithFallbacks, add, @@ -510,7 +511,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: }, "items": {"$ref": "#/definitions/PromptInput"}, "type": "array", - "title": "RunnableEachInput", + "title": "RunnableEachInput", } assert prompt_mapper.output_schema.schema() == snapshot @@ -571,7 +572,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "properties": {"name": {"title": "Name", "type": "string"}}, } assert seq_w_map.output_schema.schema() == { - "title": "RunnableParallelOutput", + "title": "RunnableParallelOutput", "type": "object", "properties": { "original": {"title": "Original", "type": "string"}, @@ -615,7 +616,7 @@ def test_passthrough_assign_schema() -> None: # expected dict input_schema assert invalid_seq_w_assign.input_schema.schema() == { "properties": {"question": {"title": "Question"}}, - "title": "RunnableParallelInput", + "title": "RunnableParallelInput", "type": "object", } @@ -774,7 +775,7 @@ def test_schema_complex_seq() -> None: ) assert chain2.input_schema.schema() == { - "title": "RunnableParallelInput", + "title": "RunnableParallelInput", "type": "object", "properties": { "person": {"title": "Person", "type": "string"}, @@ -2160,8 +2161,8 @@ async def test_stream_log_retriever() -> None: "FakeListLLM:2", "Retriever", "RunnableLambda", - "RunnableParallel", - "RunnableParallel:2", + "RunnableParallel", + "RunnableParallel", ] @@ -2444,7 +2445,7 @@ What is your name?""" parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 4 map_run = parent_run.child_runs[0] - assert map_run.name == "RunnableParallel" + assert map_run.name == "RunnableParallel" assert len(map_run.child_runs) == 3 @@ -2505,7 +2506,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 3 map_run = parent_run.child_runs[2] - assert map_run.name == "RunnableParallel" + assert map_run.name == "RunnableParallel" assert len(map_run.child_runs) == 2 @@ -2721,7 +2722,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 3 map_run = parent_run.child_runs[2] - assert map_run.name == "RunnableParallel" + assert map_run.name == "RunnableParallel" assert len(map_run.child_runs) == 3 @@ -2770,7 +2771,7 @@ def test_map_stream() -> None: {"question": "What is your name?"} ) - chain_pick_one = chain | RunnablePassthrough.pick("llm") + chain_pick_one = chain.pick("llm") assert chain_pick_one.output_schema.schema() == { "title": "RunnableSequenceOutput", @@ -2791,10 +2792,8 @@ def test_map_stream() -> None: assert streamed_chunks[0] == "i" assert len(streamed_chunks) == len(llm_res) - chain_pick_two = ( - chain - | RunnablePassthrough.assign(hello=RunnablePassthrough.pick("llm") | llm) - | RunnablePassthrough.pick(["llm", "hello"]) + chain_pick_two = chain.assign(hello=RunnablePick("llm").pipe(llm)).pick( + ["llm", "hello"] ) assert chain_pick_two.output_schema.schema() == { @@ -2940,12 +2939,15 @@ async def test_map_astream() -> None: assert final_state.state["logs"]["ChatPromptTemplate"][ "final_output" ] == prompt.invoke({"question": "What is your name?"}) - assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel" + assert ( + final_state.state["logs"]["RunnableParallel"]["name"] + == "RunnableParallel" + ) assert sorted(final_state.state["logs"]) == [ "ChatPromptTemplate", "FakeListChatModel", "FakeStreamingListLLM", - "RunnableParallel", + "RunnableParallel", "RunnablePassthrough", ] @@ -2985,11 +2987,14 @@ async def test_map_astream() -> None: assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == ( prompt.invoke({"question": "What is your name?"}) ) - assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel" + assert ( + final_state.state["logs"]["RunnableParallel"]["name"] + == "RunnableParallel" + ) assert sorted(final_state.state["logs"]) == [ "ChatPromptTemplate", "FakeStreamingListLLM", - "RunnableParallel", + "RunnableParallel", "RunnablePassthrough", ] @@ -3130,9 +3135,7 @@ def test_deep_stream_assign() -> None: assert len(chunks) == len("foo-lish") assert add(chunks) == {"str": "foo-lish"} - chain_with_assign = chain | RunnablePassthrough.assign( - hello=itemgetter("str") | llm - ) + chain_with_assign = chain.assign(hello=itemgetter("str") | llm) assert chain_with_assign.input_schema.schema() == { "title": "PromptInput", @@ -3179,7 +3182,7 @@ def test_deep_stream_assign() -> None: "hello": "foo-lish", } - chain_with_assign_shadow = chain | RunnablePassthrough.assign( + chain_with_assign_shadow = chain.assign( str=lambda _: "shadow", hello=itemgetter("str") | llm, ) @@ -3254,7 +3257,7 @@ async def test_deep_astream_assign() -> None: assert len(chunks) == len("foo-lish") assert add(chunks) == {"str": "foo-lish"} - chain_with_assign = chain | RunnablePassthrough.assign( + chain_with_assign = chain.assign( hello=itemgetter("str") | llm, ) @@ -4473,15 +4476,15 @@ def test_invoke_stream_passthrough_assign_trace() -> None: tracer = FakeTracer() chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) - assert tracer.runs[0].name == "RunnableAssign" - assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" tracer = FakeTracer() for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): pass - assert tracer.runs[0].name == "RunnableAssign" - assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" async def test_ainvoke_astream_passthrough_assign_trace() -> None: @@ -4493,15 +4496,15 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None: tracer = FakeTracer() await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) - assert tracer.runs[0].name == "RunnableAssign" - assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" tracer = FakeTracer() async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): pass - assert tracer.runs[0].name == "RunnableAssign" - assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" async def test_astream_log_deep_copies() -> None: diff --git a/libs/langchain/langchain/chains/history_aware_retriever.py b/libs/langchain/langchain/chains/history_aware_retriever.py index 6e4704e7a89..a29d31a8563 100644 --- a/libs/langchain/langchain/chains/history_aware_retriever.py +++ b/libs/langchain/langchain/chains/history_aware_retriever.py @@ -35,13 +35,13 @@ def create_history_aware_retriever( # pip install -U langchain langchain-community from langchain_community.chat_models import ChatOpenAI - from langchain.chains import create_chat_history_retriever + from langchain.chains import create_history_aware_retriever from langchain import hub rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase") llm = ChatOpenAI() retriever = ... - chat_retriever_chain = create_chat_retriever_chain( + chat_retriever_chain = create_history_aware_retriever( llm, retriever, rephrase_prompt ) diff --git a/libs/langchain/langchain/chains/retrieval.py b/libs/langchain/langchain/chains/retrieval.py index ea53ff99ece..7781bb5a922 100644 --- a/libs/langchain/langchain/chains/retrieval.py +++ b/libs/langchain/langchain/chains/retrieval.py @@ -64,8 +64,7 @@ def create_retrieval_chain( RunnablePassthrough.assign( context=retrieval_docs.with_config(run_name="retrieve_documents"), chat_history=lambda x: x.get("chat_history", []), - ) - | RunnablePassthrough.assign(answer=combine_docs_chain) + ).assign(answer=combine_docs_chain) ).with_config(run_name="retrieval_chain") return retrieval_chain