Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-23 15:19:33 +00:00)

commit 24a197f96a
Merge branch 'master' into bagatur/locals_in_config
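The diff below threads **kwargs through Runnable.transform/atransform, makes the config wrappers signature-aware, and gives RunnableMap true streaming: each step consumes a copy of the input stream and emits one {step_name: chunk} dict per chunk. A minimal sketch of the behaviour these changes enable, assuming `prompt`, `chat_model`, and `llm` are any runnables (mirroring the new tests at the bottom of the diff):

    from langchain.schema.runnable import RunnablePassthrough

    # A dict literal in a chain is coerced to a RunnableMap; with this change,
    # .stream() yields a one-key {step_name: chunk} dict as each step produces output.
    chain = prompt | {
        "chat": chat_model,
        "llm": llm,
        "passthrough": RunnablePassthrough(),
    }
    for chunk in chain.stream({"question": "What is your name?"}):
        print(chunk)  # e.g. {"llm": "i"}, then {"chat": "i"}, ...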
@@ -1,4 +1,6 @@
 """Fake ChatModel for testing purposes."""
+import asyncio
+import time
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

 from langchain.callbacks.manager import (
@@ -14,6 +16,7 @@ class FakeListChatModel(SimpleChatModel):
     """Fake ChatModel for testing purposes."""

     responses: List
+    sleep: Optional[float] = None
     i: int = 0

     @property
@@ -48,6 +51,8 @@ class FakeListChatModel(SimpleChatModel):
         else:
             self.i = 0
         for c in response:
+            if self.sleep is not None:
+                time.sleep(self.sleep)
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     async def _astream(
@@ -63,6 +68,8 @@ class FakeListChatModel(SimpleChatModel):
         else:
             self.i = 0
         for c in response:
+            if self.sleep is not None:
+                await asyncio.sleep(self.sleep)
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     @property
@@ -1,3 +1,5 @@
+import asyncio
+import time
 from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

 from langchain.callbacks.manager import (
@@ -13,6 +15,7 @@ class FakeListLLM(LLM):
     """Fake LLM for testing purposes."""

     responses: List
+    sleep: Optional[float] = None
     i: int = 0

     @property
@@ -68,6 +71,8 @@ class FakeStreamingListLLM(FakeListLLM):
     ) -> Iterator[str]:
         result = self.invoke(input, config)
         for c in result:
+            if self.sleep is not None:
+                time.sleep(self.sleep)
             yield c

     async def astream(
@@ -80,4 +85,6 @@ class FakeStreamingListLLM(FakeListLLM):
     ) -> AsyncIterator[str]:
         result = await self.ainvoke(input, config)
         for c in result:
+            if self.sleep is not None:
+                await asyncio.sleep(self.sleep)
             yield c
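The new `sleep` field on both fakes inserts a real pause before each streamed character, so the map-streaming tests further down can observe chunks arriving interleaved rather than all at once. A usage sketch (the import path is an assumption; these fakes live in the unit-test tree):

    import time

    from tests.unit_tests.llms.fake_llm import FakeStreamingListLLM  # assumed path

    llm = FakeStreamingListLLM(responses=["hello"], sleep=0.01)
    start = time.time()
    chunks = list(llm.stream("ignored"))
    assert chunks == list("hello")      # one chunk per character
    assert time.time() - start >= 0.05  # ~0.01s pause before each chunk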
@@ -277,6 +277,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
         self,
         input: Iterator[Union[str, BaseMessage]],
         config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Iterator[T]:
         yield from self._transform_stream_with_config(
             input, self._transform, config, run_type="parser"
@@ -286,6 +287,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
         self,
         input: AsyncIterator[Union[str, BaseMessage]],
         config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> AsyncIterator[T]:
         async for chunk in self._atransform_stream_with_config(
             input, self._atransform, config, run_type="parser"
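Adding `**kwargs: Any` keeps the parser's transform/atransform signatures compatible with the widened `Runnable.transform` below, so a parser can sit at the end of a streaming chain. A small sketch, assuming `StrOutputParser` was already a `BaseTransformOutputParser` subclass at this point:

    from langchain.schema.output_parser import StrOutputParser

    parser = StrOutputParser()
    # transform() consumes a stream of chunks and re-emits parsed output per chunk
    assert list(parser.transform(iter(["hel", "lo"]))) == ["hel", "lo"]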
@@ -1,11 +1,14 @@
 from __future__ import annotations

 import asyncio
+import copy
+import threading
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from copy import deepcopy
 from itertools import tee
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Awaitable,
@@ -24,6 +27,13 @@ from typing import (
     cast,
 )

+if TYPE_CHECKING:
+    from langchain.callbacks.manager import (
+        AsyncCallbackManagerForChainRun,
+        CallbackManagerForChainRun,
+    )
+
+
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
@@ -35,9 +45,12 @@ from langchain.schema.runnable.config import (
     get_callback_manager_for_config,
 )
 from langchain.schema.runnable.utils import (
+    accepts_run_manager,
+    accepts_run_manager_and_config,
     gather_with_concurrency,
 )
 from langchain.utils.aiter import atee, py_anext
+from langchain.utils.iter import safetee


 Input = TypeVar("Input")
@@ -55,7 +68,7 @@ class Runnable(Generic[Input, Output], ABC):
         other: Union[
             Runnable[Any, Other],
             Callable[[Any], Other],
-            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
+            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
         ],
     ) -> RunnableSequence[Input, Other]:
         return RunnableSequence(first=self, last=coerce_to_runnable(other))
@@ -65,7 +78,7 @@ class Runnable(Generic[Input, Output], ABC):
         other: Union[
             Runnable[Other, Any],
             Callable[[Any], Other],
-            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
+            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
         ],
     ) -> RunnableSequence[Other, Output]:
         return RunnableSequence(first=coerce_to_runnable(other), last=self)
@@ -142,7 +155,10 @@ class Runnable(Generic[Input, Output], ABC):
         yield await self.ainvoke(input, config)

     def transform(
-        self, input: Iterator[Input], config: Optional[RunnableConfig] = None
+        self,
+        input: Iterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
     ) -> Iterator[Output]:
         """
         Default implementation of transform, which buffers input and then calls stream.
@@ -159,10 +175,13 @@ class Runnable(Generic[Input, Output], ABC):
                 # This method should throw an error if gathering fails.
                 final += chunk  # type: ignore[operator]
         if final:
-            yield from self.stream(final, config)
+            yield from self.stream(final, config, **kwargs)

     async def atransform(
-        self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
+        self,
+        input: AsyncIterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
     ) -> AsyncIterator[Output]:
         """
         Default implementation of atransform, which buffers input and calls astream.
@@ -180,7 +199,7 @@ class Runnable(Generic[Input, Output], ABC):
                 final += chunk  # type: ignore[operator]

         if final:
-            async for output in self.astream(final, config):
+            async for output in self.astream(final, config, **kwargs):
                 yield output

     def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
@@ -224,8 +243,12 @@ class Runnable(Generic[Input, Output], ABC):

     def _call_with_config(
         self,
-        func: Callable[[Any], Output],
-        input: Any,
+        func: Union[
+            Callable[[Input], Output],
+            Callable[[Input, CallbackManagerForChainRun], Output],
+            Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
+        ],
+        input: Input,
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Output:
@@ -239,7 +262,16 @@ class Runnable(Generic[Input, Output], ABC):
             run_type=run_type,
         )
         try:
-            output = func(input)
+            if accepts_run_manager_and_config(func):
+                output = func(
+                    input,
+                    run_manager=run_manager,
+                    config=config,
+                )  # type: ignore[call-arg]
+            elif accepts_run_manager(func):
+                output = func(input, run_manager=run_manager)  # type: ignore[call-arg]
+            else:
+                output = func(input)  # type: ignore[call-arg]
         except Exception as e:
             run_manager.on_chain_error(e)
             raise
@@ -254,8 +286,15 @@ class Runnable(Generic[Input, Output], ABC):

     async def _acall_with_config(
         self,
-        func: Callable[[Any], Awaitable[Output]],
-        input: Any,
+        func: Union[
+            Callable[[Input], Awaitable[Output]],
+            Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
+            Callable[
+                [Input, AsyncCallbackManagerForChainRun, RunnableConfig],
+                Awaitable[Output],
+            ],
+        ],
+        input: Input,
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Output:
@@ -269,7 +308,19 @@ class Runnable(Generic[Input, Output], ABC):
             run_type=run_type,
         )
         try:
-            output = await func(input)
+            if accepts_run_manager_and_config(func):
+                output = await func(
+                    input,
+                    run_manager=run_manager,
+                    config=config,
+                )  # type: ignore[call-arg]
+            elif accepts_run_manager(func):
+                output = await func(
+                    input,
+                    run_manager=run_manager,
+                )  # type: ignore[call-arg]
+            else:
+                output = await func(input)  # type: ignore[call-arg]
         except Exception as e:
             await run_manager.on_chain_error(e)
             raise
@@ -285,7 +336,18 @@ class Runnable(Generic[Input, Output], ABC):
     def _transform_stream_with_config(
         self,
         input: Iterator[Input],
-        transformer: Callable[[Iterator[Input]], Iterator[Output]],
+        transformer: Union[
+            Callable[[Iterator[Input]], Iterator[Output]],
+            Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
+            Callable[
+                [
+                    Iterator[Input],
+                    CallbackManagerForChainRun,
+                    RunnableConfig,
+                ],
+                Iterator[Output],
+            ],
+        ],
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Iterator[Output]:
@@ -308,7 +370,20 @@ class Runnable(Generic[Input, Output], ABC):
             run_type=run_type,
         )
         try:
-            for chunk in transformer(input_for_transform):
+            if accepts_run_manager_and_config(transformer):
+                iterator = transformer(
+                    input_for_transform,
+                    run_manager=run_manager,
+                    config=config,
+                )  # type: ignore[call-arg]
+            elif accepts_run_manager(transformer):
+                iterator = transformer(
+                    input_for_transform,
+                    run_manager=run_manager,
+                )  # type: ignore[call-arg]
+            else:
+                iterator = transformer(input_for_transform)  # type: ignore[call-arg]
+            for chunk in iterator:
                 yield chunk
                 if final_output_supported:
                     if final_output is None:
@@ -350,7 +425,21 @@ class Runnable(Generic[Input, Output], ABC):
     async def _atransform_stream_with_config(
         self,
         input: AsyncIterator[Input],
-        transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
+        transformer: Union[
+            Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
+            Callable[
+                [AsyncIterator[Input], AsyncCallbackManagerForChainRun],
+                AsyncIterator[Output],
+            ],
+            Callable[
+                [
+                    AsyncIterator[Input],
+                    AsyncCallbackManagerForChainRun,
+                    RunnableConfig,
+                ],
+                AsyncIterator[Output],
+            ],
+        ],
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> AsyncIterator[Output]:
@@ -373,7 +462,22 @@ class Runnable(Generic[Input, Output], ABC):
             run_type=run_type,
         )
         try:
-            async for chunk in transformer(input_for_transform):
+            # mypy can't quite work out the type guard here, but this is safe,
+            # check implementations of the accepts_* functions
+            if accepts_run_manager_and_config(transformer):
+                iterator = transformer(
+                    input_for_transform,
+                    run_manager=run_manager,
+                    config=config,
+                )  # type: ignore[call-arg]
+            elif accepts_run_manager(transformer):
+                iterator = transformer(
+                    input_for_transform,
+                    run_manager=run_manager,
+                )  # type: ignore[call-arg]
+            else:
+                iterator = transformer(input_for_transform)  # type: ignore[call-arg]
+            async for chunk in iterator:
                 yield chunk
                 if final_output_supported:
                     if final_output is None:
@@ -663,7 +767,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         other: Union[
             Runnable[Any, Other],
             Callable[[Any], Other],
-            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
+            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
         ],
     ) -> RunnableSequence[Input, Other]:
         if isinstance(other, RunnableSequence):
@@ -684,7 +788,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         other: Union[
             Runnable[Other, Any],
             Callable[[Any], Other],
-            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
+            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
         ],
     ) -> RunnableSequence[Other, Output]:
         if isinstance(other, RunnableSequence):
@@ -818,7 +922,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
     ) -> List[Output]:
         from langchain.callbacks.manager import (
             AsyncCallbackManager,
-            AsyncCallbackManagerForChainRun,
         )

         # setup callbacks
@@ -1008,6 +1111,21 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         )


+class RunnableMapChunk(Dict[str, Any]):
+    """
+    Partial output from a RunnableMap
+    """
+
+    def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
+        chunk = copy.deepcopy(self)
+        for key in other:
+            if key not in chunk or chunk[key] is None:
+                chunk[key] = other[key]
+            elif other[key] is not None:
+                chunk[key] += other[key]
+        return chunk
+
+
 class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
     """
     A runnable that runs a mapping of runnables in parallel,
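A quick illustration of the merge semantics of `RunnableMapChunk.__add__` defined above: keys are unioned, values under the same key are combined with `+=`, and `None` values are simply replaced:

    a = RunnableMapChunk({"llm": "hel", "chat": None})
    b = RunnableMapChunk({"llm": "lo", "passthrough": "x"})
    assert a + b == {"llm": "hello", "chat": None, "passthrough": "x"}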
@@ -1057,7 +1175,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
             local_metadata=None,
         )
         # start the root run
-        run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
+        )

         # gather results from all steps
         try:
@@ -1090,7 +1210,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), {"input": input}
+            dumpd(self), input if isinstance(input, dict) else {"input": input}
         )

         # gather results from all steps
@@ -1116,6 +1236,134 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
         await run_manager.on_chain_end(output)
         return output

+    def _transform(
+        self,
+        input: Iterator[Input],
+        run_manager: CallbackManagerForChainRun,
+        config: RunnableConfig,
+    ) -> Iterator[RunnableMapChunk]:
+        # Shallow copy steps to ignore mutations while in progress
+        steps = dict(self.steps)
+        # Each step gets a copy of the input iterator,
+        # which is consumed in parallel in a separate thread.
+        input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
+        with ThreadPoolExecutor() as executor:
+            # Create the transform() generator for each step
+            named_generators = [
+                (
+                    name,
+                    step.transform(
+                        input_copies.pop(),
+                        patch_config(config, run_manager.get_child()),
+                    ),
+                )
+                for name, step in steps.items()
+            ]
+            # Start the first iteration of each generator
+            futures = {
+                executor.submit(next, generator): (step_name, generator)
+                for step_name, generator in named_generators
+            }
+            # Yield chunks from each as they become available,
+            # and start the next iteration of the generator that yielded it.
+            # When all generators are exhausted, stop.
+            while futures:
+                completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED)
+                for future in completed_futures:
+                    (step_name, generator) = futures.pop(future)
+                    try:
+                        chunk = RunnableMapChunk({step_name: future.result()})
+                        yield chunk
+                        futures[executor.submit(next, generator)] = (
+                            step_name,
+                            generator,
+                        )
+                    except StopIteration:
+                        pass
+
+    def transform(
+        self,
+        input: Iterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        yield from self._transform_stream_with_config(
+            input, self._transform, config, **kwargs
+        )
+
+    def stream(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> Iterator[Dict[str, Any]]:
+        yield from self.transform(iter([input]), config)
+
+    async def _atransform(
+        self,
+        input: AsyncIterator[Input],
+        run_manager: AsyncCallbackManagerForChainRun,
+        config: RunnableConfig,
+    ) -> AsyncIterator[RunnableMapChunk]:
+        # Shallow copy steps to ignore mutations while in progress
+        steps = dict(self.steps)
+        # Each step gets a copy of the input iterator,
+        # which is consumed in parallel.
+        input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
+        # Create the atransform() generator for each step
+        named_generators = [
+            (
+                name,
+                step.atransform(
+                    input_copies.pop(), patch_config(config, run_manager.get_child())
+                ),
+            )
+            for name, step in steps.items()
+        ]
+
+        # Wrap in a coroutine to satisfy linter
+        async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]:
+            return await py_anext(generator)
+
+        # Start the first iteration of each generator
+        tasks = {
+            asyncio.create_task(get_next_chunk(generator)): (step_name, generator)
+            for step_name, generator in named_generators
+        }
+        # Yield chunks from each as they become available,
+        # and start the next iteration of the generator that yielded it.
+        # When all generators are exhausted, stop.
+        while tasks:
+            completed_tasks, _ = await asyncio.wait(
+                tasks, return_when=asyncio.FIRST_COMPLETED
+            )
+            for task in completed_tasks:
+                (step_name, generator) = tasks.pop(task)
+                try:
+                    chunk = RunnableMapChunk({step_name: task.result()})
+                    yield chunk
+                    new_task = asyncio.create_task(get_next_chunk(generator))
+                    tasks[new_task] = (step_name, generator)
+                except StopAsyncIteration:
+                    pass
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async for chunk in self._atransform_stream_with_config(
+            input, self._atransform, config, **kwargs
+        ):
+            yield chunk
+
+    async def astream(
+        self, input: Input, config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async def input_aiter() -> AsyncIterator[Input]:
+            yield input
+
+        async for chunk in self.atransform(input_aiter(), config):
+            yield chunk
+

 class RunnableLambda(Runnable[Input, Output]):
     """
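`RunnableMap._transform` above is a select loop over ordinary generators: every step's `next()` runs in a worker thread, the first future to finish has its chunk yielded, and a fresh `next()` is submitted for that generator until all are exhausted. A self-contained sketch of the same pattern with plain iterators:

    from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

    gens = {"nums": iter([1, 2]), "chars": iter("ab")}
    with ThreadPoolExecutor() as executor:
        # one in-flight next() per generator
        futures = {executor.submit(next, g): (name, g) for name, g in gens.items()}
        while futures:
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for future in done:
                name, gen = futures.pop(future)
                try:
                    print({name: future.result()})  # one {name: item} chunk at a time
                    futures[executor.submit(next, gen)] = (name, gen)
                except StopIteration:
                    pass  # this generator is exhausted; stop polling it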
@@ -1206,14 +1454,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
             yield item

     def transform(
-        self, input: Iterator[Input], config: Optional[RunnableConfig] = None
+        self,
+        input: Iterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Iterator[Output]:
-        yield from self.bound.transform(input, config, **self.kwargs)
+        yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})

     async def atransform(
-        self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
+        self,
+        input: AsyncIterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> AsyncIterator[Output]:
-        async for item in self.bound.atransform(input, config, **self.kwargs):
+        async for item in self.bound.atransform(
+            input, config, **{**self.kwargs, **kwargs}
+        ):
             yield item


@@ -1233,7 +1489,7 @@ def coerce_to_runnable(
     thing: Union[
         Runnable[Input, Output],
         Callable[[Input], Output],
-        Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
+        Mapping[str, Any],
     ]
 ) -> Runnable[Input, Output]:
     if isinstance(thing, Runnable):
@@ -1241,7 +1497,9 @@ def coerce_to_runnable(
     elif callable(thing):
         return RunnableLambda(thing)
     elif isinstance(thing, dict):
-        runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
+        runnables: Mapping[str, Runnable[Any, Any]] = {
+            key: coerce_to_runnable(r) for key, r in thing.items()
+        }
         return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
     else:
         raise TypeError(
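Relaxing the Mapping overloads to accept `Any` values matches what `coerce_to_runnable` actually does: every dict value is itself coerced, keeping runnables as-is, wrapping callables in `RunnableLambda`, and turning dicts into a `RunnableMap`. That is what makes the bare-dict chain syntax work, e.g., assuming `prompt` is any runnable:

    chain = prompt | {
        "original": RunnablePassthrough(),
        "uppercased": lambda text: str(text).upper(),  # coerced to RunnableLambda
    }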
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import AsyncIterator, Iterator, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Optional

 from langchain.load.serializable import Serializable
 from langchain.schema.runnable.base import Input, Runnable
@@ -32,17 +32,23 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
         return self._call_with_config(identity, input, config)

     async def ainvoke(
-        self, input: Input, config: RunnableConfig | None = None
+        self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Input:
         return await self._acall_with_config(aidentity, input, config)

     def transform(
-        self, input: Iterator[Input], config: RunnableConfig | None = None
+        self,
+        input: Iterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Iterator[Input]:
         return self._transform_stream_with_config(input, identity, config)

     async def atransform(
-        self, input: AsyncIterator[Input], config: RunnableConfig | None = None
+        self,
+        input: AsyncIterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> AsyncIterator[Input]:
         async for chunk in self._atransform_stream_with_config(input, identity, config):
             yield chunk
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import asyncio
-from typing import Any, Coroutine, Union
+from inspect import signature
+from typing import Any, Callable, Coroutine, Union


 async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@@ -16,3 +17,17 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
     semaphore = asyncio.Semaphore(n)

     return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
+
+
+def accepts_run_manager(callable: Callable[..., Any]) -> bool:
+    try:
+        return signature(callable).parameters.get("run_manager") is not None
+    except ValueError:
+        return False
+
+
+def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool:
+    return (
+        accepts_run_manager(callable)
+        and signature(callable).parameters.get("config") is not None
+    )
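These two helpers drive the new dispatch in `_call_with_config` and the transform wrappers: the wrapper inspects the callable's signature and only passes `run_manager` and `config` when those parameters exist. A quick demonstration of the underlying check:

    from inspect import signature
    from typing import Any

    def plain(input: Any) -> Any: ...
    def with_manager(input: Any, run_manager: Any) -> Any: ...
    def with_both(input: Any, run_manager: Any, config: Any) -> Any: ...

    # the same checks accepts_run_manager / accepts_run_manager_and_config perform
    assert signature(plain).parameters.get("run_manager") is None
    assert signature(with_manager).parameters.get("run_manager") is not None
    assert signature(with_both).parameters.get("config") is not None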
@@ -7,6 +7,7 @@ MIT License
 from collections import deque
 from typing import (
     Any,
+    AsyncContextManager,
     AsyncGenerator,
     AsyncIterator,
     Awaitable,
@@ -15,6 +16,7 @@ from typing import (
     Generic,
     Iterator,
     List,
+    Optional,
     Tuple,
     TypeVar,
     Union,
@@ -64,33 +66,45 @@ def py_anext(
     return anext_impl()


+class NoLock:
+    """Dummy lock that provides the proper interface but no protection"""
+
+    async def __aenter__(self) -> None:
+        pass
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
+        return False
+
+
 async def tee_peer(
     iterator: AsyncIterator[T],
     # the buffer specific to this peer
     buffer: Deque[T],
     # the buffers of all peers, including our own
     peers: List[Deque[T]],
+    lock: AsyncContextManager[Any],
 ) -> AsyncGenerator[T, None]:
     """An individual iterator of a :py:func:`~.tee`"""
     try:
         while True:
             if not buffer:
-                # Another peer produced an item while we were waiting for the lock.
-                # Proceed with the next loop iteration to yield the item.
-                if buffer:
-                    continue
-                try:
-                    item = await iterator.__anext__()
-                except StopAsyncIteration:
-                    break
-                else:
-                    # Append to all buffers, including our own. We'll fetch our
-                    # item from the buffer again, instead of yielding it directly.
-                    # This ensures the proper item ordering if any of our peers
-                    # are fetching items concurrently. They may have buffered their
-                    # item already.
-                    for peer_buffer in peers:
-                        peer_buffer.append(item)
+                async with lock:
+                    # Another peer produced an item while we were waiting for the lock.
+                    # Proceed with the next loop iteration to yield the item.
+                    if buffer:
+                        continue
+                    try:
+                        item = await iterator.__anext__()
+                    except StopAsyncIteration:
+                        break
+                    else:
+                        # Append to all buffers, including our own. We'll fetch our
+                        # item from the buffer again, instead of yielding it directly.
+                        # This ensures the proper item ordering if any of our peers
+                        # are fetching items concurrently. They may have buffered their
+                        # item already.
+                        for peer_buffer in peers:
+                            peer_buffer.append(item)
             yield buffer.popleft()
     finally:
         # this peer is done – remove its buffer
@@ -145,6 +159,8 @@ class Tee(Generic[T]):
         self,
         iterable: AsyncIterator[T],
         n: int = 2,
+        *,
+        lock: Optional[AsyncContextManager[Any]] = None,
     ):
         self._iterator = iterable.__aiter__()  # before 3.10 aiter() doesn't exist
         self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
@@ -153,6 +169,7 @@ class Tee(Generic[T]):
                 iterator=self._iterator,
                 buffer=buffer,
                 peers=self._buffers,
+                lock=lock if lock is not None else NoLock(),
             )
             for buffer in self._buffers
         )
libs/langchain/langchain/utils/iter.py (new file, 162 lines)
@@ -0,0 +1,162 @@
+from collections import deque
+from typing import (
+    Any,
+    ContextManager,
+    Deque,
+    Generator,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
+
+from typing_extensions import Literal
+
+T = TypeVar("T")
+
+
+class NoLock:
+    """Dummy lock that provides the proper interface but no protection"""
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
+        return False
+
+
+def tee_peer(
+    iterator: Iterator[T],
+    # the buffer specific to this peer
+    buffer: Deque[T],
+    # the buffers of all peers, including our own
+    peers: List[Deque[T]],
+    lock: ContextManager[Any],
+) -> Generator[T, None, None]:
+    """An individual iterator of a :py:func:`~.tee`"""
+    try:
+        while True:
+            if not buffer:
+                with lock:
+                    # Another peer produced an item while we were waiting for the lock.
+                    # Proceed with the next loop iteration to yield the item.
+                    if buffer:
+                        continue
+                    try:
+                        item = next(iterator)
+                    except StopIteration:
+                        break
+                    else:
+                        # Append to all buffers, including our own. We'll fetch our
+                        # item from the buffer again, instead of yielding it directly.
+                        # This ensures the proper item ordering if any of our peers
+                        # are fetching items concurrently. They may have buffered their
+                        # item already.
+                        for peer_buffer in peers:
+                            peer_buffer.append(item)
+            yield buffer.popleft()
+    finally:
+        # this peer is done – remove its buffer
+        for idx, peer_buffer in enumerate(peers):  # pragma: no branch
+            if peer_buffer is buffer:
+                peers.pop(idx)
+                break
+        # if we are the last peer, try and close the iterator
+        if not peers and hasattr(iterator, "close"):
+            iterator.close()
+
+
+class Tee(Generic[T]):
+    """
+    Create ``n`` separate asynchronous iterators over ``iterable``
+
+    This splits a single ``iterable`` into multiple iterators, each providing
+    the same items in the same order.
+    All child iterators may advance separately but share the same items
+    from ``iterable`` -- when the most advanced iterator retrieves an item,
+    it is buffered until the least advanced iterator has yielded it as well.
+    A ``tee`` works lazily and can handle an infinite ``iterable``, provided
+    that all iterators advance.
+
+    .. code-block:: python3
+
+        async def derivative(sensor_data):
+            previous, current = a.tee(sensor_data, n=2)
+            await a.anext(previous)  # advance one iterator
+            return a.map(operator.sub, previous, current)
+
+    Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
+    of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
+    to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
+    immediately closes all children, and it can be used in an ``async with`` context
+    for the same effect.
+
+    If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
+    provide these items. Also, ``tee`` must internally buffer each item until the
+    last iterator has yielded it; if the most and least advanced iterator differ
+    by most data, using a :py:class:`list` is more efficient (but not lazy).
+
+    If the underlying iterable is concurrency safe (``anext`` may be awaited
+    concurrently) the resulting iterators are concurrency safe as well. Otherwise,
+    the iterators are safe if there is only ever one single "most advanced" iterator.
+    To enforce sequential use of ``anext``, provide a ``lock``
+    - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
+    and access is automatically synchronised.
+    """
+
+    def __init__(
+        self,
+        iterable: Iterator[T],
+        n: int = 2,
+        *,
+        lock: Optional[ContextManager[Any]] = None,
+    ):
+        self._iterator = iter(iterable)
+        self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
+        self._children = tuple(
+            tee_peer(
+                iterator=self._iterator,
+                buffer=buffer,
+                peers=self._buffers,
+                lock=lock if lock is not None else NoLock(),
+            )
+            for buffer in self._buffers
+        )
+
+    def __len__(self) -> int:
+        return len(self._children)
+
+    @overload
+    def __getitem__(self, item: int) -> Iterator[T]:
+        ...
+
+    @overload
+    def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]:
+        ...
+
+    def __getitem__(
+        self, item: Union[int, slice]
+    ) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
+        return self._children[item]
+
+    def __iter__(self) -> Iterator[Iterator[T]]:
+        yield from self._children
+
+    def __enter__(self) -> "Tee[T]":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
+        self.close()
+        return False
+
+    def close(self) -> None:
+        for child in self._children:
+            child.close()
+
+
+# Why this is needed https://stackoverflow.com/a/44638570
+safetee = Tee
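As the linked answer explains, `itertools.tee` is not safe to consume from multiple threads, which is exactly what `RunnableMap._transform` does; `safetee` accepts a lock so sibling iterators can be drained concurrently. A usage sketch:

    import threading

    from langchain.utils.iter import safetee

    left, right = safetee(iter(range(5)), 2, lock=threading.Lock())

    collected = []
    worker = threading.Thread(target=lambda: collected.extend(left))
    worker.start()
    main_side = list(right)  # both consumers see every item, in order
    worker.join()
    assert sorted(collected) == main_side == [0, 1, 2, 3, 4]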
@@ -538,7 +538,7 @@ Question:

     parser = CommaSeparatedListOutputParser()

-    chain = (
+    chain: Runnable = (
         {
             "question": RunnablePassthrough[str]() | passthrough,
             "documents": passthrough | retriever,
@@ -770,6 +770,188 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
     assert len(map_run.child_runs) == 3


+def test_map_stream() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+
+    chat_res = "i'm a chatbot"
+    # sleep to better simulate a real stream
+    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    chain: Runnable = prompt | {
+        "chat": chat.bind(stop=["Thought:"]),
+        "llm": llm,
+        "passthrough": RunnablePassthrough(),
+    }
+
+    stream = chain.stream({"question": "What is your name?"})
+
+    final_value = None
+    streamed_chunks = []
+    for chunk in stream:
+        streamed_chunks.append(chunk)
+        if final_value is None:
+            final_value = chunk
+        else:
+            final_value += chunk
+
+    assert streamed_chunks[0] in [
+        {"passthrough": prompt.invoke({"question": "What is your name?"})},
+        {"llm": "i"},
+        {"chat": "i"},
+    ]
+    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
+    assert all(len(c.keys()) == 1 for c in streamed_chunks)
+    assert final_value is not None
+    assert final_value.get("chat").content == "i'm a chatbot"
+    assert final_value.get("llm") == "i'm a textbot"
+    assert final_value.get("passthrough") == prompt.invoke(
+        {"question": "What is your name?"}
+    )
+
+
+def test_map_stream_iterator_input() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+
+    chat_res = "i'm a chatbot"
+    # sleep to better simulate a real stream
+    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    chain: Runnable = (
+        prompt
+        | llm
+        | {
+            "chat": chat.bind(stop=["Thought:"]),
+            "llm": llm,
+            "passthrough": RunnablePassthrough(),
+        }
+    )
+
+    stream = chain.stream({"question": "What is your name?"})
+
+    final_value = None
+    streamed_chunks = []
+    for chunk in stream:
+        streamed_chunks.append(chunk)
+        if final_value is None:
+            final_value = chunk
+        else:
+            final_value += chunk
+
+    assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
+    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
+    assert all(len(c.keys()) == 1 for c in streamed_chunks)
+    assert final_value is not None
+    assert final_value.get("chat").content == "i'm a chatbot"
+    assert final_value.get("llm") == "i'm a textbot"
+    assert final_value.get("passthrough") == "i'm a textbot"
+
+
+@pytest.mark.asyncio
+async def test_map_astream() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+
+    chat_res = "i'm a chatbot"
+    # sleep to better simulate a real stream
+    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    chain: Runnable = prompt | {
+        "chat": chat.bind(stop=["Thought:"]),
+        "llm": llm,
+        "passthrough": RunnablePassthrough(),
+    }
+
+    stream = chain.astream({"question": "What is your name?"})
+
+    final_value = None
+    streamed_chunks = []
+    async for chunk in stream:
+        streamed_chunks.append(chunk)
+        if final_value is None:
+            final_value = chunk
+        else:
+            final_value += chunk
+
+    assert streamed_chunks[0] in [
+        {"passthrough": prompt.invoke({"question": "What is your name?"})},
+        {"llm": "i"},
+        {"chat": "i"},
+    ]
+    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
+    assert all(len(c.keys()) == 1 for c in streamed_chunks)
+    assert final_value is not None
+    assert final_value.get("chat").content == "i'm a chatbot"
+    assert final_value.get("llm") == "i'm a textbot"
+    assert final_value.get("passthrough") == prompt.invoke(
+        {"question": "What is your name?"}
+    )
+
+
+@pytest.mark.asyncio
+async def test_map_astream_iterator_input() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+
+    chat_res = "i'm a chatbot"
+    # sleep to better simulate a real stream
+    chat = FakeListChatModel(responses=[chat_res], sleep=0.01)
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    chain: Runnable = (
+        prompt
+        | llm
+        | {
+            "chat": chat.bind(stop=["Thought:"]),
+            "llm": llm,
+            "passthrough": RunnablePassthrough(),
+        }
+    )
+
+    stream = chain.astream({"question": "What is your name?"})
+
+    final_value = None
+    streamed_chunks = []
+    async for chunk in stream:
+        streamed_chunks.append(chunk)
+        if final_value is None:
+            final_value = chunk
+        else:
+            final_value += chunk
+
+    assert streamed_chunks[0] in [{"passthrough": "i"}, {"llm": "i"}, {"chat": "i"}]
+    assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res)
+    assert all(len(c.keys()) == 1 for c in streamed_chunks)
+    assert final_value is not None
+    assert final_value.get("chat").content == "i'm a chatbot"
+    assert final_value.get("llm") == "i'm a textbot"
+    assert final_value.get("passthrough") == llm_res
+
+
 def test_bind_bind() -> None:
     llm = FakeListLLM(responses=["i'm a textbot"])