mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-22 14:49:29 +00:00)
Adds transform support for runnables (#8762)
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 4d72288487
commit b8df15cd64
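A minimal sketch of what this change enables, based only on pieces added in this diff (see `test_deep_stream` near the bottom; imports assumed to match the touched modules):

```python
from langchain.llms.fake import FakeStreamingListLLM
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.schema.output_parser import StrOutputParser

prompt = (
    SystemMessagePromptTemplate.from_template("You are a nice assistant.")
    + "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain = prompt | llm | StrOutputParser()

# The parser now transforms the llm's chunk stream instead of buffering it:
# this prints "foo-lish" one character at a time.
for chunk in chain.stream({"question": "What up"}):
    print(chunk, end="")
```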
@@ -47,6 +47,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
             parent_run = self.run_map[str(run.parent_run_id)]
             if parent_run:
                 self._add_child_run(parent_run, run)
+                parent_run.child_execution_order = max(
+                    parent_run.child_execution_order, run.child_execution_order
+                )
             else:
                 logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
         self.run_map[str(run.id)] = run
@@ -254,7 +257,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self._on_chain_start(chain_run)

     def on_chain_end(
-        self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
+        self,
+        outputs: Dict[str, Any],
+        *,
+        run_id: UUID,
+        inputs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> None:
         """End a trace for a chain run."""
         if not run_id:
@@ -266,6 +274,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
         chain_run.outputs = outputs
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "end", "time": chain_run.end_time})
+        if inputs is not None:
+            chain_run.inputs = inputs
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)

@@ -273,6 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         error: Union[Exception, KeyboardInterrupt],
         *,
+        inputs: Optional[Dict[str, Any]] = None,
         run_id: UUID,
         **kwargs: Any,
     ) -> None:
@@ -286,6 +297,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
         chain_run.error = repr(error)
         chain_run.end_time = datetime.utcnow()
         chain_run.events.append({"name": "error", "time": chain_run.end_time})
+        if inputs is not None:
+            chain_run.inputs = inputs
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)

@@ -1,9 +1,11 @@
 """Base interface that all chains should implement."""
+import asyncio
 import inspect
 import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
+from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

@@ -55,18 +57,26 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
     """

     def invoke(
-        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
-        return self(input, **(config or {}))
+        return self(input, **(config or {}), **kwargs)

     async def ainvoke(
-        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         if type(self)._acall == Chain._acall:
             # If the chain does not implement async, fall back to default implementation
-            return await super().ainvoke(input, config)
+            return await asyncio.get_running_loop().run_in_executor(
+                None, partial(self.invoke, input, config, **kwargs)
+            )

-        return await self.acall(input, **(config or {}))
+        return await self.acall(input, **(config or {}), **kwargs)

     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
@@ -3,6 +3,8 @@ import functools
 import logging
 from typing import Any, Awaitable, Callable, Dict, List, Optional

+from pydantic import Field
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
@@ -27,9 +29,11 @@ class TransformChain(Chain):
     """The keys expected by the transform's input dictionary."""
     output_variables: List[str]
     """The keys returned by the transform's output dictionary."""
-    transform: Callable[[Dict[str, str]], Dict[str, str]]
+    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
     """The transform function."""
-    atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
+    atransform_cb: Optional[
+        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
+    ] = Field(None, alias="atransform")
     """The async coroutine transform function."""

     @staticmethod
@@ -62,18 +66,18 @@ class TransformChain(Chain):
         inputs: Dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
-        return self.transform(inputs)
+        return self.transform_cb(inputs)

     async def _acall(
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
-        if self.atransform is not None:
-            return await self.atransform(inputs)
+        if self.atransform_cb is not None:
+            return await self.atransform_cb(inputs)
         else:
             self._log_once(
                 "TransformChain's atransform is not provided, falling"
                 " back to synchronous transform"
             )
-            return self.transform(inputs)
+            return self.transform_cb(inputs)
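Why the rename above matters (my gloss, not from the diff): an attribute named `transform` would now collide with the new `Runnable.transform()` streaming method, so the callable moves to `transform_cb` while the pydantic alias keeps the constructor argument spelled `transform`. A sketch, mirroring the existing `test_transform_chain`:

```python
def dummy_transform(inputs: dict) -> dict:
    return {"greeting": f"hello {inputs['first_name']}"}

# Callers keep writing `transform=`; the alias stores the callable on
# `transform_cb`, freeing the `transform` name for streaming support.
chain = TransformChain(
    input_variables=["first_name"],
    output_variables=["greeting"],
    transform=dummy_transform,
)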
@@ -1,9 +1,13 @@
 """Fake ChatModel for testing purposes."""
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.chat_models.base import SimpleChatModel
-from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import AIMessageChunk, BaseMessage
+from langchain.schema.output import ChatGenerationChunk


 class FakeListChatModel(SimpleChatModel):
@@ -31,6 +35,36 @@ class FakeListChatModel(SimpleChatModel):
             self.i = 0
         return response

+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Union[List[str], None] = None,
+        run_manager: Union[CallbackManagerForLLMRun, None] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
+        for c in response:
+            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Union[List[str], None] = None,
+        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        response = self.responses[self.i]
+        if self.i < len(self.responses) - 1:
+            self.i += 1
+        else:
+            self.i = 0
+        for c in response:
+            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         return {"responses": self.responses}
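The effect of `_stream` above shows up in the updated `test_prompt_with_chat_model` assertion further down: a chain ending in this fake chat model now emits one message per character rather than one full message. A sketch based on that test (the `prompt` construction is elided here, as in the test's setup):

```python
# Based on the updated assertion in test_prompt_with_chat_model below:
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat  # `prompt` as built in the existing test
assert [*chain.stream({"question": "What is your name?"})] == [
    AIMessage(content="f"),
    AIMessage(content="o"),
    AIMessage(content="o"),
]
```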
@@ -1,10 +1,12 @@
-from typing import Any, List, Mapping, Optional
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import LLM
+from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.runnable import RunnableConfig


 class FakeListLLM(LLM):
@@ -51,3 +53,29 @@ class FakeListLLM(LLM):
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         return {"responses": self.responses}
+
+
+class FakeStreamingListLLM(FakeListLLM):
+    def stream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        result = self.invoke(input, config)
+        for c in result:
+            yield c
+
+    async def astream(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        result = await self.ainvoke(input, config)
+        for c in result:
+            yield c
@@ -2,7 +2,17 @@ from __future__ import annotations

 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+)

 from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage
@@ -47,7 +57,7 @@ class BaseGenerationOutputParser(
     BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
 ):
     def invoke(
-        self, input: str | BaseMessage, config: RunnableConfig | None = None
+        self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
     ) -> T:
         if isinstance(input, BaseMessage):
             return self._call_with_config(
@@ -115,7 +125,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
     """  # noqa: E501

     def invoke(
-        self, input: str | BaseMessage, config: RunnableConfig | None = None
+        self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
     ) -> T:
         if isinstance(input, BaseMessage):
             return self._call_with_config(
@@ -242,8 +252,47 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
         return output_parser_dict


-class StrOutputParser(BaseOutputParser[str]):
-    """OutputParser that parses LLMResult into the top likely string.."""
+class BaseTransformOutputParser(BaseOutputParser[T]):
+    """Base class for an output parser that can handle streaming input."""
+
+    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
+        for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                yield self.parse_result([ChatGeneration(message=chunk)])
+            else:
+                yield self.parse_result([Generation(text=chunk)])
+
+    async def _atransform(
+        self, input: AsyncIterator[Union[str, BaseMessage]]
+    ) -> AsyncIterator[T]:
+        async for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                yield self.parse_result([ChatGeneration(message=chunk)])
+            else:
+                yield self.parse_result([Generation(text=chunk)])
+
+    def transform(
+        self,
+        input: Iterator[Union[str, BaseMessage]],
+        config: Optional[RunnableConfig] = None,
+    ) -> Iterator[T]:
+        yield from self._transform_stream_with_config(
+            input, self._transform, config, run_type="parser"
+        )
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Union[str, BaseMessage]],
+        config: Optional[RunnableConfig] = None,
+    ) -> AsyncIterator[T]:
+        async for chunk in self._atransform_stream_with_config(
+            input, self._atransform, config, run_type="parser"
+        ):
+            yield chunk
+
+
+class StrOutputParser(BaseTransformOutputParser[str]):
+    """OutputParser that parses LLMResult into the top likely string."""

     @property
     def lc_serializable(self) -> bool:
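A side benefit (my sketch; the subclass name is hypothetical): any parser deriving from the new `BaseTransformOutputParser` inherits streaming, since `_transform` parses and re-emits each chunk as it arrives:

```python
from langchain.schema.output_parser import BaseTransformOutputParser

# Hypothetical subclass: upper-cases each streamed chunk independently.
class UpperOutputParser(BaseTransformOutputParser[str]):
    def parse(self, text: str) -> str:
        return text.upper()

parser = UpperOutputParser()
# Each element is parsed as it arrives; nothing waits for the full stream.
assert list(parser.transform(iter(["foo", "bar"]))) == ["FOO", "BAR"]
```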
@@ -3,6 +3,7 @@ from __future__ import annotations
 import asyncio
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
+from itertools import tee
 from typing import (
     Any,
     AsyncIterator,
@@ -29,6 +30,7 @@ from pydantic import Field
 from langchain.callbacks.base import BaseCallbackManager, Callbacks
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
+from langchain.utils.aiter import atee, py_anext


 async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@@ -92,6 +94,8 @@ class Runnable(Generic[Input, Output], ABC):
     ) -> RunnableSequence[Other, Output]:
         return RunnableSequence(first=_coerce_to_runnable(other), last=self)

+    """ --- Public API --- """
+
     @abstractmethod
     def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
         ...
@@ -99,6 +103,10 @@ class Runnable(Generic[Input, Output], ABC):
     async def ainvoke(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Output:
+        """
+        Default implementation of ainvoke, which calls invoke in a thread pool.
+        Subclasses should override this method if they can run asynchronously.
+        """
         return await asyncio.get_running_loop().run_in_executor(
             None, self.invoke, input, config
         )
@@ -110,6 +118,10 @@ class Runnable(Generic[Input, Output], ABC):
         *,
         max_concurrency: Optional[int] = None,
     ) -> List[Output]:
+        """
+        Default implementation of batch, which calls invoke N times.
+        Subclasses should override this method if they can batch more efficiently.
+        """
         configs = self._get_config_list(config, len(inputs))

         # If there's only one input, don't bother with the executor
@@ -126,6 +138,10 @@ class Runnable(Generic[Input, Output], ABC):
         *,
         max_concurrency: Optional[int] = None,
     ) -> List[Output]:
+        """
+        Default implementation of abatch, which calls ainvoke N times.
+        Subclasses should override this method if they can batch more efficiently.
+        """
         configs = self._get_config_list(config, len(inputs))
         coros = map(self.ainvoke, inputs, configs)

@@ -134,22 +150,90 @@ class Runnable(Generic[Input, Output], ABC):
     def stream(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Iterator[Output]:
+        """
+        Default implementation of stream, which calls invoke.
+        Subclasses should override this method if they support streaming output.
+        """
         yield self.invoke(input, config)

     async def astream(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> AsyncIterator[Output]:
+        """
+        Default implementation of astream, which calls ainvoke.
+        Subclasses should override this method if they support streaming output.
+        """
         yield await self.ainvoke(input, config)

+    def transform(
+        self, input: Iterator[Input], config: Optional[RunnableConfig] = None
+    ) -> Iterator[Output]:
+        """
+        Default implementation of transform, which buffers input and then calls stream.
+        Subclasses should override this method if they can start producing output while
+        input is still being generated.
+        """
+        final: Union[Input, None] = None
+
+        for chunk in input:
+            if final is None:
+                final = chunk
+            else:
+                # Make a best effort to gather, for any type that supports `+`
+                # This method should throw an error if gathering fails.
+                final += chunk  # type: ignore[operator]
+        if final:
+            yield from self.stream(final, config)
+
+    async def atransform(
+        self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Output]:
+        """
+        Default implementation of atransform, which buffers input and calls astream.
+        Subclasses should override this method if they can start producing output while
+        input is still being generated.
+        """
+        final: Union[Input, None] = None
+
+        async for chunk in input:
+            if final is None:
+                final = chunk
+            else:
+                # Make a best effort to gather, for any type that supports `+`
+                # This method should throw an error if gathering fails.
+                final += chunk  # type: ignore[operator]
+
+        if final:
+            async for output in self.astream(final, config):
+                yield output
+
     def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
         """
         Bind arguments to a Runnable, returning a new Runnable.
         """
         return RunnableBinding(bound=self, kwargs=kwargs)

+    def with_fallbacks(
+        self,
+        fallbacks: Sequence[Runnable[Input, Output]],
+        *,
+        exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
+    ) -> RunnableWithFallbacks[Input, Output]:
+        return RunnableWithFallbacks(
+            runnable=self,
+            fallbacks=fallbacks,
+            exceptions_to_handle=exceptions_to_handle,
+        )
+
+    """ --- Helper methods for Subclasses --- """
+
     def _get_config_list(
         self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
     ) -> List[RunnableConfig]:
+        """
+        Helper method to get a list of configs from a single config or a list of
+        configs, useful for subclasses overriding batch() or abatch().
+        """
         if isinstance(config, list) and len(config) != length:
             raise ValueError(
                 f"config must be a list of the same length as inputs, "
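To make the default `transform()` contract concrete, a small sketch (mine, not from the diff; import path assumed from this module): a runnable with no native streaming support buffers chunks with `+` and only then delegates to `stream()`:

```python
from langchain.schema.runnable import RunnableLambda

shout = RunnableLambda(lambda text: text.upper())

# RunnableLambda does not override transform, so the default buffers:
# "fo" + "o" -> "foo", then stream() yields a single result.
assert list(shout.transform(iter(["fo", "o"]))) == ["FOO"]
```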
@@ -169,6 +253,8 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Output:
+        """Helper method to transform an Input value to an Output value,
+        with callbacks. Use this method to implement invoke() in subclasses."""
         from langchain.callbacks.manager import CallbackManager

         config = config or {}
@@ -200,6 +286,8 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig],
         run_type: Optional[str] = None,
     ) -> Output:
+        """Helper method to transform an Input value to an Output value,
+        with callbacks. Use this method to implement ainvoke() in subclasses."""
         from langchain.callbacks.manager import AsyncCallbackManager

         config = config or {}
@@ -224,20 +312,154 @@ class Runnable(Generic[Input, Output], ABC):
         )
         return output

-    def with_fallbacks(
-        self,
-        fallbacks: Sequence[Runnable[Input, Output]],
-        *,
-        exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,),
-    ) -> RunnableWithFallbacks[Input, Output]:
-        return RunnableWithFallbacks(
-            runnable=self,
-            fallbacks=fallbacks,
-            exceptions_to_handle=exceptions_to_handle,
-        )
+    def _transform_stream_with_config(
+        self,
+        input: Iterator[Input],
+        transformer: Callable[[Iterator[Input]], Iterator[Output]],
+        config: Optional[RunnableConfig],
+        run_type: Optional[str] = None,
+    ) -> Iterator[Output]:
+        """Helper method to transform an Iterator of Input values into an Iterator of
+        Output values, with callbacks.
+        Use this to implement `stream()` or `transform()` in Runnable subclasses."""
+        from langchain.callbacks.manager import CallbackManager
+
+        # tee the input so we can iterate over it twice
+        input_for_tracing, input_for_transform = tee(input, 2)
+        # Start the input iterator to ensure the input runnable starts before this one
+        final_input: Optional[Input] = next(input_for_tracing, None)
+        final_input_supported = True
+        final_output: Optional[Output] = None
+        final_output_supported = True
+
+        config = config or {}
+        callback_manager = CallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            inheritable_tags=config.get("tags"),
+            inheritable_metadata=config.get("metadata"),
+        )
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self),
+            {"input": ""},
+            run_type=run_type,
+        )
+        try:
+            for chunk in transformer(input_for_transform):
+                yield chunk
+                if final_output_supported:
+                    if final_output is None:
+                        final_output = chunk
+                    else:
+                        try:
+                            final_output += chunk  # type: ignore[operator]
+                        except TypeError:
+                            final_output = None
+                            final_output_supported = False
+            for ichunk in input_for_tracing:
+                if final_input_supported:
+                    if final_input is None:
+                        final_input = ichunk
+                    else:
+                        try:
+                            final_input += ichunk  # type: ignore[operator]
+                        except TypeError:
+                            final_input = None
+                            final_input_supported = False
+        except Exception as e:
+            run_manager.on_chain_error(
+                e,
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )
+            raise
+        else:
+            run_manager.on_chain_end(
+                final_output
+                if isinstance(final_output, dict)
+                else {"output": final_output},
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )
+
+    async def _atransform_stream_with_config(
+        self,
+        input: AsyncIterator[Input],
+        transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
+        config: Optional[RunnableConfig],
+        run_type: Optional[str] = None,
+    ) -> AsyncIterator[Output]:
+        """Helper method to transform an Async Iterator of Input values into an Async
+        Iterator of Output values, with callbacks.
+        Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
+        from langchain.callbacks.manager import AsyncCallbackManager
+
+        # tee the input so we can iterate over it twice
+        input_for_tracing, input_for_transform = atee(input, 2)
+        # Start the input iterator to ensure the input runnable starts before this one
+        final_input: Optional[Input] = await py_anext(input_for_tracing, None)
+        final_input_supported = True
+        final_output: Optional[Output] = None
+        final_output_supported = True
+
+        config = config or {}
+        callback_manager = AsyncCallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            inheritable_tags=config.get("tags"),
+            inheritable_metadata=config.get("metadata"),
+        )
+        run_manager = await callback_manager.on_chain_start(
+            dumpd(self),
+            {"input": ""},
+            run_type=run_type,
+        )
+        try:
+            async for chunk in transformer(input_for_transform):
+                yield chunk
+                if final_output_supported:
+                    if final_output is None:
+                        final_output = chunk
+                    else:
+                        try:
+                            final_output += chunk  # type: ignore[operator]
+                        except TypeError:
+                            final_output = None
+                            final_output_supported = False
+            async for ichunk in input_for_tracing:
+                if final_input_supported:
+                    if final_input is None:
+                        final_input = ichunk
+                    else:
+                        try:
+                            final_input += ichunk  # type: ignore[operator]
+                        except TypeError:
+                            final_input = None
+                            final_input_supported = False
+        except Exception as e:
+            await run_manager.on_chain_error(
+                e,
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )
+            raise
+        else:
+            await run_manager.on_chain_end(
+                final_output
+                if isinstance(final_output, dict)
+                else {"output": final_output},
+                inputs=final_input
+                if isinstance(final_input, dict)
+                else {"input": final_input},
+            )


 class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
+    """
+    A Runnable that can fallback to other Runnables if it fails.
+    """
+
     runnable: Runnable[Input, Output]
     fallbacks: Sequence[Runnable[Input, Output]]
     exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,)
@@ -467,6 +689,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):


 class RunnableSequence(Serializable, Runnable[Input, Output]):
+    """
+    A sequence of runnables, where the output of each is the input of the next.
+    """
+
     first: Runnable[Input, Any]
     middle: List[Runnable[Any, Any]] = Field(default_factory=list)
     last: Runnable[Any, Output]
@@ -738,9 +964,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             dumpd(self), input if isinstance(input, dict) else {"input": input}
         )

+        steps = [self.first] + self.middle + [self.last]
+        streaming_start_index = 0
+
+        for i in range(len(steps) - 1, 0, -1):
+            if type(steps[i]).transform != Runnable.transform:
+                streaming_start_index = i - 1
+            else:
+                break
+
         # invoke the first steps
         try:
-            for step in [self.first] + self.middle:
+            for step in steps[0:streaming_start_index]:
                 input = step.invoke(
                     input,
                     # mark each step as a child run
@@ -750,15 +985,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             run_manager.on_chain_error(e)
             raise

-        # stream the last step
+        # stream the last steps
         final: Union[Output, None] = None
         final_supported = True
         try:
-            for output in self.last.stream(
-                input,
-                # mark the last step as a child run
-                _patch_config(config, run_manager.get_child()),
-            ):
+            # stream the first of the last steps with non-streaming input
+            final_pipeline = steps[streaming_start_index].stream(
+                input, _patch_config(config, run_manager.get_child())
+            )
+            # stream the rest of the last steps with streaming input
+            for step in steps[streaming_start_index + 1 :]:
+                final_pipeline = step.transform(
+                    final_pipeline, _patch_config(config, run_manager.get_child())
+                )
+            for output in final_pipeline:
                 yield output
                 # Accumulate output if possible, otherwise disable accumulation
                 if final_supported:
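The backward scan above decides where invoking stops and streaming begins. Restated as a standalone helper for clarity (the function name is mine; the loop body is verbatim from the hunk):

```python
from typing import List
from langchain.schema.runnable import Runnable

def find_streaming_start(steps: List[Runnable]) -> int:
    """Walk backwards while steps override Runnable.transform; the step just
    before that streak seeds the pipeline with .stream(), and every later
    step consumes its predecessor's chunks via .transform()."""
    streaming_start_index = 0
    for i in range(len(steps) - 1, 0, -1):
        if type(steps[i]).transform != Runnable.transform:
            streaming_start_index = i - 1
        else:
            break
    return streaming_start_index

# e.g. [prompt, llm, StrOutputParser()]: the parser overrides transform, so
# the scan stops at the llm -> the prompt is invoked, the llm is streamed,
# and the parser transforms the llm's chunk iterator.
```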
@@ -801,9 +1041,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             dumpd(self), input if isinstance(input, dict) else {"input": input}
         )

+        steps = [self.first] + self.middle + [self.last]
+        streaming_start_index = len(steps) - 1
+
+        for i in range(len(steps) - 1, 0, -1):
+            if type(steps[i]).transform != Runnable.transform:
+                streaming_start_index = i - 1
+            else:
+                break
+
         # invoke the first steps
         try:
-            for step in [self.first] + self.middle:
+            for step in steps[0:streaming_start_index]:
                 input = await step.ainvoke(
                     input,
                     # mark each step as a child run
@@ -813,15 +1062,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             await run_manager.on_chain_error(e)
             raise

-        # stream the last step
+        # stream the last steps
         final: Union[Output, None] = None
         final_supported = True
         try:
-            async for output in self.last.astream(
-                input,
-                # mark the last step as a child run
-                _patch_config(config, run_manager.get_child()),
-            ):
+            # stream the first of the last steps with non-streaming input
+            final_pipeline = steps[streaming_start_index].astream(
+                input, _patch_config(config, run_manager.get_child())
+            )
+            # stream the rest of the last steps with streaming input
+            for step in steps[streaming_start_index + 1 :]:
+                final_pipeline = step.atransform(
+                    final_pipeline, _patch_config(config, run_manager.get_child())
+                )
+            async for output in final_pipeline:
                 yield output
                 # Accumulate output if possible, otherwise disable accumulation
                 if final_supported:
@@ -845,6 +1099,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):


 class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
+    """
+    A runnable that runs a mapping of runnables in parallel,
+    and returns a mapping of their outputs.
+    """
+
     steps: Mapping[str, Runnable[Input, Any]]

     def __init__(
@@ -957,6 +1216,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):


 class RunnableLambda(Runnable[Input, Output]):
+    """
+    A runnable that runs a callable.
+    """
+
     def __init__(self, func: Callable[[Input], Output]) -> None:
         if callable(func):
             self.func = func
@@ -977,6 +1240,10 @@ class RunnableLambda(Runnable[Input, Output]):


 class RunnablePassthrough(Serializable, Runnable[Input, Input]):
+    """
+    A runnable that passes through the input.
+    """
+
     @property
     def lc_serializable(self) -> bool:
         return True
@@ -986,6 +1253,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):


 class RunnableBinding(Serializable, Runnable[Input, Output]):
+    """
+    A runnable that binds a runnable to a set of kwargs.
+    """
+
     bound: Runnable[Input, Output]

     kwargs: Mapping[str, Any]
@@ -1041,6 +1312,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
         async for item in self.bound.astream(input, config, **self.kwargs):
             yield item

+    def transform(
+        self, input: Iterator[Input], config: Optional[RunnableConfig] = None
+    ) -> Iterator[Output]:
+        yield from self.bound.transform(input, config, **self.kwargs)
+
+    async def atransform(
+        self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
+    ) -> AsyncIterator[Output]:
+        async for item in self.bound.atransform(input, config, **self.kwargs):
+            yield item
+

 class RouterInput(TypedDict):
     key: str
@@ -1050,6 +1332,11 @@ class RouterInput(TypedDict):
 class RouterRunnable(
     Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
 ):
+    """
+    A runnable that routes to a set of runnables based on Input['key'].
+    Returns the output of the selected runnable.
+    """
+
     runnables: Mapping[str, Runnable[Input, Output]]

     def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
libs/langchain/langchain/utils/aiter.py (new file, 191 lines)
@@ -0,0 +1,191 @@
+"""
+Adapted from
+https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
+MIT License
+"""
+
+from collections import deque
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Deque,
+    Generic,
+    Iterator,
+    List,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+T = TypeVar("T")
+
+_no_default = object()
+
+
+# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
+# before 3.10, the builtin anext() was not available
+def py_anext(
+    iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
+) -> Awaitable[Union[T, None, Any]]:
+    """Pure-Python implementation of anext() for testing purposes.
+
+    Closely matches the builtin anext() C implementation.
+    Can be used to compare the built-in implementation of the inner
+    coroutines machinery to C-implementation of __anext__() and send()
+    or throw() on the returned generator.
+    """
+
+    try:
+        __anext__ = cast(
+            Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
+        )
+    except AttributeError:
+        raise TypeError(f"{iterator!r} is not an async iterator")
+
+    if default is _no_default:
+        return __anext__(iterator)
+
+    async def anext_impl() -> Union[T, Any]:
+        try:
+            # The C code is way more low-level than this, as it implements
+            # all methods of the iterator protocol. In this implementation
+            # we're relying on higher-level coroutine concepts, but that's
+            # exactly what we want -- crosstest pure-Python high-level
+            # implementation and low-level C anext() iterators.
+            return await __anext__(iterator)
+        except StopAsyncIteration:
+            return default
+
+    return anext_impl()
+
+
+async def tee_peer(
+    iterator: AsyncIterator[T],
+    # the buffer specific to this peer
+    buffer: Deque[T],
+    # the buffers of all peers, including our own
+    peers: List[Deque[T]],
+) -> AsyncGenerator[T, None]:
+    """An individual iterator of a :py:func:`~.tee`"""
+    try:
+        while True:
+            if not buffer:
+                # Another peer produced an item while we were waiting for the lock.
+                # Proceed with the next loop iteration to yield the item.
+                if buffer:
+                    continue
+                try:
+                    item = await iterator.__anext__()
+                except StopAsyncIteration:
+                    break
+                else:
+                    # Append to all buffers, including our own. We'll fetch our
+                    # item from the buffer again, instead of yielding it directly.
+                    # This ensures the proper item ordering if any of our peers
+                    # are fetching items concurrently. They may have buffered their
+                    # item already.
+                    for peer_buffer in peers:
+                        peer_buffer.append(item)
+            yield buffer.popleft()
+    finally:
+        # this peer is done – remove its buffer
+        for idx, peer_buffer in enumerate(peers):  # pragma: no branch
+            if peer_buffer is buffer:
+                peers.pop(idx)
+                break
+        # if we are the last peer, try and close the iterator
+        if not peers and hasattr(iterator, "aclose"):
+            await iterator.aclose()
+
+
+class Tee(Generic[T]):
+    """
+    Create ``n`` separate asynchronous iterators over ``iterable``
+
+    This splits a single ``iterable`` into multiple iterators, each providing
+    the same items in the same order.
+    All child iterators may advance separately but share the same items
+    from ``iterable`` -- when the most advanced iterator retrieves an item,
+    it is buffered until the least advanced iterator has yielded it as well.
+    A ``tee`` works lazily and can handle an infinite ``iterable``, provided
+    that all iterators advance.
+
+    .. code-block:: python3
+
+        async def derivative(sensor_data):
+            previous, current = a.tee(sensor_data, n=2)
+            await a.anext(previous)  # advance one iterator
+            return a.map(operator.sub, previous, current)
+
+    Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
+    of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
+    to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
+    immediately closes all children, and it can be used in an ``async with`` context
+    for the same effect.
+
+    If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
+    provide these items. Also, ``tee`` must internally buffer each item until the
+    last iterator has yielded it; if the most and least advanced iterator differ
+    by most data, using a :py:class:`list` is more efficient (but not lazy).
+
+    If the underlying iterable is concurrency safe (``anext`` may be awaited
+    concurrently) the resulting iterators are concurrency safe as well. Otherwise,
+    the iterators are safe if there is only ever one single "most advanced" iterator.
+    To enforce sequential use of ``anext``, provide a ``lock``
+    - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
+    and access is automatically synchronised.
+    """
+
+    def __init__(
+        self,
+        iterable: AsyncIterator[T],
+        n: int = 2,
+    ):
+        self._iterator = iterable.__aiter__()  # before 3.10 aiter() doesn't exist
+        self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
+        self._children = tuple(
+            tee_peer(
+                iterator=self._iterator,
+                buffer=buffer,
+                peers=self._buffers,
+            )
+            for buffer in self._buffers
+        )
+
+    def __len__(self) -> int:
+        return len(self._children)
+
+    @overload
+    def __getitem__(self, item: int) -> AsyncIterator[T]:
+        ...
+
+    @overload
+    def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]:
+        ...
+
+    def __getitem__(
+        self, item: Union[int, slice]
+    ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
+        return self._children[item]
+
+    def __iter__(self) -> Iterator[AsyncIterator[T]]:
+        yield from self._children
+
+    async def __aenter__(self) -> "Tee[T]":
+        return self
+
+    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
+        await self.aclose()
+        return False
+
+    async def aclose(self) -> None:
+        for child in self._children:
+            await child.aclose()
+
+
+atee = Tee
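A small usage sketch of the new module (mine, not from the diff): `atee` splits one async iterator into independent, buffered copies, which is exactly how `_atransform_stream_with_config` watches the input stream for tracing while the transformer consumes it:

```python
import asyncio

from langchain.utils.aiter import atee

async def numbers():
    for n in range(3):
        yield n

async def main() -> None:
    first, second = atee(numbers(), 2)
    print([n async for n in first])   # [0, 1, 2]
    print([n async for n in second])  # [0, 1, 2], served from the buffer

asyncio.run(main())
```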
@@ -15,7 +15,7 @@ def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
     return outputs


-def test_tranform_chain() -> None:
+def test_transform_chain() -> None:
     """Test basic transform chain."""
     transform_chain = TransformChain(
         input_variables=["first_name", "last_name"],
File diff suppressed because one or more lines are too long
@@ -11,7 +11,7 @@ from langchain.callbacks.manager import Callbacks
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
 from langchain.chat_models.fake import FakeListChatModel
-from langchain.llms.fake import FakeListLLM
+from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
 from langchain.load.dump import dumpd, dumps
 from langchain.output_parsers.list import CommaSeparatedListOutputParser
 from langchain.prompts.chat import (
@@ -22,6 +22,7 @@ from langchain.prompts.chat import (
 )
 from langchain.schema.document import Document
 from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
+from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.retriever import BaseRetriever
 from langchain.schema.runnable import (
     RouterRunnable,
@@ -61,6 +62,8 @@ class FakeTracer(BaseTracer):
                 if run.parent_run_id
                 else None,
                 "child_runs": [self._copy_run(child) for child in run.child_runs],
+                "execution_order": None,
+                "child_execution_order": None,
             }
         )

@@ -302,7 +305,7 @@ async def test_prompt_with_chat_model(
     tracer = FakeTracer()
     assert [
         *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
-    ] == [AIMessage(content="foo")]
+    ] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
         messages=[
@@ -678,7 +681,12 @@ async def test_router_runnable(
         "key": "math",
         "input": {"question": "2 + 2"},
     }
-    assert tracer.runs == snapshot
+    assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
+    parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
+    assert len(parent_run.child_runs) == 2
+    router_run = parent_run.child_runs[1]
+    assert router_run.name == "RunnableSequence"  # TODO: should be RunnableRouter
+    assert len(router_run.child_runs) == 2


 @freeze_time("2023-01-01")
@@ -758,6 +766,45 @@ def test_bind_bind() -> None:
     ) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))


+def test_deep_stream() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain = prompt | llm | StrOutputParser()
+
+    stream = chain.stream({"question": "What up"})
+
+    chunks = []
+    for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert "".join(chunks) == "foo-lish"
+
+
+@pytest.mark.asyncio
+async def test_deep_astream() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain = prompt | llm | StrOutputParser()
+
+    stream = chain.astream({"question": "What up"})
+
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert "".join(chunks) == "foo-lish"
+
+
 @pytest.fixture()
 def llm_with_fallbacks() -> RunnableWithFallbacks:
     error_llm = FakeListLLM(responses=["foo"], i=1)
|
Loading…
Reference in New Issue
Block a user