mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
Add async support for transform chain (#8205)
This commit is contained in:
parent
8f158b72fc
commit
3662aca7d4
@ -1,9 +1,16 @@
|
||||
"""Chain that runs an arbitrary python function."""
|
||||
from typing import Callable, Dict, List, Optional
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformChain(Chain):
|
||||
"""Chain transform chain output.
|
||||
@ -17,8 +24,22 @@ class TransformChain(Chain):
|
||||
"""
|
||||
|
||||
input_variables: List[str]
|
||||
"""The keys expected by the transform's input dictionary."""
|
||||
output_variables: List[str]
|
||||
"""The keys returned by the transform's output dictionary."""
|
||||
transform: Callable[[Dict[str, str]], Dict[str, str]]
|
||||
"""The transform function."""
|
||||
atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
|
||||
"""The async coroutine transform function."""
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache
|
||||
def _log_once(msg: str) -> None:
|
||||
"""Log a message once.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
logger.warning(msg)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@ -42,3 +63,17 @@ class TransformChain(Chain):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
return self.transform(inputs)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if self.atransform is not None:
|
||||
return await self.atransform(inputs)
|
||||
else:
|
||||
self._log_once(
|
||||
"TransformChain's atransform is not provided, falling"
|
||||
" back to synchronous transform"
|
||||
)
|
||||
return self.transform(inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user