From bbd7015b5d23da30213b1b0bbd618b063df2dcf1 Mon Sep 17 00:00:00 2001 From: hmasdev <73353463+hmasdev@users.noreply.github.com> Date: Tue, 28 May 2024 06:34:34 +0900 Subject: [PATCH] core[patch]: Add `TypeError` handler into `get_graph` of `Runnable` (#19856) # Description ## Problem `Runnable.get_graph` fails when `InputType` or `OutputType` property raises `TypeError`. - https://github.com/langchain-ai/langchain/tree/003c98e5b440a420d8dbd5f9fdea08888a4a5a33/libs/core/langchain_core/runnables/base.py#L250-L274 - https://github.com/langchain-ai/langchain/tree/003c98e5b440a420d8dbd5f9fdea08888a4a5a33/libs/core/langchain_core/runnables/base.py#L394-L396 This problem prevents getting a graph of `Runnable` objects whose `InputType` or `OutputType` property raises `TypeError` but whose `invoke` works well, such as `langchain.output_parsers.RegexParser`, which I have already pointed out in #19792 that a `TypeError` would occur. ## Solution - Add `try-except` syntax to handle `TypeError` to the codes which get `input_node` and `output_node`. # Issue - #19801 # Twitter Handle - [hmdev3](https://twitter.com/hmdev3) --------- Co-authored-by: Bagatur --- libs/core/langchain_core/runnables/base.py | 10 +++- .../tests/unit_tests/runnables/test_graph.py | 48 ++++++++++++++++++- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 9ee3c7707c5..eee42e876e6 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -391,9 +391,15 @@ class Runnable(Generic[Input, Output], ABC): from langchain_core.runnables.graph import Graph graph = Graph() - input_node = graph.add_node(self.get_input_schema(config)) + try: + input_node = graph.add_node(self.get_input_schema(config)) + except TypeError: + input_node = graph.add_node(create_model(self.get_name("Input"))) runnable_node = graph.add_node(self) - output_node = graph.add_node(self.get_output_schema(config)) + try: + output_node = graph.add_node(self.get_output_schema(config)) + except TypeError: + output_node = graph.add_node(create_model(self.get_name("Output"))) graph.add_edge(input_node, runnable_node) graph.add_edge(runnable_node, output_node) return graph diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index bc4ef621fc5..50291e83816 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,3 +1,5 @@ +from typing import Optional + from syrupy import SnapshotAssertion from langchain_core.language_models import FakeListLLM @@ -5,7 +7,7 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.runnables.base import Runnable +from langchain_core.runnables.base import Runnable, RunnableConfig def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None: @@ -687,3 +689,47 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: assert graph.draw_ascii() == snapshot(name="ascii") assert graph.draw_mermaid() == snapshot(name="mermaid") assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple") + + +def test_runnable_get_graph_with_invalid_input_type() -> None: + """Test that error isn't raised when getting graph with invalid input type.""" + + class InvalidInputTypeRunnable(Runnable[int, int]): + @property + def InputType(self) -> type: + raise TypeError() + + def invoke( + self, + input: int, + config: Optional[RunnableConfig] = None, + ) -> int: + return input + + runnable = InvalidInputTypeRunnable() + # check whether runnable.invoke works + assert runnable.invoke(1) == 1 + # check whether runnable.get_graph works + runnable.get_graph() + + +def test_runnable_get_graph_with_invalid_output_type() -> None: + """Test that error is't raised when getting graph with invalid output type.""" + + class InvalidOutputTypeRunnable(Runnable[int, int]): + @property + def OutputType(self) -> type: + raise TypeError() + + def invoke( + self, + input: int, + config: Optional[RunnableConfig] = None, + ) -> int: + return input + + runnable = InvalidOutputTypeRunnable() + # check whether runnable.invoke works + assert runnable.invoke(1) == 1 + # check whether runnable.get_graph works + runnable.get_graph()