mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 03:56:39 +00:00
parent
8c150ad7f6
commit
34ffb94770
@ -14,7 +14,6 @@ creating more responsive UX.
|
|||||||
|
|
||||||
This module contains schema and implementation of LangChain Runnables primitives.
|
This module contains schema and implementation of LangChain Runnables primitives.
|
||||||
"""
|
"""
|
||||||
from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
|
||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableBinding,
|
RunnableBinding,
|
||||||
@ -40,9 +39,7 @@ __all__ = [
|
|||||||
"ConfigurableField",
|
"ConfigurableField",
|
||||||
"ConfigurableFieldSingleOption",
|
"ConfigurableFieldSingleOption",
|
||||||
"ConfigurableFieldMultiOption",
|
"ConfigurableFieldMultiOption",
|
||||||
"GetLocalVar",
|
|
||||||
"patch_config",
|
"patch_config",
|
||||||
"PutLocalVar",
|
|
||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
"Runnable",
|
"Runnable",
|
||||||
|
@ -1,168 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
AsyncIterator,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
|
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForChainRun,
|
|
||||||
CallbackManagerForChainRun,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PutLocalVar(RunnablePassthrough):
|
|
||||||
key: Union[str, Mapping[str, str]]
|
|
||||||
"""The key(s) to use for storing the input variable(s) in local state.
|
|
||||||
|
|
||||||
If a string is provided then the entire input is stored under that key. If a
|
|
||||||
Mapping is provided, then the map values are gotten from the input and
|
|
||||||
stored in local state under the map keys.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
|
|
||||||
super().__init__(key=key, **kwargs)
|
|
||||||
|
|
||||||
def _concat_put(
|
|
||||||
self,
|
|
||||||
input: Other,
|
|
||||||
*,
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
replace: bool = False,
|
|
||||||
) -> None:
|
|
||||||
if config is None:
|
|
||||||
raise ValueError(
|
|
||||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
|
||||||
"therefore always receive a non-null config."
|
|
||||||
)
|
|
||||||
if isinstance(self.key, str):
|
|
||||||
if self.key not in config["locals"] or replace:
|
|
||||||
config["locals"][self.key] = input
|
|
||||||
else:
|
|
||||||
config["locals"][self.key] += input
|
|
||||||
elif isinstance(self.key, Mapping):
|
|
||||||
if not isinstance(input, Mapping):
|
|
||||||
raise TypeError(
|
|
||||||
f"Received key of type Mapping but input of type {type(input)}. "
|
|
||||||
f"input is expected to be of type Mapping when key is Mapping."
|
|
||||||
)
|
|
||||||
for input_key, put_key in self.key.items():
|
|
||||||
if put_key not in config["locals"] or replace:
|
|
||||||
config["locals"][put_key] = input[input_key]
|
|
||||||
else:
|
|
||||||
config["locals"][put_key] += input[input_key]
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
f"`key` should be a string or Mapping[str, str], received type "
|
|
||||||
f"{(type(self.key))}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(
|
|
||||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
|
||||||
) -> Other:
|
|
||||||
self._concat_put(input, config=config, replace=True)
|
|
||||||
return super().invoke(input, config=config, **kwargs)
|
|
||||||
|
|
||||||
async def ainvoke(
|
|
||||||
self,
|
|
||||||
input: Other,
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> Other:
|
|
||||||
self._concat_put(input, config=config, replace=True)
|
|
||||||
return await super().ainvoke(input, config=config, **kwargs)
|
|
||||||
|
|
||||||
def transform(
|
|
||||||
self,
|
|
||||||
input: Iterator[Other],
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> Iterator[Other]:
|
|
||||||
for chunk in super().transform(input, config=config, **kwargs):
|
|
||||||
self._concat_put(chunk, config=config)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def atransform(
|
|
||||||
self,
|
|
||||||
input: AsyncIterator[Other],
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> AsyncIterator[Other]:
|
|
||||||
async for chunk in super().atransform(input, config=config, **kwargs):
|
|
||||||
self._concat_put(chunk, config=config)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
class GetLocalVar(
|
|
||||||
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
|
||||||
):
|
|
||||||
key: str
|
|
||||||
"""The key to extract from the local state."""
|
|
||||||
passthrough_key: Optional[str] = None
|
|
||||||
"""The key to use for passing through the invocation input.
|
|
||||||
|
|
||||||
If None, then only the value retrieved from local state is returned. Otherwise a
|
|
||||||
dictionary ``{self.key: <<retrieved_value>>, self.passthrough_key: <<input>>}``
|
|
||||||
is returned.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
|
||||||
super().__init__(key=key, **kwargs)
|
|
||||||
|
|
||||||
def _get(
|
|
||||||
self,
|
|
||||||
input: Input,
|
|
||||||
run_manager: Union[CallbackManagerForChainRun, Any],
|
|
||||||
config: RunnableConfig,
|
|
||||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
|
||||||
if self.passthrough_key:
|
|
||||||
return {
|
|
||||||
self.key: config["locals"][self.key],
|
|
||||||
self.passthrough_key: input,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return config["locals"][self.key]
|
|
||||||
|
|
||||||
async def _aget(
|
|
||||||
self,
|
|
||||||
input: Input,
|
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
|
||||||
config: RunnableConfig,
|
|
||||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
|
||||||
return self._get(input, run_manager, config)
|
|
||||||
|
|
||||||
def invoke(
|
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
|
||||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
|
||||||
if config is None:
|
|
||||||
raise ValueError(
|
|
||||||
"GetLocalVar should only be used in a RunnableSequence, and should "
|
|
||||||
"therefore always receive a non-null config."
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._call_with_config(self._get, input, config)
|
|
||||||
|
|
||||||
async def ainvoke(
|
|
||||||
self,
|
|
||||||
input: Input,
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
|
||||||
if config is None:
|
|
||||||
raise ValueError(
|
|
||||||
"GetLocalVar should only be used in a RunnableSequence, and should "
|
|
||||||
"therefore always receive a non-null config."
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self._acall_with_config(self._aget, input, config)
|
|
@ -1656,7 +1656,6 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
copy_locals=True,
|
|
||||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -2534,10 +2533,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
[merge_configs(self.config, conf) for conf in config],
|
[merge_configs(self.config, conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs = [
|
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
|
||||||
patch_config(merge_configs(self.config, config), copy_locals=True)
|
|
||||||
for _ in range(len(inputs))
|
|
||||||
]
|
|
||||||
return self.bound.batch(
|
return self.bound.batch(
|
||||||
inputs,
|
inputs,
|
||||||
configs,
|
configs,
|
||||||
@ -2559,10 +2555,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
[merge_configs(self.config, conf) for conf in config],
|
[merge_configs(self.config, conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs = [
|
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
|
||||||
patch_config(merge_configs(self.config, config), copy_locals=True)
|
|
||||||
for _ in range(len(inputs))
|
|
||||||
]
|
|
||||||
return await self.bound.abatch(
|
return await self.bound.abatch(
|
||||||
inputs,
|
inputs,
|
||||||
configs,
|
configs,
|
||||||
|
@ -64,13 +64,6 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
Name for the tracer run for this call. Defaults to the name of the class.
|
Name for the tracer run for this call. Defaults to the name of the class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
locals: Dict[str, Any]
|
|
||||||
"""
|
|
||||||
Variables scoped to this call and any sub-calls. Usually used with
|
|
||||||
GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable
|
|
||||||
objects in locals, as they will be shared between parallel sub-calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_concurrency: Optional[int]
|
max_concurrency: Optional[int]
|
||||||
"""
|
"""
|
||||||
Maximum number of parallel calls to make. If not provided, defaults to
|
Maximum number of parallel calls to make. If not provided, defaults to
|
||||||
@ -96,7 +89,6 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
|||||||
tags=[],
|
tags=[],
|
||||||
metadata={},
|
metadata={},
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
locals={},
|
|
||||||
recursion_limit=25,
|
recursion_limit=25,
|
||||||
)
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
@ -124,14 +116,13 @@ def get_config_list(
|
|||||||
return (
|
return (
|
||||||
list(map(ensure_config, config))
|
list(map(ensure_config, config))
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [patch_config(config, copy_locals=True) for _ in range(length)]
|
else [ensure_config(config) for _ in range(length)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_config(
|
def patch_config(
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
*,
|
*,
|
||||||
copy_locals: bool = False,
|
|
||||||
callbacks: Optional[BaseCallbackManager] = None,
|
callbacks: Optional[BaseCallbackManager] = None,
|
||||||
recursion_limit: Optional[int] = None,
|
recursion_limit: Optional[int] = None,
|
||||||
max_concurrency: Optional[int] = None,
|
max_concurrency: Optional[int] = None,
|
||||||
@ -139,8 +130,6 @@ def patch_config(
|
|||||||
configurable: Optional[Dict[str, Any]] = None,
|
configurable: Optional[Dict[str, Any]] = None,
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if copy_locals:
|
|
||||||
config["locals"] = config["locals"].copy()
|
|
||||||
if callbacks is not None:
|
if callbacks is not None:
|
||||||
# If we're replacing callbacks we need to unset run_name
|
# If we're replacing callbacks we need to unset run_name
|
||||||
# As that should apply only to the same run as the original callbacks
|
# As that should apply only to the same run as the original callbacks
|
||||||
|
@ -1,94 +0,0 @@
|
|||||||
from typing import Any, Callable, Type
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from langchain.llms import FakeListLLM
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from langchain.schema.runnable import (
|
|
||||||
GetLocalVar,
|
|
||||||
PutLocalVar,
|
|
||||||
Runnable,
|
|
||||||
RunnablePassthrough,
|
|
||||||
RunnableSequence,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("method", "input", "output"),
|
|
||||||
[
|
|
||||||
(lambda r, x: r.invoke(x), "foo", "foo"),
|
|
||||||
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
|
|
||||||
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_put_get(method: Callable, input: Any, output: Any) -> None:
|
|
||||||
runnable = PutLocalVar("input") | GetLocalVar("input")
|
|
||||||
assert method(runnable, input) == output
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("method", "input", "output"),
|
|
||||||
[
|
|
||||||
(lambda r, x: r.ainvoke(x), "foo", "foo"),
|
|
||||||
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
|
|
||||||
runnable = PutLocalVar("input") | GetLocalVar("input")
|
|
||||||
assert await method(runnable, input) == output
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("runnable", "error"),
|
|
||||||
[
|
|
||||||
(PutLocalVar("input"), ValueError),
|
|
||||||
(GetLocalVar("input"), ValueError),
|
|
||||||
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
|
|
||||||
with pytest.raises(error):
|
|
||||||
runnable.invoke("foo")
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_in_map() -> None:
|
|
||||||
runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
|
||||||
assert runnable.invoke("foo") == {"bar": "foo"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_put_in_map() -> None:
|
|
||||||
runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
runnable.invoke("foo")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"runnable",
|
|
||||||
[
|
|
||||||
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
|
|
||||||
(
|
|
||||||
PutLocalVar("input")
|
|
||||||
| {"input": RunnablePassthrough()}
|
|
||||||
| PromptTemplate.from_template("say {input}")
|
|
||||||
| FakeListLLM(responses=["hello"])
|
|
||||||
| GetLocalVar("input", passthrough_key="output")
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("method", "input", "output"),
|
|
||||||
[
|
|
||||||
(lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}),
|
|
||||||
(lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]),
|
|
||||||
(
|
|
||||||
lambda r, x: list(r.stream(x))[0],
|
|
||||||
"hello",
|
|
||||||
{"input": "hello", "output": "hello"},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_put_get_sequence(
|
|
||||||
runnable: RunnableSequence, method: Callable, input: Any, output: Any
|
|
||||||
) -> None:
|
|
||||||
assert method(runnable, input) == output
|
|
@ -1209,7 +1209,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=["c"],
|
tags=["c"],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
locals={},
|
|
||||||
recursion_limit=5,
|
recursion_limit=5,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -1219,7 +1218,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=["c"],
|
tags=["c"],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
locals={},
|
|
||||||
recursion_limit=5,
|
recursion_limit=5,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -1290,7 +1288,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
locals={},
|
|
||||||
recursion_limit=25,
|
recursion_limit=25,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -1300,7 +1297,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
locals={},
|
|
||||||
recursion_limit=25,
|
recursion_limit=25,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
Loading…
Reference in New Issue
Block a user