Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-24 07:35:18 +00:00
core[patch]: Fix runnable map ser/de (#20631)
parent 1cbab0ebda
commit 48307e46a3
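The change renames RunnableParallel's internal mapping from steps (and the positional-only __steps constructor parameter) to steps__, so the kwargs captured when a runnable map is serialized line up with a constructor parameter that deserialization can pass back in. A minimal round-trip sketch of the behaviour the fix enables, modeled on the test added at the bottom of this diff (RunnableMap is the RunnableParallel alias exported from langchain_core.runnables.base):

    from langchain_core.load import dumps
    from langchain_core.load.load import loads
    from langchain_core.runnables import RunnablePassthrough
    from langchain_core.runnables.base import RunnableMap

    # Build a trivial runnable map and serialize it to a JSON string.
    simple_map = RunnableMap(passthrough=RunnablePassthrough())
    serialized = dumps(simple_map)

    # With the steps__ field, loads() reconstructs an equal object, because the
    # serialized kwargs key now matches the constructor parameter it is passed to.
    assert loads(serialized) == simple_map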
@@ -2136,7 +2136,7 @@ def _seq_input_schema(
                 **{
                     k: (v.annotation, v.default)
                     for k, v in next_input_schema.__fields__.items()
-                    if k not in first.mapper.steps
+                    if k not in first.mapper.steps__
                 },
             )
     elif isinstance(first, RunnablePick):
@@ -2981,11 +2981,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
                 print(output)  # noqa: T201
     """

-    steps: Mapping[str, Runnable[Input, Any]]
+    steps__: Mapping[str, Runnable[Input, Any]]

     def __init__(
         self,
-        __steps: Optional[
+        steps__: Optional[
             Mapping[
                 str,
                 Union[
@@ -3001,10 +3001,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
             Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
         ],
     ) -> None:
-        merged = {**__steps} if __steps is not None else {}
+        merged = {**steps__} if steps__ is not None else {}
         merged.update(kwargs)
         super().__init__(  # type: ignore[call-arg]
-            steps={key: coerce_to_runnable(r) for key, r in merged.items()}
+            steps__={key: coerce_to_runnable(r) for key, r in merged.items()}
         )

     @classmethod
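The public construction surface is unchanged by the rename: a mapping can still be passed positionally or the steps supplied as keyword arguments, with bare callables coerced via coerce_to_runnable. A small usage sketch with illustrative step names ("doubled" and "original" are hypothetical, not from this diff):

    from langchain_core.runnables import RunnableLambda, RunnableParallel

    # Keyword form: each keyword becomes one entry in the steps__ mapping.
    as_kwargs = RunnableParallel(
        doubled=RunnableLambda(lambda x: x * 2),
        original=RunnableLambda(lambda x: x),
    )

    # Positional form: a mapping as the first argument is merged the same way,
    # and plain callables are coerced to runnables.
    as_mapping = RunnableParallel({"doubled": lambda x: x * 2, "original": lambda x: x})

    assert as_kwargs.invoke(3) == {"doubled": 6, "original": 3}
    assert as_mapping.invoke(3) == {"doubled": 6, "original": 3}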
@@ -3022,12 +3022,12 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
     def get_name(
         self, suffix: Optional[str] = None, *, name: Optional[str] = None
     ) -> str:
-        name = name or self.name or f"RunnableParallel<{','.join(self.steps.keys())}>"
+        name = name or self.name or f"RunnableParallel<{','.join(self.steps__.keys())}>"
         return super().get_name(suffix, name=name)

     @property
     def InputType(self) -> Any:
-        for step in self.steps.values():
+        for step in self.steps__.values():
             if step.InputType:
                 return step.InputType

@@ -3038,14 +3038,14 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
     ) -> Type[BaseModel]:
         if all(
             s.get_input_schema(config).schema().get("type", "object") == "object"
-            for s in self.steps.values()
+            for s in self.steps__.values()
         ):
             # This is correct, but pydantic typings/mypy don't think so.
             return create_model(  # type: ignore[call-overload]
                 self.get_name("Input"),
                 **{
                     k: (v.annotation, v.default)
-                    for step in self.steps.values()
+                    for step in self.steps__.values()
                     for k, v in step.get_input_schema(config).__fields__.items()
                     if k != "__root__"
                 },
@@ -3059,13 +3059,13 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             self.get_name("Output"),
-            **{k: (v.OutputType, None) for k, v in self.steps.items()},
+            **{k: (v.OutputType, None) for k, v in self.steps__.items()},
         )

     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
         return get_unique_config_specs(
-            spec for step in self.steps.values() for spec in step.config_specs
+            spec for step in self.steps__.values() for spec in step.config_specs
         )

     def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
@@ -3074,7 +3074,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         graph = Graph()
         input_node = graph.add_node(self.get_input_schema(config))
         output_node = graph.add_node(self.get_output_schema(config))
-        for step in self.steps.values():
+        for step in self.steps__.values():
             step_graph = step.get_graph()
             step_graph.trim_first_node()
             step_graph.trim_last_node()
@@ -3096,7 +3096,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
     def __repr__(self) -> str:
         map_for_repr = ",\n  ".join(
             f"{k}: {indent_lines_after_first(repr(v), '  ' + k + ': ')}"
-            for k, v in self.steps.items()
+            for k, v in self.steps__.items()
         )
         return "{\n  " + map_for_repr + "\n}"

@@ -3127,7 +3127,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         # gather results from all steps
         try:
             # copy to avoid issues from the caller mutating the steps during invoke()
-            steps = dict(self.steps)
+            steps = dict(self.steps__)
             with get_executor_for_config(config) as executor:
                 futures = [
                     executor.submit(
@@ -3170,7 +3170,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         # gather results from all steps
         try:
             # copy to avoid issues from the caller mutating the steps during invoke()
-            steps = dict(self.steps)
+            steps = dict(self.steps__)
             results = await asyncio.gather(
                 *(
                     step.ainvoke(
@@ -3199,7 +3199,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         config: RunnableConfig,
     ) -> Iterator[AddableDict]:
         # Shallow copy steps to ignore mutations while in progress
-        steps = dict(self.steps)
+        steps = dict(self.steps__)
         # Each step gets a copy of the input iterator,
         # which is consumed in parallel in a separate thread.
         input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
@@ -3264,7 +3264,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         config: RunnableConfig,
     ) -> AsyncIterator[AddableDict]:
         # Shallow copy steps to ignore mutations while in progress
-        steps = dict(self.steps)
+        steps = dict(self.steps__)
         # Each step gets a copy of the input iterator,
         # which is consumed in parallel in a separate thread.
         input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
@@ -369,7 +369,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         self, suffix: Optional[str] = None, *, name: Optional[str] = None
     ) -> str:
         name = (
-            name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
+            name
+            or self.name
+            or f"RunnableAssign<{','.join(self.mapper.steps__.keys())}>"
         )
         return super().get_name(suffix, name=name)

@@ -488,7 +490,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         **kwargs: Any,
     ) -> Iterator[Dict[str, Any]]:
         # collect mapper keys
-        mapper_keys = set(self.mapper.steps.keys())
+        mapper_keys = set(self.mapper.steps__.keys())
         # create two streams, one for the map and one for the passthrough
         for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())

@@ -544,7 +546,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         **kwargs: Any,
     ) -> AsyncIterator[Dict[str, Any]]:
         # collect mapper keys
-        mapper_keys = set(self.mapper.steps.keys())
+        mapper_keys = set(self.mapper.steps__.keys())
         # create two streams, one for the map and one for the passthrough
         for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
         # create map output stream
@@ -21,7 +21,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "buz": {
             "lc": 1,
             "type": "not_implemented",
@@ -569,7 +569,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "text": {
             "lc": 1,
             "type": "constructor",
@@ -2051,7 +2051,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "key": {
             "lc": 1,
             "type": "not_implemented",
@@ -2073,7 +2073,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "question": {
             "lc": 1,
             "type": "not_implemented",
@@ -4459,7 +4459,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "key": {
             "lc": 1,
             "type": "not_implemented",
@@ -4481,7 +4481,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "question": {
             "lc": 1,
             "type": "not_implemented",
@@ -8760,7 +8760,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "question": {
             "lc": 1,
             "type": "constructor",
@@ -9860,7 +9860,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "chat": {
             "lc": 1,
             "type": "not_implemented",
@@ -10352,7 +10352,7 @@
         "RunnableParallel"
       ],
       "kwargs": {
-        "steps": {
+        "steps__": {
           "chat": {
             "lc": 1,
             "type": "constructor",
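The snapshot updates all make the same point: in the serialized constructor envelope, the map of child runnables now sits under the kwargs key "steps__" instead of "steps". A quick way to inspect that shape (a sketch using a trivial map; the "lc"/"type"/"id"/"kwargs" envelope is the serialization format visible in the snapshots above):

    from langchain_core.load import dumpd
    from langchain_core.runnables import RunnableParallel, RunnablePassthrough

    serialized = dumpd(RunnableParallel(passthrough=RunnablePassthrough()))

    # Envelope: {"lc": 1, "type": "constructor", "id": [...], "kwargs": {...}}
    assert serialized["type"] == "constructor"
    assert "steps__" in serialized["kwargs"]
    assert "passthrough" in serialized["kwargs"]["steps__"]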
@@ -35,6 +35,7 @@ from langchain_core.language_models import (
     FakeStreamingListLLM,
 )
 from langchain_core.load import dumpd, dumps
+from langchain_core.load.load import loads
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -76,7 +77,7 @@ from langchain_core.runnables import (
     add,
     chain,
 )
-from langchain_core.runnables.base import RunnableSerializable
+from langchain_core.runnables.base import RunnableMap, RunnableSerializable
 from langchain_core.runnables.utils import Input, Output
 from langchain_core.tools import BaseTool, tool
 from langchain_core.tracers import (
@@ -3553,6 +3554,9 @@ async def test_map_astream_iterator_input() -> None:
     assert final_value.get("llm") == "i'm a textbot"
     assert final_value.get("passthrough") == llm_res

+    simple_map = RunnableMap(passthrough=RunnablePassthrough())
+    assert loads(dumps(simple_map)) == simple_map
+

 def test_with_config_with_config() -> None:
     llm = FakeListLLM(responses=["i'm a textbot"])