mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
wip
This commit is contained in:
parent
6f69b19ff5
commit
ab21af71be
@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_callback_manager(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_callback_manager(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
@ -1091,7 +1091,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
step.invoke,
|
||||
input,
|
||||
# mark each step as a child run
|
||||
patch_config(config, run_manager.get_child()),
|
||||
patch_config(deepcopy(config), run_manager.get_child()),
|
||||
)
|
||||
for step in steps.values()
|
||||
]
|
||||
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
# setup callbacks
|
||||
config = config or {}
|
||||
config = config or _empty_config()
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
|
@ -1,6 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from langchain.schema.runnable import GetLocalVar, PutLocalVar
|
||||
from langchain import PromptTemplate
|
||||
from langchain.llms import FakeListLLM
|
||||
from langchain.schema.runnable import (
|
||||
GetLocalVar,
|
||||
PutLocalVar,
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -29,3 +36,30 @@ def test_get_missing_var_invoke() -> None:
|
||||
with pytest.raises(KeyError):
|
||||
runnable.invoke("foo")
|
||||
|
||||
|
||||
def test_get_in_map() -> None:
|
||||
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
||||
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||
|
||||
|
||||
def test_cant_put_in_map() -> None:
|
||||
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||
with pytest.raises(KeyError):
|
||||
runnable.invoke("foo")
|
||||
|
||||
|
||||
def test_get_passthrough_key() -> None:
|
||||
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
|
||||
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"}
|
||||
|
||||
|
||||
def test_multi_step_sequence() -> None:
|
||||
prompt = PromptTemplate.from_template("say {foo}")
|
||||
runnable = (
|
||||
PutLocalVar("foo")
|
||||
| {"foo": RunnablePassthrough()}
|
||||
| prompt
|
||||
| FakeListLLM(responses=["bar"])
|
||||
| GetLocalVar("foo", passthrough_key="output")
|
||||
)
|
||||
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}
|
||||
|
Loading…
Reference in New Issue
Block a user