mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
Improve graph repr for runnable passthrough and itemgetter (#15083)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
0d0901ea18
commit
a2d3042823
@ -2007,7 +2007,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
):
|
):
|
||||||
# This is correct, but pydantic typings/mypy don't think so.
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"RunnableParallelInput",
|
"RunnableMapInput",
|
||||||
**{
|
**{
|
||||||
k: (v.annotation, v.default)
|
k: (v.annotation, v.default)
|
||||||
for step in self.steps.values()
|
for step in self.steps.values()
|
||||||
@ -2024,7 +2024,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
# This is correct, but pydantic typings/mypy don't think so.
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"RunnableParallelOutput",
|
"RunnableMapOutput",
|
||||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||||
__config__=_SchemaConfig,
|
__config__=_SchemaConfig,
|
||||||
)
|
)
|
||||||
@ -2650,7 +2650,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""A string representation of this runnable."""
|
"""A string representation of this runnable."""
|
||||||
if hasattr(self, "func"):
|
if hasattr(self, "func") and isinstance(self.func, itemgetter):
|
||||||
|
return f"RunnableLambda({str(self.func)[len('operator.'):]})"
|
||||||
|
elif hasattr(self, "func"):
|
||||||
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
|
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
|
||||||
elif hasattr(self, "afunc"):
|
elif hasattr(self, "afunc"):
|
||||||
return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})"
|
return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})"
|
||||||
|
@ -123,13 +123,13 @@ class Graph:
|
|||||||
or len(data.splitlines()) > 1
|
or len(data.splitlines()) > 1
|
||||||
):
|
):
|
||||||
data = node.data.__class__.__name__
|
data = node.data.__class__.__name__
|
||||||
elif len(data) > 36:
|
elif len(data) > 42:
|
||||||
data = data[:36] + "..."
|
data = data[:42] + "..."
|
||||||
except Exception:
|
except Exception:
|
||||||
data = node.data.__class__.__name__
|
data = node.data.__class__.__name__
|
||||||
else:
|
else:
|
||||||
data = node.data.__name__
|
data = node.data.__name__
|
||||||
return data
|
return data if not data.startswith("Runnable") else data[8:]
|
||||||
|
|
||||||
return draw(
|
return draw(
|
||||||
{node.id: node_data(node) for node in self.nodes.values()},
|
{node.id: node_data(node) for node in self.nodes.values()},
|
||||||
|
@ -34,6 +34,7 @@ from langchain_core.runnables.config import (
|
|||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
|
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
|
||||||
from langchain_core.utils.aiter import atee, py_anext
|
from langchain_core.utils.aiter import atee, py_anext
|
||||||
from langchain_core.utils.iter import safetee
|
from langchain_core.utils.iter import safetee
|
||||||
@ -297,6 +298,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
|
||||||
|
|
||||||
|
|
||||||
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||||
@ -355,6 +359,18 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
return self.mapper.config_specs
|
return self.mapper.config_specs
|
||||||
|
|
||||||
|
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
|
||||||
|
# get graph from mapper
|
||||||
|
graph = self.mapper.get_graph(config)
|
||||||
|
# add passthrough node and edges
|
||||||
|
input_node = graph.first_node()
|
||||||
|
output_node = graph.last_node()
|
||||||
|
if input_node is not None and output_node is not None:
|
||||||
|
passthrough_node = graph.add_node(_graph_passthrough)
|
||||||
|
graph.add_edge(input_node, passthrough_node)
|
||||||
|
graph.add_edge(passthrough_node, output_node)
|
||||||
|
return graph
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
input: Dict[str, Any],
|
input: Dict[str, Any],
|
||||||
|
@ -32,51 +32,51 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_graph_sequence_map
|
# name: test_graph_sequence_map
|
||||||
'''
|
'''
|
||||||
+-------------+
|
+-------------+
|
||||||
| PromptInput |
|
| PromptInput |
|
||||||
+-------------+
|
+-------------+
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
+----------------+
|
+----------------+
|
||||||
| PromptTemplate |
|
| PromptTemplate |
|
||||||
+----------------+
|
+----------------+
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
+-------------+
|
+-------------+
|
||||||
| FakeListLLM |
|
| FakeListLLM |
|
||||||
+-------------+
|
+-------------+
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
+-----------------------+
|
+---------------+
|
||||||
| RunnableParallelInput |
|
| ParallelInput |
|
||||||
+-----------------------+**
|
+---------------+*****
|
||||||
**** *******
|
*** ******
|
||||||
**** *****
|
*** *****
|
||||||
** *******
|
** *****
|
||||||
+---------------------+ ***
|
+-------------+ ***
|
||||||
| RunnableLambdaInput | *
|
| LambdaInput | *
|
||||||
+---------------------+ *
|
+-------------+ *
|
||||||
*** *** *
|
** ** *
|
||||||
*** *** *
|
*** *** *
|
||||||
** ** *
|
** ** *
|
||||||
+-----------------+ +-----------------+ *
|
+-----------------+ +-----------------+ *
|
||||||
| StrOutputParser | | XMLOutputParser | *
|
| StrOutputParser | | XMLOutputParser | *
|
||||||
+-----------------+ +-----------------+ *
|
+-----------------+ +-----------------+ *
|
||||||
*** *** *
|
** ** *
|
||||||
*** *** *
|
*** *** *
|
||||||
** ** *
|
** ** *
|
||||||
+----------------------+ +--------------------------------+
|
+--------------+ +--------------------------------+
|
||||||
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
|
| LambdaOutput | | CommaSeparatedListOutputParser |
|
||||||
+----------------------+ +--------------------------------+
|
+--------------+ +--------------------------------+
|
||||||
**** *******
|
*** ******
|
||||||
**** *****
|
*** *****
|
||||||
** ****
|
** ***
|
||||||
+------------------------+
|
+-----------+
|
||||||
| RunnableParallelOutput |
|
| MapOutput |
|
||||||
+------------------------+
|
+-----------+
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
# name: test_graph_single_runnable
|
# name: test_graph_single_runnable
|
||||||
|
@ -569,7 +569,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||||
}
|
}
|
||||||
assert seq_w_map.output_schema.schema() == {
|
assert seq_w_map.output_schema.schema() == {
|
||||||
"title": "RunnableParallelOutput",
|
"title": "RunnableMapOutput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"original": {"title": "Original", "type": "string"},
|
"original": {"title": "Original", "type": "string"},
|
||||||
@ -613,7 +613,7 @@ def test_passthrough_assign_schema() -> None:
|
|||||||
# expected dict input_schema
|
# expected dict input_schema
|
||||||
assert invalid_seq_w_assign.input_schema.schema() == {
|
assert invalid_seq_w_assign.input_schema.schema() == {
|
||||||
"properties": {"question": {"title": "Question"}},
|
"properties": {"question": {"title": "Question"}},
|
||||||
"title": "RunnableParallelInput",
|
"title": "RunnableMapInput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -768,7 +768,7 @@ def test_schema_complex_seq() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert chain2.input_schema.schema() == {
|
assert chain2.input_schema.schema() == {
|
||||||
"title": "RunnableParallelInput",
|
"title": "RunnableMapInput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"person": {"title": "Person", "type": "string"},
|
"person": {"title": "Person", "type": "string"},
|
||||||
@ -2221,7 +2221,6 @@ async def test_stream_log_lists() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
async def test_prompt_with_llm_and_async_lambda(
|
async def test_prompt_with_llm_and_async_lambda(
|
||||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
@ -4262,7 +4261,6 @@ def test_with_config_callbacks() -> None:
|
|||||||
assert isinstance(result, RunnableBinding)
|
assert isinstance(result, RunnableBinding)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ainvoke_on_returned_runnable() -> None:
|
async def test_ainvoke_on_returned_runnable() -> None:
|
||||||
"""Verify that a runnable returned by a sync runnable in the async path will
|
"""Verify that a runnable returned by a sync runnable in the async path will
|
||||||
be runthroughaasync path (issue #13407)"""
|
be runthroughaasync path (issue #13407)"""
|
||||||
@ -4301,7 +4299,6 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
|||||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||||
def idchain_sync(__input: dict) -> bool:
|
def idchain_sync(__input: dict) -> bool:
|
||||||
return False
|
return False
|
||||||
|
Loading…
Reference in New Issue
Block a user