mirror of https://github.com/hwchase17/langchain.git (synced 2026-04-24 04:36:46 +00:00)
chore: increase coverage for shell, filesystem, and summarization middleware (#33928)
cc generated; just a start here, but wanted to bump things up from ~70%
@@ -7,6 +7,9 @@ import pytest
from langchain.agents.middleware.file_search import (
    FilesystemFileSearchMiddleware,
    _expand_include_patterns,
    _is_valid_include_pattern,
    _match_include_pattern,
)

@@ -259,3 +262,105 @@ class TestPathTraversalSecurity:
        assert result == "No matches found"
        assert "secret" not in result


class TestExpandIncludePatterns:
    """Tests for _expand_include_patterns helper function."""

    def test_expand_patterns_basic_brace_expansion(self) -> None:
        """Test basic brace expansion with multiple options."""
        result = _expand_include_patterns("*.{py,txt}")
        assert result == ["*.py", "*.txt"]

    def test_expand_patterns_nested_braces(self) -> None:
        """Test nested brace expansion."""
        result = _expand_include_patterns("test.{a,b}.{c,d}")
        assert result is not None
        assert len(result) == 4
        assert "test.a.c" in result
        assert "test.b.d" in result

    @pytest.mark.parametrize(
        "pattern",
        [
            "*.py}",  # closing brace without opening
            "*.{}",  # empty braces
            "*.{py",  # unclosed brace
        ],
    )
    def test_expand_patterns_invalid_braces(self, pattern: str) -> None:
        """Test patterns with invalid brace syntax return None."""
        result = _expand_include_patterns(pattern)
        assert result is None


class TestValidateIncludePattern:
    """Tests for _is_valid_include_pattern helper function."""

    @pytest.mark.parametrize(
        "pattern",
        [
            "",  # empty pattern
            "*.py\x00",  # null byte
            "*.py\n",  # newline
        ],
    )
    def test_validate_invalid_patterns(self, pattern: str) -> None:
        """Test that invalid patterns are rejected."""
        assert not _is_valid_include_pattern(pattern)


class TestMatchIncludePattern:
    """Tests for _match_include_pattern helper function."""

    def test_match_pattern_with_braces(self) -> None:
        """Test matching with brace expansion."""
        assert _match_include_pattern("test.py", "*.{py,txt}")
        assert _match_include_pattern("test.txt", "*.{py,txt}")
        assert not _match_include_pattern("test.md", "*.{py,txt}")

    def test_match_pattern_invalid_expansion(self) -> None:
        """Test matching with a pattern that cannot be expanded returns False."""
        assert not _match_include_pattern("test.py", "*.{}")


class TestGrepEdgeCases:
    """Tests for edge cases in grep search."""

    def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None:
        """Test grep with special characters in pattern."""
        (tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        result = middleware.grep_search.func(pattern="def.*:")

        assert "/test.py" in result

    def test_grep_case_insensitive(self, tmp_path: Path) -> None:
        """Test grep with case-insensitive search."""
        (tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        result = middleware.grep_search.func(pattern="(?i)hello")

        assert "/test.py" in result

    def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None:
        """Test that grep skips files larger than max_file_size_mb."""
        # Create a file larger than 1MB
        large_content = "x" * (2 * 1024 * 1024)  # 2MB
        (tmp_path / "large.txt").write_text(large_content, encoding="utf-8")
        (tmp_path / "small.txt").write_text("x", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(
            root_path=str(tmp_path),
            use_ripgrep=False,
            max_file_size_mb=1,  # 1MB limit
        )

        result = middleware.grep_search.func(pattern="x")

        # Large file should be skipped
        assert "/small.txt" in result

@@ -1,17 +1,21 @@
from __future__ import annotations

-import time
+import asyncio
+import gc
+import tempfile
+import time
from pathlib import Path

import pytest
from langchain_core.messages import AIMessage, ToolMessage
+from langchain_core.tools.base import ToolException

from langchain.agents.middleware.shell_tool import (
    HostExecutionPolicy,
+    RedactionRule,
    ShellToolMiddleware,
    _SessionResources,
-    RedactionRule,
    _ShellToolInput,
)
from langchain.agents.middleware.types import AgentState

@@ -173,3 +177,294 @@ def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
    assert not finalizer.alive
    assert session.stopped
    assert not tempdir_path.exists()


def test_shell_tool_input_validation() -> None:
    """Test _ShellToolInput validation rules."""
    # Both command and restart not allowed
    with pytest.raises(ValueError, match="only one"):
        _ShellToolInput(command="ls", restart=True)

    # Neither command nor restart provided
    with pytest.raises(ValueError, match="requires either"):
        _ShellToolInput()

    # Valid: command only
    valid_cmd = _ShellToolInput(command="ls")
    assert valid_cmd.command == "ls"
    assert not valid_cmd.restart

    # Valid: restart only
    valid_restart = _ShellToolInput(restart=True)
    assert valid_restart.restart is True
    assert valid_restart.command is None


def test_normalize_shell_command_empty() -> None:
    """Test that an empty shell command raises an error."""
    with pytest.raises(ValueError, match="at least one argument"):
        ShellToolMiddleware(shell_command=[])


def test_normalize_env_non_string_keys() -> None:
    """Test that non-string environment keys raise an error."""
    with pytest.raises(TypeError, match="must be strings"):
        ShellToolMiddleware(env={123: "value"})  # type: ignore[dict-item]


def test_normalize_env_coercion(tmp_path: Path) -> None:
    """Test that environment values are coerced to strings."""
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", env={"NUM": 42, "BOOL": True}
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None
        )
        assert "42" in result
        assert "True" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
    """Test that shell tool raises an error when command is not a string."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        with pytest.raises(ToolException, match="expects a 'command' string"):
            middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None)

        with pytest.raises(ToolException, match="expects a 'command' string"):
            middleware._run_shell_tool(
                resources,
                {"command": 123},  # type: ignore[dict-item]
                tool_call_id=None,
            )
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
    """Test that tool messages are properly formatted with tool_call_id."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources, {"command": "echo test"}, tool_call_id="test-id-123"
        )

        assert isinstance(result, ToolMessage)
        assert result.tool_call_id == "test-id-123"
        assert result.name == "shell"
        assert result.status == "success"
        assert "test" in result.content
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
    """Test that non-zero exit codes are marked as errors."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources,
            {"command": "false"},  # Command that exits with 1 but doesn't kill shell
            tool_call_id="test-id",
        )

        assert isinstance(result, ToolMessage)
        assert result.status == "error"
        assert "Exit code: 1" in result.content
        assert result.artifact["exit_code"] == 1  # type: ignore[index]
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_truncation_by_bytes(tmp_path: Path) -> None:
    """Test that output is truncated by bytes when max_output_bytes is exceeded."""
    policy = HostExecutionPolicy(max_output_bytes=50, command_timeout=5.0)
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None
        )

        assert "truncated at 50 bytes" in result.lower()
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_startup_command_failure(tmp_path: Path) -> None:
    """Test that startup command failure raises an error."""
    policy = HostExecutionPolicy(startup_timeout=1.0)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", startup_commands=("exit 1",), execution_policy=policy
    )
    state: AgentState = _empty_state()
    with pytest.raises(RuntimeError, match="Startup command.*failed"):
        middleware.before_agent(state, None)


def test_shutdown_command_failure_logged(tmp_path: Path) -> None:
    """Test that shutdown command failures are logged but don't raise."""
    policy = HostExecutionPolicy(command_timeout=1.0)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        shutdown_commands=("exit 1",),
        execution_policy=policy,
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
    finally:
        # Should not raise despite shutdown command failing
        middleware.after_agent(state, None)


def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
    """Test that shutdown command timeouts are logged but don't raise."""
    policy = HostExecutionPolicy(command_timeout=0.1)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=policy,
        shutdown_commands=("sleep 2",),
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
    finally:
        # Should not raise despite shutdown command timing out
        middleware.after_agent(state, None)


def test_ensure_resources_missing_state(tmp_path: Path) -> None:
    """Test that _ensure_resources raises when resources are missing."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    state: AgentState = _empty_state()

    with pytest.raises(ToolException, match="Shell session resources are unavailable"):
        middleware._ensure_resources(state)  # type: ignore[attr-defined]


def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
    """Test that empty command output is replaced with '<no output>'."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources,
            {"command": "true"},  # Command that produces no output
            tool_call_id=None,
        )

        assert "<no output>" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_stderr_output_labeling(tmp_path: Path) -> None:
    """Test that stderr output is properly labeled."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources, {"command": "echo error >&2"}, tool_call_id=None
        )

        assert "[stderr] error" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


@pytest.mark.parametrize(
    ("startup_commands", "expected"),
    [
        ("echo test", ("echo test",)),  # String
        (["echo test", "pwd"], ("echo test", "pwd")),  # List
        (("echo test",), ("echo test",)),  # Tuple
        (None, ()),  # None
    ],
)
def test_normalize_commands_string_tuple_list(
    tmp_path: Path,
    startup_commands: str | list[str] | tuple[str, ...] | None,
    expected: tuple[str, ...],
) -> None:
    """Test various command normalization formats."""
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", startup_commands=startup_commands
    )
    assert middleware._startup_commands == expected  # type: ignore[attr-defined]


def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
    """Test that async methods properly delegate to sync methods."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()

        # Test abefore_agent
        updates = asyncio.run(middleware.abefore_agent(state, None))
        if updates:
            state.update(updates)

        # Test aafter_agent
        asyncio.run(middleware.aafter_agent(state, None))
    finally:
        pass

@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING
+from unittest.mock import patch

import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from unittest.mock import patch

from langchain.agents.middleware.summarization import SummarizationMiddleware

@@ -15,6 +15,35 @@ if TYPE_CHECKING:
    from langchain_model_profiles import ModelProfile


class MockChatModel(BaseChatModel):
    """Mock chat model for testing."""

    def invoke(self, prompt):  # type: ignore[no-untyped-def]
        return AIMessage(content="Generated summary")

    def _generate(self, messages, **kwargs):  # type: ignore[no-untyped-def]
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

    @property
    def _llm_type(self) -> str:
        return "mock"


class ProfileChatModel(BaseChatModel):
    """Mock chat model with profile for testing."""

    def _generate(self, messages, **kwargs):  # type: ignore[no-untyped-def]
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

    @property
    def _llm_type(self) -> str:
        return "mock"

    @property
    def profile(self) -> "ModelProfile":
        return {"max_input_tokens": 1000}


def test_summarization_middleware_initialization() -> None:
    """Test SummarizationMiddleware initialization."""
    model = FakeToolCallingModel()

@@ -139,19 +168,7 @@ def test_summarization_middleware_tool_call_safety() -> None:

def test_summarization_middleware_summary_creation() -> None:
    """Test SummarizationMiddleware summary creation."""

-    class MockModel(BaseChatModel):
-        def invoke(self, prompt):
-            return AIMessage(content="Generated summary")
-
-        def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
-
-        @property
-        def _llm_type(self):
-            return "mock"
-
-    middleware = SummarizationMiddleware(model=MockModel(), trigger=("tokens", 1000))
+    middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("tokens", 1000))

    # Test normal summary creation
    messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]

@@ -180,25 +197,16 @@ def test_summarization_middleware_summary_creation() -> None:

    # Test we raise warning if max_tokens_before_summary or messages_to_keep is specified
    with pytest.warns(DeprecationWarning, match="max_tokens_before_summary is deprecated"):
-        SummarizationMiddleware(model=MockModel(), max_tokens_before_summary=500)
+        SummarizationMiddleware(model=MockChatModel(), max_tokens_before_summary=500)
    with pytest.warns(DeprecationWarning, match="messages_to_keep is deprecated"):
-        SummarizationMiddleware(model=MockModel(), messages_to_keep=5)
+        SummarizationMiddleware(model=MockChatModel(), messages_to_keep=5)


def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
    """Verify disabling trim limit preserves full message sequence."""

-    class MockModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
-
-        @property
-        def _llm_type(self):
-            return "mock"
-
    messages = [HumanMessage(content=str(i)) for i in range(10)]
    middleware = SummarizationMiddleware(
-        model=MockModel(),
+        model=MockChatModel(),
        trim_tokens_to_summarize=None,
    )
    middleware.token_counter = lambda msgs: len(msgs)

@@ -209,23 +217,10 @@ def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:

def test_summarization_middleware_profile_inference_triggers_summary() -> None:
    """Ensure automatic profile inference triggers summarization when limits are exceeded."""

-    class ProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
-
-        @property
-        def _llm_type(self) -> str:
-            return "mock"
-
-        @property
-        def profile(self) -> "ModelProfile":
-            return {"max_input_tokens": 1000}
-
    token_counter = lambda messages: len(messages) * 200

    middleware = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.81),
        keep=("fraction", 0.5),
        token_counter=token_counter,

@@ -250,7 +245,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
    # Engage summarization
    # 0.80 * 1000 == 800 <= 800
    middleware = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.80),
        keep=("fraction", 0.5),
        token_counter=token_counter,

@@ -270,7 +265,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
    # With keep=("fraction", 0.6) the target token allowance becomes 600,
    # so the cutoff shifts to keep the last three messages instead of two.
    middleware = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.80),
        keep=("fraction", 0.6),
        token_counter=token_counter,

@@ -287,7 +282,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
    # context (target tokens = 800), so token-based retention keeps everything
    # and summarization is skipped entirely.
    middleware = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.80),
        keep=("fraction", 0.8),
        token_counter=token_counter,

@@ -296,7 +291,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:

    # Test with tokens_to_keep as absolute int value
    middleware_int = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.80),
        keep=("tokens", 400),  # Keep exactly 400 tokens (2 messages)
        token_counter=token_counter,

@@ -310,7 +305,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:

    # Test with tokens_to_keep as larger int value
    middleware_int_large = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.80),
        keep=("tokens", 600),  # Keep 600 tokens (3 messages)
        token_counter=token_counter,

@@ -327,23 +322,11 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:

def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
    """Ensure token retention keeps pairs together even if exceeding target tokens."""

-    class ProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
-
-        @property
-        def _llm_type(self) -> str:
-            return "mock"
-
-        @property
-        def profile(self) -> "ModelProfile":
-            return {"max_input_tokens": 1000}
-
-    def token_counter(messages):
+    def token_counter(messages: list[AnyMessage]) -> int:
        return sum(len(getattr(message, "content", "")) for message in messages)

    middleware = SummarizationMiddleware(
-        model=ProfileModel(),
+        model=ProfileChatModel(),
        trigger=("fraction", 0.1),
        keep=("fraction", 0.5),
    )

@@ -400,22 +383,10 @@ def test_summarization_middleware_missing_profile() -> None:

def test_summarization_middleware_full_workflow() -> None:
    """Test SummarizationMiddleware complete summarization workflow."""

-    class MockModel(BaseChatModel):
-        def invoke(self, prompt):
-            return AIMessage(content="Generated summary")
-
-        def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
-
-        @property
-        def _llm_type(self):
-            return "mock"
-
    with pytest.warns(DeprecationWarning):
        # keep test for functionality
        middleware = SummarizationMiddleware(
-            model=MockModel(), max_tokens_before_summary=1000, messages_to_keep=2
+            model=MockChatModel(), max_tokens_before_summary=1000, messages_to_keep=2
        )

    # Mock high token count to trigger summarization

@@ -504,21 +475,9 @@ async def test_summarization_middleware_full_workflow_async() -> None:

def test_summarization_middleware_keep_messages() -> None:
    """Test SummarizationMiddleware with keep parameter specifying messages."""

-    class MockModel(BaseChatModel):
-        def invoke(self, prompt):
-            return AIMessage(content="Generated summary")
-
-        def _generate(self, messages, **kwargs):
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
-
-        @property
-        def _llm_type(self):
-            return "mock"
-
    # Test that summarization is triggered when message count reaches threshold
    middleware = SummarizationMiddleware(
-        model=MockModel(), trigger=("messages", 5), keep=("messages", 2)
+        model=MockChatModel(), trigger=("messages", 5), keep=("messages", 2)
    )

    # Below threshold - no summarization

@@ -561,6 +520,379 @@ def test_summarization_middleware_keep_messages() -> None:
    assert [message.content for message in result["messages"][2:]] == ["5", "6"]

    # Test with both parameters disabled
-    middleware_disabled = SummarizationMiddleware(model=MockModel(), trigger=None)
+    middleware_disabled = SummarizationMiddleware(model=MockChatModel(), trigger=None)
    result = middleware_disabled.before_model(state_above, None)
    assert result is None


@pytest.mark.parametrize(
    ("param_name", "param_value", "expected_error"),
    [
        ("trigger", ("fraction", 0.0), "Fractional trigger values must be between 0 and 1"),
        ("trigger", ("fraction", 1.5), "Fractional trigger values must be between 0 and 1"),
        ("keep", ("fraction", -0.1), "Fractional keep values must be between 0 and 1"),
        ("trigger", ("tokens", 0), "trigger thresholds must be greater than 0"),
        ("trigger", ("messages", -5), "trigger thresholds must be greater than 0"),
        ("keep", ("tokens", 0), "keep thresholds must be greater than 0"),
        ("trigger", ("invalid", 100), "Unsupported context size type"),
        ("keep", ("invalid", 100), "Unsupported context size type"),
    ],
)
def test_summarization_middleware_validation_edge_cases(
    param_name: str, param_value: tuple[str, float | int], expected_error: str
) -> None:
    """Test validation of context size parameters with edge cases."""
    model = FakeToolCallingModel()
    with pytest.raises(ValueError, match=expected_error):
        SummarizationMiddleware(model=model, **{param_name: param_value})


def test_summarization_middleware_multiple_triggers() -> None:
    """Test middleware with multiple trigger conditions."""
    # Test with multiple triggers - should activate when ANY condition is met
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=[("messages", 10), ("tokens", 500)],
        keep=("messages", 2),
    )

    # Mock token counter to return low count
    def mock_low_tokens(messages):
        return 100

    middleware.token_counter = mock_low_tokens

    # Should not trigger - neither condition met
    messages = [HumanMessage(content=str(i)) for i in range(5)]
    state = {"messages": messages}
    result = middleware.before_model(state, None)
    assert result is None

    # Should trigger - message count threshold met
    messages = [HumanMessage(content=str(i)) for i in range(10)]
    state = {"messages": messages}
    result = middleware.before_model(state, None)
    assert result is not None

    # Test token trigger
    def mock_high_tokens(messages):
        return 600

    middleware.token_counter = mock_high_tokens
    messages = [HumanMessage(content=str(i)) for i in range(5)]
    state = {"messages": messages}
    result = middleware.before_model(state, None)
    assert result is not None


def test_summarization_middleware_profile_edge_cases() -> None:
    """Test profile retrieval with various edge cases."""

    class NoProfileModel(BaseChatModel):
        def _generate(self, messages, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

        @property
        def _llm_type(self):
            return "mock"

    # Model without profile attribute
    middleware = SummarizationMiddleware(model=NoProfileModel(), trigger=("messages", 5))
    assert middleware._get_profile_limits() is None

    class InvalidProfileModel(BaseChatModel):
        def _generate(self, messages, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

        @property
        def _llm_type(self):
            return "mock"

        @property
        def profile(self):
            return "invalid_profile_type"

    # Model with non-dict profile
    middleware = SummarizationMiddleware(model=InvalidProfileModel(), trigger=("messages", 5))
    assert middleware._get_profile_limits() is None

    class MissingTokensModel(BaseChatModel):
        def _generate(self, messages, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

        @property
        def _llm_type(self):
            return "mock"

        @property
        def profile(self):
            return {"other_field": 100}

    # Model with profile but no max_input_tokens
    middleware = SummarizationMiddleware(model=MissingTokensModel(), trigger=("messages", 5))
    assert middleware._get_profile_limits() is None

    class InvalidTokenTypeModel(BaseChatModel):
        def _generate(self, messages, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

        @property
        def _llm_type(self):
            return "mock"

        @property
        def profile(self):
            return {"max_input_tokens": "not_an_int"}

    # Model with non-int max_input_tokens
    middleware = SummarizationMiddleware(model=InvalidTokenTypeModel(), trigger=("messages", 5))
    assert middleware._get_profile_limits() is None


def test_summarization_middleware_trim_messages_error_fallback() -> None:
    """Test that trim_messages_for_summary falls back gracefully on errors."""
    middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))

    # Create a mock token counter that raises an exception
    def failing_token_counter(messages):
        raise Exception("Token counting failed")

    middleware.token_counter = failing_token_counter

    # Should fall back to last 15 messages
    messages = [HumanMessage(content=str(i)) for i in range(20)]
    trimmed = middleware._trim_messages_for_summary(messages)
    assert len(trimmed) == 15
    assert trimmed == messages[-15:]


def test_summarization_middleware_binary_search_edge_cases() -> None:
    """Test binary search in _find_token_based_cutoff with edge cases."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(), trigger=("messages", 5), keep=("tokens", 100)
    )

    # Test with single message that's too large
    def token_counter_single_large(messages):
        return len(messages) * 200

    middleware.token_counter = token_counter_single_large

    single_message = [HumanMessage(content="x" * 200)]
    cutoff = middleware._find_token_based_cutoff(single_message)
    assert cutoff == 0

    # Test with empty messages
    cutoff = middleware._find_token_based_cutoff([])
    assert cutoff == 0

    # Test when all messages fit within token budget
    def token_counter_small(messages):
        return len(messages) * 10

    middleware.token_counter = token_counter_small
    messages = [HumanMessage(content=str(i)) for i in range(5)]
    cutoff = middleware._find_token_based_cutoff(messages)
    assert cutoff == 0


def test_summarization_middleware_tool_call_extraction_edge_cases() -> None:
    """Test tool call ID extraction with various message formats."""
    model = FakeToolCallingModel()
    middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))

    # Test with dict-style tool calls
    ai_message_dict = AIMessage(
        content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}]
    )
    ids = middleware._extract_tool_call_ids(ai_message_dict)
    assert ids == {"id1"}

    # Test with multiple tool calls
    ai_message_multiple = AIMessage(
        content="test",
        tool_calls=[
            {"name": "tool1", "args": {}, "id": "id1"},
            {"name": "tool2", "args": {}, "id": "id2"},
        ],
    )
    ids = middleware._extract_tool_call_ids(ai_message_multiple)
    assert ids == {"id1", "id2"}

    # Test with empty tool calls list
    ai_message_empty = AIMessage(content="test", tool_calls=[])
    ids = middleware._extract_tool_call_ids(ai_message_empty)
    assert len(ids) == 0


def test_summarization_middleware_complex_tool_pair_scenarios() -> None:
    """Test complex tool call pairing scenarios."""
    model = FakeToolCallingModel()
    middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3))

    # Test with multiple AI messages with tool calls
    messages = [
        HumanMessage(content="msg1"),
        AIMessage(content="ai1", tool_calls=[{"name": "tool1", "args": {}, "id": "call1"}]),
        ToolMessage(content="result1", tool_call_id="call1"),
        HumanMessage(content="msg2"),
        AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]),
        ToolMessage(content="result2", tool_call_id="call2"),
        HumanMessage(content="msg3"),
    ]

    # Test cutoff at index 2 - unsafe (separates first AI/Tool pair)
    assert not middleware._is_safe_cutoff_point(messages, 2)

    # Test cutoff at index 3 - safe (keeps first pair together)
    assert middleware._is_safe_cutoff_point(messages, 3)

    # Test cutoff at index 5 - unsafe (separates second AI/Tool pair)
    assert not middleware._is_safe_cutoff_point(messages, 5)

    # Test _cutoff_separates_tool_pair directly
    assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"})
    assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"})
    assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"})


def test_summarization_middleware_tool_call_in_search_range() -> None:
    """Test tool call safety with messages at edge of search range."""
    model = FakeToolCallingModel()
    middleware = SummarizationMiddleware(
        model=model, trigger=("messages", 10), keep=("messages", 2)
    )

    # Create messages with tool pair separated by some distance
    # Search range is 5, so messages within 5 positions of cutoff are checked
    messages = [
        HumanMessage(content="msg1"),
        HumanMessage(content="msg2"),
        AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
        HumanMessage(content="msg3"),
        HumanMessage(content="msg4"),
        ToolMessage(content="result", tool_call_id="call1"),
        HumanMessage(content="msg6"),
    ]

    # Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6]
    # AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe
    assert not middleware._is_safe_cutoff_point(messages, 3)

    # Cutoff at index 6 keeps AI and Tool both in summarized section
    assert middleware._is_safe_cutoff_point(messages, 6)

    # Cutoff at index 0 or 1 also safe - both AI and Tool in preserved section
    assert middleware._is_safe_cutoff_point(messages, 0)
    assert middleware._is_safe_cutoff_point(messages, 1)


def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
    """Test handling of edge cases with target token calculations."""
    # Test with very small fraction that rounds to zero
    middleware = SummarizationMiddleware(
        model=ProfileChatModel(), trigger=("fraction", 0.0001), keep=("fraction", 0.0001)
    )

    # Should set threshold to 1 when calculated value is <= 0
    messages = [HumanMessage(content="test")]
    state = {"messages": messages}

    # The trigger fraction calculation: int(1000 * 0.0001) = 0, but should be set to 1
    # Token count of 1 message should exceed threshold of 1
    def token_counter(msgs):
        return 2

    middleware.token_counter = token_counter
    assert middleware._should_summarize(messages, 2)


async def test_summarization_middleware_async_error_handling() -> None:
    """Test async summary creation with errors."""

    class ErrorAsyncModel(BaseChatModel):
        def _generate(self, messages, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

        async def _agenerate(self, messages, **kwargs):
            raise Exception("Async model error")

        @property
        def _llm_type(self):
            return "mock"

    middleware = SummarizationMiddleware(model=ErrorAsyncModel(), trigger=("messages", 5))
    messages = [HumanMessage(content="test")]
    summary = await middleware._acreate_summary(messages)
    assert "Error generating summary: Async model error" in summary


def test_summarization_middleware_cutoff_at_boundary() -> None:
    """Test cutoff index determination at exact message boundaries."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(), trigger=("messages", 5), keep=("messages", 5)
    )

    # When we want to keep exactly as many messages as we have
    messages = [HumanMessage(content=str(i)) for i in range(5)]
    cutoff = middleware._find_safe_cutoff(messages, 5)
    assert cutoff == 0  # Should not cut anything

    # When we want to keep more messages than we have
    cutoff = middleware._find_safe_cutoff(messages, 10)
    assert cutoff == 0


def test_summarization_middleware_deprecated_parameters_with_defaults() -> None:
    """Test that deprecated parameters work correctly with default values."""
    # Test that deprecated max_tokens_before_summary is ignored when trigger is set
    with pytest.warns(DeprecationWarning):
        middleware = SummarizationMiddleware(
            model=MockChatModel(), trigger=("tokens", 2000), max_tokens_before_summary=1000
        )
    assert middleware.trigger == ("tokens", 2000)

    # Test that messages_to_keep is ignored when keep is not default
    with pytest.warns(DeprecationWarning):
        middleware = SummarizationMiddleware(
            model=MockChatModel(), keep=("messages", 5), messages_to_keep=10
        )
    assert middleware.keep == ("messages", 5)


def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
    """Test fractional trigger condition when profile data becomes unavailable."""
    middleware = SummarizationMiddleware(
        model=ProfileChatModel(),
        trigger=[("fraction", 0.5), ("messages", 100)],
        keep=("messages", 5),
    )

    # Test that when fractional condition can't be evaluated, other triggers still work
    messages = [HumanMessage(content=str(i)) for i in range(100)]

    # Mock _get_profile_limits to return None
    original_method = middleware._get_profile_limits
    middleware._get_profile_limits = lambda: None

    # Should still trigger based on message count
    state = {"messages": messages}
    result = middleware.before_model(state, None)
    assert result is not None

    # Restore original method
    middleware._get_profile_limits = original_method


def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
    """Test _is_safe_cutoff_point when cutoff is at or past the end."""
    model = FakeToolCallingModel()
    middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))

    messages = [HumanMessage(content=str(i)) for i in range(5)]

    # Cutoff at exactly the length should be safe
    assert middleware._is_safe_cutoff_point(messages, len(messages))

    # Cutoff past the length should also be safe
    assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)

@@ -1,22 +1,22 @@
"""Unit tests for LLM tool selection middleware."""

import typing
-from typing import Union, Any, Literal
+from itertools import cycle
+from typing import Any, Literal, Union

import pytest
from pydantic import BaseModel

from langchain.agents import create_agent
-from langchain.agents.middleware import AgentState, ModelRequest, wrap_model_call
-from langchain.agents.middleware import LLMToolSelectorMiddleware
+from langchain.agents.middleware import LLMToolSelectorMiddleware, ModelRequest, wrap_model_call
+from langchain.agents.middleware.tool_selection import _create_tool_selection_response
+from langchain.agents.middleware.types import AgentState
from langchain.messages import AIMessage
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
-from langchain_core.messages import BaseMessage
-from langchain_core.messages import HumanMessage
+from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.runnables import Runnable
-from langchain_core.tools import BaseTool
-from langchain_core.tools import tool
+from langchain_core.tools import BaseTool, tool


@tool

@@ -596,3 +596,12 @@ class TestDuplicateAndInvalidTools:
        assert len(tool_names) == 2
        assert "get_weather" in tool_names
        assert "search_web" in tool_names


class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_empty_tools_list_raises_error(self) -> None:
        """Test that empty tools list raises an error in schema creation."""
        with pytest.raises(AssertionError, match="tools must be non-empty"):
            _create_tool_selection_response([])