Revert "Locals in config" (#9661)

Reverts langchain-ai/langchain#9007
This commit is contained in:
Bagatur 2023-08-23 10:24:59 -07:00 committed by GitHub
parent 1c64db575c
commit ef87affd4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 140 additions and 353 deletions

View File

@ -1,4 +1,3 @@
from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
@ -12,8 +11,6 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [
"GetLocalVar",
"PutLocalVar",
"RouterInput",
"RouterRunnable",
"Runnable",

View File

@ -1,156 +0,0 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]]
"""The key(s) to use for storing the input variable(s) in local state.
If a string is provided then the entire input is stored under that key. If a
Mapping is provided, then the map values are gotten from the input and
stored in local state under the map keys.
"""
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _concat_put(
self,
input: Input,
*,
config: Optional[RunnableConfig] = None,
replace: bool = False,
) -> None:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
if self.key not in config["_locals"] or replace:
config["_locals"][self.key] = input
else:
config["_locals"][self.key] += input
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
if put_key not in config["_locals"] or replace:
config["_locals"][put_key] = input[input_key]
else:
config["_locals"][put_key] += input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
f"{(type(self.key))}."
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
self._concat_put(input, config=config, replace=True)
return super().invoke(input, config=config)
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Input:
self._concat_put(input, config=config, replace=True)
return await super().ainvoke(input, config=config)
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Input]:
for chunk in super().transform(input, config=config):
self._concat_put(chunk, config=config)
yield chunk
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Input]:
async for chunk in super().atransform(input, config=config):
self._concat_put(chunk, config=config)
yield chunk
class GetLocalVar(
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
):
key: str
"""The key to extract from the local state."""
passthrough_key: Optional[str] = None
"""The key to use for passing through the invocation input.
If None, then only the value retrieved from local state is returned. Otherwise a
dictionary ``{self.key: <<retrieved_value>>, self.passthrough_key: <<input>>}``
is returned.
"""
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _get(
self,
input: Input,
run_manager: Union[CallbackManagerForChainRun, Any],
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key:
return {
self.key: config["_locals"][self.key],
self.passthrough_key: input,
}
else:
return config["_locals"][self.key]
async def _aget(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Union[Output, Dict[str, Union[Input, Output]]]:
return self._get(input, run_manager, config)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"GetLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
return self._call_with_config(self._get, input, config)
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"GetLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
return await self._acall_with_config(self._aget, input, config)

View File

@ -5,7 +5,6 @@ import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from functools import partial
from itertools import tee
from typing import (
@ -35,16 +34,11 @@ if TYPE_CHECKING:
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.utils import (
accepts_run_manager,
accepts_run_manager_and_config,
@ -244,9 +238,9 @@ class Runnable(Generic[Input, Output], ABC):
)
return (
list(map(ensure_config, config))
config
if isinstance(config, list)
else [deepcopy(ensure_config(config)) for _ in range(length)]
else [config.copy() if config is not None else {} for _ in range(length)]
)
def _call_with_config(
@ -262,8 +256,14 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
from langchain.callbacks.manager import CallbackManager
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
@ -303,8 +303,14 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
from langchain.callbacks.manager import AsyncCallbackManager
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
input,
@ -352,6 +358,8 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
from langchain.callbacks.manager import CallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
@ -360,8 +368,12 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
@ -432,6 +444,8 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
@ -440,8 +454,12 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
@ -517,9 +535,19 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
yield from self.fallbacks
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
first_error = None
@ -549,9 +577,19 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
@ -770,9 +808,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
@ -798,9 +846,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
@ -935,9 +993,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input)
@ -1001,9 +1069,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
@ -1115,7 +1193,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = ensure_config(config)
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
@ -1138,7 +1216,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.invoke,
input,
# mark each step as a child run
patch_config(deepcopy(config), run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
for step in steps.values()
]
@ -1157,9 +1235,19 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Dict[str, Any]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
@ -1452,6 +1540,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item
def patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callback_manager
return config
def coerce_to_runnable(
thing: Union[
Runnable[Input, Output],

View File

@ -1,9 +1,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, TypedDict
from typing import Any, Dict, List, TypedDict
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain.callbacks.base import Callbacks
class RunnableConfig(TypedDict, total=False):
@ -26,42 +25,3 @@ class RunnableConfig(TypedDict, total=False):
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
_locals: Dict[str, Any]
"""
Local variables
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def patch_config(
config: RunnableConfig,
callbacks: BaseCallbackManager,
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callbacks
return config
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def get_async_callback_manager_for_config(
config: RunnableConfig,
) -> AsyncCallbackManager:
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)

View File

@ -47,11 +47,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
) -> Iterator[Input]:
return self._transform_stream_with_config(input, identity, config)
async def atransform(
def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Input]:
async for chunk in self._atransform_stream_with_config(input, identity, config):
yield chunk
return self._atransform_stream_with_config(input, identity, config)

View File

@ -1352,7 +1352,6 @@
"lc": 1,
"type": "not_implemented",
"id": [
"runnable",
"test_runnable",
"FakeRetriever"
]

View File

@ -1,93 +0,0 @@
from typing import Any, Callable, Type
import pytest
from langchain import PromptTemplate
from langchain.llms import FakeListLLM
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
RunnablePassthrough,
RunnableSequence,
)
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.invoke(x), "foo", "foo"),
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
],
)
def test_put_get(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert method(runnable, input) == output
@pytest.mark.asyncio
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.ainvoke(x), "foo", "foo"),
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
],
)
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert await method(runnable, input) == output
@pytest.mark.parametrize(
("runnable", "error"),
[
(PutLocalVar("input"), ValueError),
(GetLocalVar("input"), ValueError),
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
],
)
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
with pytest.raises(error):
runnable.invoke("foo")
def test_get_in_map() -> None:
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
assert runnable.invoke("foo") == {"bar": "foo"}
def test_put_in_map() -> None:
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError):
runnable.invoke("foo")
@pytest.mark.parametrize(
"runnable",
[
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
(
PutLocalVar("input")
| {"input": RunnablePassthrough()}
| PromptTemplate.from_template("say {input}")
| FakeListLLM(responses=["hello"])
| GetLocalVar("input", passthrough_key="output")
),
],
)
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}),
(lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]),
(
lambda r, x: list(r.stream(x))[0],
"hello",
{"input": "hello", "output": "hello"},
),
],
)
def test_put_get_sequence(
runnable: RunnableSequence, method: Callable, input: Any, output: Any
) -> None:
assert method(runnable, input) == output

View File

@ -132,24 +132,15 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert spy.call_args_list == [
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(tags=["a-tag"])),
]
spy.reset_mock()
@ -170,14 +161,8 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
7,
]
assert spy.call_args_list == [
mocker.call(
"hello",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
mocker.call("hello", dict(metadata={"key": "value"})),
mocker.call("wooorld", dict(metadata={"key": "value"})),
]