core[patch]: Add TypeError handler into get_graph of Runnable (#19856)

# Description

## Problem

`Runnable.get_graph` fails when `InputType` or `OutputType` property
raises `TypeError`.

-
003c98e5b4/libs/core/langchain_core/runnables/base.py (L250-L274)
-
003c98e5b4/libs/core/langchain_core/runnables/base.py (L394-L396)

This problem prevents getting a graph of `Runnable` objects whose
`InputType` or `OutputType` property raises `TypeError` but whose
`invoke` works well — such as `langchain.output_parsers.RegexParser`,
for which I already pointed out in #19792 that a `TypeError` would
occur.

## Solution

- Add `try-except` blocks that handle `TypeError` around the code that
builds `input_node` and `output_node`.

# Issue
- #19801 

# Twitter Handle
- [hmdev3](https://twitter.com/hmdev3)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
hmasdev 2024-05-28 06:34:34 +09:00 committed by GitHub
parent 753353411f
commit bbd7015b5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 3 deletions

View File

@ -391,9 +391,15 @@ class Runnable(Generic[Input, Output], ABC):
from langchain_core.runnables.graph import Graph
graph = Graph()
input_node = graph.add_node(self.get_input_schema(config))
try:
input_node = graph.add_node(self.get_input_schema(config))
except TypeError:
input_node = graph.add_node(create_model(self.get_name("Input")))
runnable_node = graph.add_node(self)
output_node = graph.add_node(self.get_output_schema(config))
try:
output_node = graph.add_node(self.get_output_schema(config))
except TypeError:
output_node = graph.add_node(create_model(self.get_name("Output")))
graph.add_edge(input_node, runnable_node)
graph.add_edge(runnable_node, output_node)
return graph

View File

@ -1,3 +1,5 @@
from typing import Optional
from syrupy import SnapshotAssertion
from langchain_core.language_models import FakeListLLM
@ -5,7 +7,7 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.base import Runnable, RunnableConfig
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
@ -687,3 +689,47 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")
def test_runnable_get_graph_with_invalid_input_type() -> None:
    """``get_graph`` must not raise when ``InputType`` raises ``TypeError``."""

    class InvalidInputTypeRunnable(Runnable[int, int]):
        # Simulate a runnable whose input schema cannot be introspected.
        @property
        def InputType(self) -> type:
            raise TypeError()

        def invoke(
            self,
            input: int,
            config: Optional[RunnableConfig] = None,
        ) -> int:
            return input

    broken = InvalidInputTypeRunnable()
    # invoke itself keeps working despite the broken InputType property
    assert broken.invoke(1) == 1
    # building the graph must succeed rather than propagate the TypeError
    broken.get_graph()
def test_runnable_get_graph_with_invalid_output_type() -> None:
    """Test that error isn't raised when getting graph with invalid output type."""

    class InvalidOutputTypeRunnable(Runnable[int, int]):
        # Simulate a runnable whose output schema cannot be introspected.
        @property
        def OutputType(self) -> type:
            raise TypeError()

        def invoke(
            self,
            input: int,
            config: Optional[RunnableConfig] = None,
        ) -> int:
            return input

    runnable = InvalidOutputTypeRunnable()
    # check whether runnable.invoke works
    assert runnable.invoke(1) == 1
    # check whether runnable.get_graph works (must not propagate the TypeError)
    runnable.get_graph()