mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
wip
This commit is contained in:
parent
539672a7fd
commit
50b13ab938
@ -62,7 +62,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return self(input, **(config or {}), **kwargs)
|
_config: Dict[str, Any] = dict(config) if config else {}
|
||||||
|
_config.pop("_locals", None)
|
||||||
|
return self(input, **_config, **kwargs)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
@ -76,7 +78,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
None, partial(self.invoke, input, config, **kwargs)
|
None, partial(self.invoke, input, config, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.acall(input, **(config or {}), **kwargs)
|
_config: Dict[str, Any] = dict(config) if config else {}
|
||||||
|
_config.pop("_locals", None)
|
||||||
|
return await self.acall(input, **_config, **kwargs)
|
||||||
|
|
||||||
memory: Optional[BaseMemory] = None
|
memory: Optional[BaseMemory] = None
|
||||||
"""Optional memory object. Defaults to None.
|
"""Optional memory object. Defaults to None.
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from copy import deepcopy
|
||||||
from itertools import tee
|
from itertools import tee
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -66,6 +67,35 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_locals: Dict[str, Any]
|
||||||
|
"""
|
||||||
|
Local variables
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_config() -> RunnableConfig:
|
||||||
|
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||||
|
|
||||||
|
|
||||||
|
def _get_callback_manager(config: Mapping) -> Any:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
|
return CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_async_callback_manager(config: Mapping) -> Any:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
|
return AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input")
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
@ -243,7 +273,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return (
|
return (
|
||||||
config
|
config
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [config.copy() if config is not None else {} for _ in range(length)]
|
else [deepcopy(config) if config is not None else {} for _ in range(length)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
@ -255,14 +285,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = _get_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input if isinstance(input, dict) else {"input": input},
|
||||||
@ -288,14 +312,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = _get_async_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input if isinstance(input, dict) else {"input": input},
|
||||||
@ -322,8 +340,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""Helper method to transform an Iterator of Input values into an Iterator of
|
"""Helper method to transform an Iterator of Input values into an Iterator of
|
||||||
Output values, with callbacks.
|
Output values, with callbacks.
|
||||||
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
|
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = tee(input, 2)
|
input_for_tracing, input_for_transform = tee(input, 2)
|
||||||
# Start the input iterator to ensure the input runnable starts before this one
|
# Start the input iterator to ensure the input runnable starts before this one
|
||||||
@ -333,11 +349,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = _get_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
@ -393,8 +405,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""Helper method to transform an Async Iterator of Input values into an Async
|
"""Helper method to transform an Async Iterator of Input values into an Async
|
||||||
Iterator of Output values, with callbacks.
|
Iterator of Output values, with callbacks.
|
||||||
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
|
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = atee(input, 2)
|
input_for_tracing, input_for_transform = atee(input, 2)
|
||||||
# Start the input iterator to ensure the input runnable starts before this one
|
# Start the input iterator to ensure the input runnable starts before this one
|
||||||
@ -404,11 +414,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = _get_async_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
@ -473,19 +479,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
yield from self.fallbacks
|
yield from self.fallbacks
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = _get_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -516,19 +512,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = _get_async_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -751,19 +737,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = _get_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -771,11 +747,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
|
callbacks = run_manager.get_child()
|
||||||
for step in self.steps:
|
for step in self.steps:
|
||||||
input = step.invoke(
|
input = step.invoke(
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
_patch_config(config, callbacks),
|
||||||
)
|
)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
@ -790,19 +767,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = _get_async_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -946,19 +913,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
def stream(
|
def stream(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = _get_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -1023,19 +980,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
async def astream(
|
async def astream(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = _get_async_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -1173,19 +1120,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or {}
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = _get_async_callback_manager(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), {"input": input}
|
dumpd(self), {"input": input}
|
||||||
@ -1464,10 +1401,11 @@ class RouterRunnable(
|
|||||||
|
|
||||||
|
|
||||||
def _patch_config(
|
def _patch_config(
|
||||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = config.copy()
|
config = deepcopy(config)
|
||||||
config["callbacks"] = callback_manager
|
config["callbacks"] = callback_manager
|
||||||
|
config["_locals"] = _locals or {}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@ -636,7 +636,9 @@ async def _arun_chain(
|
|||||||
else:
|
else:
|
||||||
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
|
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
|
||||||
else:
|
else:
|
||||||
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
|
runnable_config = RunnableConfig(
|
||||||
|
tags=tags or [], callbacks=callbacks, _locals={}
|
||||||
|
)
|
||||||
output = await chain.ainvoke(inputs_, config=runnable_config)
|
output = await chain.ainvoke(inputs_, config=runnable_config)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -957,7 +959,9 @@ def _run_chain(
|
|||||||
else:
|
else:
|
||||||
output = chain(inputs_, callbacks=callbacks, tags=tags)
|
output = chain(inputs_, callbacks=callbacks, tags=tags)
|
||||||
else:
|
else:
|
||||||
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
|
runnable_config = RunnableConfig(
|
||||||
|
tags=tags or [], callbacks=callbacks, _locals={}
|
||||||
|
)
|
||||||
output = chain.invoke(inputs_, config=runnable_config)
|
output = chain.invoke(inputs_, config=runnable_config)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user