Support using RunnableMap directly (#8317)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
This commit is contained in:
Nuno Campos 2023-07-27 17:24:29 +01:00 committed by GitHub
parent 944321c6ab
commit 1bbadde77b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 236 additions and 9 deletions

View File

@ -12,6 +12,7 @@ from typing import (
Generic,
Iterator,
List,
Mapping,
Optional,
TypedDict,
TypeVar,
@ -71,7 +72,7 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Any, Other],
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
@ -80,7 +81,7 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Other, Any],
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
@ -194,7 +195,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
other: Union[
Runnable[Any, Other],
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
if isinstance(other, RunnableSequence):
@ -214,7 +215,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
other: Union[
Runnable[Other, Any],
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
if isinstance(other, RunnableSequence):
@ -551,7 +552,22 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
steps: Dict[str, Runnable[Input, Any]]
steps: Mapping[str, Runnable[Input, Any]]
def __init__(
self,
steps: Mapping[
str,
Union[
Runnable[Input, Any],
Callable[[Input], Any],
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
],
],
) -> None:
super().__init__(
steps={key: _coerce_to_runnable(r) for key, r in steps.items()}
)
@property
def lc_serializable(self) -> bool:
@ -582,7 +598,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = self.steps.copy()
steps = dict(self.steps)
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
@ -626,7 +642,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = self.steps.copy()
steps = dict(self.steps)
results = await asyncio.gather(
*(
step.ainvoke(
@ -688,7 +704,7 @@ def _coerce_to_runnable(
thing: Union[
Runnable[Input, Output],
Callable[[Input], Output],
Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
]
) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):

File diff suppressed because one or more lines are too long

View File

@ -508,7 +508,7 @@ def test_seq_prompt_dict(
chain = (
prompt
| passthrough
| { # type: ignore
| {
"chat": chat,
"llm": llm,
}
@ -545,3 +545,67 @@ def test_seq_prompt_dict(
]
)
assert tracer.runs == snapshot
@freeze_time("2023-01-01")
def test_seq_prompt_map(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["i'm a chatbot"])
llm = FakeListLLM(responses=["i'm a textbot"])
chain = (
prompt
| passthrough
| {
"chat": chat,
"llm": llm,
"passthrough": passthrough,
}
)
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)]
assert isinstance(chain.last, RunnableMap)
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
llm_spy = mocker.spy(llm.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": AIMessage(content="i'm a chatbot"),
"llm": "i'm a textbot",
"passthrough": ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
}
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot