Improve graph repr for runnable passthrough and itemgetter (#15083)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-22 16:05:48 -08:00 committed by GitHub
parent 0d0901ea18
commit a2d3042823
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 57 deletions

View File

@ -2007,7 +2007,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
): ):
# This is correct, but pydantic typings/mypy don't think so. # This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"RunnableParallelInput", "RunnableMapInput",
**{ **{
k: (v.annotation, v.default) k: (v.annotation, v.default)
for step in self.steps.values() for step in self.steps.values()
@ -2024,7 +2024,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
) -> Type[BaseModel]: ) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so. # This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"RunnableParallelOutput", "RunnableMapOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()}, **{k: (v.OutputType, None) for k, v in self.steps.items()},
__config__=_SchemaConfig, __config__=_SchemaConfig,
) )
@ -2650,7 +2650,9 @@ class RunnableLambda(Runnable[Input, Output]):
def __repr__(self) -> str: def __repr__(self) -> str:
"""A string representation of this runnable.""" """A string representation of this runnable."""
if hasattr(self, "func"): if hasattr(self, "func") and isinstance(self.func, itemgetter):
return f"RunnableLambda({str(self.func)[len('operator.'):]})"
elif hasattr(self, "func"):
return f"RunnableLambda({get_lambda_source(self.func) or '...'})" return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
elif hasattr(self, "afunc"): elif hasattr(self, "afunc"):
return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})" return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})"

View File

@ -123,13 +123,13 @@ class Graph:
or len(data.splitlines()) > 1 or len(data.splitlines()) > 1
): ):
data = node.data.__class__.__name__ data = node.data.__class__.__name__
elif len(data) > 36: elif len(data) > 42:
data = data[:36] + "..." data = data[:42] + "..."
except Exception: except Exception:
data = node.data.__class__.__name__ data = node.data.__class__.__name__
else: else:
data = node.data.__name__ data = node.data.__name__
return data return data if not data.startswith("Runnable") else data[8:]
return draw( return draw(
{node.id: node_data(node) for node in self.nodes.values()}, {node.id: node_data(node) for node in self.nodes.values()},

View File

@ -34,6 +34,7 @@ from langchain_core.runnables.config import (
get_executor_for_config, get_executor_for_config,
patch_config, patch_config,
) )
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee from langchain_core.utils.iter import safetee
@ -297,6 +298,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
yield chunk yield chunk
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
""" """
A runnable that assigns key-value pairs to Dict[str, Any] inputs. A runnable that assigns key-value pairs to Dict[str, Any] inputs.
@ -355,6 +359,18 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs return self.mapper.config_specs
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
# get graph from mapper
graph = self.mapper.get_graph(config)
# add passthrough node and edges
input_node = graph.first_node()
output_node = graph.last_node()
if input_node is not None and output_node is not None:
passthrough_node = graph.add_node(_graph_passthrough)
graph.add_edge(input_node, passthrough_node)
graph.add_edge(passthrough_node, output_node)
return graph
def _invoke( def _invoke(
self, self,
input: Dict[str, Any], input: Dict[str, Any],

View File

@ -32,51 +32,51 @@
# --- # ---
# name: test_graph_sequence_map # name: test_graph_sequence_map
''' '''
+-------------+ +-------------+
| PromptInput | | PromptInput |
+-------------+ +-------------+
* *
* *
* *
+----------------+ +----------------+
| PromptTemplate | | PromptTemplate |
+----------------+ +----------------+
* *
* *
* *
+-------------+ +-------------+
| FakeListLLM | | FakeListLLM |
+-------------+ +-------------+
* *
* *
* *
+-----------------------+ +---------------+
| RunnableParallelInput | | ParallelInput |
+-----------------------+** +---------------+*****
**** ******* *** ******
**** ***** *** *****
** ******* ** *****
+---------------------+ *** +-------------+ ***
| RunnableLambdaInput | * | LambdaInput | *
+---------------------+ * +-------------+ *
*** *** * ** ** *
*** *** * *** *** *
** ** * ** ** *
+-----------------+ +-----------------+ * +-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | * | StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ * +-----------------+ +-----------------+ *
*** *** * ** ** *
*** *** * *** *** *
** ** * ** ** *
+----------------------+ +--------------------------------+ +--------------+ +--------------------------------+
| RunnableLambdaOutput | | CommaSeparatedListOutputParser | | LambdaOutput | | CommaSeparatedListOutputParser |
+----------------------+ +--------------------------------+ +--------------+ +--------------------------------+
**** ******* *** ******
**** ***** *** *****
** **** ** ***
+------------------------+ +-----------+
| RunnableParallelOutput | | MapOutput |
+------------------------+ +-----------+
''' '''
# --- # ---
# name: test_graph_single_runnable # name: test_graph_single_runnable

View File

@ -569,7 +569,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"properties": {"name": {"title": "Name", "type": "string"}}, "properties": {"name": {"title": "Name", "type": "string"}},
} }
assert seq_w_map.output_schema.schema() == { assert seq_w_map.output_schema.schema() == {
"title": "RunnableParallelOutput", "title": "RunnableMapOutput",
"type": "object", "type": "object",
"properties": { "properties": {
"original": {"title": "Original", "type": "string"}, "original": {"title": "Original", "type": "string"},
@ -613,7 +613,7 @@ def test_passthrough_assign_schema() -> None:
# expected dict input_schema # expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == { assert invalid_seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question"}}, "properties": {"question": {"title": "Question"}},
"title": "RunnableParallelInput", "title": "RunnableMapInput",
"type": "object", "type": "object",
} }
@ -768,7 +768,7 @@ def test_schema_complex_seq() -> None:
) )
assert chain2.input_schema.schema() == { assert chain2.input_schema.schema() == {
"title": "RunnableParallelInput", "title": "RunnableMapInput",
"type": "object", "type": "object",
"properties": { "properties": {
"person": {"title": "Person", "type": "string"}, "person": {"title": "Person", "type": "string"},
@ -2221,7 +2221,6 @@ async def test_stream_log_lists() -> None:
} }
@pytest.mark.asyncio
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda( async def test_prompt_with_llm_and_async_lambda(
mocker: MockerFixture, snapshot: SnapshotAssertion mocker: MockerFixture, snapshot: SnapshotAssertion
@ -4262,7 +4261,6 @@ def test_with_config_callbacks() -> None:
assert isinstance(result, RunnableBinding) assert isinstance(result, RunnableBinding)
@pytest.mark.asyncio
async def test_ainvoke_on_returned_runnable() -> None: async def test_ainvoke_on_returned_runnable() -> None:
"""Verify that a runnable returned by a sync runnable in the async path will """Verify that a runnable returned by a sync runnable in the async path will
be run through the async path (issue #13407)""" be run through the async path (issue #13407)"""
@ -4301,7 +4299,6 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
assert tracer.runs[0].child_runs[0].name == "RunnableParallel" assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
@pytest.mark.asyncio
async def test_ainvoke_astream_passthrough_assign_trace() -> None: async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool: def idchain_sync(__input: dict) -> bool:
return False return False