diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index eb502b08997..b81a15d5d6a 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4895,17 +4895,39 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): def __getattr__(self, name: str) -> Any: attr = getattr(self.bound, name) - if callable(attr) and accepts_config(attr): + if callable(attr) and ( + config_param := inspect.signature(attr).parameters.get("config") + ): + if config_param.kind == inspect.Parameter.KEYWORD_ONLY: - @wraps(attr) - def wrapper(*args: Any, **kwargs: Any) -> Any: - return attr( - *args, - **kwargs, - config=merge_configs(self.config, kwargs.get("config")), - ) + @wraps(attr) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return attr( + *args, + config=merge_configs(self.config, kwargs.pop("config", None)), + **kwargs, + ) - return wrapper + return wrapper + elif config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + idx = list(inspect.signature(attr).parameters).index("config") + + @wraps(attr) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if len(args) >= idx + 1: + argsl = list(args) + argsl[idx] = merge_configs(self.config, argsl[idx]) + return attr(*argsl, **kwargs) + else: + return attr( + *args, + config=merge_configs( + self.config, kwargs.pop("config", None) + ), + **kwargs, + ) + + return wrapper return attr diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index b9690b9edb9..97fbf00ed9a 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -344,6 +344,7 @@ class Graph: def draw_mermaid( self, + *, curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: NodeColors = NodeColors( start="#ffdfba", end="#baffc9", other="#fad7de" @@ -372,6 +373,7 @@ class Graph: def draw_mermaid_png( self, + *, curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: NodeColors = NodeColors( start="#ffdfba", end="#baffc9", other="#fad7de" diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 210a3b5be5a..122020a47d6 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -14,6 +14,7 @@ from langchain_core.runnables.graph import ( def draw_mermaid( nodes: Dict[str, str], edges: List[Edge], + *, first_node_label: Optional[str] = None, last_node_label: Optional[str] = None, curve_style: CurveStyle = CurveStyle.LINEAR, diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 6fa048a3d87..c5d74df5ee9 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -34,7 +34,14 @@ class MyRunnable(RunnableSerializable[str, str]): def my_custom_function(self) -> str: return self.my_property - def my_custom_function_w_config(self, config: RunnableConfig) -> str: + def my_custom_function_w_config( + self, config: Optional[RunnableConfig] = None + ) -> str: + return self.my_property + + def my_custom_function_w_kw_config( + self, *, config: Optional[RunnableConfig] = None + ) -> str: return self.my_property @@ -175,7 +182,73 @@ def test_config_passthrough_nested() -> None: configurable={"my_property": "b"} ).my_custom_function() # type: ignore[attr-defined] == "b" - ) + ), "function without config can be called w bound config" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_config( # type: ignore[attr-defined] + ) + == "b" + ), "func with config arg can be called w bound config without config" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "func with config arg can be called w bound config with config as kwarg" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_kw_config( # type: ignore[attr-defined] + ) + == "b" + ), "function with config kwarg can be called w bound config w/out config" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_kw_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "function with config kwarg can be called w bound config with config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function() # type: ignore[attr-defined] + == "b" + ), "function without config can be called w bound config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_config( # type: ignore[attr-defined] + ) + == "b" + ), "func with config arg can be called w bound config without config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "func with config arg can be called w bound config with config as kwarg" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_kw_config( # type: ignore[attr-defined] + ) + == "b" + ), "function with config kwarg can be called w bound config w/out config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_kw_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "function with config kwarg can be called w bound config with config" # second one with pytest.raises(AttributeError): configurable_runnable.my_other_custom_function() # type: ignore[attr-defined]