Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-21 22:29:51 +00:00)
Adds transform support for runnables (#8762)
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit b8df15cd64 (parent 4d72288487)
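In short, this PR adds a transform()/atransform() API to Runnable for processing streams of inputs, teaches RunnableSequence to keep streaming through any trailing steps that support it, turns StrOutputParser into a streaming-capable BaseTransformOutputParser, and adds fake streaming models for tests. A minimal sketch of the resulting behavior, adapted from the test_deep_stream test added in this PR (FakeStreamingListLLM is a test double, not a real model):

    from langchain.llms.fake import FakeStreamingListLLM
    from langchain.prompts.chat import SystemMessagePromptTemplate
    from langchain.schema.output_parser import StrOutputParser

    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chain = prompt | FakeStreamingListLLM(responses=["foo-lish"]) | StrOutputParser()

    # The llm and the parser both support transform(), so the parser re-yields
    # chunks as the llm produces them instead of waiting for the full string.
    for chunk in chain.stream({"question": "What up"}):
        print(chunk, end="")  # prints "foo-lish" one character at a time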
libs/langchain/langchain/callbacks/tracers/base.py

@@ -47,6 +47,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
             parent_run = self.run_map[str(run.parent_run_id)]
             if parent_run:
                 self._add_child_run(parent_run, run)
+                parent_run.child_execution_order = max(
+                    parent_run.child_execution_order, run.child_execution_order
+                )
             else:
                 logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
         self.run_map[str(run.id)] = run

@@ -254,7 +257,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self._on_chain_start(chain_run)

     def on_chain_end(
-        self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
+        self,
+        outputs: Dict[str, Any],
+        *,
+        run_id: UUID,
+        inputs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> None:
         """End a trace for a chain run."""
         if not run_id:

@@ -266,6 +274,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
         chain_run.outputs = outputs
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "end", "time": chain_run.end_time})
+        if inputs is not None:
+            chain_run.inputs = inputs
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)

@@ -273,6 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         error: Union[Exception, KeyboardInterrupt],
         *,
+        inputs: Optional[Dict[str, Any]] = None,
         run_id: UUID,
         **kwargs: Any,
     ) -> None:

@@ -286,6 +297,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
         chain_run.error = repr(error)
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "error", "time": chain_run.end_time})
+        if inputs is not None:
+            chain_run.inputs = inputs
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)
libs/langchain/langchain/chains/base.py

@@ -1,9 +1,11 @@
 """Base interface that all chains should implement."""
+import asyncio
 import inspect
 import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
+from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

@@ -55,18 +57,26 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
     """

     def invoke(
-        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
-        return self(input, **(config or {}))
+        return self(input, **(config or {}), **kwargs)

     async def ainvoke(
-        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         if type(self)._acall == Chain._acall:
             # If the chain does not implement async, fall back to default implementation
-            return await super().ainvoke(input, config)
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self.invoke, input, config, **kwargs)
+            )

-        return await self.acall(input, **(config or {}))
+        return await self.acall(input, **(config or {}), **kwargs)

     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
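The invoke()/ainvoke() change above forwards extra keyword arguments to Chain.__call__/Chain.acall. A minimal sketch (LLMChain, PromptTemplate and FakeListLLM are used purely for illustration; return_only_outputs is a pre-existing __call__ option):

    from langchain.chains import LLMChain
    from langchain.llms.fake import FakeListLLM
    from langchain.prompts import PromptTemplate

    chain = LLMChain(
        llm=FakeListLLM(responses=["blue"]),
        prompt=PromptTemplate.from_template("What color is the {thing}?"),
    )

    # **kwargs now pass through invoke() to __call__
    assert chain.invoke({"thing": "sky"}, return_only_outputs=True) == {"text": "blue"}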
libs/langchain/langchain/chains/transform.py

@@ -3,6 +3,8 @@ import functools
 import logging
 from typing import Any, Awaitable, Callable, Dict, List, Optional

+from pydantic import Field
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,

@@ -27,9 +29,11 @@ class TransformChain(Chain):
     """The keys expected by the transform's input dictionary."""
     output_variables: List[str]
     """The keys returned by the transform's output dictionary."""
-    transform: Callable[[Dict[str, str]], Dict[str, str]]
+    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
     """The transform function."""
-    atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
+    atransform_cb: Optional[
+        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
+    ] = Field(None, alias="atransform")
     """The async coroutine transform function."""

     @staticmethod

@@ -62,18 +66,18 @@ class TransformChain(Chain):
         inputs: Dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
-        return self.transform(inputs)
+        return self.transform_cb(inputs)

     async def _acall(
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
-        if self.atransform is not None:
-            return await self.atransform(inputs)
+        if self.atransform_cb is not None:
+            return await self.atransform_cb(inputs)
         else:
             self._log_once(
                 "TransformChain's atransform is not provided, falling"
                 " back to synchronous transform"
             )
-            return self.transform(inputs)
+            return self.transform_cb(inputs)
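The field rename frees up the name transform, which now refers to the streaming method inherited from Runnable; the pydantic aliases keep the public constructor arguments unchanged. A minimal sketch (to_upper is illustrative):

    from langchain.chains import TransformChain

    def to_upper(inputs: dict) -> dict:
        return {"shout": inputs["text"].upper()}

    chain = TransformChain(
        input_variables=["text"],
        output_variables=["shout"],
        transform=to_upper,  # still accepted, via Field(alias="transform")
    )
    assert chain.invoke({"text": "hello"}) == {"text": "hello", "shout": "HELLO"}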
libs/langchain/langchain/chat_models/fake.py

@@ -1,9 +1,13 @@
 """Fake ChatModel for testing purposes."""
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models.base import SimpleChatModel
-from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import AIMessageChunk, BaseMessage
+from langchain.schema.output import ChatGenerationChunk


 class FakeListChatModel(SimpleChatModel):

@@ -31,6 +35,36 @@ class FakeListChatModel(SimpleChatModel):
             self.i = 0
         return response

+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Union[List[str], None] = None,
+        run_manager: Union[CallbackManagerForLLMRun, None] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
+        for c in response:
+            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Union[List[str], None] = None,
+        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
+        for c in response:
+            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         return {"responses": self.responses}
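With _stream()/_astream() in place, the fake chat model yields one AIMessageChunk per character of the current response, which the updated test_prompt_with_chat_model assertions below rely on. A quick sketch (assuming the stream() entrypoint that chat models inherit from earlier streaming work):

    from langchain.chat_models.fake import FakeListChatModel

    model = FakeListChatModel(responses=["foo"])
    chunks = list(model.stream("hi"))
    assert [c.content for c in chunks] == ["f", "o", "o"]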
libs/langchain/langchain/llms/fake.py

@@ -1,10 +1,12 @@
-from typing import Any, List, Mapping, Optional
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM
+from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.runnable import RunnableConfig


 class FakeListLLM(LLM):

@@ -51,3 +53,29 @@ class FakeListLLM(LLM):
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         return {"responses": self.responses}
+
+
+class FakeStreamingListLLM(FakeListLLM):
+    def stream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        result = self.invoke(input, config)
+        for c in result:
+            yield c
+
+    async def astream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        result = await self.ainvoke(input, config)
+        for c in result:
+            yield c
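FakeStreamingListLLM fakes streaming by computing the full invoke() result first and then yielding it one character at a time; it exists purely so tests can exercise streaming pipelines deterministically. A quick sketch:

    from langchain.llms.fake import FakeStreamingListLLM

    llm = FakeStreamingListLLM(responses=["foo-lish"])
    assert list(llm.stream("any prompt")) == list("foo-lish")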
libs/langchain/langchain/schema/output_parser.py

@@ -2,7 +2,17 @@ from __future__ import annotations

 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+)

 from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage

@@ -47,7 +57,7 @@ class BaseGenerationOutputParser(
     BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
 ):
     def invoke(
-        self, input: str | BaseMessage, config: RunnableConfig | None = None
+        self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
     ) -> T:
         if isinstance(input, BaseMessage):
             return self._call_with_config(

@@ -115,7 +125,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
     """  # noqa: E501

     def invoke(
-        self, input: str | BaseMessage, config: RunnableConfig | None = None
+        self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
     ) -> T:
         if isinstance(input, BaseMessage):
             return self._call_with_config(

@@ -242,8 +252,47 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
         return output_parser_dict


-class StrOutputParser(BaseOutputParser[str]):
-    """OutputParser that parses LLMResult into the top likely string.."""
+class BaseTransformOutputParser(BaseOutputParser[T]):
+    """Base class for an output parser that can handle streaming input."""
+
+    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
+        for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                yield self.parse_result([ChatGeneration(message=chunk)])
+            else:
+                yield self.parse_result([Generation(text=chunk)])
+
+    async def _atransform(
+        self, input: AsyncIterator[Union[str, BaseMessage]]
+    ) -> AsyncIterator[T]:
+        async for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                yield self.parse_result([ChatGeneration(message=chunk)])
+            else:
+                yield self.parse_result([Generation(text=chunk)])
+
+    def transform(
+        self,
+        input: Iterator[Union[str, BaseMessage]],
+        config: Optional[RunnableConfig] = None,
+    ) -> Iterator[T]:
+        yield from self._transform_stream_with_config(
+            input, self._transform, config, run_type="parser"
+        )
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Union[str, BaseMessage]],
+        config: Optional[RunnableConfig] = None,
+    ) -> AsyncIterator[T]:
+        async for chunk in self._atransform_stream_with_config(
+            input, self._atransform, config, run_type="parser"
+        ):
+            yield chunk
+
+
+class StrOutputParser(BaseTransformOutputParser[str]):
+    """OutputParser that parses LLMResult into the top likely string."""

     @property
     def lc_serializable(self) -> bool:
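Because StrOutputParser is now a BaseTransformOutputParser, it can parse a stream chunk by chunk rather than waiting for the final result. A minimal sketch (the hand-built iterator stands in for chunks arriving from an upstream model):

    from langchain.schema.messages import AIMessageChunk
    from langchain.schema.output_parser import StrOutputParser

    parser = StrOutputParser()
    chunks = iter([AIMessageChunk(content="foo"), AIMessageChunk(content="-lish")])
    # each chunk is parsed and re-yielded immediately
    assert list(parser.transform(chunks)) == ["foo", "-lish"]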
libs/langchain/langchain/schema/runnable.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import asyncio
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
+from itertools import tee
 from typing import (
     Any,
     AsyncIterator,

@@ -29,6 +30,7 @@ from pydantic import Field
 from langchain.callbacks.base import BaseCallbackManager, Callbacks
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
+from langchain.utils.aiter import atee, py_anext


 async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:

@@ -92,6 +94,8 @@ class Runnable(Generic[Input, Output], ABC):
     ) -> RunnableSequence[Other, Output]:
         return RunnableSequence(first=_coerce_to_runnable(other), last=self)

+    """ --- Public API --- """
+
     @abstractmethod
     def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
         ...

@@ -99,6 +103,10 @@ class Runnable(Generic[Input, Output], ABC):
     async def ainvoke(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Output:
+        """
+        Default implementation of ainvoke, which calls invoke in a thread pool.
+        Subclasses should override this method if they can run asynchronously.
+        """
         return await asyncio.get_running_loop().run_in_executor(
             None, self.invoke, input, config
         )

@@ -110,6 +118,10 @@ class Runnable(Generic[Input, Output], ABC):
         *,
         max_concurrency: Optional[int] = None,
     ) -> List[Output]:
+        """
+        Default implementation of batch, which calls invoke N times.
+        Subclasses should override this method if they can batch more efficiently.
+        """
         configs = self._get_config_list(config, len(inputs))

         # If there's only one input, don't bother with the executor

@@ -126,6 +138,10 @@ class Runnable(Generic[Input, Output], ABC):
         *,
         max_concurrency: Optional[int] = None,
     ) -> List[Output]:
+        """
+        Default implementation of abatch, which calls ainvoke N times.
+        Subclasses should override this method if they can batch more efficiently.
+        """
         configs = self._get_config_list(config, len(inputs))
         coros = map(self.ainvoke, inputs, configs)
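The defaults documented above mean any Runnable gets batch() and the async variants for free from its invoke(). A minimal sketch with RunnableLambda:

    from langchain.schema.runnable import RunnableLambda

    double = RunnableLambda(lambda x: x * 2)
    assert double.invoke(3) == 6
    # the default batch() simply calls invoke once per input
    assert double.batch([1, 2, 3]) == [2, 4, 6]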
@@ -134,22 +150,90 @@ class Runnable(Generic[Input, Output], ABC):
     def stream(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Iterator[Output]:
+        """
+        Default implementation of stream, which calls invoke.
+        Subclasses should override this method if they support streaming output.
+        """
         yield self.invoke(input, config)

     async def astream(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> AsyncIterator[Output]:
+        """
+        Default implementation of astream, which calls ainvoke.
+        Subclasses should override this method if they support streaming output.
+        """
         yield await self.ainvoke(input, config)

+    def transform(
+        self, input: Iterator[Input], config: Optional[RunnableConfig] = None
+    ) -> Iterator[Output]:
+        """
+        Default implementation of transform, which buffers input and then calls stream.
+        Subclasses should override this method if they can start producing output while
+        input is still being generated.
+        """
+        final: Union[Input, None] = None
+
+        for chunk in input:
+            if final is None:
+                final = chunk
+            else:
+                # Make a best effort to gather, for any type that supports `+`
+                # This method should throw an error if gathering fails.
+                final += chunk  # type: ignore[operator]
+        if final:
+            yield from self.stream(final, config)
+
+    async def atransform(
+        self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Output]:
+        """
+        Default implementation of atransform, which buffers input and calls astream.
+        Subclasses should override this method if they can start producing output while
+        input is still being generated.
+        """
+        final: Union[Input, None] = None
+
+        async for chunk in input:
+            if final is None:
+                final = chunk
+            else:
+                # Make a best effort to gather, for any type that supports `+`
+                # This method should throw an error if gathering fails.
+                final += chunk  # type: ignore[operator]
+
+        if final:
+            async for output in self.astream(final, config):
+                yield output
+
     def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
         """
         Bind arguments to a Runnable, returning a new Runnable.
         """
         return RunnableBinding(bound=self, kwargs=kwargs)

+    def with_fallbacks(
+        self,
+        fallbacks: Sequence[Runnable[Input, Output]],
+        *,
+        exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
+    ) -> RunnableWithFallbacks[Input, Output]:
+        return RunnableWithFallbacks(
+            runnable=self,
+            fallbacks=fallbacks,
+            exceptions_to_handle=exceptions_to_handle,
+        )
+
+    """ --- Helper methods for Subclasses --- """
+
     def _get_config_list(
         self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
     ) -> List[RunnableConfig]:
+        """
+        Helper method to get a list of configs from a single config or a list of
+        configs, useful for subclasses overriding batch() or abatch().
+        """
         if isinstance(config, list) and len(config) != length:
             raise ValueError(
                 f"config must be a list of the same length as inputs, "
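The default transform() above just accumulates the chunks with + and defers to stream(), so a step that knows nothing about streaming still composes inside a streaming pipeline; it simply cannot emit anything until its input is complete. A minimal sketch:

    from langchain.schema.runnable import RunnableLambda

    upper = RunnableLambda(lambda s: s.upper())  # does not override transform()

    # "foo" and "-lish" are buffered into "foo-lish", then streamed as one chunk
    assert list(upper.transform(iter(["foo", "-lish"]))) == ["FOO-LISH"]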
@@ -169,6 +253,8 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Output:
+        """Helper method to transform an Input value to an Output value,
+        with callbacks. Use this method to implement invoke() in subclasses."""
         from langchain.callbacks.manager import CallbackManager

         config = config or {}

@@ -200,6 +286,8 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Output:
+        """Helper method to transform an Input value to an Output value,
+        with callbacks. Use this method to implement ainvoke() in subclasses."""
         from langchain.callbacks.manager import AsyncCallbackManager

         config = config or {}
@@ -224,20 +312,154 @@ class Runnable(Generic[Input, Output], ABC):
         )
         return output

-    def with_fallbacks(
+    def _transform_stream_with_config(
         self,
-        fallbacks: Sequence[Runnable[Input, Output]],
-        *,
-        exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
-    ) -> RunnableWithFallbacks[Input, Output]:
-        return RunnableWithFallbacks(
-            runnable=self,
-            fallbacks=fallbacks,
-            exceptions_to_handle=exceptions_to_handle,
+        input: Iterator[Input],
+        transformer: Callable[[Iterator[Input]], Iterator[Output]],
+        config: Optional[RunnableConfig],
+        run_type: Optional[str] = None,
+    ) -> Iterator[Output]:
+        """Helper method to transform an Iterator of Input values into an Iterator of
+        Output values, with callbacks.
+        Use this to implement `stream()` or `transform()` in Runnable subclasses."""
+        from langchain.callbacks.manager import CallbackManager
+
+        # tee the input so we can iterate over it twice
+        input_for_tracing, input_for_transform = tee(input, 2)
+        # Start the input iterator to ensure the input runnable starts before this one
+        final_input: Optional[Input] = next(input_for_tracing, None)
+        final_input_supported = True
+        final_output: Optional[Output] = None
+        final_output_supported = True
+
+        config = config or {}
+        callback_manager = CallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            inheritable_tags=config.get("tags"),
+            inheritable_metadata=config.get("metadata"),
+        )
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self),
+            {"input": ""},
+            run_type=run_type,
+        )
+        try:
+            for chunk in transformer(input_for_transform):
+                yield chunk
+                if final_output_supported:
+                    if final_output is None:
+                        final_output = chunk
+                    else:
+                        try:
+                            final_output += chunk  # type: ignore[operator]
+                        except TypeError:
+                            final_output = None
+                            final_output_supported = False
+            for ichunk in input_for_tracing:
+                if final_input_supported:
+                    if final_input is None:
+                        final_input = ichunk
+                    else:
+                        try:
+                            final_input += ichunk  # type: ignore[operator]
+                        except TypeError:
+                            final_input = None
+                            final_input_supported = False
+        except Exception as e:
+            run_manager.on_chain_error(
+                e,
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )
+            raise
+        else:
+            run_manager.on_chain_end(
+                final_output
+                if isinstance(final_output, dict)
+                else {"output": final_output},
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )
+
+    async def _atransform_stream_with_config(
+        self,
+        input: AsyncIterator[Input],
+        transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
+        config: Optional[RunnableConfig],
+        run_type: Optional[str] = None,
+    ) -> AsyncIterator[Output]:
+        """Helper method to transform an Async Iterator of Input values into an Async
+        Iterator of Output values, with callbacks.
+        Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
+        from langchain.callbacks.manager import AsyncCallbackManager
+
+        # tee the input so we can iterate over it twice
+        input_for_tracing, input_for_transform = atee(input, 2)
+        # Start the input iterator to ensure the input runnable starts before this one
+        final_input: Optional[Input] = await py_anext(input_for_tracing, None)
+        final_input_supported = True
+        final_output: Optional[Output] = None
+        final_output_supported = True
+
+        config = config or {}
+        callback_manager = AsyncCallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            inheritable_tags=config.get("tags"),
+            inheritable_metadata=config.get("metadata"),
+        )
+        run_manager = await callback_manager.on_chain_start(
+            dumpd(self),
+            {"input": ""},
+            run_type=run_type,
+        )
+        try:
+            async for chunk in transformer(input_for_transform):
+                yield chunk
+                if final_output_supported:
+                    if final_output is None:
+                        final_output = chunk
+                    else:
+                        try:
+                            final_output += chunk  # type: ignore[operator]
+                        except TypeError:
+                            final_output = None
+                            final_output_supported = False
+            async for ichunk in input_for_tracing:
+                if final_input_supported:
+                    if final_input is None:
+                        final_input = ichunk
+                    else:
+                        try:
+                            final_input += ichunk  # type: ignore[operator]
+                        except TypeError:
+                            final_input = None
+                            final_input_supported = False
+        except Exception as e:
+            await run_manager.on_chain_error(
+                e,
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )
+            raise
+        else:
+            await run_manager.on_chain_end(
+                final_output
+                if isinstance(final_output, dict)
+                else {"output": final_output},
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )


 class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
+    """
+    A Runnable that can fallback to other Runnables if it fails.
+    """
+
     runnable: Runnable[Input, Output]
     fallbacks: Sequence[Runnable[Input, Output]]
     exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
@@ -467,6 +689,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):


 class RunnableSequence(Serializable, Runnable[Input, Output]):
+    """
+    A sequence of runnables, where the output of each is the input of the next.
+    """
+
     first: Runnable[Input, Any]
     middle: List[Runnable[Any, Any]] = Field(default_factory=list)
     last: Runnable[Any, Output]

@@ -738,9 +964,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             dumpd(self), input if isinstance(input, dict) else {"input": input}
         )

+        steps = [self.first] + self.middle + [self.last]
+        streaming_start_index = 0
+
+        for i in range(len(steps) - 1, 0, -1):
+            if type(steps[i]).transform != Runnable.transform:
+                streaming_start_index = i - 1
+            else:
+                break
+
         # invoke the first steps
         try:
-            for step in [self.first] + self.middle:
+            for step in steps[0:streaming_start_index]:
                 input = step.invoke(
                     input,
                     # mark each step as a child run

@@ -750,15 +985,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             run_manager.on_chain_error(e)
             raise

-        # stream the last step
+        # stream the last steps
         final: Union[Output, None] = None
         final_supported = True
         try:
-            for output in self.last.stream(
-                input,
-                # mark the last step as a child run
-                _patch_config(config, run_manager.get_child()),
-            ):
+            # stream the first of the last steps with non-streaming input
+            final_pipeline = steps[streaming_start_index].stream(
+                input, _patch_config(config, run_manager.get_child())
+            )
+            # stream the rest of the last steps with streaming input
+            for step in steps[streaming_start_index + 1 :]:
+                final_pipeline = step.transform(
+                    final_pipeline, _patch_config(config, run_manager.get_child())
+                )
+            for output in final_pipeline:
                 yield output
                 # Accumulate output if possible, otherwise disable accumulation
                 if final_supported:

@@ -801,9 +1041,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             dumpd(self), input if isinstance(input, dict) else {"input": input}
         )

+        steps = [self.first] + self.middle + [self.last]
+        streaming_start_index = len(steps) - 1
+
+        for i in range(len(steps) - 1, 0, -1):
+            if type(steps[i]).transform != Runnable.transform:
+                streaming_start_index = i - 1
+            else:
+                break
+
         # invoke the first steps
         try:
-            for step in [self.first] + self.middle:
+            for step in steps[0:streaming_start_index]:
                 input = await step.ainvoke(
                     input,
                     # mark each step as a child run

@@ -813,15 +1062,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             await run_manager.on_chain_error(e)
             raise

-        # stream the last step
+        # stream the last steps
         final: Union[Output, None] = None
         final_supported = True
         try:
-            async for output in self.last.astream(
-                input,
-                # mark the last step as a child run
-                _patch_config(config, run_manager.get_child()),
-            ):
+            # stream the first of the last steps with non-streaming input
+            final_pipeline = steps[streaming_start_index].astream(
+                input, _patch_config(config, run_manager.get_child())
+            )
+            # stream the rest of the last steps with streaming input
+            for step in steps[streaming_start_index + 1 :]:
+                final_pipeline = step.atransform(
+                    final_pipeline, _patch_config(config, run_manager.get_child())
+                )
+            async for output in final_pipeline:
                 yield output
                 # Accumulate output if possible, otherwise disable accumulation
                 if final_supported:
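The streaming_start_index scan above splits the sequence in two: every step before the index is invoked eagerly, while the suffix of steps that override transform() runs as one lazy pipeline. The async path mirrors the sync one; a sketch of it, adapted from the test_deep_astream test added below:

    import asyncio

    from langchain.llms.fake import FakeStreamingListLLM
    from langchain.prompts.chat import SystemMessagePromptTemplate
    from langchain.schema.output_parser import StrOutputParser

    async def main() -> None:
        prompt = (
            SystemMessagePromptTemplate.from_template("You are a nice assistant.")
            + "{question}"
        )
        chain = prompt | FakeStreamingListLLM(responses=["foo-lish"]) | StrOutputParser()
        # the prompt is ainvoked once; the llm and the parser stream
        async for chunk in chain.astream({"question": "What up"}):
            print(chunk, end="")

    asyncio.run(main())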
@@ -845,6 +1099,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):


 class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
+    """
+    A runnable that runs a mapping of runnables in parallel,
+    and returns a mapping of their outputs.
+    """
+
     steps: Mapping[str, Runnable[Input, Any]]

     def __init__(

@@ -957,6 +1216,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):


 class RunnableLambda(Runnable[Input, Output]):
+    """
+    A runnable that runs a callable.
+    """
+
     def __init__(self, func: Callable[[Input], Output]) -> None:
         if callable(func):
             self.func = func

@@ -977,6 +1240,10 @@ class RunnableLambda(Runnable[Input, Output]):


 class RunnablePassthrough(Serializable, Runnable[Input, Input]):
+    """
+    A runnable that passes through the input.
+    """
+
     @property
     def lc_serializable(self) -> bool:
         return True

@@ -986,6 +1253,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):


 class RunnableBinding(Serializable, Runnable[Input, Output]):
+    """
+    A runnable that binds a runnable to a set of kwargs.
+    """
+
     bound: Runnable[Input, Output]

     kwargs: Mapping[str, Any]

@@ -1041,6 +1312,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
         async for item in self.bound.astream(input, config, **self.kwargs):
             yield item

+    def transform(
+        self, input: Iterator[Input], config: Optional[RunnableConfig] = None
+    ) -> Iterator[Output]:
+        yield from self.bound.transform(input, config, **self.kwargs)
+
+    async def atransform(
+        self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Output]:
+        async for item in self.bound.atransform(input, config, **self.kwargs):
+            yield item
+

 class RouterInput(TypedDict):
     key: str

@@ -1050,6 +1332,11 @@ class RouterInput(TypedDict):
 class RouterRunnable(
     Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
 ):
+    """
+    A runnable that routes to a set of runnables based on Input['key'].
+    Returns the output of the selected runnable.
+    """
+
     runnables: Mapping[str, Runnable[Input, Output]]

     def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
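Forwarding transform()/atransform() through RunnableBinding means a runnable keeps its streaming behavior after bind(). A minimal sketch (the stop kwarg mirrors the test_bind usage below and is simply passed through to the wrapped llm):

    from langchain.llms.fake import FakeStreamingListLLM

    llm = FakeStreamingListLLM(responses=["foo-lish"])
    bound = llm.bind(stop=["Observation:"])  # a RunnableBinding

    # stream() is delegated to the wrapped llm along with the bound kwargs
    assert "".join(bound.stream("any prompt")) == "foo-lish"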
libs/langchain/langchain/utils/aiter.py (new file, 191 lines)

@@ -0,0 +1,191 @@
"""
Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""

from collections import deque
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Awaitable,
    Callable,
    Deque,
    Generic,
    Iterator,
    List,
    Tuple,
    TypeVar,
    Union,
    cast,
    overload,
)

T = TypeVar("T")

_no_default = object()


# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
    iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
    """Pure-Python implementation of anext() for testing purposes.

    Closely matches the builtin anext() C implementation.
    Can be used to compare the built-in implementation of the inner
    coroutines machinery to C-implementation of __anext__() and send()
    or throw() on the returned generator.
    """

    try:
        __anext__ = cast(
            Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
        )
    except AttributeError:
        raise TypeError(f"{iterator!r} is not an async iterator")

    if default is _no_default:
        return __anext__(iterator)

    async def anext_impl() -> Union[T, Any]:
        try:
            # The C code is way more low-level than this, as it implements
            # all methods of the iterator protocol. In this implementation
            # we're relying on higher-level coroutine concepts, but that's
            # exactly what we want -- crosstest pure-Python high-level
            # implementation and low-level C anext() iterators.
            return await __anext__(iterator)
        except StopAsyncIteration:
            return default

    return anext_impl()


async def tee_peer(
    iterator: AsyncIterator[T],
    # the buffer specific to this peer
    buffer: Deque[T],
    # the buffers of all peers, including our own
    peers: List[Deque[T]],
) -> AsyncGenerator[T, None]:
    """An individual iterator of a :py:func:`~.tee`"""
    try:
        while True:
            if not buffer:
                # Another peer produced an item while we were waiting for the lock.
                # Proceed with the next loop iteration to yield the item.
                if buffer:
                    continue
                try:
                    item = await iterator.__anext__()
                except StopAsyncIteration:
                    break
                else:
                    # Append to all buffers, including our own. We'll fetch our
                    # item from the buffer again, instead of yielding it directly.
                    # This ensures the proper item ordering if any of our peers
                    # are fetching items concurrently. They may have buffered their
                    # item already.
                    for peer_buffer in peers:
                        peer_buffer.append(item)
            yield buffer.popleft()
    finally:
        # this peer is done – remove its buffer
        for idx, peer_buffer in enumerate(peers):  # pragma: no branch
            if peer_buffer is buffer:
                peers.pop(idx)
                break
        # if we are the last peer, try and close the iterator
        if not peers and hasattr(iterator, "aclose"):
            await iterator.aclose()


class Tee(Generic[T]):
    """
    Create ``n`` separate asynchronous iterators over ``iterable``

    This splits a single ``iterable`` into multiple iterators, each providing
    the same items in the same order.
    All child iterators may advance separately but share the same items
    from ``iterable`` -- when the most advanced iterator retrieves an item,
    it is buffered until the least advanced iterator has yielded it as well.
    A ``tee`` works lazily and can handle an infinite ``iterable``, provided
    that all iterators advance.

    .. code-block:: python3

        async def derivative(sensor_data):
            previous, current = a.tee(sensor_data, n=2)
            await a.anext(previous)  # advance one iterator
            return a.map(operator.sub, previous, current)

    Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
    of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
    to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
    immediately closes all children, and it can be used in an ``async with`` context
    for the same effect.

    If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
    provide these items. Also, ``tee`` must internally buffer each item until the
    last iterator has yielded it; if the most and least advanced iterator differ
    by most data, using a :py:class:`list` is more efficient (but not lazy).

    If the underlying iterable is concurrency safe (``anext`` may be awaited
    concurrently) the resulting iterators are concurrency safe as well. Otherwise,
    the iterators are safe if there is only ever one single "most advanced" iterator.
    To enforce sequential use of ``anext``, provide a ``lock``
    - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
    and access is automatically synchronised.
    """

    def __init__(
        self,
        iterable: AsyncIterator[T],
        n: int = 2,
    ):
        self._iterator = iterable.__aiter__()  # before 3.10 aiter() doesn't exist
        self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
        self._children = tuple(
            tee_peer(
                iterator=self._iterator,
                buffer=buffer,
                peers=self._buffers,
            )
            for buffer in self._buffers
        )

    def __len__(self) -> int:
        return len(self._children)

    @overload
    def __getitem__(self, item: int) -> AsyncIterator[T]:
        ...

    @overload
    def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
        ...

    def __getitem__(
        self, item: Union[int, slice]
    ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
        return self._children[item]

    def __iter__(self) -> Iterator[AsyncIterator[T]]:
        yield from self._children

    async def __aenter__(self) -> "Tee[T]":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        await self.aclose()
        return False

    async def aclose(self) -> None:
        for child in self._children:
            await child.aclose()


atee = Tee
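atee is what _atransform_stream_with_config uses to watch the input stream for tracing while the transformer consumes it. A minimal usage sketch (numbers and main are illustrative):

    import asyncio

    from langchain.utils.aiter import atee

    async def numbers():
        for i in range(3):
            yield i

    async def main() -> None:
        a, b = atee(numbers(), 2)
        # both children see the same items; whichever lags reads from a buffer
        assert [x async for x in a] == [0, 1, 2]
        assert [x async for x in b] == [0, 1, 2]

    asyncio.run(main())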
libs/langchain/tests/unit_tests/chains/test_transform.py

@@ -15,7 +15,7 @@ def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
     return outputs


-def test_tranform_chain() -> None:
+def test_transform_chain() -> None:
     """Test basic transform chain."""
     transform_chain = TransformChain(
         input_variables=["first_name", "last_name"],
File diff suppressed because one or more lines are too long
libs/langchain/tests/unit_tests/schema/test_runnable.py

@@ -11,7 +11,7 @@ from langchain.callbacks.manager import Callbacks
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
 from langchain.chat_models.fake import FakeListChatModel
-from langchain.llms.fake import FakeListLLM
+from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
 from langchain.load.dump import dumpd, dumps
 from langchain.output_parsers.list import CommaSeparatedListOutputParser
 from langchain.prompts.chat import (

@@ -22,6 +22,7 @@ from langchain.prompts.chat import (
 )
 from langchain.schema.document import Document
 from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
+from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.retriever import BaseRetriever
 from langchain.schema.runnable import (
     RouterRunnable,

@@ -61,6 +62,8 @@ class FakeTracer(BaseTracer):
                 if run.parent_run_id
                 else None,
                 "child_runs": [self._copy_run(child) for child in run.child_runs],
+                "execution_order": None,
+                "child_execution_order": None,
             }
         )

@@ -302,7 +305,7 @@ async def test_prompt_with_chat_model(
     tracer = FakeTracer()
     assert [
         *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
-    ] == [AIMessage(content="foo")]
+    ] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[

@@ -678,7 +681,12 @@ async def test_router_runnable(
         "key": "math",
         "input": {"question": "2 + 2"},
     }
-    assert tracer.runs == snapshot
+    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
+    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
+    assert len(parent_run.child_runs) == 2
+    router_run = parent_run.child_runs[1]
+    assert router_run.name == "RunnableSequence"  # TODO: should be RunnableRouter
+    assert len(router_run.child_runs) == 2


 @freeze_time("2023-01-01")

@@ -758,6 +766,45 @@ def test_bind_bind() -> None:
     ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))


+def test_deep_stream() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain = prompt | llm | StrOutputParser()
+
+    stream = chain.stream({"question": "What up"})
+
+    chunks = []
+    for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert "".join(chunks) == "foo-lish"
+
+
+@pytest.mark.asyncio
+async def test_deep_astream() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain = prompt | llm | StrOutputParser()
+
+    stream = chain.astream({"question": "What up"})
+
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert "".join(chunks) == "foo-lish"
+
+
 @pytest.fixture()
 def llm_with_fallbacks() -> RunnableWithFallbacks:
     error_llm = FakeListLLM(responses=["foo"], i=1)