mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
perf(core): copy-on-write callbacks, cache inspect.signature, frozenset config keys
- Add copy-on-write (COW) to BaseCallbackManager.copy() — defer shallow copies of handlers/tags/metadata until first mutation - Cache inspect.signature() results for BaseChatModel._generate and _agenerate to avoid repeated introspection per invoke - Cache signature(self._run) and _get_runnable_config_param in ChildTool.__init__ to avoid per-invocation introspection - Convert CONFIG_KEYS and COPIABLE_KEYS from lists to frozensets for O(1) membership checks - Fast path in _format_for_tracing when no messages have list content Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -969,18 +969,36 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
self._cow = False
|
||||
|
||||
def _cow_copy(self) -> None:
|
||||
"""Materialize copy-on-write shared state before mutation."""
|
||||
if self._cow:
|
||||
self.handlers = self.handlers.copy()
|
||||
self.inheritable_handlers = self.inheritable_handlers.copy()
|
||||
self.tags = self.tags.copy()
|
||||
self.inheritable_tags = self.inheritable_tags.copy()
|
||||
self.metadata = self.metadata.copy()
|
||||
self.inheritable_metadata = self.inheritable_metadata.copy()
|
||||
self._cow = False
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the callback manager."""
|
||||
return self.__class__(
|
||||
handlers=self.handlers.copy(),
|
||||
inheritable_handlers=self.inheritable_handlers.copy(),
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags.copy(),
|
||||
inheritable_tags=self.inheritable_tags.copy(),
|
||||
metadata=self.metadata.copy(),
|
||||
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||
)
|
||||
"""Return a copy of the callback manager.
|
||||
|
||||
Uses copy-on-write: the copy shares underlying lists/dicts until
|
||||
either the original or the copy is mutated.
|
||||
"""
|
||||
self._cow = True
|
||||
clone = self.__class__.__new__(self.__class__)
|
||||
clone.handlers = self.handlers
|
||||
clone.inheritable_handlers = self.inheritable_handlers
|
||||
clone.parent_run_id = self.parent_run_id
|
||||
clone.tags = self.tags
|
||||
clone.inheritable_tags = self.inheritable_tags
|
||||
clone.metadata = self.metadata
|
||||
clone.inheritable_metadata = self.inheritable_metadata
|
||||
clone._cow = True # noqa: SLF001
|
||||
return clone
|
||||
|
||||
def merge(self, other: BaseCallbackManager) -> Self:
|
||||
"""Merge the callback manager with another callback manager.
|
||||
@@ -1053,6 +1071,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
handler: The handler to add.
|
||||
inherit: Whether to inherit the handler.
|
||||
"""
|
||||
self._cow_copy()
|
||||
if handler not in self.handlers:
|
||||
self.handlers.append(handler)
|
||||
if inherit and handler not in self.inheritable_handlers:
|
||||
@@ -1064,6 +1083,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Args:
|
||||
handler: The handler to remove.
|
||||
"""
|
||||
self._cow_copy()
|
||||
if handler in self.handlers:
|
||||
self.handlers.remove(handler)
|
||||
if handler in self.inheritable_handlers:
|
||||
@@ -1080,6 +1100,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
handlers: The handlers to set.
|
||||
inherit: Whether to inherit the handlers.
|
||||
"""
|
||||
self._cow = False
|
||||
self.handlers = []
|
||||
self.inheritable_handlers = []
|
||||
for handler in handlers:
|
||||
@@ -1109,6 +1130,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
tags: The tags to add.
|
||||
inherit: Whether to inherit the tags.
|
||||
"""
|
||||
self._cow_copy()
|
||||
if not self.tags:
|
||||
self.tags.extend(tags)
|
||||
if inherit:
|
||||
@@ -1130,6 +1152,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Args:
|
||||
tags: The tags to remove.
|
||||
"""
|
||||
self._cow_copy()
|
||||
for tag in tags:
|
||||
if tag in self.tags:
|
||||
self.tags.remove(tag)
|
||||
@@ -1147,6 +1170,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
metadata: The metadata to add.
|
||||
inherit: Whether to inherit the metadata.
|
||||
"""
|
||||
self._cow_copy()
|
||||
self.metadata.update(metadata)
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
@@ -1157,6 +1181,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Args:
|
||||
keys: The keys to remove.
|
||||
"""
|
||||
self._cow_copy()
|
||||
for key in keys:
|
||||
self.metadata.pop(key, None)
|
||||
self.inheritable_metadata.pop(key, None)
|
||||
|
||||
@@ -125,6 +125,9 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
List of messages formatted for tracing.
|
||||
|
||||
"""
|
||||
# Fast path: if no messages have list content, no formatting is needed.
|
||||
if not any(isinstance(m.content, list) for m in messages):
|
||||
return messages
|
||||
messages_to_trace = []
|
||||
for message in messages:
|
||||
message_to_trace = message
|
||||
@@ -243,6 +246,30 @@ def _format_ls_structured_output(ls_structured_output_format: dict | None) -> di
|
||||
return ls_structured_output_format_dict
|
||||
|
||||
|
||||
_generate_accepts_run_manager: dict[type, bool] = {}
|
||||
_agenerate_accepts_run_manager: dict[type, bool] = {}
|
||||
|
||||
|
||||
def _check_generates_accept_run_manager(self: BaseChatModel) -> bool:
|
||||
cls = type(self)
|
||||
try:
|
||||
return _generate_accepts_run_manager[cls]
|
||||
except KeyError:
|
||||
result = bool(inspect.signature(self._generate).parameters.get("run_manager"))
|
||||
_generate_accepts_run_manager[cls] = result
|
||||
return result
|
||||
|
||||
|
||||
def _check_agenerates_accept_run_manager(self: BaseChatModel) -> bool:
|
||||
cls = type(self)
|
||||
try:
|
||||
return _agenerate_accepts_run_manager[cls]
|
||||
except KeyError:
|
||||
result = bool(inspect.signature(self._agenerate).parameters.get("run_manager"))
|
||||
_agenerate_accepts_run_manager[cls] = result
|
||||
return result
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
r"""Base class for chat models.
|
||||
|
||||
@@ -1231,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
run_manager.on_llm_new_token("", chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
result = generate_from_stream(iter(chunks))
|
||||
elif inspect.signature(self._generate).parameters.get("run_manager"):
|
||||
elif _check_generates_accept_run_manager(self):
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
@@ -1357,7 +1384,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
await run_manager.on_llm_new_token("", chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
result = generate_from_stream(iter(chunks))
|
||||
elif inspect.signature(self._agenerate).parameters.get("run_manager"):
|
||||
elif _check_agenerates_accept_run_manager(self):
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
||||
@@ -120,23 +120,27 @@ class RunnableConfig(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
|
||||
CONFIG_KEYS = [
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"run_name",
|
||||
"max_concurrency",
|
||||
"recursion_limit",
|
||||
"configurable",
|
||||
"run_id",
|
||||
]
|
||||
CONFIG_KEYS = frozenset(
|
||||
{
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"run_name",
|
||||
"max_concurrency",
|
||||
"recursion_limit",
|
||||
"configurable",
|
||||
"run_id",
|
||||
}
|
||||
)
|
||||
|
||||
COPIABLE_KEYS = [
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"configurable",
|
||||
]
|
||||
COPIABLE_KEYS = frozenset(
|
||||
{
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"configurable",
|
||||
}
|
||||
)
|
||||
|
||||
DEFAULT_RECURSION_LIMIT = 25
|
||||
|
||||
|
||||
@@ -549,6 +549,19 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise TypeError(msg)
|
||||
super().__init__(**kwargs)
|
||||
# Cache per-invocation introspection results
|
||||
try:
|
||||
self._has_run_manager_param: bool = bool(
|
||||
signature(self._run).parameters.get("run_manager")
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
self._has_run_manager_param = False
|
||||
try:
|
||||
self._runnable_config_param: str | None = _get_runnable_config_param(
|
||||
self._run
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
self._runnable_config_param = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
@@ -794,9 +807,7 @@ class ChildTool(BaseTool):
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||
"run_manager"
|
||||
):
|
||||
if kwargs.get("run_manager") and self._has_run_manager_param:
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
@@ -960,10 +971,10 @@ class ChildTool(BaseTool):
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(
|
||||
tool_input, tool_call_id
|
||||
)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
if self._has_run_manager_param:
|
||||
tool_kwargs |= {"run_manager": run_manager}
|
||||
if config_param := _get_runnable_config_param(self._run):
|
||||
tool_kwargs |= {config_param: config}
|
||||
if self._runnable_config_param:
|
||||
tool_kwargs |= {self._runnable_config_param: config}
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_artifact":
|
||||
msg = (
|
||||
|
||||
Reference in New Issue
Block a user