mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
wip
This commit is contained in:
parent
6f69b19ff5
commit
ab21af71be
@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
local_callbacks=None,
|
local_callbacks=None,
|
||||||
@ -1091,7 +1091,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
step.invoke,
|
step.invoke,
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
patch_config(config, run_manager.get_child()),
|
patch_config(deepcopy(config), run_manager.get_child()),
|
||||||
)
|
)
|
||||||
for step in steps.values()
|
for step in steps.values()
|
||||||
]
|
]
|
||||||
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.schema.runnable import GetLocalVar, PutLocalVar
|
from langchain import PromptTemplate
|
||||||
|
from langchain.llms import FakeListLLM
|
||||||
|
from langchain.schema.runnable import (
|
||||||
|
GetLocalVar,
|
||||||
|
PutLocalVar,
|
||||||
|
RunnablePassthrough,
|
||||||
|
RunnableSequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -29,3 +36,30 @@ def test_get_missing_var_invoke() -> None:
|
|||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
runnable.invoke("foo")
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_in_map() -> None:
|
||||||
|
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
||||||
|
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_cant_put_in_map() -> None:
|
||||||
|
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_passthrough_key() -> None:
|
||||||
|
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
|
||||||
|
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_step_sequence() -> None:
|
||||||
|
prompt = PromptTemplate.from_template("say {foo}")
|
||||||
|
runnable = (
|
||||||
|
PutLocalVar("foo")
|
||||||
|
| {"foo": RunnablePassthrough()}
|
||||||
|
| prompt
|
||||||
|
| FakeListLLM(responses=["bar"])
|
||||||
|
| GetLocalVar("foo", passthrough_key="output")
|
||||||
|
)
|
||||||
|
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}
|
||||||
|
Loading…
Reference in New Issue
Block a user