mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-13 07:52:48 +00:00
Compare commits
3 Commits
sr/type-sa
...
jk/optimiz
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1867503157 | ||
|
|
e2da14595b | ||
|
|
67b777f655 |
@@ -969,18 +969,36 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
self._cow = False
|
||||
|
||||
def _cow_copy(self) -> None:
|
||||
"""Materialize copy-on-write shared state before mutation."""
|
||||
if self._cow:
|
||||
self.handlers = self.handlers.copy()
|
||||
self.inheritable_handlers = self.inheritable_handlers.copy()
|
||||
self.tags = self.tags.copy()
|
||||
self.inheritable_tags = self.inheritable_tags.copy()
|
||||
self.metadata = self.metadata.copy()
|
||||
self.inheritable_metadata = self.inheritable_metadata.copy()
|
||||
self._cow = False
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the callback manager."""
|
||||
return self.__class__(
|
||||
handlers=self.handlers.copy(),
|
||||
inheritable_handlers=self.inheritable_handlers.copy(),
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags.copy(),
|
||||
inheritable_tags=self.inheritable_tags.copy(),
|
||||
metadata=self.metadata.copy(),
|
||||
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||
)
|
||||
"""Return a copy of the callback manager.
|
||||
|
||||
Uses copy-on-write: the copy shares underlying lists/dicts until
|
||||
either the original or the copy is mutated.
|
||||
"""
|
||||
self._cow = True
|
||||
clone = self.__class__.__new__(self.__class__)
|
||||
clone.handlers = self.handlers
|
||||
clone.inheritable_handlers = self.inheritable_handlers
|
||||
clone.parent_run_id = self.parent_run_id
|
||||
clone.tags = self.tags
|
||||
clone.inheritable_tags = self.inheritable_tags
|
||||
clone.metadata = self.metadata
|
||||
clone.inheritable_metadata = self.inheritable_metadata
|
||||
clone._cow = True # noqa: SLF001
|
||||
return clone
|
||||
|
||||
def merge(self, other: BaseCallbackManager) -> Self:
|
||||
"""Merge the callback manager with another callback manager.
|
||||
@@ -1053,6 +1071,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
handler: The handler to add.
|
||||
inherit: Whether to inherit the handler.
|
||||
"""
|
||||
self._cow_copy()
|
||||
if handler not in self.handlers:
|
||||
self.handlers.append(handler)
|
||||
if inherit and handler not in self.inheritable_handlers:
|
||||
@@ -1064,6 +1083,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Args:
|
||||
handler: The handler to remove.
|
||||
"""
|
||||
self._cow_copy()
|
||||
if handler in self.handlers:
|
||||
self.handlers.remove(handler)
|
||||
if handler in self.inheritable_handlers:
|
||||
@@ -1080,6 +1100,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
handlers: The handlers to set.
|
||||
inherit: Whether to inherit the handlers.
|
||||
"""
|
||||
self._cow = False
|
||||
self.handlers = []
|
||||
self.inheritable_handlers = []
|
||||
for handler in handlers:
|
||||
@@ -1109,12 +1130,21 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
tags: The tags to add.
|
||||
inherit: Whether to inherit the tags.
|
||||
"""
|
||||
for tag in tags:
|
||||
if tag in self.tags:
|
||||
self.remove_tags([tag])
|
||||
self.tags.extend(tags)
|
||||
self._cow_copy()
|
||||
if not self.tags:
|
||||
self.tags.extend(tags)
|
||||
if inherit:
|
||||
self.inheritable_tags.extend(tags)
|
||||
return
|
||||
# Deduplicate: tag order is not meaningful across the codebase
|
||||
# (merge_configs sorts, tracers deduplicate via sets).
|
||||
existing = set(self.tags)
|
||||
new_tags = [t for t in tags if t not in existing]
|
||||
self.tags.extend(new_tags)
|
||||
if inherit:
|
||||
self.inheritable_tags.extend(tags)
|
||||
existing_inh = set(self.inheritable_tags)
|
||||
new_inh = [t for t in tags if t not in existing_inh]
|
||||
self.inheritable_tags.extend(new_inh)
|
||||
|
||||
def remove_tags(self, tags: list[str]) -> None:
|
||||
"""Remove tags from the callback manager.
|
||||
@@ -1122,6 +1152,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Args:
|
||||
tags: The tags to remove.
|
||||
"""
|
||||
self._cow_copy()
|
||||
for tag in tags:
|
||||
if tag in self.tags:
|
||||
self.tags.remove(tag)
|
||||
@@ -1139,6 +1170,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
metadata: The metadata to add.
|
||||
inherit: Whether to inherit the metadata.
|
||||
"""
|
||||
self._cow_copy()
|
||||
self.metadata.update(metadata)
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
@@ -1149,6 +1181,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Args:
|
||||
keys: The keys to remove.
|
||||
"""
|
||||
self._cow_copy()
|
||||
for key in keys:
|
||||
self.metadata.pop(key, None)
|
||||
self.inheritable_metadata.pop(key, None)
|
||||
|
||||
@@ -269,6 +269,9 @@ def handle_event(
|
||||
**kwargs: The keyword arguments to pass to the event handler
|
||||
|
||||
"""
|
||||
if not handlers:
|
||||
return
|
||||
|
||||
coros: list[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
try:
|
||||
@@ -433,6 +436,9 @@ async def ahandle_event(
|
||||
**kwargs: The keyword arguments to pass to the event handler.
|
||||
|
||||
"""
|
||||
if not handlers:
|
||||
return
|
||||
|
||||
for handler in [h for h in handlers if h.run_inline]:
|
||||
await _ahandle_event_for_handler(
|
||||
handler, event_name, ignore_condition_name, *args, **kwargs
|
||||
@@ -574,13 +580,18 @@ class ParentRunManager(RunManager):
|
||||
The child callback manager.
|
||||
|
||||
"""
|
||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
manager.add_metadata(self.inheritable_metadata)
|
||||
tags = list(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], inherit=False)
|
||||
return manager
|
||||
tags.append(tag)
|
||||
return CallbackManager(
|
||||
handlers=list(self.inheritable_handlers),
|
||||
inheritable_handlers=list(self.inheritable_handlers),
|
||||
parent_run_id=self.run_id,
|
||||
tags=tags,
|
||||
inheritable_tags=list(self.inheritable_tags),
|
||||
metadata=dict(self.inheritable_metadata),
|
||||
inheritable_metadata=dict(self.inheritable_metadata),
|
||||
)
|
||||
|
||||
|
||||
class AsyncRunManager(BaseRunManager, ABC):
|
||||
@@ -658,13 +669,18 @@ class AsyncParentRunManager(AsyncRunManager):
|
||||
The child callback manager.
|
||||
|
||||
"""
|
||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
manager.add_metadata(self.inheritable_metadata)
|
||||
tags = list(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], inherit=False)
|
||||
return manager
|
||||
tags.append(tag)
|
||||
return AsyncCallbackManager(
|
||||
handlers=list(self.inheritable_handlers),
|
||||
inheritable_handlers=list(self.inheritable_handlers),
|
||||
parent_run_id=self.run_id,
|
||||
tags=tags,
|
||||
inheritable_tags=list(self.inheritable_tags),
|
||||
metadata=dict(self.inheritable_metadata),
|
||||
inheritable_metadata=dict(self.inheritable_metadata),
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
@@ -2340,10 +2356,6 @@ def _configure(
|
||||
tracing_tags = tracing_context["tags"]
|
||||
run_tree: Run | None = tracing_context["parent"]
|
||||
parent_run_id = None if run_tree is None else run_tree.id
|
||||
callback_manager = callback_manager_cls(
|
||||
handlers=[],
|
||||
parent_run_id=parent_run_id,
|
||||
)
|
||||
if inheritable_callbacks or local_callbacks:
|
||||
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
||||
inheritable_callbacks_ = inheritable_callbacks or []
|
||||
@@ -2381,6 +2393,11 @@ def _configure(
|
||||
)
|
||||
for handler in local_handlers_:
|
||||
callback_manager.add_handler(handler, inherit=False)
|
||||
else:
|
||||
callback_manager = callback_manager_cls(
|
||||
handlers=[],
|
||||
parent_run_id=parent_run_id,
|
||||
)
|
||||
if inheritable_tags or local_tags:
|
||||
callback_manager.add_tags(inheritable_tags or [])
|
||||
callback_manager.add_tags(local_tags or [], inherit=False)
|
||||
|
||||
@@ -125,6 +125,9 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
List of messages formatted for tracing.
|
||||
|
||||
"""
|
||||
# Fast path: if no messages have list content, no formatting is needed.
|
||||
if not any(isinstance(m.content, list) for m in messages):
|
||||
return messages
|
||||
messages_to_trace = []
|
||||
for message in messages:
|
||||
message_to_trace = message
|
||||
@@ -243,6 +246,30 @@ def _format_ls_structured_output(ls_structured_output_format: dict | None) -> di
|
||||
return ls_structured_output_format_dict
|
||||
|
||||
|
||||
_generate_accepts_run_manager: dict[type, bool] = {}
|
||||
_agenerate_accepts_run_manager: dict[type, bool] = {}
|
||||
|
||||
|
||||
def _check_generates_accept_run_manager(self: BaseChatModel) -> bool:
|
||||
cls = type(self)
|
||||
try:
|
||||
return _generate_accepts_run_manager[cls]
|
||||
except KeyError:
|
||||
result = bool(inspect.signature(self._generate).parameters.get("run_manager"))
|
||||
_generate_accepts_run_manager[cls] = result
|
||||
return result
|
||||
|
||||
|
||||
def _check_agenerates_accept_run_manager(self: BaseChatModel) -> bool:
|
||||
cls = type(self)
|
||||
try:
|
||||
return _agenerate_accepts_run_manager[cls]
|
||||
except KeyError:
|
||||
result = bool(inspect.signature(self._agenerate).parameters.get("run_manager"))
|
||||
_agenerate_accepts_run_manager[cls] = result
|
||||
return result
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
r"""Base class for chat models.
|
||||
|
||||
@@ -1231,7 +1258,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
run_manager.on_llm_new_token("", chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
result = generate_from_stream(iter(chunks))
|
||||
elif inspect.signature(self._generate).parameters.get("run_manager"):
|
||||
elif _check_generates_accept_run_manager(self):
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
@@ -1357,7 +1384,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
await run_manager.on_llm_new_token("", chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
result = generate_from_stream(iter(chunks))
|
||||
elif inspect.signature(self._agenerate).parameters.get("run_manager"):
|
||||
elif _check_agenerates_accept_run_manager(self):
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
||||
@@ -120,23 +120,27 @@ class RunnableConfig(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
|
||||
CONFIG_KEYS = [
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"run_name",
|
||||
"max_concurrency",
|
||||
"recursion_limit",
|
||||
"configurable",
|
||||
"run_id",
|
||||
]
|
||||
CONFIG_KEYS = frozenset(
|
||||
{
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"run_name",
|
||||
"max_concurrency",
|
||||
"recursion_limit",
|
||||
"configurable",
|
||||
"run_id",
|
||||
}
|
||||
)
|
||||
|
||||
COPIABLE_KEYS = [
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"configurable",
|
||||
]
|
||||
COPIABLE_KEYS = frozenset(
|
||||
{
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"configurable",
|
||||
}
|
||||
)
|
||||
|
||||
DEFAULT_RECURSION_LIMIT = 25
|
||||
|
||||
|
||||
@@ -549,6 +549,19 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
raise TypeError(msg)
|
||||
super().__init__(**kwargs)
|
||||
# Cache per-invocation introspection results
|
||||
try:
|
||||
self._has_run_manager_param: bool = bool(
|
||||
signature(self._run).parameters.get("run_manager")
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
self._has_run_manager_param = False
|
||||
try:
|
||||
self._runnable_config_param: str | None = _get_runnable_config_param(
|
||||
self._run
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
self._runnable_config_param = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
@@ -794,9 +807,7 @@ class ChildTool(BaseTool):
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||
"run_manager"
|
||||
):
|
||||
if kwargs.get("run_manager") and self._has_run_manager_param:
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
@@ -960,10 +971,10 @@ class ChildTool(BaseTool):
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(
|
||||
tool_input, tool_call_id
|
||||
)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
if self._has_run_manager_param:
|
||||
tool_kwargs |= {"run_manager": run_manager}
|
||||
if config_param := _get_runnable_config_param(self._run):
|
||||
tool_kwargs |= {config_param: config}
|
||||
if self._runnable_config_param:
|
||||
tool_kwargs |= {self._runnable_config_param: config}
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_artifact":
|
||||
msg = (
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Literal, TypeAlias, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
from langchain_core.callbacks import (
|
||||
@@ -197,7 +198,8 @@ class ChatDeepSeek(BaseChatOpenAI):
|
||||
@property
|
||||
def _is_azure_endpoint(self) -> bool:
|
||||
"""Check if the configured endpoint is an Azure deployment."""
|
||||
return "azure.com" in (self.api_base or "").lower()
|
||||
hostname = urlparse(self.api_base or "").hostname or ""
|
||||
return hostname == "azure.com" or hostname.endswith(".azure.com")
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
||||
@@ -348,6 +348,9 @@ class TestChatDeepSeekAzureToolChoice:
|
||||
DEFAULT_API_BASE,
|
||||
"https://api.openai.com/v1",
|
||||
"https://custom-endpoint.com/api",
|
||||
"https://evil-azure.com/v1", # hostname bypass attempt
|
||||
"https://notazure.com.evil.com/", # subdomain bypass attempt
|
||||
"https://example.com/azure.com", # path bypass attempt
|
||||
]
|
||||
for endpoint in non_azure_endpoints:
|
||||
llm = ChatDeepSeek(
|
||||
|
||||
Reference in New Issue
Block a user