From f3aa26d6bfc24e83bf019cf5362a8045f624e95f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 16 Apr 2024 13:10:29 -0700 Subject: [PATCH] Fix getattr in runnable binding for cases where config is passed in as arg too (#20528) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …s arg too Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --- libs/core/langchain_core/runnables/base.py | 40 +++++++--- libs/core/langchain_core/runnables/graph.py | 2 + .../langchain_core/runnables/graph_mermaid.py | 1 + .../unit_tests/runnables/test_configurable.py | 77 ++++++++++++++++++- 4 files changed, 109 insertions(+), 11 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index eb502b08997..b81a15d5d6a 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4895,17 +4895,39 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): def __getattr__(self, name: str) -> Any: attr = getattr(self.bound, name) - if callable(attr) and accepts_config(attr): + if callable(attr) and ( + config_param := inspect.signature(attr).parameters.get("config") + ): + if config_param.kind == inspect.Parameter.KEYWORD_ONLY: - @wraps(attr) - def wrapper(*args: Any, **kwargs: Any) -> Any: - return attr( - *args, - **kwargs, - config=merge_configs(self.config, kwargs.get("config")), - ) + @wraps(attr) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return attr( + *args, + config=merge_configs(self.config, kwargs.pop("config", None)), + **kwargs, + ) - return wrapper + return wrapper + elif config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + idx = list(inspect.signature(attr).parameters).index("config") + + @wraps(attr) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if len(args) >= idx + 1: + argsl = list(args) + argsl[idx] = merge_configs(self.config, argsl[idx]) + return attr(*argsl, **kwargs) + else: + return attr( + *args, + config=merge_configs( + self.config, kwargs.pop("config", None) + ), + **kwargs, + ) + + return wrapper return attr diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index b9690b9edb9..97fbf00ed9a 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -344,6 +344,7 @@ class Graph: def draw_mermaid( self, + *, curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: NodeColors = NodeColors( start="#ffdfba", end="#baffc9", other="#fad7de" @@ -372,6 +373,7 @@ class Graph: def draw_mermaid_png( self, + *, curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: NodeColors = NodeColors( start="#ffdfba", end="#baffc9", other="#fad7de" diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 210a3b5be5a..122020a47d6 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -14,6 +14,7 @@ from langchain_core.runnables.graph import ( def draw_mermaid( nodes: Dict[str, str], edges: List[Edge], + *, first_node_label: Optional[str] = None, last_node_label: Optional[str] = None, curve_style: CurveStyle = CurveStyle.LINEAR, diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 6fa048a3d87..c5d74df5ee9 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -34,7 +34,14 @@ class MyRunnable(RunnableSerializable[str, str]): def my_custom_function(self) -> str: return self.my_property - def my_custom_function_w_config(self, config: RunnableConfig) -> str: + def my_custom_function_w_config( + self, config: Optional[RunnableConfig] = None + ) -> str: + return self.my_property + + def my_custom_function_w_kw_config( + self, *, config: Optional[RunnableConfig] = None + ) -> str: return self.my_property @@ -175,7 +182,73 @@ def test_config_passthrough_nested() -> None: configurable={"my_property": "b"} ).my_custom_function() # type: ignore[attr-defined] == "b" - ) + ), "function without config can be called w bound config" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_config( # type: ignore[attr-defined] + ) + == "b" + ), "func with config arg can be called w bound config without config" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "func with config arg can be called w bound config with config as kwarg" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_kw_config( # type: ignore[attr-defined] + ) + == "b" + ), "function with config kwarg can be called w bound config w/out config" + assert ( + configurable_runnable.with_config( + configurable={"my_property": "b"} + ).my_custom_function_w_kw_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "function with config kwarg can be called w bound config with config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function() # type: ignore[attr-defined] + == "b" + ), "function without config can be called w bound config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_config( # type: ignore[attr-defined] + ) + == "b" + ), "func with config arg can be called w bound config without config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "func with config arg can be called w bound config with config as kwarg" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_kw_config( # type: ignore[attr-defined] + ) + == "b" + ), "function with config kwarg can be called w bound config w/out config" + assert ( + configurable_runnable.with_config(configurable={"my_property": "b"}) + .with_types() + .my_custom_function_w_kw_config( # type: ignore[attr-defined] + config={"configurable": {"my_property": "c"}} + ) + == "c" + ), "function with config kwarg can be called w bound config with config" # second one with pytest.raises(AttributeError): configurable_runnable.my_other_custom_function() # type: ignore[attr-defined]