perf(core): copy-on-write callbacks, cache inspect.signature, frozenset config keys

- Add copy-on-write (COW) to BaseCallbackManager.copy() — defer
  shallow copies of handlers/tags/metadata until first mutation
- Cache inspect.signature() results for BaseChatModel._generate and
  _agenerate to avoid repeated introspection per invoke
- Cache signature(self._run) and _get_runnable_config_param in
  ChildTool.__init__ to avoid per-invocation introspection
- Convert CONFIG_KEYS and COPIABLE_KEYS from lists to frozensets
  for O(1) membership checks
- Fast path in _format_for_tracing when no messages have list content

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
John Kennedy
2026-02-27 22:10:45 -08:00
parent e2da14595b
commit 1867503157
4 changed files with 101 additions and 34 deletions

View File

@@ -969,18 +969,36 @@ class BaseCallbackManager(CallbackManagerMixin):
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
self._cow = False
def _cow_copy(self) -> None:
    """Give this manager private containers before it is mutated.

    When the ``_cow`` flag is set, the handler lists, tag lists and
    metadata dicts are each replaced by a shallow copy of themselves and
    the flag is cleared. When the flag is not set this is a no-op.
    """
    if not self._cow:
        return
    # Same fields, in the same order, as the attributes handed over by
    # ``copy()``.
    for field_name in (
        "handlers",
        "inheritable_handlers",
        "tags",
        "inheritable_tags",
        "metadata",
        "inheritable_metadata",
    ):
        setattr(self, field_name, getattr(self, field_name).copy())
    self._cow = False
def copy(self) -> Self:
"""Return a copy of the callback manager."""
return self.__class__(
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
)
"""Return a copy of the callback manager.
Uses copy-on-write: the copy shares underlying lists/dicts until
either the original or the copy is mutated.
"""
self._cow = True
clone = self.__class__.__new__(self.__class__)
clone.handlers = self.handlers
clone.inheritable_handlers = self.inheritable_handlers
clone.parent_run_id = self.parent_run_id
clone.tags = self.tags
clone.inheritable_tags = self.inheritable_tags
clone.metadata = self.metadata
clone.inheritable_metadata = self.inheritable_metadata
clone._cow = True # noqa: SLF001
return clone
def merge(self, other: BaseCallbackManager) -> Self:
"""Merge the callback manager with another callback manager.
@@ -1053,6 +1071,7 @@ class BaseCallbackManager(CallbackManagerMixin):
handler: The handler to add.
inherit: Whether to inherit the handler.
"""
self._cow_copy()
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
@@ -1064,6 +1083,7 @@ class BaseCallbackManager(CallbackManagerMixin):
Args:
handler: The handler to remove.
"""
self._cow_copy()
if handler in self.handlers:
self.handlers.remove(handler)
if handler in self.inheritable_handlers:
@@ -1080,6 +1100,7 @@ class BaseCallbackManager(CallbackManagerMixin):
handlers: The handlers to set.
inherit: Whether to inherit the handlers.
"""
self._cow = False
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
@@ -1109,6 +1130,7 @@ class BaseCallbackManager(CallbackManagerMixin):
tags: The tags to add.
inherit: Whether to inherit the tags.
"""
self._cow_copy()
if not self.tags:
self.tags.extend(tags)
if inherit:
@@ -1130,6 +1152,7 @@ class BaseCallbackManager(CallbackManagerMixin):
Args:
tags: The tags to remove.
"""
self._cow_copy()
for tag in tags:
if tag in self.tags:
self.tags.remove(tag)
@@ -1147,6 +1170,7 @@ class BaseCallbackManager(CallbackManagerMixin):
metadata: The metadata to add.
inherit: Whether to inherit the metadata.
"""
self._cow_copy()
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
@@ -1157,6 +1181,7 @@ class BaseCallbackManager(CallbackManagerMixin):
Args:
keys: The keys to remove.
"""
self._cow_copy()
for key in keys:
self.metadata.pop(key, None)
self.inheritable_metadata.pop(key, None)

View File

@@ -125,6 +125,9 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
List of messages formatted for tracing.
"""
# Fast path: if no messages have list content, no formatting is needed.
if not any(isinstance(m.content, list) for m in messages):
return messages
messages_to_trace = []
for message in messages:
message_to_trace = message
@@ -243,6 +246,30 @@ def _format_ls_structured_output(ls_structured_output_format: dict | None) -> di
return ls_structured_output_format_dict
# Per-class memo of whether ``_generate`` / ``_agenerate`` declare a
# ``run_manager`` parameter, so ``inspect.signature`` runs once per class
# instead of once per call. Entries are keyed by ``type(model)`` and are never
# evicted; a method replaced on an instance or class after the first lookup is
# not re-inspected.
_generate_accepts_run_manager: dict[type, bool] = {}
_agenerate_accepts_run_manager: dict[type, bool] = {}


def _method_accepts_run_manager(
    cache: dict[type, bool], model: BaseChatModel, method_name: str
) -> bool:
    """Return whether ``model.<method_name>`` has a ``run_manager`` parameter.

    Args:
        cache: Memo dict to read from and populate, keyed by model class.
        model: Instance whose bound method is inspected on a cache miss.
        method_name: Attribute name of the method to inspect.

    Returns:
        ``True`` if the signature names a ``run_manager`` parameter.
    """
    cls = type(model)
    try:
        return cache[cls]
    except KeyError:
        parameters = inspect.signature(getattr(model, method_name)).parameters
        accepted = cache[cls] = "run_manager" in parameters
        return accepted


def _check_generates_accept_run_manager(self: BaseChatModel) -> bool:
    """Return whether ``self._generate`` accepts ``run_manager`` (memoized)."""
    return _method_accepts_run_manager(
        _generate_accepts_run_manager, self, "_generate"
    )


def _check_agenerates_accept_run_manager(self: BaseChatModel) -> bool:
    """Return whether ``self._agenerate`` accepts ``run_manager`` (memoized)."""
    return _method_accepts_run_manager(
        _agenerate_accepts_run_manager, self, "_agenerate"
    )
class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
r"""Base class for chat models.
@@ -1231,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
run_manager.on_llm_new_token("", chunk=chunk)
chunks.append(chunk)
result = generate_from_stream(iter(chunks))
elif inspect.signature(self._generate).parameters.get("run_manager"):
elif _check_generates_accept_run_manager(self):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
@@ -1357,7 +1384,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
await run_manager.on_llm_new_token("", chunk=chunk)
chunks.append(chunk)
result = generate_from_stream(iter(chunks))
elif inspect.signature(self._agenerate).parameters.get("run_manager"):
elif _check_agenerates_accept_run_manager(self):
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)

View File

@@ -120,23 +120,27 @@ class RunnableConfig(TypedDict, total=False):
"""
# Names of the recognised config keys, stored as a frozenset so ``key in
# CONFIG_KEYS`` is a hash lookup. Being a set, it has no defined iteration
# order and supports neither indexing nor ``+`` concatenation.
CONFIG_KEYS: frozenset[str] = frozenset(
    {
        "tags",
        "metadata",
        "callbacks",
        "run_name",
        "max_concurrency",
        "recursion_limit",
        "configurable",
        "run_id",
    }
)

# Subset of CONFIG_KEYS; presumably the keys whose values are copied when a
# config is duplicated - confirm against the callers. Same frozenset caveats
# as above.
COPIABLE_KEYS: frozenset[str] = frozenset(
    {
        "tags",
        "metadata",
        "callbacks",
        "configurable",
    }
)

# NOTE(review): presumably the default for the "recursion_limit" key - confirm.
DEFAULT_RECURSION_LIMIT = 25

View File

@@ -549,6 +549,19 @@ class ChildTool(BaseTool):
)
raise TypeError(msg)
super().__init__(**kwargs)
# Cache per-invocation introspection results
try:
self._has_run_manager_param: bool = bool(
signature(self._run).parameters.get("run_manager")
)
except (ValueError, TypeError):
self._has_run_manager_param = False
try:
self._runnable_config_param: str | None = _get_runnable_config_param(
self._run
)
except (ValueError, TypeError):
self._runnable_config_param = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
@@ -794,9 +807,7 @@ class ChildTool(BaseTool):
Returns:
The result of the tool execution.
"""
if kwargs.get("run_manager") and signature(self._run).parameters.get(
"run_manager"
):
if kwargs.get("run_manager") and self._has_run_manager_param:
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
@@ -960,10 +971,10 @@ class ChildTool(BaseTool):
tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_input, tool_call_id
)
if signature(self._run).parameters.get("run_manager"):
if self._has_run_manager_param:
tool_kwargs |= {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs |= {config_param: config}
if self._runnable_config_param:
tool_kwargs |= {self._runnable_config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
msg = (