mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
core[patch]: Add TypeError
handler into get_graph
of Runnable
(#19856)
# Description ## Problem `Runnable.get_graph` fails when `InputType` or `OutputType` property raises `TypeError`. -003c98e5b4/libs/core/langchain_core/runnables/base.py (L250-L274)
-003c98e5b4/libs/core/langchain_core/runnables/base.py (L394-L396)
This problem prevents getting a graph of `Runnable` objects whose `InputType` or `OutputType` property raises `TypeError` but whose `invoke` works well, such as `langchain.output_parsers.RegexParser`, which I have already pointed out in #19792 that a `TypeError` would occur. ## Solution - Add `try-except` syntax to handle `TypeError` to the codes which get `input_node` and `output_node`. # Issue - #19801 # Twitter Handle - [hmdev3](https://twitter.com/hmdev3) --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
753353411f
commit
bbd7015b5d
@ -391,9 +391,15 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
from langchain_core.runnables.graph import Graph
|
||||
|
||||
graph = Graph()
|
||||
input_node = graph.add_node(self.get_input_schema(config))
|
||||
try:
|
||||
input_node = graph.add_node(self.get_input_schema(config))
|
||||
except TypeError:
|
||||
input_node = graph.add_node(create_model(self.get_name("Input")))
|
||||
runnable_node = graph.add_node(self)
|
||||
output_node = graph.add_node(self.get_output_schema(config))
|
||||
try:
|
||||
output_node = graph.add_node(self.get_output_schema(config))
|
||||
except TypeError:
|
||||
output_node = graph.add_node(create_model(self.get_name("Output")))
|
||||
graph.add_edge(input_node, runnable_node)
|
||||
graph.add_edge(runnable_node, output_node)
|
||||
return graph
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.language_models import FakeListLLM
|
||||
@ -5,7 +7,7 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||
|
||||
|
||||
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
@ -687,3 +689,47 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
assert graph.draw_ascii() == snapshot(name="ascii")
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")
|
||||
|
||||
|
||||
def test_runnable_get_graph_with_invalid_input_type() -> None:
|
||||
"""Test that error isn't raised when getting graph with invalid input type."""
|
||||
|
||||
class InvalidInputTypeRunnable(Runnable[int, int]):
|
||||
@property
|
||||
def InputType(self) -> type:
|
||||
raise TypeError()
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: int,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> int:
|
||||
return input
|
||||
|
||||
runnable = InvalidInputTypeRunnable()
|
||||
# check whether runnable.invoke works
|
||||
assert runnable.invoke(1) == 1
|
||||
# check whether runnable.get_graph works
|
||||
runnable.get_graph()
|
||||
|
||||
|
||||
def test_runnable_get_graph_with_invalid_output_type() -> None:
|
||||
"""Test that error is't raised when getting graph with invalid output type."""
|
||||
|
||||
class InvalidOutputTypeRunnable(Runnable[int, int]):
|
||||
@property
|
||||
def OutputType(self) -> type:
|
||||
raise TypeError()
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: int,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> int:
|
||||
return input
|
||||
|
||||
runnable = InvalidOutputTypeRunnable()
|
||||
# check whether runnable.invoke works
|
||||
assert runnable.invoke(1) == 1
|
||||
# check whether runnable.get_graph works
|
||||
runnable.get_graph()
|
||||
|
Loading…
Reference in New Issue
Block a user