diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 9ee3c7707c5..eee42e876e6 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -391,9 +391,15 @@ class Runnable(Generic[Input, Output], ABC): from langchain_core.runnables.graph import Graph graph = Graph() - input_node = graph.add_node(self.get_input_schema(config)) + try: + input_node = graph.add_node(self.get_input_schema(config)) + except TypeError: + input_node = graph.add_node(create_model(self.get_name("Input"))) runnable_node = graph.add_node(self) - output_node = graph.add_node(self.get_output_schema(config)) + try: + output_node = graph.add_node(self.get_output_schema(config)) + except TypeError: + output_node = graph.add_node(create_model(self.get_name("Output"))) graph.add_edge(input_node, runnable_node) graph.add_edge(runnable_node, output_node) return graph diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index bc4ef621fc5..50291e83816 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,3 +1,5 @@ +from typing import Optional + from syrupy import SnapshotAssertion from langchain_core.language_models import FakeListLLM @@ -5,7 +7,7 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.runnables.base import Runnable +from langchain_core.runnables.base import Runnable, RunnableConfig def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None: @@ -687,3 +689,47 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: assert graph.draw_ascii() == snapshot(name="ascii") assert graph.draw_mermaid() == snapshot(name="mermaid") assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple") + + +def test_runnable_get_graph_with_invalid_input_type() -> None: + """Test that error isn't raised when getting graph with invalid input type.""" + + class InvalidInputTypeRunnable(Runnable[int, int]): + @property + def InputType(self) -> type: + raise TypeError() + + def invoke( + self, + input: int, + config: Optional[RunnableConfig] = None, + ) -> int: + return input + + runnable = InvalidInputTypeRunnable() + # check whether runnable.invoke works + assert runnable.invoke(1) == 1 + # check whether runnable.get_graph works + runnable.get_graph() + + +def test_runnable_get_graph_with_invalid_output_type() -> None: + """Test that error is't raised when getting graph with invalid output type.""" + + class InvalidOutputTypeRunnable(Runnable[int, int]): + @property + def OutputType(self) -> type: + raise TypeError() + + def invoke( + self, + input: int, + config: Optional[RunnableConfig] = None, + ) -> int: + return input + + runnable = InvalidOutputTypeRunnable() + # check whether runnable.invoke works + assert runnable.invoke(1) == 1 + # check whether runnable.get_graph works + runnable.get_graph()