From 112208baa5de65887cee452426f04d98439fbde3 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 11 Jan 2024 18:47:55 -0800 Subject: [PATCH] Passthrough configurable primitive values as tracer metadata (#15915) --- libs/core/langchain_core/runnables/config.py | 3 ++ .../unit_tests/runnables/test_runnable.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index bd9330a1fcc..015cb8f4664 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -125,6 +125,9 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: empty.update( cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) ) + for key, value in empty.get("configurable", {}).items(): + if isinstance(value, (str, int, float, bool)) and key not in empty["metadata"]: + empty["metadata"][key] = value return empty diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 7a8585e4018..2a94cf2469b 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -70,6 +70,7 @@ from langchain_core.runnables import ( add, chain, ) +from langchain_core.runnables.base import RunnableSerializable from langchain_core.tools import BaseTool, tool from langchain_core.tracers import ( BaseTracer, @@ -142,6 +143,17 @@ class FakeRunnable(Runnable[str, int]): return len(input) +class FakeRunnableSerializable(RunnableSerializable[str, int]): + hello: str = "" + + def invoke( + self, + input: str, + config: Optional[RunnableConfig] = None, + ) -> int: + return len(input) + + class FakeRetriever(BaseRetriever): def _get_relevant_documents( self, @@ -1302,6 +1314,30 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: mock.reset_mock() +async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: + fake = FakeRunnableSerializable() + spy = mocker.spy(fake.__class__, "invoke") + fakew = fake.configurable_fields(hello=ConfigurableField(id="hello", name="Hello")) + + assert ( + fakew.with_config(tags=["a-tag"]).invoke( + "hello", {"configurable": {"hello": "there"}, "metadata": {"bye": "now"}} + ) + == 5 + ) + assert spy.call_args_list[0].args[1:] == ( + "hello", + dict( + tags=["a-tag"], + callbacks=None, + recursion_limit=25, + configurable={"hello": "there"}, + metadata={"hello": "there", "bye": "now"}, + ), + ) + spy.reset_mock() + + async def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke")