mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
Adds streaming for runnable maps (#9283)
@nfcampos @baskaryan --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
0dd2c21089
commit
0689628489
@ -1,4 +1,6 @@
|
|||||||
"""Fake ChatModel for testing purposes."""
|
"""Fake ChatModel for testing purposes."""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
@ -14,6 +16,7 @@ class FakeListChatModel(SimpleChatModel):
|
|||||||
"""Fake ChatModel for testing purposes."""
|
"""Fake ChatModel for testing purposes."""
|
||||||
|
|
||||||
responses: List
|
responses: List
|
||||||
|
sleep: Optional[float] = None
|
||||||
i: int = 0
|
i: int = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -48,6 +51,8 @@ class FakeListChatModel(SimpleChatModel):
|
|||||||
else:
|
else:
|
||||||
self.i = 0
|
self.i = 0
|
||||||
for c in response:
|
for c in response:
|
||||||
|
if self.sleep is not None:
|
||||||
|
time.sleep(self.sleep)
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||||
|
|
||||||
async def _astream(
|
async def _astream(
|
||||||
@ -63,6 +68,8 @@ class FakeListChatModel(SimpleChatModel):
|
|||||||
else:
|
else:
|
||||||
self.i = 0
|
self.i = 0
|
||||||
for c in response:
|
for c in response:
|
||||||
|
if self.sleep is not None:
|
||||||
|
await asyncio.sleep(self.sleep)
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
|
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
@ -13,6 +15,7 @@ class FakeListLLM(LLM):
|
|||||||
"""Fake LLM for testing purposes."""
|
"""Fake LLM for testing purposes."""
|
||||||
|
|
||||||
responses: List
|
responses: List
|
||||||
|
sleep: Optional[float] = None
|
||||||
i: int = 0
|
i: int = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -68,6 +71,8 @@ class FakeStreamingListLLM(FakeListLLM):
|
|||||||
) -> Iterator[str]:
|
) -> Iterator[str]:
|
||||||
result = self.invoke(input, config)
|
result = self.invoke(input, config)
|
||||||
for c in result:
|
for c in result:
|
||||||
|
if self.sleep is not None:
|
||||||
|
time.sleep(self.sleep)
|
||||||
yield c
|
yield c
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
@ -80,4 +85,6 @@ class FakeStreamingListLLM(FakeListLLM):
|
|||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
result = await self.ainvoke(input, config)
|
result = await self.ainvoke(input, config)
|
||||||
for c in result:
|
for c in result:
|
||||||
|
if self.sleep is not None:
|
||||||
|
await asyncio.sleep(self.sleep)
|
||||||
yield c
|
yield c
|
||||||
|
@ -277,6 +277,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
self,
|
self,
|
||||||
input: Iterator[Union[str, BaseMessage]],
|
input: Iterator[Union[str, BaseMessage]],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Iterator[T]:
|
) -> Iterator[T]:
|
||||||
yield from self._transform_stream_with_config(
|
yield from self._transform_stream_with_config(
|
||||||
input, self._transform, config, run_type="parser"
|
input, self._transform, config, run_type="parser"
|
||||||
@ -286,6 +287,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|||||||
self,
|
self,
|
||||||
input: AsyncIterator[Union[str, BaseMessage]],
|
input: AsyncIterator[Union[str, BaseMessage]],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
async for chunk in self._atransform_stream_with_config(
|
async for chunk in self._atransform_stream_with_config(
|
||||||
input, self._atransform, config, run_type="parser"
|
input, self._atransform, config, run_type="parser"
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||||
from itertools import tee
|
from itertools import tee
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
@ -23,15 +26,25 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
from langchain.schema.runnable.utils import (
|
from langchain.schema.runnable.utils import (
|
||||||
|
accepts_run_manager,
|
||||||
|
accepts_run_manager_and_config,
|
||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
)
|
)
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
|
from langchain.utils.iter import safetee
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input")
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
@ -48,7 +61,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
other: Union[
|
other: Union[
|
||||||
Runnable[Any, Other],
|
Runnable[Any, Other],
|
||||||
Callable[[Any], Other],
|
Callable[[Any], Other],
|
||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Input, Other]:
|
) -> RunnableSequence[Input, Other]:
|
||||||
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||||
@ -58,7 +71,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
other: Union[
|
other: Union[
|
||||||
Runnable[Other, Any],
|
Runnable[Other, Any],
|
||||||
Callable[[Any], Other],
|
Callable[[Any], Other],
|
||||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Other, Output]:
|
) -> RunnableSequence[Other, Output]:
|
||||||
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||||
@ -135,7 +148,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
yield await self.ainvoke(input, config)
|
yield await self.ainvoke(input, config)
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
"""
|
"""
|
||||||
Default implementation of transform, which buffers input and then calls stream.
|
Default implementation of transform, which buffers input and then calls stream.
|
||||||
@ -152,10 +168,13 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
# This method should throw an error if gathering fails.
|
# This method should throw an error if gathering fails.
|
||||||
final += chunk # type: ignore[operator]
|
final += chunk # type: ignore[operator]
|
||||||
if final:
|
if final:
|
||||||
yield from self.stream(final, config)
|
yield from self.stream(final, config, **kwargs)
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
"""
|
"""
|
||||||
Default implementation of atransform, which buffers input and calls astream.
|
Default implementation of atransform, which buffers input and calls astream.
|
||||||
@ -173,7 +192,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final += chunk # type: ignore[operator]
|
final += chunk # type: ignore[operator]
|
||||||
|
|
||||||
if final:
|
if final:
|
||||||
async for output in self.astream(final, config):
|
async for output in self.astream(final, config, **kwargs):
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
@ -217,7 +236,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
self,
|
self,
|
||||||
func: Callable[[Input], Output],
|
func: Union[
|
||||||
|
Callable[[Input], Output],
|
||||||
|
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||||
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||||
|
],
|
||||||
input: Input,
|
input: Input,
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
@ -238,7 +261,16 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
output = func(input)
|
if accepts_run_manager_and_config(func):
|
||||||
|
output = func(
|
||||||
|
input,
|
||||||
|
run_manager=run_manager,
|
||||||
|
config=config,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
elif accepts_run_manager(func):
|
||||||
|
output = func(input, run_manager=run_manager) # type: ignore[call-arg]
|
||||||
|
else:
|
||||||
|
output = func(input) # type: ignore[call-arg]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -253,7 +285,14 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async def _acall_with_config(
|
async def _acall_with_config(
|
||||||
self,
|
self,
|
||||||
func: Callable[[Input], Awaitable[Output]],
|
func: Union[
|
||||||
|
Callable[[Input], Awaitable[Output]],
|
||||||
|
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
||||||
|
Callable[
|
||||||
|
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
|
||||||
|
Awaitable[Output],
|
||||||
|
],
|
||||||
|
],
|
||||||
input: Input,
|
input: Input,
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
@ -274,7 +313,19 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
output = await func(input)
|
if accepts_run_manager_and_config(func):
|
||||||
|
output = await func(
|
||||||
|
input,
|
||||||
|
run_manager=run_manager,
|
||||||
|
config=config,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
elif accepts_run_manager(func):
|
||||||
|
output = await func(
|
||||||
|
input,
|
||||||
|
run_manager=run_manager,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
else:
|
||||||
|
output = await func(input) # type: ignore[call-arg]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -290,7 +341,18 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
def _transform_stream_with_config(
|
def _transform_stream_with_config(
|
||||||
self,
|
self,
|
||||||
input: Iterator[Input],
|
input: Iterator[Input],
|
||||||
transformer: Callable[[Iterator[Input]], Iterator[Output]],
|
transformer: Union[
|
||||||
|
Callable[[Iterator[Input]], Iterator[Output]],
|
||||||
|
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
|
||||||
|
Callable[
|
||||||
|
[
|
||||||
|
Iterator[Input],
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
RunnableConfig,
|
||||||
|
],
|
||||||
|
Iterator[Output],
|
||||||
|
],
|
||||||
|
],
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
@ -319,7 +381,20 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
for chunk in transformer(input_for_transform):
|
if accepts_run_manager_and_config(transformer):
|
||||||
|
iterator = transformer(
|
||||||
|
input_for_transform,
|
||||||
|
run_manager=run_manager,
|
||||||
|
config=config,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
elif accepts_run_manager(transformer):
|
||||||
|
iterator = transformer(
|
||||||
|
input_for_transform,
|
||||||
|
run_manager=run_manager,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
else:
|
||||||
|
iterator = transformer(input_for_transform) # type: ignore[call-arg]
|
||||||
|
for chunk in iterator:
|
||||||
yield chunk
|
yield chunk
|
||||||
if final_output_supported:
|
if final_output_supported:
|
||||||
if final_output is None:
|
if final_output is None:
|
||||||
@ -361,7 +436,21 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
async def _atransform_stream_with_config(
|
async def _atransform_stream_with_config(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Input],
|
||||||
transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
transformer: Union[
|
||||||
|
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
||||||
|
Callable[
|
||||||
|
[AsyncIterator[Input], AsyncCallbackManagerForChainRun],
|
||||||
|
AsyncIterator[Output],
|
||||||
|
],
|
||||||
|
Callable[
|
||||||
|
[
|
||||||
|
AsyncIterator[Input],
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
RunnableConfig,
|
||||||
|
],
|
||||||
|
AsyncIterator[Output],
|
||||||
|
],
|
||||||
|
],
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
run_type: Optional[str] = None,
|
run_type: Optional[str] = None,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
@ -390,7 +479,22 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
async for chunk in transformer(input_for_transform):
|
# mypy can't quite work out thew type guard here, but this is safe,
|
||||||
|
# check implementations of the accepts_* functions
|
||||||
|
if accepts_run_manager_and_config(transformer):
|
||||||
|
iterator = transformer(
|
||||||
|
input_for_transform,
|
||||||
|
run_manager=run_manager,
|
||||||
|
config=config,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
elif accepts_run_manager(transformer):
|
||||||
|
iterator = transformer(
|
||||||
|
input_for_transform,
|
||||||
|
run_manager=run_manager,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
else:
|
||||||
|
iterator = transformer(input_for_transform) # type: ignore[call-arg]
|
||||||
|
async for chunk in iterator:
|
||||||
yield chunk
|
yield chunk
|
||||||
if final_output_supported:
|
if final_output_supported:
|
||||||
if final_output is None:
|
if final_output is None:
|
||||||
@ -700,7 +804,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
other: Union[
|
other: Union[
|
||||||
Runnable[Any, Other],
|
Runnable[Any, Other],
|
||||||
Callable[[Any], Other],
|
Callable[[Any], Other],
|
||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Input, Other]:
|
) -> RunnableSequence[Input, Other]:
|
||||||
if isinstance(other, RunnableSequence):
|
if isinstance(other, RunnableSequence):
|
||||||
@ -721,7 +825,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
other: Union[
|
other: Union[
|
||||||
Runnable[Other, Any],
|
Runnable[Other, Any],
|
||||||
Callable[[Any], Other],
|
Callable[[Any], Other],
|
||||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Other, Output]:
|
) -> RunnableSequence[Other, Output]:
|
||||||
if isinstance(other, RunnableSequence):
|
if isinstance(other, RunnableSequence):
|
||||||
@ -875,7 +979,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManager,
|
AsyncCallbackManager,
|
||||||
AsyncCallbackManagerForChainRun,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
@ -1085,6 +1188,21 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableMapChunk(Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Partial output from a RunnableMap
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
|
||||||
|
chunk = copy.deepcopy(self)
|
||||||
|
for key in other:
|
||||||
|
if key not in chunk or chunk[key] is None:
|
||||||
|
chunk[key] = other[key]
|
||||||
|
elif other[key] is not None:
|
||||||
|
chunk[key] += other[key]
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
A runnable that runs a mapping of runnables in parallel,
|
A runnable that runs a mapping of runnables in parallel,
|
||||||
@ -1134,7 +1252,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
|
)
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
@ -1177,7 +1297,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
)
|
)
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
@ -1203,6 +1323,134 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
await run_manager.on_chain_end(output)
|
await run_manager.on_chain_end(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _transform(
|
||||||
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
run_manager: CallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Iterator[RunnableMapChunk]:
|
||||||
|
# Shallow copy steps to ignore mutations while in progress
|
||||||
|
steps = dict(self.steps)
|
||||||
|
# Each step gets a copy of the input iterator,
|
||||||
|
# which is consumed in parallel in a separate thread.
|
||||||
|
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
# Create the transform() generator for each step
|
||||||
|
named_generators = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
step.transform(
|
||||||
|
input_copies.pop(),
|
||||||
|
patch_config(config, run_manager.get_child()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for name, step in steps.items()
|
||||||
|
]
|
||||||
|
# Start the first iteration of each generator
|
||||||
|
futures = {
|
||||||
|
executor.submit(next, generator): (step_name, generator)
|
||||||
|
for step_name, generator in named_generators
|
||||||
|
}
|
||||||
|
# Yield chunks from each as they become available,
|
||||||
|
# and start the next iteration of that generator that yielded it.
|
||||||
|
# When all generators are exhausted, stop.
|
||||||
|
while futures:
|
||||||
|
completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED)
|
||||||
|
for future in completed_futures:
|
||||||
|
(step_name, generator) = futures.pop(future)
|
||||||
|
try:
|
||||||
|
chunk = RunnableMapChunk({step_name: future.result()})
|
||||||
|
yield chunk
|
||||||
|
futures[executor.submit(next, generator)] = (
|
||||||
|
step_name,
|
||||||
|
generator,
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def transform(
|
||||||
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
yield from self._transform_stream_with_config(
|
||||||
|
input, self._transform, config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
yield from self.transform(iter([input]), config)
|
||||||
|
|
||||||
|
async def _atransform(
|
||||||
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> AsyncIterator[RunnableMapChunk]:
|
||||||
|
# Shallow copy steps to ignore mutations while in progress
|
||||||
|
steps = dict(self.steps)
|
||||||
|
# Each step gets a copy of the input iterator,
|
||||||
|
# which is consumed in parallel in a separate thread.
|
||||||
|
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
|
||||||
|
# Create the transform() generator for each step
|
||||||
|
named_generators = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
step.atransform(
|
||||||
|
input_copies.pop(), patch_config(config, run_manager.get_child())
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for name, step in steps.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Wrap in a coroutine to satisfy linter
|
||||||
|
async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]:
|
||||||
|
return await py_anext(generator)
|
||||||
|
|
||||||
|
# Start the first iteration of each generator
|
||||||
|
tasks = {
|
||||||
|
asyncio.create_task(get_next_chunk(generator)): (step_name, generator)
|
||||||
|
for step_name, generator in named_generators
|
||||||
|
}
|
||||||
|
# Yield chunks from each as they become available,
|
||||||
|
# and start the next iteration of the generator that yielded it.
|
||||||
|
# When all generators are exhausted, stop.
|
||||||
|
while tasks:
|
||||||
|
completed_tasks, _ = await asyncio.wait(
|
||||||
|
tasks, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
for task in completed_tasks:
|
||||||
|
(step_name, generator) = tasks.pop(task)
|
||||||
|
try:
|
||||||
|
chunk = RunnableMapChunk({step_name: task.result()})
|
||||||
|
yield chunk
|
||||||
|
new_task = asyncio.create_task(get_next_chunk(generator))
|
||||||
|
tasks[new_task] = (step_name, generator)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def atransform(
|
||||||
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[Dict[str, Any]]:
|
||||||
|
async for chunk in self._atransform_stream_with_config(
|
||||||
|
input, self._atransform, config, **kwargs
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> AsyncIterator[Dict[str, Any]]:
|
||||||
|
async def input_aiter() -> AsyncIterator[Input]:
|
||||||
|
yield input
|
||||||
|
|
||||||
|
async for chunk in self.atransform(input_aiter(), config):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableLambda(Runnable[Input, Output]):
|
class RunnableLambda(Runnable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
@ -1293,14 +1541,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
yield item
|
yield item
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.transform(input, config, **self.kwargs)
|
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.atransform(input, config, **self.kwargs):
|
async for item in self.bound.atransform(
|
||||||
|
input, config, **{**self.kwargs, **kwargs}
|
||||||
|
):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
@ -1316,7 +1572,7 @@ def coerce_to_runnable(
|
|||||||
thing: Union[
|
thing: Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
Callable[[Input], Output],
|
Callable[[Input], Output],
|
||||||
Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
|
Mapping[str, Any],
|
||||||
]
|
]
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
if isinstance(thing, Runnable):
|
if isinstance(thing, Runnable):
|
||||||
@ -1324,7 +1580,9 @@ def coerce_to_runnable(
|
|||||||
elif callable(thing):
|
elif callable(thing):
|
||||||
return RunnableLambda(thing)
|
return RunnableLambda(thing)
|
||||||
elif isinstance(thing, dict):
|
elif isinstance(thing, dict):
|
||||||
runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
|
runnables: Mapping[str, Runnable[Any, Any]] = {
|
||||||
|
key: coerce_to_runnable(r) for key, r in thing.items()
|
||||||
|
}
|
||||||
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import AsyncIterator, Iterator, List, Optional
|
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.runnable.base import Input, Runnable
|
from langchain.schema.runnable.base import Input, Runnable
|
||||||
@ -32,16 +32,22 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
return self._call_with_config(identity, input, config)
|
return self._call_with_config(identity, input, config)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: RunnableConfig | None = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Input:
|
) -> Input:
|
||||||
return await self._acall_with_config(aidentity, input, config)
|
return await self._acall_with_config(aidentity, input, config)
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Iterator[Input]:
|
) -> Iterator[Input]:
|
||||||
return self._transform_stream_with_config(input, identity, config)
|
return self._transform_stream_with_config(input, identity, config)
|
||||||
|
|
||||||
def atransform(
|
def atransform(
|
||||||
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Input]:
|
||||||
return self._atransform_stream_with_config(input, identity, config)
|
return self._atransform_stream_with_config(input, identity, config)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Coroutine, Union
|
from inspect import signature
|
||||||
|
from typing import Any, Callable, Coroutine, Union
|
||||||
|
|
||||||
|
|
||||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||||
@ -16,3 +17,17 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
|||||||
semaphore = asyncio.Semaphore(n)
|
semaphore = asyncio.Semaphore(n)
|
||||||
|
|
||||||
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
||||||
|
|
||||||
|
|
||||||
|
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||||
|
try:
|
||||||
|
return signature(callable).parameters.get("run_manager") is not None
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool:
|
||||||
|
return (
|
||||||
|
accepts_run_manager(callable)
|
||||||
|
and signature(callable).parameters.get("config") is not None
|
||||||
|
)
|
||||||
|
@ -7,6 +7,7 @@ MIT License
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncContextManager,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
@ -15,6 +16,7 @@ from typing import (
|
|||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -64,33 +66,45 @@ def py_anext(
|
|||||||
return anext_impl()
|
return anext_impl()
|
||||||
|
|
||||||
|
|
||||||
|
class NoLock:
|
||||||
|
"""Dummy lock that provides the proper interface but no protection"""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def tee_peer(
|
async def tee_peer(
|
||||||
iterator: AsyncIterator[T],
|
iterator: AsyncIterator[T],
|
||||||
# the buffer specific to this peer
|
# the buffer specific to this peer
|
||||||
buffer: Deque[T],
|
buffer: Deque[T],
|
||||||
# the buffers of all peers, including our own
|
# the buffers of all peers, including our own
|
||||||
peers: List[Deque[T]],
|
peers: List[Deque[T]],
|
||||||
|
lock: AsyncContextManager[Any],
|
||||||
) -> AsyncGenerator[T, None]:
|
) -> AsyncGenerator[T, None]:
|
||||||
"""An individual iterator of a :py:func:`~.tee`"""
|
"""An individual iterator of a :py:func:`~.tee`"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if not buffer:
|
if not buffer:
|
||||||
# Another peer produced an item while we were waiting for the lock.
|
async with lock:
|
||||||
# Proceed with the next loop iteration to yield the item.
|
# Another peer produced an item while we were waiting for the lock.
|
||||||
if buffer:
|
# Proceed with the next loop iteration to yield the item.
|
||||||
continue
|
if buffer:
|
||||||
try:
|
continue
|
||||||
item = await iterator.__anext__()
|
try:
|
||||||
except StopAsyncIteration:
|
item = await iterator.__anext__()
|
||||||
break
|
except StopAsyncIteration:
|
||||||
else:
|
break
|
||||||
# Append to all buffers, including our own. We'll fetch our
|
else:
|
||||||
# item from the buffer again, instead of yielding it directly.
|
# Append to all buffers, including our own. We'll fetch our
|
||||||
# This ensures the proper item ordering if any of our peers
|
# item from the buffer again, instead of yielding it directly.
|
||||||
# are fetching items concurrently. They may have buffered their
|
# This ensures the proper item ordering if any of our peers
|
||||||
# item already.
|
# are fetching items concurrently. They may have buffered their
|
||||||
for peer_buffer in peers:
|
# item already.
|
||||||
peer_buffer.append(item)
|
for peer_buffer in peers:
|
||||||
|
peer_buffer.append(item)
|
||||||
yield buffer.popleft()
|
yield buffer.popleft()
|
||||||
finally:
|
finally:
|
||||||
# this peer is done – remove its buffer
|
# this peer is done – remove its buffer
|
||||||
@ -145,6 +159,8 @@ class Tee(Generic[T]):
|
|||||||
self,
|
self,
|
||||||
iterable: AsyncIterator[T],
|
iterable: AsyncIterator[T],
|
||||||
n: int = 2,
|
n: int = 2,
|
||||||
|
*,
|
||||||
|
lock: Optional[AsyncContextManager[Any]] = None,
|
||||||
):
|
):
|
||||||
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
|
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
|
||||||
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
|
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
|
||||||
@ -153,6 +169,7 @@ class Tee(Generic[T]):
|
|||||||
iterator=self._iterator,
|
iterator=self._iterator,
|
||||||
buffer=buffer,
|
buffer=buffer,
|
||||||
peers=self._buffers,
|
peers=self._buffers,
|
||||||
|
lock=lock if lock is not None else NoLock(),
|
||||||
)
|
)
|
||||||
for buffer in self._buffers
|
for buffer in self._buffers
|
||||||
)
|
)
|
||||||
|
162
libs/langchain/langchain/utils/iter.py
Normal file
162
libs/langchain/langchain/utils/iter.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
from collections import deque
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
ContextManager,
|
||||||
|
Deque,
|
||||||
|
Generator,
|
||||||
|
Generic,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class NoLock:
|
||||||
|
"""Dummy lock that provides the proper interface but no protection"""
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def tee_peer(
|
||||||
|
iterator: Iterator[T],
|
||||||
|
# the buffer specific to this peer
|
||||||
|
buffer: Deque[T],
|
||||||
|
# the buffers of all peers, including our own
|
||||||
|
peers: List[Deque[T]],
|
||||||
|
lock: ContextManager[Any],
|
||||||
|
) -> Generator[T, None, None]:
|
||||||
|
"""An individual iterator of a :py:func:`~.tee`"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if not buffer:
|
||||||
|
with lock:
|
||||||
|
# Another peer produced an item while we were waiting for the lock.
|
||||||
|
# Proceed with the next loop iteration to yield the item.
|
||||||
|
if buffer:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
item = next(iterator)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Append to all buffers, including our own. We'll fetch our
|
||||||
|
# item from the buffer again, instead of yielding it directly.
|
||||||
|
# This ensures the proper item ordering if any of our peers
|
||||||
|
# are fetching items concurrently. They may have buffered their
|
||||||
|
# item already.
|
||||||
|
for peer_buffer in peers:
|
||||||
|
peer_buffer.append(item)
|
||||||
|
yield buffer.popleft()
|
||||||
|
finally:
|
||||||
|
# this peer is done – remove its buffer
|
||||||
|
for idx, peer_buffer in enumerate(peers): # pragma: no branch
|
||||||
|
if peer_buffer is buffer:
|
||||||
|
peers.pop(idx)
|
||||||
|
break
|
||||||
|
# if we are the last peer, try and close the iterator
|
||||||
|
if not peers and hasattr(iterator, "close"):
|
||||||
|
iterator.close()
|
||||||
|
|
||||||
|
|
||||||
|
class Tee(Generic[T]):
|
||||||
|
"""
|
||||||
|
Create ``n`` separate asynchronous iterators over ``iterable``
|
||||||
|
|
||||||
|
This splits a single ``iterable`` into multiple iterators, each providing
|
||||||
|
the same items in the same order.
|
||||||
|
All child iterators may advance separately but share the same items
|
||||||
|
from ``iterable`` -- when the most advanced iterator retrieves an item,
|
||||||
|
it is buffered until the least advanced iterator has yielded it as well.
|
||||||
|
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
|
||||||
|
that all iterators advance.
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
async def derivative(sensor_data):
|
||||||
|
previous, current = a.tee(sensor_data, n=2)
|
||||||
|
await a.anext(previous) # advance one iterator
|
||||||
|
return a.map(operator.sub, previous, current)
|
||||||
|
|
||||||
|
Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
|
||||||
|
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
|
||||||
|
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
|
||||||
|
immediately closes all children, and it can be used in an ``async with`` context
|
||||||
|
for the same effect.
|
||||||
|
|
||||||
|
If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
|
||||||
|
provide these items. Also, ``tee`` must internally buffer each item until the
|
||||||
|
last iterator has yielded it; if the most and least advanced iterator differ
|
||||||
|
by most data, using a :py:class:`list` is more efficient (but not lazy).
|
||||||
|
|
||||||
|
If the underlying iterable is concurrency safe (``anext`` may be awaited
|
||||||
|
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
|
||||||
|
the iterators are safe if there is only ever one single "most advanced" iterator.
|
||||||
|
To enforce sequential use of ``anext``, provide a ``lock``
|
||||||
|
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
|
||||||
|
and access is automatically synchronised.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
iterable: Iterator[T],
|
||||||
|
n: int = 2,
|
||||||
|
*,
|
||||||
|
lock: Optional[ContextManager[Any]] = None,
|
||||||
|
):
|
||||||
|
self._iterator = iter(iterable)
|
||||||
|
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
|
||||||
|
self._children = tuple(
|
||||||
|
tee_peer(
|
||||||
|
iterator=self._iterator,
|
||||||
|
buffer=buffer,
|
||||||
|
peers=self._buffers,
|
||||||
|
lock=lock if lock is not None else NoLock(),
|
||||||
|
)
|
||||||
|
for buffer in self._buffers
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._children)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: int) -> Iterator[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __getitem__(
|
||||||
|
self, item: Union[int, slice]
|
||||||
|
) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
|
||||||
|
return self._children[item]
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Iterator[T]]:
|
||||||
|
yield from self._children
|
||||||
|
|
||||||
|
def __enter__(self) -> "Tee[T]":
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
|
||||||
|
self.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
for child in self._children:
|
||||||
|
child.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Why this is needed https://stackoverflow.com/a/44638570
|
||||||
|
safetee = Tee
|
@ -528,7 +528,7 @@ Question:
|
|||||||
|
|
||||||
parser = CommaSeparatedListOutputParser()
|
parser = CommaSeparatedListOutputParser()
|
||||||
|
|
||||||
chain = (
|
chain: Runnable = (
|
||||||
{
|
{
|
||||||
"question": RunnablePassthrough[str]() | passthrough,
|
"question": RunnablePassthrough[str]() | passthrough,
|
||||||
"documents": passthrough | retriever,
|
"documents": passthrough | retriever,
|
||||||
@ -760,6 +760,188 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
|
|||||||
assert len(map_run.child_runs) == 3
|
assert len(map_run.child_runs) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_stream() -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_res = "i'm a chatbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
chain: Runnable = prompt | {
|
||||||
|
"chat": chat.bind(stop=["Thought:"]),
|
||||||
|
"llm": llm,
|
||||||
|
"passthrough": RunnablePassthrough(),
|
||||||
|
}
|
||||||
|
|
||||||
|
stream = chain.stream({"question": "What is your name?"})
|
||||||
|
|
||||||
|
final_value = None
|
||||||
|
streamed_chunks = []
|
||||||
|
for chunk in stream:
|
||||||
|
streamed_chunks.append(chunk)
|
||||||
|
if final_value is None:
|
||||||
|
final_value = chunk
|
||||||
|
else:
|
||||||
|
final_value += chunk
|
||||||
|
|
||||||
|
assert streamed_chunks[0] in [
|
||||||
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
|
{"llm": "i"},
|
||||||
|
{"chat": "i"},
|
||||||
|
]
|
||||||
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
|
assert final_value is not None
|
||||||
|
assert final_value.get("chat").content == "i'm a chatbot"
|
||||||
|
assert final_value.get("llm") == "i'm a textbot"
|
||||||
|
assert final_value.get("passthrough") == prompt.invoke(
|
||||||
|
{"question": "What is your name?"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_stream_iterator_input() -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_res = "i'm a chatbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
chain: Runnable = (
|
||||||
|
prompt
|
||||||
|
| llm
|
||||||
|
| {
|
||||||
|
"chat": chat.bind(stop=["Thought:"]),
|
||||||
|
"llm": llm,
|
||||||
|
"passthrough": RunnablePassthrough(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = chain.stream({"question": "What is your name?"})
|
||||||
|
|
||||||
|
final_value = None
|
||||||
|
streamed_chunks = []
|
||||||
|
for chunk in stream:
|
||||||
|
streamed_chunks.append(chunk)
|
||||||
|
if final_value is None:
|
||||||
|
final_value = chunk
|
||||||
|
else:
|
||||||
|
final_value += chunk
|
||||||
|
|
||||||
|
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
|
||||||
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
||||||
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
|
assert final_value is not None
|
||||||
|
assert final_value.get("chat").content == "i'm a chatbot"
|
||||||
|
assert final_value.get("llm") == "i'm a textbot"
|
||||||
|
assert final_value.get("passthrough") == "i'm a textbot"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_map_astream() -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_res = "i'm a chatbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
chain: Runnable = prompt | {
|
||||||
|
"chat": chat.bind(stop=["Thought:"]),
|
||||||
|
"llm": llm,
|
||||||
|
"passthrough": RunnablePassthrough(),
|
||||||
|
}
|
||||||
|
|
||||||
|
stream = chain.astream({"question": "What is your name?"})
|
||||||
|
|
||||||
|
final_value = None
|
||||||
|
streamed_chunks = []
|
||||||
|
async for chunk in stream:
|
||||||
|
streamed_chunks.append(chunk)
|
||||||
|
if final_value is None:
|
||||||
|
final_value = chunk
|
||||||
|
else:
|
||||||
|
final_value += chunk
|
||||||
|
|
||||||
|
assert streamed_chunks[0] in [
|
||||||
|
{"passthrough": prompt.invoke({"question": "What is your name?"})},
|
||||||
|
{"llm": "i"},
|
||||||
|
{"chat": "i"},
|
||||||
|
]
|
||||||
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
|
||||||
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
|
assert final_value is not None
|
||||||
|
assert final_value.get("chat").content == "i'm a chatbot"
|
||||||
|
assert final_value.get("llm") == "i'm a textbot"
|
||||||
|
assert final_value.get("passthrough") == prompt.invoke(
|
||||||
|
{"question": "What is your name?"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_map_astream_iterator_input() -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_res = "i'm a chatbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
|
||||||
|
|
||||||
|
llm_res = "i'm a textbot"
|
||||||
|
# sleep to better simulate a real stream
|
||||||
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
|
chain: Runnable = (
|
||||||
|
prompt
|
||||||
|
| llm
|
||||||
|
| {
|
||||||
|
"chat": chat.bind(stop=["Thought:"]),
|
||||||
|
"llm": llm,
|
||||||
|
"passthrough": RunnablePassthrough(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = chain.astream({"question": "What is your name?"})
|
||||||
|
|
||||||
|
final_value = None
|
||||||
|
streamed_chunks = []
|
||||||
|
async for chunk in stream:
|
||||||
|
streamed_chunks.append(chunk)
|
||||||
|
if final_value is None:
|
||||||
|
final_value = chunk
|
||||||
|
else:
|
||||||
|
final_value += chunk
|
||||||
|
|
||||||
|
assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
|
||||||
|
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
|
||||||
|
assert all(len(c.keys()) == 1 for c in streamed_chunks)
|
||||||
|
assert final_value is not None
|
||||||
|
assert final_value.get("chat").content == "i'm a chatbot"
|
||||||
|
assert final_value.get("llm") == "i'm a textbot"
|
||||||
|
assert final_value.get("passthrough") == llm_res
|
||||||
|
|
||||||
|
|
||||||
def test_bind_bind() -> None:
|
def test_bind_bind() -> None:
|
||||||
llm = FakeListLLM(responses=["i'm a textbot"])
|
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user