mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Adds transform and atransform support to runnable sequences (#9583)
Allow runnable sequences to support transform if each individual runnable inside supports transform/atransform. @nfcampos
This commit is contained in:
parent
c0e1a1d32c
commit
a50e62e44b
@ -1231,11 +1231,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
def stream(
|
||||
def _transform(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
input: Iterator[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Iterator[Output]:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
@ -1254,37 +1254,50 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
else:
|
||||
break
|
||||
|
||||
# invoke the first steps
|
||||
try:
|
||||
for step in steps[0:streaming_start_index]:
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
final_pipeline = None
|
||||
gathered_input = None
|
||||
if streaming_start_index == 0:
|
||||
final_pipeline = steps[streaming_start_index].transform(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
for input_chunk in input:
|
||||
if gathered_input is None:
|
||||
gathered_input = input_chunk
|
||||
else:
|
||||
gathered_input += input_chunk
|
||||
# invoke the first steps
|
||||
for step in steps[0:streaming_start_index]:
|
||||
gathered_input = step.invoke(
|
||||
gathered_input,
|
||||
# mark each step as a child run
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{steps.index(step)+1}"
|
||||
),
|
||||
),
|
||||
)
|
||||
# stream the first of the last steps with the final non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].stream(
|
||||
gathered_input,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{steps.index(step)+1}"
|
||||
f"seq:step:{streaming_start_index+1}"
|
||||
),
|
||||
),
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
# stream the last steps
|
||||
final: Union[Output, None] = None
|
||||
final_supported = True
|
||||
try:
|
||||
# stream the first of the last steps with non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].stream(
|
||||
input,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{streaming_start_index+1}"
|
||||
),
|
||||
),
|
||||
)
|
||||
# stream the rest of the last steps with streaming input
|
||||
for step in steps[streaming_start_index + 1 :]:
|
||||
final_pipeline = step.transform(
|
||||
@ -1296,6 +1309,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
for output in final_pipeline:
|
||||
yield output
|
||||
# Accumulate output if possible, otherwise disable accumulation
|
||||
@ -1316,11 +1330,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
else:
|
||||
run_manager.on_chain_end(final)
|
||||
|
||||
async def astream(
|
||||
async def _atransform(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
input: AsyncIterator[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> AsyncIterator[Output]:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
@ -1334,42 +1348,55 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
streaming_start_index = len(steps) - 1
|
||||
|
||||
for i in range(len(steps) - 1, 0, -1):
|
||||
if type(steps[i]).transform != Runnable.transform:
|
||||
if type(steps[i]).atransform != Runnable.atransform:
|
||||
streaming_start_index = i - 1
|
||||
else:
|
||||
break
|
||||
|
||||
# invoke the first steps
|
||||
try:
|
||||
for step in steps[0:streaming_start_index]:
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
final_pipeline = None
|
||||
gathered_input = None
|
||||
if streaming_start_index == 0:
|
||||
final_pipeline = steps[0].atransform(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
async for input_chunk in input:
|
||||
if gathered_input is None:
|
||||
gathered_input = input_chunk
|
||||
else:
|
||||
gathered_input += input_chunk
|
||||
# invoke the first steps
|
||||
for step in steps[0:streaming_start_index]:
|
||||
gathered_input = await step.ainvoke(
|
||||
gathered_input,
|
||||
# mark each step as a child run
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{steps.index(step)+1}"
|
||||
),
|
||||
),
|
||||
)
|
||||
# stream the first of the last steps with the final non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].astream(
|
||||
gathered_input,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{steps.index(step)+1}"
|
||||
f"seq:step:{streaming_start_index+1}"
|
||||
),
|
||||
),
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
# stream the last steps
|
||||
final: Union[Output, None] = None
|
||||
final_supported = True
|
||||
try:
|
||||
# stream the first of the last steps with non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].astream(
|
||||
input,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{streaming_start_index+1}"
|
||||
),
|
||||
),
|
||||
)
|
||||
# stream the rest of the last steps with streaming input
|
||||
for step in steps[streaming_start_index + 1 :]:
|
||||
final_pipeline = step.atransform(
|
||||
@ -1401,6 +1428,47 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
else:
|
||||
await run_manager.on_chain_end(final)
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
yield from self.transform(iter([input]), config, **kwargs)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async def input_aiter() -> AsyncIterator[Input]:
|
||||
yield input
|
||||
|
||||
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableMapChunk(Dict[str, Any]):
|
||||
"""
|
||||
|
@ -1315,6 +1315,37 @@ async def test_deep_astream() -> None:
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
def test_runnable_sequence_transform() -> None:
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain = llm | StrOutputParser()
|
||||
|
||||
stream = chain.transform(llm.stream("Hi there!"))
|
||||
|
||||
chunks = []
|
||||
for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runnable_sequence_atransform() -> None:
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain = llm | StrOutputParser()
|
||||
|
||||
stream = chain.atransform(llm.astream("Hi there!"))
|
||||
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_with_fallbacks() -> RunnableWithFallbacks:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
|
Loading…
Reference in New Issue
Block a user