mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
Fix combining runnable sequences (#8557)
Combining runnable sequences was dropping a step in the middle. @nfcampos @baskaryan
This commit is contained in:
parent
3fbb737bb3
commit
2a26cc6d2b
@ -214,7 +214,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
if isinstance(other, RunnableSequence):
|
if isinstance(other, RunnableSequence):
|
||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
first=self.first,
|
first=self.first,
|
||||||
middle=self.middle + [self.last] + other.middle,
|
middle=self.middle + [self.last] + [other.first] + other.middle,
|
||||||
last=other.last,
|
last=other.last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -235,7 +235,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
if isinstance(other, RunnableSequence):
|
if isinstance(other, RunnableSequence):
|
||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
first=other.first,
|
first=other.first,
|
||||||
middle=other.middle + [other.last] + self.middle,
|
middle=other.middle + [other.last] + [self.first] + self.middle,
|
||||||
last=self.last,
|
last=self.last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
File diff suppressed because one or more lines are too long
@ -440,6 +440,64 @@ def test_prompt_with_chat_model_and_parser(
|
|||||||
assert tracer.runs == snapshot
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
def test_combining_sequences(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
chat = FakeListChatModel(responses=["foo, bar"])
|
||||||
|
parser = CommaSeparatedListOutputParser()
|
||||||
|
|
||||||
|
chain = prompt | chat | parser
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == [chat]
|
||||||
|
assert chain.last == parser
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
prompt2 = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nicer assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
chat2 = FakeListChatModel(responses=["baz, qux"])
|
||||||
|
parser2 = CommaSeparatedListOutputParser()
|
||||||
|
input_formatter: RunnableLambda[List[str], Dict[str, Any]] = RunnableLambda(
|
||||||
|
lambda x: {"question": x[0] + x[1]}
|
||||||
|
)
|
||||||
|
|
||||||
|
chain2 = input_formatter | prompt2 | chat2 | parser2
|
||||||
|
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain2.first == input_formatter
|
||||||
|
assert chain2.middle == [prompt2, chat2]
|
||||||
|
assert chain2.last == parser2
|
||||||
|
assert dumps(chain2, pretty=True) == snapshot
|
||||||
|
|
||||||
|
combined_chain = chain | chain2
|
||||||
|
|
||||||
|
assert combined_chain.first == prompt
|
||||||
|
assert combined_chain.middle == [
|
||||||
|
chat,
|
||||||
|
parser,
|
||||||
|
input_formatter,
|
||||||
|
prompt2,
|
||||||
|
chat2,
|
||||||
|
]
|
||||||
|
assert combined_chain.last == parser2
|
||||||
|
assert dumps(combined_chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert combined_chain.invoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
) == ["baz", "qux"]
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
def test_seq_dict_prompt_llm(
|
def test_seq_dict_prompt_llm(
|
||||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
|
Loading…
Reference in New Issue
Block a user