Add .pick and .assign methods to Runnable (#15229)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-27 13:35:34 -08:00 committed by GitHub
parent 0252a24471
commit 6a5a2fb9c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 169 additions and 109 deletions

View File

@ -152,8 +152,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"full_chain = (\n", "full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response)\n", " RunnablePassthrough.assign(query=sql_response).assign(\n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n", " schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n", " response=lambda x: db.run(x[\"query\"]),\n",
" )\n", " )\n",

View File

@ -31,7 +31,11 @@ from langchain_core.runnables.config import (
patch_config, patch_config,
) )
from langchain_core.runnables.fallbacks import RunnableWithFallbacks from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import (
RunnableAssign,
RunnablePassthrough,
RunnablePick,
)
from langchain_core.runnables.router import RouterInput, RouterRunnable from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
AddableDict, AddableDict,
@ -60,6 +64,8 @@ __all__ = [
"RunnableMap", "RunnableMap",
"RunnableParallel", "RunnableParallel",
"RunnablePassthrough", "RunnablePassthrough",
"RunnableAssign",
"RunnablePick",
"RunnableSequence", "RunnableSequence",
"RunnableWithFallbacks", "RunnableWithFallbacks",
"get_config_list", "get_config_list",

View File

@ -220,9 +220,11 @@ class Runnable(Generic[Input, Output], ABC):
name: Optional[str] = None name: Optional[str] = None
"""The name of the runnable. Used for debugging and tracing.""" """The name of the runnable. Used for debugging and tracing."""
def get_name(self, suffix: Optional[str] = None) -> str: def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
"""Get the name of the runnable.""" """Get the name of the runnable."""
name = self.name or self.__class__.__name__ name = name or self.name or self.__class__.__name__
if suffix: if suffix:
if name[0].isupper(): if name[0].isupper():
return name + suffix.title() return name + suffix.title()
@ -410,6 +412,38 @@ class Runnable(Generic[Input, Output], ABC):
"""Compose this runnable with another object to create a RunnableSequence.""" """Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(coerce_to_runnable(other), self) return RunnableSequence(coerce_to_runnable(other), self)
def pipe(
self,
*others: Union[Runnable[Any, Other], Callable[[Any], Other]],
name: Optional[str] = None,
) -> RunnableSerializable[Input, Other]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(self, *others, name=name)
def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]:
"""Pick keys from the dict output of this runnable.
Returns a new runnable."""
from langchain_core.runnables.passthrough import RunnablePick
return self | RunnablePick(keys)
def assign(
self,
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
Mapping[
str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
],
],
) -> RunnableSerializable[Any, Any]:
"""Assigns new fields to the dict output of this runnable.
Returns a new runnable."""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel(kwargs))
""" --- Public API --- """ """ --- Public API --- """
@abstractmethod @abstractmethod
@ -1669,7 +1703,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.name dumpd(self), input, name=config.get("run_name") or self.get_name()
) )
# invoke all steps in sequence # invoke all steps in sequence
@ -1703,7 +1737,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.name dumpd(self), input, name=config.get("run_name") or self.get_name()
) )
# invoke all steps in sequence # invoke all steps in sequence
@ -1760,7 +1794,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
cm.on_chain_start( cm.on_chain_start(
dumpd(self), dumpd(self),
input, input,
name=config.get("run_name") or self.name, name=config.get("run_name") or self.get_name(),
) )
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input, config in zip(callback_managers, inputs, configs)
] ]
@ -1884,7 +1918,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
cm.on_chain_start( cm.on_chain_start(
dumpd(self), dumpd(self),
input, input,
name=config.get("run_name") or self.name, name=config.get("run_name") or self.get_name(),
) )
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input, config in zip(callback_managers, inputs, configs)
) )
@ -2119,6 +2153,12 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = name or self.name or f"RunnableParallel<{','.join(self.steps.keys())}>"
return super().get_name(suffix, name=name)
@property @property
def InputType(self) -> Any: def InputType(self) -> Any:
for step in self.steps.values(): for step in self.steps.values():
@ -2214,7 +2254,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
) )
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") dumpd(self), input, name=config.get("run_name") or self.get_name()
) )
# gather results from all steps # gather results from all steps
@ -2254,7 +2294,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") dumpd(self), input, name=config.get("run_name") or self.get_name()
) )
# gather results from all steps # gather results from all steps
@ -3174,6 +3214,12 @@ class RunnableEach(RunnableEachBase[Input, Output]):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = name or self.name or f"RunnableEach<{self.bound.get_name()}>"
return super().get_name(suffix, name=name)
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs)) return RunnableEach(bound=self.bound.bind(**kwargs))
@ -3298,8 +3344,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
**other_kwargs, **other_kwargs,
) )
def get_name(self, suffix: Optional[str] = None) -> str: def get_name(
return self.bound.get_name(suffix) self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
return self.bound.get_name(suffix, name=name)
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> Type[Input]:

View File

@ -202,21 +202,6 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
""" """
return RunnableAssign(RunnableParallel(kwargs)) return RunnableAssign(RunnableParallel(kwargs))
@classmethod
def pick(
cls,
keys: Union[str, List[str]],
) -> "RunnablePick":
"""Pick keys from the Dict input.
Args:
keys: A string or list of strings representing the keys to pick.
Returns:
A runnable that picks keys from the Dict input.
"""
return RunnablePick(keys)
def invoke( def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Other: ) -> Other:
@ -335,6 +320,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = (
name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
)
return super().get_name(suffix, name=name)
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> Type[BaseModel]:
@ -589,6 +582,16 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = (
name
or self.name
or f"RunnablePick<{','.join([self.keys] if isinstance(self.keys, str) else self.keys)}>" # noqa: E501
)
return super().get_name(suffix, name=name)
def _pick(self, input: Dict[str, Any]) -> Any: def _pick(self, input: Dict[str, Any]) -> Any:
assert isinstance( assert isinstance(
input, dict input, dict

View File

@ -32,51 +32,51 @@
# --- # ---
# name: test_graph_sequence_map # name: test_graph_sequence_map
''' '''
+-------------+ +-------------+
| PromptInput | | PromptInput |
+-------------+ +-------------+
* *
* *
* *
+----------------+ +----------------+
| PromptTemplate | | PromptTemplate |
+----------------+ +----------------+
* *
* *
* *
+-------------+ +-------------+
| FakeListLLM | | FakeListLLM |
+-------------+ +-------------+
* *
* *
* *
+---------------+ +-------------------------------+
| ParallelInput | | Parallel<as_list,as_str>Input |
+---------------+****** +-------------------------------+
***** ****** ***** ******
*** ****** *** ******
*** ****** *** ******
+------------------------------+ *** +------------------------------+ ****
| conditional_str_parser_input | * | conditional_str_parser_input | *
+------------------------------+ * +------------------------------+ *
*** *** * *** *** *
*** *** * *** *** *
** ** * ** ** *
+-----------------+ +-----------------+ * +-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | * | StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ * +-----------------+ +-----------------+ *
*** *** * *** *** *
*** *** * *** *** *
** ** * ** ** *
+-------------------------------+ +--------------------------------+ +-------------------------------+ +--------------------------------+
| conditional_str_parser_output | | CommaSeparatedListOutputParser | | conditional_str_parser_output | | CommaSeparatedListOutputParser |
+-------------------------------+ +--------------------------------+ +-------------------------------+ +--------------------------------+
***** ****** ***** ******
*** ****** *** ******
*** *** *** ****
+----------------+ +--------------------------------+
| ParallelOutput | | Parallel<as_list,as_str>Output |
+----------------+ +--------------------------------+
''' '''
# --- # ---
# name: test_graph_single_runnable # name: test_graph_single_runnable

View File

@ -4012,7 +4012,7 @@
'items': dict({ 'items': dict({
'$ref': '#/definitions/PromptTemplateOutput', '$ref': '#/definitions/PromptTemplateOutput',
}), }),
'title': 'RunnableEachOutput', 'title': 'RunnableEach<PromptTemplate>Output',
'type': 'array', 'type': 'array',
}) })
# --- # ---

View File

@ -18,6 +18,8 @@ EXPECTED_ALL = [
"RunnableMap", "RunnableMap",
"RunnableParallel", "RunnableParallel",
"RunnablePassthrough", "RunnablePassthrough",
"RunnableAssign",
"RunnablePick",
"RunnableSequence", "RunnableSequence",
"RunnableWithFallbacks", "RunnableWithFallbacks",
"get_config_list", "get_config_list",

View File

@ -64,6 +64,7 @@ from langchain_core.runnables import (
RunnableLambda, RunnableLambda,
RunnableParallel, RunnableParallel,
RunnablePassthrough, RunnablePassthrough,
RunnablePick,
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableWithFallbacks,
add, add,
@ -510,7 +511,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}, },
"items": {"$ref": "#/definitions/PromptInput"}, "items": {"$ref": "#/definitions/PromptInput"},
"type": "array", "type": "array",
"title": "RunnableEachInput", "title": "RunnableEach<PromptTemplate>Input",
} }
assert prompt_mapper.output_schema.schema() == snapshot assert prompt_mapper.output_schema.schema() == snapshot
@ -571,7 +572,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"properties": {"name": {"title": "Name", "type": "string"}}, "properties": {"name": {"title": "Name", "type": "string"}},
} }
assert seq_w_map.output_schema.schema() == { assert seq_w_map.output_schema.schema() == {
"title": "RunnableParallelOutput", "title": "RunnableParallel<original,as_list,length>Output",
"type": "object", "type": "object",
"properties": { "properties": {
"original": {"title": "Original", "type": "string"}, "original": {"title": "Original", "type": "string"},
@ -615,7 +616,7 @@ def test_passthrough_assign_schema() -> None:
# expected dict input_schema # expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == { assert invalid_seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question"}}, "properties": {"question": {"title": "Question"}},
"title": "RunnableParallelInput", "title": "RunnableParallel<context>Input",
"type": "object", "type": "object",
} }
@ -774,7 +775,7 @@ def test_schema_complex_seq() -> None:
) )
assert chain2.input_schema.schema() == { assert chain2.input_schema.schema() == {
"title": "RunnableParallelInput", "title": "RunnableParallel<city,language>Input",
"type": "object", "type": "object",
"properties": { "properties": {
"person": {"title": "Person", "type": "string"}, "person": {"title": "Person", "type": "string"},
@ -2160,8 +2161,8 @@ async def test_stream_log_retriever() -> None:
"FakeListLLM:2", "FakeListLLM:2",
"Retriever", "Retriever",
"RunnableLambda", "RunnableLambda",
"RunnableParallel", "RunnableParallel<documents,question>",
"RunnableParallel:2", "RunnableParallel<one,two>",
] ]
@ -2444,7 +2445,7 @@ What is your name?"""
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 4 assert len(parent_run.child_runs) == 4
map_run = parent_run.child_runs[0] map_run = parent_run.child_runs[0]
assert map_run.name == "RunnableParallel" assert map_run.name == "RunnableParallel<question,documents,just_to_test_lambda>"
assert len(map_run.child_runs) == 3 assert len(map_run.child_runs) == 3
@ -2505,7 +2506,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 3 assert len(parent_run.child_runs) == 3
map_run = parent_run.child_runs[2] map_run = parent_run.child_runs[2]
assert map_run.name == "RunnableParallel" assert map_run.name == "RunnableParallel<chat,llm>"
assert len(map_run.child_runs) == 2 assert len(map_run.child_runs) == 2
@ -2721,7 +2722,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 3 assert len(parent_run.child_runs) == 3
map_run = parent_run.child_runs[2] map_run = parent_run.child_runs[2]
assert map_run.name == "RunnableParallel" assert map_run.name == "RunnableParallel<chat,llm,passthrough>"
assert len(map_run.child_runs) == 3 assert len(map_run.child_runs) == 3
@ -2770,7 +2771,7 @@ def test_map_stream() -> None:
{"question": "What is your name?"} {"question": "What is your name?"}
) )
chain_pick_one = chain | RunnablePassthrough.pick("llm") chain_pick_one = chain.pick("llm")
assert chain_pick_one.output_schema.schema() == { assert chain_pick_one.output_schema.schema() == {
"title": "RunnableSequenceOutput", "title": "RunnableSequenceOutput",
@ -2791,10 +2792,8 @@ def test_map_stream() -> None:
assert streamed_chunks[0] == "i" assert streamed_chunks[0] == "i"
assert len(streamed_chunks) == len(llm_res) assert len(streamed_chunks) == len(llm_res)
chain_pick_two = ( chain_pick_two = chain.assign(hello=RunnablePick("llm").pipe(llm)).pick(
chain ["llm", "hello"]
| RunnablePassthrough.assign(hello=RunnablePassthrough.pick("llm") | llm)
| RunnablePassthrough.pick(["llm", "hello"])
) )
assert chain_pick_two.output_schema.schema() == { assert chain_pick_two.output_schema.schema() == {
@ -2940,12 +2939,15 @@ async def test_map_astream() -> None:
assert final_state.state["logs"]["ChatPromptTemplate"][ assert final_state.state["logs"]["ChatPromptTemplate"][
"final_output" "final_output"
] == prompt.invoke({"question": "What is your name?"}) ] == prompt.invoke({"question": "What is your name?"})
assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel" assert (
final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
== "RunnableParallel<chat,llm,passthrough>"
)
assert sorted(final_state.state["logs"]) == [ assert sorted(final_state.state["logs"]) == [
"ChatPromptTemplate", "ChatPromptTemplate",
"FakeListChatModel", "FakeListChatModel",
"FakeStreamingListLLM", "FakeStreamingListLLM",
"RunnableParallel", "RunnableParallel<chat,llm,passthrough>",
"RunnablePassthrough", "RunnablePassthrough",
] ]
@ -2985,11 +2987,14 @@ async def test_map_astream() -> None:
assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == ( assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
prompt.invoke({"question": "What is your name?"}) prompt.invoke({"question": "What is your name?"})
) )
assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel" assert (
final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
== "RunnableParallel<chat,llm,passthrough>"
)
assert sorted(final_state.state["logs"]) == [ assert sorted(final_state.state["logs"]) == [
"ChatPromptTemplate", "ChatPromptTemplate",
"FakeStreamingListLLM", "FakeStreamingListLLM",
"RunnableParallel", "RunnableParallel<chat,llm,passthrough>",
"RunnablePassthrough", "RunnablePassthrough",
] ]
@ -3130,9 +3135,7 @@ def test_deep_stream_assign() -> None:
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"} assert add(chunks) == {"str": "foo-lish"}
chain_with_assign = chain | RunnablePassthrough.assign( chain_with_assign = chain.assign(hello=itemgetter("str") | llm)
hello=itemgetter("str") | llm
)
assert chain_with_assign.input_schema.schema() == { assert chain_with_assign.input_schema.schema() == {
"title": "PromptInput", "title": "PromptInput",
@ -3179,7 +3182,7 @@ def test_deep_stream_assign() -> None:
"hello": "foo-lish", "hello": "foo-lish",
} }
chain_with_assign_shadow = chain | RunnablePassthrough.assign( chain_with_assign_shadow = chain.assign(
str=lambda _: "shadow", str=lambda _: "shadow",
hello=itemgetter("str") | llm, hello=itemgetter("str") | llm,
) )
@ -3254,7 +3257,7 @@ async def test_deep_astream_assign() -> None:
assert len(chunks) == len("foo-lish") assert len(chunks) == len("foo-lish")
assert add(chunks) == {"str": "foo-lish"} assert add(chunks) == {"str": "foo-lish"}
chain_with_assign = chain | RunnablePassthrough.assign( chain_with_assign = chain.assign(
hello=itemgetter("str") | llm, hello=itemgetter("str") | llm,
) )
@ -4473,15 +4476,15 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
tracer = FakeTracer() tracer = FakeTracer()
chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
assert tracer.runs[0].name == "RunnableAssign" assert tracer.runs[0].name == "RunnableAssign<urls>"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel" assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
tracer = FakeTracer() tracer = FakeTracer()
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
pass pass
assert tracer.runs[0].name == "RunnableAssign" assert tracer.runs[0].name == "RunnableAssign<urls>"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel" assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
async def test_ainvoke_astream_passthrough_assign_trace() -> None: async def test_ainvoke_astream_passthrough_assign_trace() -> None:
@ -4493,15 +4496,15 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
tracer = FakeTracer() tracer = FakeTracer()
await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
assert tracer.runs[0].name == "RunnableAssign" assert tracer.runs[0].name == "RunnableAssign<urls>"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel" assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
tracer = FakeTracer() tracer = FakeTracer()
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
pass pass
assert tracer.runs[0].name == "RunnableAssign" assert tracer.runs[0].name == "RunnableAssign<urls>"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel" assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
async def test_astream_log_deep_copies() -> None: async def test_astream_log_deep_copies() -> None:

View File

@ -35,13 +35,13 @@ def create_history_aware_retriever(
# pip install -U langchain langchain-community # pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI from langchain_community.chat_models import ChatOpenAI
from langchain.chains import create_chat_history_retriever from langchain.chains import create_history_aware_retriever
from langchain import hub from langchain import hub
rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase") rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase")
llm = ChatOpenAI() llm = ChatOpenAI()
retriever = ... retriever = ...
chat_retriever_chain = create_chat_retriever_chain( chat_retriever_chain = create_history_aware_retriever(
llm, retriever, rephrase_prompt llm, retriever, rephrase_prompt
) )

View File

@ -64,8 +64,7 @@ def create_retrieval_chain(
RunnablePassthrough.assign( RunnablePassthrough.assign(
context=retrieval_docs.with_config(run_name="retrieve_documents"), context=retrieval_docs.with_config(run_name="retrieve_documents"),
chat_history=lambda x: x.get("chat_history", []), chat_history=lambda x: x.get("chat_history", []),
) ).assign(answer=combine_docs_chain)
| RunnablePassthrough.assign(answer=combine_docs_chain)
).with_config(run_name="retrieval_chain") ).with_config(run_name="retrieval_chain")
return retrieval_chain return retrieval_chain