Add .pick and .assign methods to Runnable (#15229)
This commit is contained in:
parent 0252a24471
commit 6a5a2fb9c8
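This change adds `.pick(...)` and `.assign(...)` instance methods to `Runnable`, plus a keyword-only `name` override for `get_name(...)`, so dict-shaped chains can be extended without piping through `RunnablePassthrough`. A minimal sketch of the resulting usage, assuming a langchain_core build that includes this commit (the step names and lambdas below are illustrative only):

    from langchain_core.runnables import RunnableLambda

    # A runnable whose output is a dict.
    base = RunnableLambda(lambda x: {"num": x["num"]})

    # .assign adds computed fields to the dict output; .pick keeps only the
    # listed keys. Both return a new runnable.
    chain = base.assign(doubled=lambda d: d["num"] * 2).pick(["num", "doubled"])

    print(chain.invoke({"num": 3}))  # {'num': 3, 'doubled': 6}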
@@ -152,8 +152,7 @@
     "outputs": [],
     "source": [
     "full_chain = (\n",
-    "    RunnablePassthrough.assign(query=sql_response)\n",
-    "    | RunnablePassthrough.assign(\n",
+    "    RunnablePassthrough.assign(query=sql_response).assign(\n",
     "        schema=get_schema,\n",
     "        response=lambda x: db.run(x[\"query\"]),\n",
     "    )\n",
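The chained `.assign(...)` above works because each `assign` layers its new keys onto the running dict, so the second call can read the `query` produced by the first. A small sketch of that layering, with illustrative keys and a build that includes this commit:

    from langchain_core.runnables import RunnablePassthrough

    chain = RunnablePassthrough.assign(double=lambda d: d["n"] * 2).assign(
        quadruple=lambda d: d["double"] * 2  # sees the key added by the first assign
    )
    print(chain.invoke({"n": 2}))  # {'n': 2, 'double': 4, 'quadruple': 8}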
@@ -31,7 +31,11 @@ from langchain_core.runnables.config import (
     patch_config,
 )
 from langchain_core.runnables.fallbacks import RunnableWithFallbacks
-from langchain_core.runnables.passthrough import RunnablePassthrough
+from langchain_core.runnables.passthrough import (
+    RunnableAssign,
+    RunnablePassthrough,
+    RunnablePick,
+)
 from langchain_core.runnables.router import RouterInput, RouterRunnable
 from langchain_core.runnables.utils import (
     AddableDict,
@@ -60,6 +64,8 @@ __all__ = [
     "RunnableMap",
     "RunnableParallel",
     "RunnablePassthrough",
+    "RunnableAssign",
+    "RunnablePick",
     "RunnableSequence",
     "RunnableWithFallbacks",
     "get_config_list",
@@ -220,9 +220,11 @@ class Runnable(Generic[Input, Output], ABC):
     name: Optional[str] = None
     """The name of the runnable. Used for debugging and tracing."""

-    def get_name(self, suffix: Optional[str] = None) -> str:
+    def get_name(
+        self, suffix: Optional[str] = None, *, name: Optional[str] = None
+    ) -> str:
         """Get the name of the runnable."""
-        name = self.name or self.__class__.__name__
+        name = name or self.name or self.__class__.__name__
         if suffix:
             if name[0].isupper():
                 return name + suffix.title()
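The new keyword-only `name` parameter lets subclasses (and callers) supply a default display name while the base class keeps handling the suffix. A hedged sketch of the behaviour, using `RunnablePassthrough` (which does not override `get_name`) and an illustrative override value:

    from langchain_core.runnables import RunnablePassthrough

    r = RunnablePassthrough()
    print(r.get_name())                        # RunnablePassthrough
    print(r.get_name("Input"))                 # RunnablePassthroughInput
    print(r.get_name("Input", name="MyStep"))  # MyStepInput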
@@ -410,6 +412,38 @@ class Runnable(Generic[Input, Output], ABC):
         """Compose this runnable with another object to create a RunnableSequence."""
         return RunnableSequence(coerce_to_runnable(other), self)

+    def pipe(
+        self,
+        *others: Union[Runnable[Any, Other], Callable[[Any], Other]],
+        name: Optional[str] = None,
+    ) -> RunnableSerializable[Input, Other]:
+        """Compose this runnable with another object to create a RunnableSequence."""
+        return RunnableSequence(self, *others, name=name)
+
+    def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]:
+        """Pick keys from the dict output of this runnable.
+        Returns a new runnable."""
+        from langchain_core.runnables.passthrough import RunnablePick
+
+        return self | RunnablePick(keys)
+
+    def assign(
+        self,
+        **kwargs: Union[
+            Runnable[Dict[str, Any], Any],
+            Callable[[Dict[str, Any]], Any],
+            Mapping[
+                str,
+                Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
+            ],
+        ],
+    ) -> RunnableSerializable[Any, Any]:
+        """Assigns new fields to the dict output of this runnable.
+        Returns a new runnable."""
+        from langchain_core.runnables.passthrough import RunnableAssign
+
+        return self | RunnableAssign(RunnableParallel(kwargs))
+
     """ --- Public API --- """

     @abstractmethod
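Per the method bodies above, the new helpers are thin sugar over existing combinators: `pick` appends a `RunnablePick`, `assign` appends a `RunnableAssign(RunnableParallel(...))`, and `pipe` builds a named `RunnableSequence`. A small equivalence sketch (the lambdas and values are illustrative only):

    from langchain_core.runnables import RunnableLambda

    step = RunnableLambda(lambda d: {"x": d["x"], "y": d["x"] + 1})

    # step.pick("x")              ~= step | RunnablePick("x")
    # step.assign(z=...)          ~= step | RunnableAssign(RunnableParallel(z=...))
    # step.pipe(other, name="n")  ~= RunnableSequence(step, other, name="n")
    assert step.pick("x").invoke({"x": 1}) == 1
    assert step.assign(z=lambda d: d["y"] * 10).invoke({"x": 1}) == {
        "x": 1,
        "y": 2,
        "z": 20,
    }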
@@ -1669,7 +1703,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
         callback_manager = get_callback_manager_for_config(config)
         # start the root run
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name") or self.name
+            dumpd(self), input, name=config.get("run_name") or self.get_name()
         )

         # invoke all steps in sequence
@@ -1703,7 +1737,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name") or self.name
+            dumpd(self), input, name=config.get("run_name") or self.get_name()
         )

         # invoke all steps in sequence
@@ -1760,7 +1794,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
             cm.on_chain_start(
                 dumpd(self),
                 input,
-                name=config.get("run_name") or self.name,
+                name=config.get("run_name") or self.get_name(),
             )
             for cm, input, config in zip(callback_managers, inputs, configs)
         ]
@@ -1884,7 +1918,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
             cm.on_chain_start(
                 dumpd(self),
                 input,
-                name=config.get("run_name") or self.name,
+                name=config.get("run_name") or self.get_name(),
             )
             for cm, input, config in zip(callback_managers, inputs, configs)
         )
@@ -2119,6 +2153,12 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
     class Config:
         arbitrary_types_allowed = True

+    def get_name(
+        self, suffix: Optional[str] = None, *, name: Optional[str] = None
+    ) -> str:
+        name = name or self.name or f"RunnableParallel<{','.join(self.steps.keys())}>"
+        return super().get_name(suffix, name=name)
+
     @property
     def InputType(self) -> Any:
         for step in self.steps.values():
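With this override, a `RunnableParallel` now reports a name that encodes its step keys, which is what the tracer and schema expectations updated below rely on. A short illustration (the keys are arbitrary):

    from langchain_core.runnables import RunnableParallel, RunnablePassthrough

    mapping = RunnableParallel(as_list=RunnablePassthrough(), as_str=RunnablePassthrough())
    print(mapping.get_name())         # RunnableParallel<as_list,as_str>
    print(mapping.get_name("Input"))  # RunnableParallel<as_list,as_str>Input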
@@ -2214,7 +2254,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         )
         # start the root run
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name")
+            dumpd(self), input, name=config.get("run_name") or self.get_name()
         )

         # gather results from all steps
@@ -2254,7 +2294,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
-            dumpd(self), input, name=config.get("run_name")
+            dumpd(self), input, name=config.get("run_name") or self.get_name()
         )

         # gather results from all steps
@@ -3174,6 +3214,12 @@ class RunnableEach(RunnableEachBase[Input, Output]):
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "runnable"]

+    def get_name(
+        self, suffix: Optional[str] = None, *, name: Optional[str] = None
+    ) -> str:
+        name = name or self.name or f"RunnableEach<{self.bound.get_name()}>"
+        return super().get_name(suffix, name=name)
+
     def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
         return RunnableEach(bound=self.bound.bind(**kwargs))

@@ -3298,8 +3344,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
             **other_kwargs,
         )

-    def get_name(self, suffix: Optional[str] = None) -> str:
-        return self.bound.get_name(suffix)
+    def get_name(
+        self, suffix: Optional[str] = None, *, name: Optional[str] = None
+    ) -> str:
+        return self.bound.get_name(suffix, name=name)

     @property
     def InputType(self) -> Type[Input]:
@@ -202,21 +202,6 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
         """
         return RunnableAssign(RunnableParallel(kwargs))

-    @classmethod
-    def pick(
-        cls,
-        keys: Union[str, List[str]],
-    ) -> "RunnablePick":
-        """Pick keys from the Dict input.
-
-        Args:
-            keys: A string or list of strings representing the keys to pick.
-
-        Returns:
-            A runnable that picks keys from the Dict input.
-        """
-        return RunnablePick(keys)
-
     def invoke(
         self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Other:
@@ -335,6 +320,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "runnable"]

+    def get_name(
+        self, suffix: Optional[str] = None, *, name: Optional[str] = None
+    ) -> str:
+        name = (
+            name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
+        )
+        return super().get_name(suffix, name=name)
+
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> Type[BaseModel]:
@@ -589,6 +582,16 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "runnable"]

+    def get_name(
+        self, suffix: Optional[str] = None, *, name: Optional[str] = None
+    ) -> str:
+        name = (
+            name
+            or self.name
+            or f"RunnablePick<{','.join([self.keys] if isinstance(self.keys, str) else self.keys)}>"  # noqa: E501
+        )
+        return super().get_name(suffix, name=name)
+
     def _pick(self, input: Dict[str, Any]) -> Any:
         assert isinstance(
             input, dict
@@ -32,51 +32,51 @@
 # ---
 # name: test_graph_sequence_map
   '''
               +-------------+
               | PromptInput |
               +-------------+
                      *
                      *
                      *
             +----------------+
             | PromptTemplate |
             +----------------+
                      *
                      *
                      *
               +-------------+
               | FakeListLLM |
               +-------------+
                      *
                      *
                      *
-             +---------------+
-             | ParallelInput |
-             +---------------+******
+     +-------------------------------+
+     | Parallel<as_list,as_str>Input |
+     +-------------------------------+
                 *****           ******
              ***                      ******
           ***                               ******
- +------------------------------+                  ***
+ +------------------------------+                    ****
  | conditional_str_parser_input |                       *
  +------------------------------+                       *
           ***          ***                              *
        ***                ***                           *
      **                      **                         *
 +-----------------+   +-----------------+               *
 | StrOutputParser |   | XMLOutputParser |               *
 +-----------------+   +-----------------+               *
          ***           ***                              *
             ***      ***                                *
               **   **                                   *
 +-------------------------------+   +--------------------------------+
 | conditional_str_parser_output |   | CommaSeparatedListOutputParser |
 +-------------------------------+   +--------------------------------+
                  *****           ******
                      ***     ******
-                       ***  ***
-             +----------------+
-             | ParallelOutput |
-             +----------------+
+                      ***   ****
+    +--------------------------------+
+    | Parallel<as_list,as_str>Output |
+    +--------------------------------+
   '''
 # ---
 # name: test_graph_single_runnable
@@ -4012,7 +4012,7 @@
         'items': dict({
           '$ref': '#/definitions/PromptTemplateOutput',
         }),
-        'title': 'RunnableEachOutput',
+        'title': 'RunnableEach<PromptTemplate>Output',
         'type': 'array',
       })
 # ---
@@ -18,6 +18,8 @@ EXPECTED_ALL = [
     "RunnableMap",
     "RunnableParallel",
     "RunnablePassthrough",
+    "RunnableAssign",
+    "RunnablePick",
     "RunnableSequence",
     "RunnableWithFallbacks",
     "get_config_list",
@@ -64,6 +64,7 @@ from langchain_core.runnables import (
     RunnableLambda,
     RunnableParallel,
     RunnablePassthrough,
+    RunnablePick,
     RunnableSequence,
     RunnableWithFallbacks,
     add,
@@ -510,7 +511,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
         },
         "items": {"$ref": "#/definitions/PromptInput"},
         "type": "array",
-        "title": "RunnableEachInput",
+        "title": "RunnableEach<PromptTemplate>Input",
     }

     assert prompt_mapper.output_schema.schema() == snapshot
@@ -571,7 +572,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
         "properties": {"name": {"title": "Name", "type": "string"}},
     }
     assert seq_w_map.output_schema.schema() == {
-        "title": "RunnableParallelOutput",
+        "title": "RunnableParallel<original,as_list,length>Output",
         "type": "object",
         "properties": {
             "original": {"title": "Original", "type": "string"},
@@ -615,7 +616,7 @@ def test_passthrough_assign_schema() -> None:
     # expected dict input_schema
     assert invalid_seq_w_assign.input_schema.schema() == {
         "properties": {"question": {"title": "Question"}},
-        "title": "RunnableParallelInput",
+        "title": "RunnableParallel<context>Input",
         "type": "object",
     }

@@ -774,7 +775,7 @@ def test_schema_complex_seq() -> None:
     )

     assert chain2.input_schema.schema() == {
-        "title": "RunnableParallelInput",
+        "title": "RunnableParallel<city,language>Input",
         "type": "object",
         "properties": {
             "person": {"title": "Person", "type": "string"},
@@ -2160,8 +2161,8 @@ async def test_stream_log_retriever() -> None:
         "FakeListLLM:2",
         "Retriever",
         "RunnableLambda",
-        "RunnableParallel",
-        "RunnableParallel:2",
+        "RunnableParallel<documents,question>",
+        "RunnableParallel<one,two>",
     ]


@@ -2444,7 +2445,7 @@ What is your name?"""
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 4
     map_run = parent_run.child_runs[0]
-    assert map_run.name == "RunnableParallel"
+    assert map_run.name == "RunnableParallel<question,documents,just_to_test_lambda>"
     assert len(map_run.child_runs) == 3


@@ -2505,7 +2506,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 3
     map_run = parent_run.child_runs[2]
-    assert map_run.name == "RunnableParallel"
+    assert map_run.name == "RunnableParallel<chat,llm>"
     assert len(map_run.child_runs) == 2


@@ -2721,7 +2722,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
     assert len(parent_run.child_runs) == 3
     map_run = parent_run.child_runs[2]
-    assert map_run.name == "RunnableParallel"
+    assert map_run.name == "RunnableParallel<chat,llm,passthrough>"
     assert len(map_run.child_runs) == 3


@@ -2770,7 +2771,7 @@ def test_map_stream() -> None:
         {"question": "What is your name?"}
     )

-    chain_pick_one = chain | RunnablePassthrough.pick("llm")
+    chain_pick_one = chain.pick("llm")

     assert chain_pick_one.output_schema.schema() == {
         "title": "RunnableSequenceOutput",
@@ -2791,10 +2792,8 @@ def test_map_stream() -> None:
     assert streamed_chunks[0] == "i"
     assert len(streamed_chunks) == len(llm_res)

-    chain_pick_two = (
-        chain
-        | RunnablePassthrough.assign(hello=RunnablePassthrough.pick("llm") | llm)
-        | RunnablePassthrough.pick(["llm", "hello"])
+    chain_pick_two = chain.assign(hello=RunnablePick("llm").pipe(llm)).pick(
+        ["llm", "hello"]
     )

     assert chain_pick_two.output_schema.schema() == {
@@ -2940,12 +2939,15 @@ async def test_map_astream() -> None:
     assert final_state.state["logs"]["ChatPromptTemplate"][
         "final_output"
     ] == prompt.invoke({"question": "What is your name?"})
-    assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
+    assert (
+        final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
+        == "RunnableParallel<chat,llm,passthrough>"
+    )
     assert sorted(final_state.state["logs"]) == [
         "ChatPromptTemplate",
         "FakeListChatModel",
         "FakeStreamingListLLM",
-        "RunnableParallel",
+        "RunnableParallel<chat,llm,passthrough>",
         "RunnablePassthrough",
     ]

@@ -2985,11 +2987,14 @@ async def test_map_astream() -> None:
     assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
         prompt.invoke({"question": "What is your name?"})
     )
-    assert final_state.state["logs"]["RunnableParallel"]["name"] == "RunnableParallel"
+    assert (
+        final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
+        == "RunnableParallel<chat,llm,passthrough>"
+    )
     assert sorted(final_state.state["logs"]) == [
         "ChatPromptTemplate",
         "FakeStreamingListLLM",
-        "RunnableParallel",
+        "RunnableParallel<chat,llm,passthrough>",
         "RunnablePassthrough",
     ]

@@ -3130,9 +3135,7 @@ def test_deep_stream_assign() -> None:
     assert len(chunks) == len("foo-lish")
     assert add(chunks) == {"str": "foo-lish"}

-    chain_with_assign = chain | RunnablePassthrough.assign(
-        hello=itemgetter("str") | llm
-    )
+    chain_with_assign = chain.assign(hello=itemgetter("str") | llm)

     assert chain_with_assign.input_schema.schema() == {
         "title": "PromptInput",
@@ -3179,7 +3182,7 @@ def test_deep_stream_assign() -> None:
         "hello": "foo-lish",
     }

-    chain_with_assign_shadow = chain | RunnablePassthrough.assign(
+    chain_with_assign_shadow = chain.assign(
         str=lambda _: "shadow",
         hello=itemgetter("str") | llm,
     )
@@ -3254,7 +3257,7 @@ async def test_deep_astream_assign() -> None:
     assert len(chunks) == len("foo-lish")
     assert add(chunks) == {"str": "foo-lish"}

-    chain_with_assign = chain | RunnablePassthrough.assign(
+    chain_with_assign = chain.assign(
         hello=itemgetter("str") | llm,
     )

@@ -4473,15 +4476,15 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
     tracer = FakeTracer()
     chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))

-    assert tracer.runs[0].name == "RunnableAssign"
-    assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
+    assert tracer.runs[0].name == "RunnableAssign<urls>"
+    assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"

     tracer = FakeTracer()
     for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
         pass

-    assert tracer.runs[0].name == "RunnableAssign"
-    assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
+    assert tracer.runs[0].name == "RunnableAssign<urls>"
+    assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"


 async def test_ainvoke_astream_passthrough_assign_trace() -> None:
@@ -4493,15 +4496,15 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
     tracer = FakeTracer()
     await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))

-    assert tracer.runs[0].name == "RunnableAssign"
-    assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
+    assert tracer.runs[0].name == "RunnableAssign<urls>"
+    assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"

     tracer = FakeTracer()
     async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
         pass

-    assert tracer.runs[0].name == "RunnableAssign"
-    assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
+    assert tracer.runs[0].name == "RunnableAssign<urls>"
+    assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"


 async def test_astream_log_deep_copies() -> None:
@@ -35,13 +35,13 @@ def create_history_aware_retriever(
             # pip install -U langchain langchain-community

             from langchain_community.chat_models import ChatOpenAI
-            from langchain.chains import create_chat_history_retriever
+            from langchain.chains import create_history_aware_retriever
             from langchain import hub

             rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase")
             llm = ChatOpenAI()
             retriever = ...
-            chat_retriever_chain = create_chat_retriever_chain(
+            chat_retriever_chain = create_history_aware_retriever(
                 llm, retriever, rephrase_prompt
             )

@@ -64,8 +64,7 @@ def create_retrieval_chain(
         RunnablePassthrough.assign(
             context=retrieval_docs.with_config(run_name="retrieve_documents"),
             chat_history=lambda x: x.get("chat_history", []),
-        )
-        | RunnablePassthrough.assign(answer=combine_docs_chain)
+        ).assign(answer=combine_docs_chain)
     ).with_config(run_name="retrieval_chain")

     return retrieval_chain