mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Support using RunnableMap directly (#8317)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure you're PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
944321c6ab
commit
1bbadde77b
@ -12,6 +12,7 @@ from typing import (
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
@ -71,7 +72,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
|
||||
@ -80,7 +81,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
||||
@ -194,7 +195,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
@ -214,7 +215,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
@ -551,7 +552,22 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
|
||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
steps: Dict[str, Runnable[Input, Any]]
|
||||
steps: Mapping[str, Runnable[Input, Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
steps: Mapping[
|
||||
str,
|
||||
Union[
|
||||
Runnable[Input, Any],
|
||||
Callable[[Input], Any],
|
||||
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
|
||||
],
|
||||
],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
steps={key: _coerce_to_runnable(r) for key, r in steps.items()}
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
@ -582,7 +598,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
# gather results from all steps
|
||||
try:
|
||||
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||
steps = self.steps.copy()
|
||||
steps = dict(self.steps)
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
@ -626,7 +642,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
# gather results from all steps
|
||||
try:
|
||||
# copy to avoid issues from the caller mutating the steps during invoke()
|
||||
steps = self.steps.copy()
|
||||
steps = dict(self.steps)
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
step.ainvoke(
|
||||
@ -688,7 +704,7 @@ def _coerce_to_runnable(
|
||||
thing: Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
|
||||
Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
|
||||
]
|
||||
) -> Runnable[Input, Output]:
|
||||
if isinstance(thing, Runnable):
|
||||
|
File diff suppressed because one or more lines are too long
@ -508,7 +508,7 @@ def test_seq_prompt_dict(
|
||||
chain = (
|
||||
prompt
|
||||
| passthrough
|
||||
| { # type: ignore
|
||||
| {
|
||||
"chat": chat,
|
||||
"llm": llm,
|
||||
}
|
||||
@ -545,3 +545,67 @@ def test_seq_prompt_dict(
|
||||
]
|
||||
)
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_seq_prompt_map(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
|
||||
) -> None:
|
||||
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
|
||||
chat = FakeListChatModel(responses=["i'm a chatbot"])
|
||||
|
||||
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||
|
||||
chain = (
|
||||
prompt
|
||||
| passthrough
|
||||
| {
|
||||
"chat": chat,
|
||||
"llm": llm,
|
||||
"passthrough": passthrough,
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == [RunnableLambda(passthrough)]
|
||||
assert isinstance(chain.last, RunnableMap)
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
prompt_spy = mocker.spy(prompt.__class__, "invoke")
|
||||
chat_spy = mocker.spy(chat.__class__, "invoke")
|
||||
llm_spy = mocker.spy(llm.__class__, "invoke")
|
||||
tracer = FakeTracer()
|
||||
assert chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == {
|
||||
"chat": AIMessage(content="i'm a chatbot"),
|
||||
"llm": "i'm a textbot",
|
||||
"passthrough": ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
),
|
||||
}
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
assert llm_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
assert tracer.runs == snapshot
|
||||
|
Loading…
Reference in New Issue
Block a user