mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
Fix combining runnable sequences (#8557)
Combining runnable sequences was dropping a step in the middle. @nfcampos @baskaryan
This commit is contained in:
parent
3fbb737bb3
commit
2a26cc6d2b
@ -214,7 +214,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
middle=self.middle + [self.last] + other.middle,
|
||||
middle=self.middle + [self.last] + [other.first] + other.middle,
|
||||
last=other.last,
|
||||
)
|
||||
else:
|
||||
@ -235,7 +235,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=other.first,
|
||||
middle=other.middle + [other.last] + self.middle,
|
||||
middle=other.middle + [other.last] + [self.first] + self.middle,
|
||||
last=self.last,
|
||||
)
|
||||
else:
|
||||
|
File diff suppressed because one or more lines are too long
@ -440,6 +440,64 @@ def test_prompt_with_chat_model_and_parser(
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_combining_sequences(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo, bar"])
|
||||
parser = CommaSeparatedListOutputParser()
|
||||
|
||||
chain = prompt | chat | parser
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == [chat]
|
||||
assert chain.last == parser
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
prompt2 = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nicer assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
chat2 = FakeListChatModel(responses=["baz, qux"])
|
||||
parser2 = CommaSeparatedListOutputParser()
|
||||
input_formatter: RunnableLambda[List[str], Dict[str, Any]] = RunnableLambda(
|
||||
lambda x: {"question": x[0] + x[1]}
|
||||
)
|
||||
|
||||
chain2 = input_formatter | prompt2 | chat2 | parser2
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain2.first == input_formatter
|
||||
assert chain2.middle == [prompt2, chat2]
|
||||
assert chain2.last == parser2
|
||||
assert dumps(chain2, pretty=True) == snapshot
|
||||
|
||||
combined_chain = chain | chain2
|
||||
|
||||
assert combined_chain.first == prompt
|
||||
assert combined_chain.middle == [
|
||||
chat,
|
||||
parser,
|
||||
input_formatter,
|
||||
prompt2,
|
||||
chat2,
|
||||
]
|
||||
assert combined_chain.last == parser2
|
||||
assert dumps(combined_chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
tracer = FakeTracer()
|
||||
assert combined_chain.invoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == ["baz", "qux"]
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_seq_dict_prompt_llm(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
|
Loading…
Reference in New Issue
Block a user