mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
Implement RunnablePassthrough.assign(...) (#11222)
Passes through dict input and assigns additional keys <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
1ddf9f74b2
commit
fb66b392c6
@ -53,6 +53,7 @@ from langchain.schema.runnable.config import (
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
AddableDict,
|
||||
Input,
|
||||
Output,
|
||||
accepts_config,
|
||||
@ -1748,30 +1749,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableMapChunk(Dict[str, Any]):
|
||||
"""
|
||||
Partial output from a RunnableMap
|
||||
"""
|
||||
|
||||
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
|
||||
chunk = RunnableMapChunk(self)
|
||||
for key in other:
|
||||
if key not in chunk or chunk[key] is None:
|
||||
chunk[key] = other[key]
|
||||
elif other[key] is not None:
|
||||
chunk[key] = chunk[key] + other[key]
|
||||
return chunk
|
||||
|
||||
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
|
||||
chunk = RunnableMapChunk(other)
|
||||
for key in self:
|
||||
if key not in chunk or chunk[key] is None:
|
||||
chunk[key] = self[key]
|
||||
elif self[key] is not None:
|
||||
chunk[key] = chunk[key] + self[key]
|
||||
return chunk
|
||||
|
||||
|
||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
"""
|
||||
A runnable that runs a mapping of runnables in parallel,
|
||||
@ -1814,7 +1791,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
|
||||
if all(
|
||||
s.input_schema.schema().get("type", "object") == "object"
|
||||
for s in self.steps.values()
|
||||
):
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableMapInput",
|
||||
@ -1822,6 +1802,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
k: (v.type_, v.default)
|
||||
for step in self.steps.values()
|
||||
for k, v in step.input_schema.__fields__.items()
|
||||
if k != "__root__"
|
||||
},
|
||||
)
|
||||
|
||||
@ -1934,7 +1915,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
input: Iterator[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Iterator[RunnableMapChunk]:
|
||||
) -> Iterator[AddableDict]:
|
||||
# Shallow copy steps to ignore mutations while in progress
|
||||
steps = dict(self.steps)
|
||||
# Each step gets a copy of the input iterator,
|
||||
@ -1967,7 +1948,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
for future in completed_futures:
|
||||
(step_name, generator) = futures.pop(future)
|
||||
try:
|
||||
chunk = RunnableMapChunk({step_name: future.result()})
|
||||
chunk = AddableDict({step_name: future.result()})
|
||||
yield chunk
|
||||
futures[executor.submit(next, generator)] = (
|
||||
step_name,
|
||||
@ -1999,7 +1980,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
input: AsyncIterator[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> AsyncIterator[RunnableMapChunk]:
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
# Shallow copy steps to ignore mutations while in progress
|
||||
steps = dict(self.steps)
|
||||
# Each step gets a copy of the input iterator,
|
||||
@ -2038,7 +2019,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
for task in completed_tasks:
|
||||
(step_name, generator) = tasks.pop(task)
|
||||
try:
|
||||
chunk = RunnableMapChunk({step_name: task.result()})
|
||||
chunk = AddableDict({step_name: task.result()})
|
||||
yield chunk
|
||||
new_task = asyncio.create_task(get_next_chunk(generator))
|
||||
tasks[new_task] = (step_name, generator)
|
||||
|
@ -1,10 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional, Type
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import Input, Runnable
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.pydantic_v1 import BaseModel, create_model
|
||||
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
|
||||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
||||
from langchain.schema.runnable.utils import AddableDict
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
|
||||
|
||||
def identity(x: Input) -> Input:
|
||||
@ -38,6 +56,30 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
def OutputType(self) -> Any:
|
||||
return self.input_type or Any
|
||||
|
||||
@classmethod
|
||||
def assign(
|
||||
cls,
|
||||
**kwargs: Union[
|
||||
Runnable[Dict[str, Any], Any],
|
||||
Callable[[Dict[str, Any]], Any],
|
||||
Mapping[
|
||||
str,
|
||||
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
||||
],
|
||||
],
|
||||
) -> RunnableAssign:
|
||||
"""
|
||||
Merge the Dict input with the output produced by the mapping argument.
|
||||
|
||||
Args:
|
||||
mapping: A mapping from keys to runnables or callables.
|
||||
|
||||
Returns:
|
||||
A runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableMap(kwargs))
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
return self._call_with_config(identity, input, config)
|
||||
|
||||
@ -65,3 +107,155 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
) -> AsyncIterator[Input]:
|
||||
async for chunk in self._atransform_stream_with_config(input, identity, config):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
|
||||
"""
|
||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||
"""
|
||||
|
||||
mapper: RunnableMap[Dict[str, Any]]
|
||||
|
||||
def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
super().__init__(mapper=mapper, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
map_input_schema = self.mapper.input_schema
|
||||
if not map_input_schema.__custom_root_type__:
|
||||
# ie. it's a dict
|
||||
return map_input_schema
|
||||
|
||||
return super().input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
map_input_schema = self.mapper.input_schema
|
||||
map_output_schema = self.mapper.output_schema
|
||||
if (
|
||||
not map_input_schema.__custom_root_type__
|
||||
and not map_output_schema.__custom_root_type__
|
||||
):
|
||||
# ie. both are dicts
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableAssignOutput",
|
||||
**{
|
||||
k: (v.type_, v.default)
|
||||
for s in (map_input_schema, map_output_schema)
|
||||
for k, v in s.__fields__.items()
|
||||
},
|
||||
)
|
||||
|
||||
return super().output_schema
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
assert isinstance(input, dict)
|
||||
return {
|
||||
**input,
|
||||
**self.mapper.invoke(input, config, **kwargs),
|
||||
}
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
assert isinstance(input, dict)
|
||||
return {
|
||||
**input,
|
||||
**await self.mapper.ainvoke(input, config, **kwargs),
|
||||
}
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
||||
# create map output stream
|
||||
map_output = self.mapper.transform(for_map, config, **kwargs)
|
||||
# get executor to start map output stream in background
|
||||
with get_executor_for_config(config or {}) as executor:
|
||||
# start map output stream
|
||||
first_map_chunk_future = executor.submit(next, map_output) # type: ignore
|
||||
# consume passthrough stream
|
||||
for chunk in for_passthrough:
|
||||
assert isinstance(chunk, dict)
|
||||
# remove mapper keys from passthrough chunk, to be overwritten by map
|
||||
filtered = AddableDict(
|
||||
{k: v for k, v in chunk.items() if k not in mapper_keys}
|
||||
)
|
||||
if filtered:
|
||||
yield filtered
|
||||
# yield map output
|
||||
yield cast(Dict[str, Any], first_map_chunk_future.result())
|
||||
for chunk in map_output:
|
||||
yield chunk
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
||||
# create map output stream
|
||||
map_output = self.mapper.atransform(for_map, config, **kwargs)
|
||||
# start map output stream
|
||||
first_map_chunk_task: asyncio.Task = asyncio.create_task(
|
||||
py_anext(map_output), # type: ignore[arg-type]
|
||||
)
|
||||
# consume passthrough stream
|
||||
async for chunk in for_passthrough:
|
||||
assert isinstance(chunk, dict)
|
||||
# remove mapper keys from passthrough chunk, to be overwritten by map output
|
||||
filtered = AddableDict(
|
||||
{k: v for k, v in chunk.items() if k not in mapper_keys}
|
||||
)
|
||||
if filtered:
|
||||
yield filtered
|
||||
# yield map output
|
||||
yield await first_map_chunk_task
|
||||
async for chunk in map_output:
|
||||
yield chunk
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
return self.transform(iter([input]), config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
|
||||
yield input
|
||||
|
||||
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
||||
yield chunk
|
||||
|
@ -5,7 +5,20 @@ import asyncio
|
||||
import inspect
|
||||
import textwrap
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
@ -142,3 +155,59 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||
spaces = " " * n_spaces
|
||||
lines = text.splitlines()
|
||||
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
|
||||
|
||||
|
||||
class AddableDict(Dict[str, Any]):
|
||||
"""
|
||||
Dictionary that can be added to another dictionary.
|
||||
"""
|
||||
|
||||
def __add__(self, other: AddableDict) -> AddableDict:
|
||||
chunk = AddableDict(self)
|
||||
for key in other:
|
||||
if key not in chunk or chunk[key] is None:
|
||||
chunk[key] = other[key]
|
||||
elif other[key] is not None:
|
||||
chunk[key] = chunk[key] + other[key]
|
||||
return chunk
|
||||
|
||||
def __radd__(self, other: AddableDict) -> AddableDict:
|
||||
chunk = AddableDict(other)
|
||||
for key in self:
|
||||
if key not in chunk or chunk[key] is None:
|
||||
chunk[key] = self[key]
|
||||
elif self[key] is not None:
|
||||
chunk[key] = chunk[key] + self[key]
|
||||
return chunk
|
||||
|
||||
|
||||
_T_co = TypeVar("_T_co", covariant=True)
|
||||
_T_contra = TypeVar("_T_contra", contravariant=True)
|
||||
|
||||
|
||||
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
||||
def __add__(self, __x: _T_contra) -> _T_co:
|
||||
...
|
||||
|
||||
|
||||
Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
|
||||
|
||||
|
||||
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
final = None
|
||||
for chunk in addables:
|
||||
if final is None:
|
||||
final = chunk
|
||||
else:
|
||||
final = final + chunk
|
||||
return final
|
||||
|
||||
|
||||
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
final = None
|
||||
async for chunk in addables:
|
||||
if final is None:
|
||||
final = chunk
|
||||
else:
|
||||
final = final + chunk
|
||||
return final
|
||||
|
@ -57,6 +57,7 @@ from langchain.schema.runnable import (
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain.schema.runnable.base import RunnableGenerator
|
||||
from langchain.schema.runnable.utils import add
|
||||
from langchain.tools.base import BaseTool, tool
|
||||
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
|
||||
|
||||
@ -2018,6 +2019,104 @@ def test_deep_stream() -> None:
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
def test_deep_stream_assign() -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain: Runnable = prompt | llm | {"str": StrOutputParser()}
|
||||
|
||||
stream = chain.stream({"question": "What up"})
|
||||
|
||||
chunks = []
|
||||
for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert add(chunks) == {"str": "foo-lish"}
|
||||
|
||||
chain_with_assign = chain | RunnablePassthrough.assign(
|
||||
hello=itemgetter("str") | llm
|
||||
)
|
||||
|
||||
assert chain_with_assign.input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
}
|
||||
assert chain_with_assign.output_schema.schema() == {
|
||||
"title": "RunnableAssignOutput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str": {"title": "Str"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
chunks = []
|
||||
for chunk in chain_with_assign.stream({"question": "What up"}):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish") * 2
|
||||
assert chunks == [
|
||||
# first stream passthrough input chunks
|
||||
{"str": "f"},
|
||||
{"str": "o"},
|
||||
{"str": "o"},
|
||||
{"str": "-"},
|
||||
{"str": "l"},
|
||||
{"str": "i"},
|
||||
{"str": "s"},
|
||||
{"str": "h"},
|
||||
# then stream assign output chunks
|
||||
{"hello": "f"},
|
||||
{"hello": "o"},
|
||||
{"hello": "o"},
|
||||
{"hello": "-"},
|
||||
{"hello": "l"},
|
||||
{"hello": "i"},
|
||||
{"hello": "s"},
|
||||
{"hello": "h"},
|
||||
]
|
||||
assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
|
||||
assert chain_with_assign.invoke({"question": "What up"}) == {
|
||||
"str": "foo-lish",
|
||||
"hello": "foo-lish",
|
||||
}
|
||||
|
||||
chain_with_assign_shadow = chain | RunnablePassthrough.assign(
|
||||
str=lambda _: "shadow",
|
||||
hello=itemgetter("str") | llm,
|
||||
)
|
||||
|
||||
assert chain_with_assign_shadow.input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
}
|
||||
assert chain_with_assign_shadow.output_schema.schema() == {
|
||||
"title": "RunnableAssignOutput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str": {"title": "Str"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
chunks = []
|
||||
for chunk in chain_with_assign_shadow.stream({"question": "What up"}):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish") + 1
|
||||
assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
|
||||
assert chain_with_assign_shadow.invoke({"question": "What up"}) == {
|
||||
"str": "shadow",
|
||||
"hello": "foo-lish",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deep_astream() -> None:
|
||||
prompt = (
|
||||
@ -2045,6 +2144,105 @@ async def test_deep_astream() -> None:
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deep_astream_assign() -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain: Runnable = prompt | llm | {"str": StrOutputParser()}
|
||||
|
||||
stream = chain.astream({"question": "What up"})
|
||||
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert add(chunks) == {"str": "foo-lish"}
|
||||
|
||||
chain_with_assign = chain | RunnablePassthrough.assign(
|
||||
hello=itemgetter("str") | llm,
|
||||
)
|
||||
|
||||
assert chain_with_assign.input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
}
|
||||
assert chain_with_assign.output_schema.schema() == {
|
||||
"title": "RunnableAssignOutput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str": {"title": "Str"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
chunks = []
|
||||
async for chunk in chain_with_assign.astream({"question": "What up"}):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish") * 2
|
||||
assert chunks == [
|
||||
# first stream passthrough input chunks
|
||||
{"str": "f"},
|
||||
{"str": "o"},
|
||||
{"str": "o"},
|
||||
{"str": "-"},
|
||||
{"str": "l"},
|
||||
{"str": "i"},
|
||||
{"str": "s"},
|
||||
{"str": "h"},
|
||||
# then stream assign output chunks
|
||||
{"hello": "f"},
|
||||
{"hello": "o"},
|
||||
{"hello": "o"},
|
||||
{"hello": "-"},
|
||||
{"hello": "l"},
|
||||
{"hello": "i"},
|
||||
{"hello": "s"},
|
||||
{"hello": "h"},
|
||||
]
|
||||
assert add(chunks) == {"str": "foo-lish", "hello": "foo-lish"}
|
||||
assert await chain_with_assign.ainvoke({"question": "What up"}) == {
|
||||
"str": "foo-lish",
|
||||
"hello": "foo-lish",
|
||||
}
|
||||
|
||||
chain_with_assign_shadow = chain | RunnablePassthrough.assign(
|
||||
str=lambda _: "shadow",
|
||||
hello=itemgetter("str") | llm,
|
||||
)
|
||||
|
||||
assert chain_with_assign_shadow.input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
}
|
||||
assert chain_with_assign_shadow.output_schema.schema() == {
|
||||
"title": "RunnableAssignOutput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str": {"title": "Str"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
chunks = []
|
||||
async for chunk in chain_with_assign_shadow.astream({"question": "What up"}):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish") + 1
|
||||
assert add(chunks) == {"str": "shadow", "hello": "foo-lish"}
|
||||
assert await chain_with_assign_shadow.ainvoke({"question": "What up"}) == {
|
||||
"str": "shadow",
|
||||
"hello": "foo-lish",
|
||||
}
|
||||
|
||||
|
||||
def test_runnable_sequence_transform() -> None:
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user