From 48307e46a3136ea55fb65e5313c7dda8d8677169 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 18 Apr 2024 18:52:33 -0700 Subject: [PATCH] core[patch]: Fix runnable map ser/de (#20631) --- libs/core/langchain_core/runnables/base.py | 34 +++++++++---------- .../langchain_core/runnables/passthrough.py | 8 +++-- .../__snapshots__/test_fallbacks.ambr | 4 +-- .../__snapshots__/test_runnable.ambr | 14 ++++---- .../unit_tests/runnables/test_runnable.py | 6 +++- 5 files changed, 36 insertions(+), 30 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 7645bf52bca..a29a0677ee9 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2136,7 +2136,7 @@ def _seq_input_schema( **{ k: (v.annotation, v.default) for k, v in next_input_schema.__fields__.items() - if k not in first.mapper.steps + if k not in first.mapper.steps__ }, ) elif isinstance(first, RunnablePick): @@ -2981,11 +2981,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): print(output) # noqa: T201 """ - steps: Mapping[str, Runnable[Input, Any]] + steps__: Mapping[str, Runnable[Input, Any]] def __init__( self, - __steps: Optional[ + steps__: Optional[ Mapping[ str, Union[ @@ -3001,10 +3001,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], ], ) -> None: - merged = {**__steps} if __steps is not None else {} + merged = {**steps__} if steps__ is not None else {} merged.update(kwargs) super().__init__( # type: ignore[call-arg] - steps={key: coerce_to_runnable(r) for key, r in merged.items()} + steps__={key: coerce_to_runnable(r) for key, r in merged.items()} ) @classmethod @@ -3022,12 +3022,12 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: - name = name or self.name or f"RunnableParallel<{','.join(self.steps.keys())}>" + name = name or self.name or f"RunnableParallel<{','.join(self.steps__.keys())}>" return super().get_name(suffix, name=name) @property def InputType(self) -> Any: - for step in self.steps.values(): + for step in self.steps__.values(): if step.InputType: return step.InputType @@ -3038,14 +3038,14 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): ) -> Type[BaseModel]: if all( s.get_input_schema(config).schema().get("type", "object") == "object" - for s in self.steps.values() + for s in self.steps__.values() ): # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] self.get_name("Input"), **{ k: (v.annotation, v.default) - for step in self.steps.values() + for step in self.steps__.values() for k, v in step.get_input_schema(config).__fields__.items() if k != "__root__" }, @@ -3059,13 +3059,13 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] self.get_name("Output"), - **{k: (v.OutputType, None) for k, v in self.steps.items()}, + **{k: (v.OutputType, None) for k, v in self.steps__.items()}, ) @property def config_specs(self) -> List[ConfigurableFieldSpec]: return get_unique_config_specs( - spec for step in self.steps.values() for spec in step.config_specs + spec for step in self.steps__.values() for spec in step.config_specs ) def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: @@ -3074,7 +3074,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): graph = Graph() input_node = graph.add_node(self.get_input_schema(config)) output_node = graph.add_node(self.get_output_schema(config)) - for step in self.steps.values(): + for step in self.steps__.values(): step_graph = step.get_graph() step_graph.trim_first_node() step_graph.trim_last_node() @@ -3096,7 +3096,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): def __repr__(self) -> str: map_for_repr = ",\n ".join( f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}" - for k, v in self.steps.items() + for k, v in self.steps__.items() ) return "{\n " + map_for_repr + "\n}" @@ -3127,7 +3127,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): # gather results from all steps try: # copy to avoid issues from the caller mutating the steps during invoke() - steps = dict(self.steps) + steps = dict(self.steps__) with get_executor_for_config(config) as executor: futures = [ executor.submit( @@ -3170,7 +3170,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): # gather results from all steps try: # copy to avoid issues from the caller mutating the steps during invoke() - steps = dict(self.steps) + steps = dict(self.steps__) results = await asyncio.gather( *( step.ainvoke( @@ -3199,7 +3199,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): config: RunnableConfig, ) -> Iterator[AddableDict]: # Shallow copy steps to ignore mutations while in progress - steps = dict(self.steps) + steps = dict(self.steps__) # Each step gets a copy of the input iterator, # which is consumed in parallel in a separate thread. input_copies = list(safetee(input, len(steps), lock=threading.Lock())) @@ -3264,7 +3264,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): config: RunnableConfig, ) -> AsyncIterator[AddableDict]: # Shallow copy steps to ignore mutations while in progress - steps = dict(self.steps) + steps = dict(self.steps__) # Each step gets a copy of the input iterator, # which is consumed in parallel in a separate thread. input_copies = list(atee(input, len(steps), lock=asyncio.Lock())) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 55d61093ec7..d2fbf30e4b9 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -369,7 +369,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: name = ( - name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>" + name + or self.name + or f"RunnableAssign<{','.join(self.mapper.steps__.keys())}>" ) return super().get_name(suffix, name=name) @@ -488,7 +490,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): **kwargs: Any, ) -> Iterator[Dict[str, Any]]: # collect mapper keys - mapper_keys = set(self.mapper.steps.keys()) + mapper_keys = set(self.mapper.steps__.keys()) # create two streams, one for the map and one for the passthrough for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) @@ -544,7 +546,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): **kwargs: Any, ) -> AsyncIterator[Dict[str, Any]]: # collect mapper keys - mapper_keys = set(self.mapper.steps.keys()) + mapper_keys = set(self.mapper.steps__.keys()) # create two streams, one for the map and one for the passthrough for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) # create map output stream diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr index b802468d17e..22309fcd6d0 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_fallbacks.ambr @@ -21,7 +21,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "buz": { "lc": 1, "type": "not_implemented", @@ -569,7 +569,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "text": { "lc": 1, "type": "constructor", diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index ee65b937915..aa6499bc290 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -2051,7 +2051,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "key": { "lc": 1, "type": "not_implemented", @@ -2073,7 +2073,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "question": { "lc": 1, "type": "not_implemented", @@ -4459,7 +4459,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "key": { "lc": 1, "type": "not_implemented", @@ -4481,7 +4481,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "question": { "lc": 1, "type": "not_implemented", @@ -8760,7 +8760,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "question": { "lc": 1, "type": "constructor", @@ -9860,7 +9860,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "chat": { "lc": 1, "type": "not_implemented", @@ -10352,7 +10352,7 @@ "RunnableParallel" ], "kwargs": { - "steps": { + "steps__": { "chat": { "lc": 1, "type": "constructor", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 7a0524a90b8..ca6d2a3adaf 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -35,6 +35,7 @@ from langchain_core.language_models import ( FakeStreamingListLLM, ) from langchain_core.load import dumpd, dumps +from langchain_core.load.load import loads from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -76,7 +77,7 @@ from langchain_core.runnables import ( add, chain, ) -from langchain_core.runnables.base import RunnableSerializable +from langchain_core.runnables.base import RunnableMap, RunnableSerializable from langchain_core.runnables.utils import Input, Output from langchain_core.tools import BaseTool, tool from langchain_core.tracers import ( @@ -3553,6 +3554,9 @@ async def test_map_astream_iterator_input() -> None: assert final_value.get("llm") == "i'm a textbot" assert final_value.get("passthrough") == llm_res + simple_map = RunnableMap(passthrough=RunnablePassthrough()) + assert loads(dumps(simple_map)) == simple_map + def test_with_config_with_config() -> None: llm = FakeListLLM(responses=["i'm a textbot"])