Use reference counting for storing inherited run trees to support garbage collection

This commit is contained in:
jacoblee93
2026-04-10 14:26:46 -07:00
parent 9f232caa7a
commit fda666b4d2
5 changed files with 78 additions and 1 deletions

View File

@@ -2480,7 +2480,12 @@ def _configure(
run_tree.trace_id,
run_tree.dotted_order,
)
handler.run_map[str(run_tree.id)] = run_tree
run_id_str = str(run_tree.id)
if run_id_str not in handler.run_map:
handler.run_map[run_id_str] = run_tree
handler._external_run_ids.setdefault( # noqa: SLF001
run_id_str, 0
)
for var, inheritable, handler_class, env_var in _configure_hooks:
create_one = (
env_var is not None

View File

@@ -47,6 +47,15 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
if not run.parent_run_id:
self._persist_run(run)
self.run_map.pop(str(run.id))
# If this run's parent was injected from an external tracing context
# (e.g. a langsmith @traceable), decrement its child refcount and
# remove it from run_map once the last child is done.
parent_id = str(run.parent_run_id) if run.parent_run_id else None
if parent_id and parent_id in self._external_run_ids:
self._external_run_ids[parent_id] -= 1
if self._external_run_ids[parent_id] <= 0:
self.run_map.pop(parent_id, None)
del self._external_run_ids[parent_id]
self._on_run_update(run)
def on_chat_model_start(
@@ -568,6 +577,15 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
if not run.parent_run_id:
await self._persist_run(run)
self.run_map.pop(str(run.id))
# If this run's parent was injected from an external tracing context
# (e.g. a langsmith @traceable), decrement its child refcount and
# remove it from run_map once the last child is done.
parent_id = str(run.parent_run_id) if run.parent_run_id else None
if parent_id and parent_id in self._external_run_ids:
self._external_run_ids[parent_id] -= 1
if self._external_run_ids[parent_id] <= 0:
self.run_map.pop(parent_id, None)
del self._external_run_ids[parent_id]
await self._on_run_update(run)
@override

View File

@@ -53,6 +53,7 @@ class _TracerCore(ABC):
] = "original",
run_map: dict[str, Run] | None = None,
order_map: dict[UUID, tuple[UUID, str]] | None = None,
_external_run_ids: dict[str, int] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the tracer.
@@ -74,6 +75,7 @@ class _TracerCore(ABC):
it does NOT raise an attribute error `on_chat_model_start`
run_map: Optional shared map of run ID to run.
order_map: Optional shared map of run ID to trace ordering data.
_external_run_ids: Optional shared set of externally injected run IDs.
**kwargs: Additional keyword arguments that will be passed to the
superclass.
"""
@@ -87,6 +89,16 @@ class _TracerCore(ABC):
self.order_map = order_map if order_map is not None else {}
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
self._external_run_ids: dict[str, int] = (
_external_run_ids if _external_run_ids is not None else {}
)
"""Refcount of active children per externally-injected run ID.
These runs are added to `run_map` so child runs can find their parent,
but they are not managed by the tracer's callback lifecycle. When
the last child finishes the entry is evicted to avoid memory leaks.
"""
@abstractmethod
def _persist_run(self, run: Run) -> Coroutine[Any, Any, None] | None:
"""Persist a run."""
@@ -117,6 +129,9 @@ class _TracerCore(ABC):
run.dotted_order += "." + current_dotted_order
if parent_run := self.run_map.get(str(run.parent_run_id)):
self._add_child_run(parent_run, run)
parent_key = str(run.parent_run_id)
if parent_key in self._external_run_ids:
self._external_run_ids[parent_key] += 1
else:
if self.log_missing_parent:
logger.debug(

View File

@@ -189,6 +189,7 @@ class LangChainTracer(BaseTracer):
metadata=merged_metadata,
run_map=self.run_map,
order_map=self.order_map,
_external_run_ids=self._external_run_ids,
)
def _start_trace(self, run: Run) -> None:

View File

@@ -555,6 +555,44 @@ def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
assert kitten_run.dotted_order.startswith(grandchild_run.dotted_order)
def test_traceable_parent_run_map_cleanup() -> None:
"""External RunTree injected into run_map is cleaned up when its child ends.
When a `@traceable` function invokes a LangChain `Runnable`, the
`RunTree` is added to the tracer's `run_map` so child runs can
reference it. Previously the entry was never removed, causing a
memory leak that grew with every call.
"""
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
)
@RunnableLambda
def child(x: str) -> str:
return x
with tracing_context(client=mock_client_, enabled=True):
@traceable
def parent(x: str) -> str:
return child.invoke(x)
parent("hello")
# All LangChainTracer instances created during the call should have an
# empty run_map after the call completes.
import gc # noqa: PLC0415
gc.collect()
tracers = [o for o in gc.get_objects() if isinstance(o, LangChainTracer)]
for tracer in tracers:
assert tracer.run_map == {}, (
f"run_map should be empty but contains: "
f"{[getattr(v, 'name', k) for k, v in tracer.run_map.items()]}"
)
class TestTracerMetadataThroughInvoke:
"""Tests for tracer metadata merging through invoke calls."""