mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-18 08:03:52 +00:00
Compare commits
16 Commits
replace_ap
...
nh/danglin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
526cbb36e6 | ||
|
|
68ceeb64f6 | ||
|
|
edae976b81 | ||
|
|
9f4366bc9d | ||
|
|
99e0a60aab | ||
|
|
d38729fbac | ||
|
|
ff0d21cfd5 | ||
|
|
9140a7cb86 | ||
|
|
41fe18bc80 | ||
|
|
9105573cb3 | ||
|
|
fff87e95d1 | ||
|
|
9beb29a34c | ||
|
|
ca00f5aed9 | ||
|
|
637777b8e7 | ||
|
|
1cf851e054 | ||
|
|
961f965f0c |
2
.github/workflows/check_diffs.yml
vendored
2
.github/workflows/check_diffs.yml
vendored
@@ -186,7 +186,7 @@ jobs:
|
||||
|
||||
# We have to use 3.12 as 3.13 is not yet supported
|
||||
- name: "📦 Install UV Package Manager"
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
|
||||
@@ -1,39 +1,135 @@
|
||||
"""Derivations of standard content blocks from Groq content."""
|
||||
|
||||
import warnings
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
WARNED = False
|
||||
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]: # noqa: ARG001
|
||||
"""Derive standard content blocks from a message with Groq content."""
|
||||
global WARNED # noqa: PLW0603
|
||||
if not WARNED:
|
||||
warning_message = (
|
||||
"Content block standardization is not yet fully supported for Groq."
|
||||
def _populate_extras(
|
||||
standard_block: types.ContentBlock, block: dict[str, Any], known_fields: set[str]
|
||||
) -> types.ContentBlock:
|
||||
"""Mutate a block, populating extras."""
|
||||
if standard_block.get("type") == "non_standard":
|
||||
return standard_block
|
||||
|
||||
for key, value in block.items():
|
||||
if key not in known_fields:
|
||||
if "extras" not in standard_block:
|
||||
# Below type-ignores are because mypy thinks a non-standard block can
|
||||
# get here, although we exclude them above.
|
||||
standard_block["extras"] = {} # type: ignore[typeddict-unknown-key]
|
||||
standard_block["extras"][key] = value # type: ignore[typeddict-item]
|
||||
|
||||
return standard_block
|
||||
|
||||
|
||||
def _parse_code_json(s: str) -> dict:
|
||||
"""Extract Python code from Groq built-in tool content.
|
||||
|
||||
Extracts the value of the 'code' field from a string of the form:
|
||||
{"code": some_arbitrary_text_with_unescaped_quotes}
|
||||
|
||||
As Groq may not escape quotes in the executed tools, e.g.:
|
||||
```
|
||||
'{"code": "import math; print("The square root of 101 is: "); print(math.sqrt(101))"}'
|
||||
```
|
||||
""" # noqa: E501
|
||||
m = re.fullmatch(r'\s*\{\s*"code"\s*:\s*"(.*)"\s*\}\s*', s, flags=re.DOTALL)
|
||||
if not m:
|
||||
msg = (
|
||||
"Could not extract Python code from Groq tool arguments. "
|
||||
"Expected a JSON object with a 'code' field."
|
||||
)
|
||||
warnings.warn(warning_message, stacklevel=2)
|
||||
WARNED = True
|
||||
raise NotImplementedError
|
||||
raise ValueError(msg)
|
||||
return {"code": m.group(1)}
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]: # noqa: ARG001
|
||||
"""Derive standard content blocks from a message chunk with Groq content."""
|
||||
global WARNED # noqa: PLW0603
|
||||
if not WARNED:
|
||||
warning_message = (
|
||||
"Content block standardization is not yet fully supported for Groq."
|
||||
def _convert_to_v1_from_groq(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Convert groq message content to v1 format."""
|
||||
content_blocks: list[types.ContentBlock] = []
|
||||
|
||||
if reasoning_block := _extract_reasoning_from_additional_kwargs(message):
|
||||
content_blocks.append(reasoning_block)
|
||||
|
||||
if executed_tools := message.additional_kwargs.get("executed_tools"):
|
||||
for idx, executed_tool in enumerate(executed_tools):
|
||||
args: dict[str, Any] | None = None
|
||||
if arguments := executed_tool.get("arguments"):
|
||||
try:
|
||||
args = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
if executed_tool.get("type") == "python":
|
||||
try:
|
||||
args = _parse_code_json(arguments)
|
||||
except ValueError:
|
||||
continue
|
||||
elif (
|
||||
executed_tool.get("type") == "function"
|
||||
and executed_tool.get("name") == "python"
|
||||
):
|
||||
# GPT-OSS
|
||||
args = {"code": arguments}
|
||||
else:
|
||||
continue
|
||||
if isinstance(args, dict):
|
||||
name = ""
|
||||
if executed_tool.get("type") == "search":
|
||||
name = "web_search"
|
||||
elif executed_tool.get("type") == "python" or (
|
||||
executed_tool.get("type") == "function"
|
||||
and executed_tool.get("name") == "python"
|
||||
):
|
||||
name = "code_interpreter"
|
||||
server_tool_call: types.ServerToolCall = {
|
||||
"type": "server_tool_call",
|
||||
"name": name,
|
||||
"id": str(idx),
|
||||
"args": args,
|
||||
}
|
||||
content_blocks.append(server_tool_call)
|
||||
if tool_output := executed_tool.get("output"):
|
||||
tool_result: types.ServerToolResult = {
|
||||
"type": "server_tool_result",
|
||||
"tool_call_id": str(idx),
|
||||
"output": tool_output,
|
||||
"status": "success",
|
||||
}
|
||||
known_fields = {"type", "arguments", "index", "output"}
|
||||
_populate_extras(tool_result, executed_tool, known_fields)
|
||||
content_blocks.append(tool_result)
|
||||
|
||||
if isinstance(message.content, str) and message.content:
|
||||
content_blocks.append({"type": "text", "text": message.content})
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
content_blocks.append( # noqa: PERF401
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
"id": tool_call.get("id"),
|
||||
}
|
||||
)
|
||||
warnings.warn(warning_message, stacklevel=2)
|
||||
WARNED = True
|
||||
raise NotImplementedError
|
||||
|
||||
return content_blocks
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message with groq content."""
|
||||
return _convert_to_v1_from_groq(message)
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
|
||||
"""Derive standard content blocks from a message chunk with groq content."""
|
||||
return _convert_to_v1_from_groq(message)
|
||||
|
||||
|
||||
def _register_groq_translator() -> None:
|
||||
"""Register the Groq translator with the central registry.
|
||||
"""Register the groq translator with the central registry.
|
||||
|
||||
Run automatically when the module is imported.
|
||||
"""
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
from pydantic.v1 import * # noqa: F403
|
||||
|
||||
from langchain_core._api.deprecation import warn_deprecated
|
||||
|
||||
try:
|
||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||
except metadata.PackageNotFoundError:
|
||||
_PYDANTIC_MAJOR_VERSION = 0
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain_core.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from pydantic.v1.dataclasses import * # noqa: F403
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain_core.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Pydantic v1 compatibility shim."""
|
||||
|
||||
from pydantic.v1.main import * # noqa: F403
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain_core.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_core.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
||||
@@ -27,7 +27,7 @@ from pydantic.v1 import create_model as create_model_v1
|
||||
from typing_extensions import TypedDict, is_typeddict
|
||||
|
||||
import langchain_core
|
||||
from langchain_core._api import beta, deprecated
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
@@ -168,42 +168,6 @@ def _convert_pydantic_to_openai_function(
|
||||
)
|
||||
|
||||
|
||||
convert_pydantic_to_openai_function = deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="1.0",
|
||||
)(_convert_pydantic_to_openai_function)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_tool(
|
||||
model: type[BaseModel],
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> ToolDescription:
|
||||
"""Converts a Pydantic model to a function description for the OpenAI API.
|
||||
|
||||
Args:
|
||||
model: The Pydantic model to convert.
|
||||
name: The name of the function. If not provided, the title of the schema will be
|
||||
used.
|
||||
description: The description of the function. If not provided, the description
|
||||
of the schema will be used.
|
||||
|
||||
Returns:
|
||||
The tool description.
|
||||
"""
|
||||
function = _convert_pydantic_to_openai_function(
|
||||
model, name=name, description=description
|
||||
)
|
||||
return {"type": "function", "function": function}
|
||||
|
||||
|
||||
def _get_python_function_name(function: Callable) -> str:
|
||||
"""Get the name of a Python function."""
|
||||
return function.__name__
|
||||
@@ -240,13 +204,6 @@ def _convert_python_function_to_openai_function(
|
||||
)
|
||||
|
||||
|
||||
convert_python_function_to_openai_function = deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="1.0",
|
||||
)(_convert_python_function_to_openai_function)
|
||||
|
||||
|
||||
def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription:
|
||||
visited: dict = {}
|
||||
|
||||
@@ -368,31 +325,6 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
}
|
||||
|
||||
|
||||
format_tool_to_openai_function = deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
|
||||
removal="1.0",
|
||||
)(_format_tool_to_openai_function)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.16",
|
||||
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
|
||||
removal="1.0",
|
||||
)
|
||||
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
"""Format tool into the OpenAI function API.
|
||||
|
||||
Args:
|
||||
tool: The tool to format.
|
||||
|
||||
Returns:
|
||||
The tool description.
|
||||
"""
|
||||
function = _format_tool_to_openai_function(tool)
|
||||
return {"type": "function", "function": function}
|
||||
|
||||
|
||||
def convert_to_openai_function(
|
||||
function: dict[str, Any] | type | Callable | BaseTool,
|
||||
*,
|
||||
|
||||
@@ -486,8 +486,10 @@ def test_provider_warns() -> None:
|
||||
# implemented.
|
||||
# This test should be removed when all major providers support content block
|
||||
# standardization.
|
||||
message = AIMessage("Hello.", response_metadata={"model_provider": "groq"})
|
||||
with pytest.warns(match="not yet fully supported for Groq"):
|
||||
message = AIMessage(
|
||||
"Hello.", response_metadata={"model_provider": "google_vertexai"}
|
||||
)
|
||||
with pytest.warns(match="not yet fully supported for Google VertexAI"):
|
||||
content_blocks = message.content_blocks
|
||||
|
||||
assert content_blocks == [{"type": "text", "text": "Hello."}]
|
||||
|
||||
@@ -3,7 +3,9 @@ from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function as convert_pydantic_to_openai_function,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
|
||||
## Create namespaces for pydantic v1 and v2.
|
||||
# This code must stay at the top of the file before other modules may
|
||||
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
|
||||
#
|
||||
# This hack is done for the following reasons:
|
||||
# * LangChain will attempt to remain compatible with both pydantic v1 and v2 since
|
||||
# both dependencies and dependents may be stuck on either version of v1 or v2.
|
||||
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
|
||||
# unambiguously uses either v1 or v2 API.
|
||||
# * This change is easier to roll out and roll back.
|
||||
from pydantic.v1 import * # noqa: F403
|
||||
|
||||
try:
|
||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||
except metadata.PackageNotFoundError:
|
||||
_PYDANTIC_MAJOR_VERSION = 0
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_classic.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
from langchain_core._api import warn_deprecated
|
||||
from pydantic.v1.dataclasses import * # noqa: F403
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_classic.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
from langchain_core._api import warn_deprecated
|
||||
from pydantic.v1.main import * # noqa: F403
|
||||
|
||||
warn_deprecated(
|
||||
"0.3.0",
|
||||
removal="1.0.0",
|
||||
alternative="pydantic.v1 or pydantic",
|
||||
message=(
|
||||
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
|
||||
"The langchain.pydantic_v1 module was a "
|
||||
"compatibility shim for pydantic v1, and should no longer be used. "
|
||||
"Please update the code to import from Pydantic directly.\n\n"
|
||||
"For example, replace imports like: "
|
||||
"`from langchain_classic.pydantic_v1 import BaseModel`\n"
|
||||
"with: `from pydantic import BaseModel`\n"
|
||||
"or the v1 compatibility namespace if you are working in a code base "
|
||||
"that has not been fully upgraded to pydantic 2 yet. "
|
||||
"\tfrom pydantic.v1 import BaseModel\n"
|
||||
),
|
||||
)
|
||||
@@ -1,4 +1,6 @@
|
||||
from langchain_core.utils.function_calling import format_tool_to_openai_function
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function as format_tool_to_openai_function,
|
||||
)
|
||||
|
||||
# For backwards compatibility
|
||||
__all__ = ["format_tool_to_openai_function"]
|
||||
|
||||
@@ -11,8 +11,10 @@ from langchain_core.tools import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
format_tool_to_openai_function,
|
||||
format_tool_to_openai_tool,
|
||||
convert_to_openai_function as format_tool_to_openai_function,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool as format_tool_to_openai_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from langchain_core.utils.function_calling import FunctionDescription, ToolDescription
|
||||
from langchain_core.utils.function_calling import (
|
||||
FunctionDescription,
|
||||
ToolDescription,
|
||||
convert_pydantic_to_openai_function,
|
||||
convert_pydantic_to_openai_tool,
|
||||
convert_to_openai_function as convert_pydantic_to_openai_function,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_tool as convert_pydantic_to_openai_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ def test_convert_pydantic_to_openai_function() -> None:
|
||||
key: str = Field(..., description="API key")
|
||||
days: int = Field(default=0, description="Number of days to forecast")
|
||||
|
||||
actual = convert_pydantic_to_openai_function(Data)
|
||||
actual = convert_to_openai_function(Data)
|
||||
expected = {
|
||||
"name": "Data",
|
||||
"description": "The data to return.",
|
||||
@@ -41,7 +41,7 @@ def test_convert_pydantic_to_openai_function_nested() -> None:
|
||||
|
||||
data: Data
|
||||
|
||||
actual = convert_pydantic_to_openai_function(Model)
|
||||
actual = convert_to_openai_function(Model)
|
||||
expected = {
|
||||
"name": "Model",
|
||||
"description": "The model to return.",
|
||||
|
||||
@@ -537,13 +537,12 @@ def create_agent( # noqa: PLR0915
|
||||
(e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
|
||||
tools: A list of tools, dicts, or callables. If `None` or an empty list,
|
||||
the agent will consist of a model node without a tool calling loop.
|
||||
system_prompt: An optional system prompt for the LLM. If provided as a string,
|
||||
it will be converted to a SystemMessage and added to the beginning
|
||||
of the message list.
|
||||
system_prompt: An optional system prompt for the LLM. Prompts are converted to a
|
||||
`SystemMessage` and added to the beginning of the message list.
|
||||
middleware: A sequence of middleware instances to apply to the agent.
|
||||
Middleware can intercept and modify agent behavior at various stages.
|
||||
response_format: An optional configuration for structured responses.
|
||||
Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
|
||||
Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
|
||||
If provided, the agent will handle structured output during the
|
||||
conversation flow. Raw schemas will be wrapped in an appropriate strategy
|
||||
based on model capabilities.
|
||||
@@ -560,14 +559,14 @@ def create_agent( # noqa: PLR0915
|
||||
This is useful if you want to return directly or run additional processing
|
||||
on an output.
|
||||
debug: A flag indicating whether to enable debug mode.
|
||||
name: An optional name for the CompiledStateGraph.
|
||||
name: An optional name for the `CompiledStateGraph`.
|
||||
This name will be automatically used when adding the agent graph to
|
||||
another graph as a subgraph node - particularly useful for building
|
||||
multi-agent systems.
|
||||
cache: An optional BaseCache instance to enable caching of graph execution.
|
||||
cache: An optional `BaseCache` instance to enable caching of graph execution.
|
||||
|
||||
Returns:
|
||||
A compiled StateGraph that can be used for chat interactions.
|
||||
A compiled `StateGraph` that can be used for chat interactions.
|
||||
|
||||
The agent node calls the language model with the messages list (after applying
|
||||
the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will
|
||||
|
||||
@@ -11,9 +11,8 @@ from .human_in_the_loop import (
|
||||
from .model_call_limit import ModelCallLimitMiddleware
|
||||
from .model_fallback import ModelFallbackMiddleware
|
||||
from .pii import PIIDetectionError, PIIMiddleware
|
||||
from .planning import PlanningMiddleware
|
||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from .summarization import SummarizationMiddleware
|
||||
from .todo import TodoListMiddleware
|
||||
from .tool_call_limit import ToolCallLimitMiddleware
|
||||
from .tool_emulator import LLMToolEmulator
|
||||
from .tool_selection import LLMToolSelectorMiddleware
|
||||
@@ -21,6 +20,7 @@ from .types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
@@ -28,13 +28,12 @@ from .types import (
|
||||
dynamic_prompt,
|
||||
hook_config,
|
||||
wrap_model_call,
|
||||
wrap_tool_call,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
# should move to langchain-anthropic if we decide to keep it
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"ClearToolUsesEdit",
|
||||
"ContextEditingMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
@@ -44,10 +43,11 @@ __all__ = [
|
||||
"ModelCallLimitMiddleware",
|
||||
"ModelFallbackMiddleware",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"PlanningMiddleware",
|
||||
"SummarizationMiddleware",
|
||||
"TodoListMiddleware",
|
||||
"ToolCallLimitMiddleware",
|
||||
"after_agent",
|
||||
"after_model",
|
||||
@@ -56,4 +56,5 @@ __all__ = [
|
||||
"dynamic_prompt",
|
||||
"hook_config",
|
||||
"wrap_model_call",
|
||||
"wrap_tool_call",
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ with any LangChain chat model.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Apply context edits before invoking the model via handler (async version)."""
|
||||
if not request.messages:
|
||||
return await handler(request)
|
||||
|
||||
if self.token_count_method == "approximate": # noqa: S105
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClearToolUsesEdit",
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
from langchain_core.messages import AIMessage, RemoveMessage, ToolCall, ToolMessage
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
@@ -269,6 +270,42 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Before the agent runs, handle dangling tool calls from the most recent AIMessage."""
|
||||
messages = state["messages"]
|
||||
if not messages or len(messages) == 0:
|
||||
return None
|
||||
|
||||
patched_messages = []
|
||||
# Iterate over the messages and add any dangling tool calls
|
||||
for i, msg in enumerate(messages):
|
||||
patched_messages.append(msg)
|
||||
if msg.type == "ai" and msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
corresponding_tool_msg = next(
|
||||
(
|
||||
msg
|
||||
for msg in messages[i:]
|
||||
if msg.type == "tool" and msg.tool_call_id == tool_call["id"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if corresponding_tool_msg is None:
|
||||
# We have a dangling tool call which needs a ToolMessage
|
||||
tool_msg = (
|
||||
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
|
||||
"cancelled - another message came in before it could be completed."
|
||||
)
|
||||
patched_messages.append(
|
||||
ToolMessage(
|
||||
content=tool_msg,
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
|
||||
return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_messages]}
|
||||
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Trigger interrupt flows for relevant tool calls after an AIMessage."""
|
||||
messages = state["messages"]
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Try fallback models in sequence on errors (async version).
|
||||
|
||||
Args:
|
||||
request: Initial model request.
|
||||
handler: Async callback to execute the model.
|
||||
|
||||
Returns:
|
||||
AIMessage from successful model call.
|
||||
|
||||
Raises:
|
||||
Exception: If all models fail, re-raises last exception.
|
||||
"""
|
||||
# Try primary model first
|
||||
last_exception: Exception
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
raise last_exception
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Anthropic prompt caching middleware."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
"""Prompt Caching Middleware.
|
||||
|
||||
Optimizes API usage by caching conversation prefixes for Anthropic models.
|
||||
|
||||
Learn more about Anthropic prompt caching
|
||||
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Literal["ephemeral"] = "ephemeral",
|
||||
ttl: Literal["5m", "1h"] = "5m",
|
||||
min_messages_to_cache: int = 0,
|
||||
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
|
||||
) -> None:
|
||||
"""Initialize the middleware with cache control settings.
|
||||
|
||||
Args:
|
||||
type: The type of cache to use, only "ephemeral" is supported.
|
||||
ttl: The time to live for the cache, only "5m" and "1h" are supported.
|
||||
min_messages_to_cache: The minimum number of messages until the cache is used,
|
||||
default is 0.
|
||||
unsupported_model_behavior: The behavior to take when an unsupported model is used.
|
||||
"ignore" will ignore the unsupported model and continue without caching.
|
||||
"warn" will warn the user and continue without caching.
|
||||
"raise" will raise an error and stop the agent.
|
||||
"""
|
||||
self.type = type
|
||||
self.ttl = ttl
|
||||
self.min_messages_to_cache = min_messages_to_cache
|
||||
self.unsupported_model_behavior = unsupported_model_behavior
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Modify the model request to add cache control blocks."""
|
||||
try:
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
except ImportError:
|
||||
ChatAnthropic = None # noqa: N806
|
||||
|
||||
msg: str | None = None
|
||||
|
||||
if ChatAnthropic is None:
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
"Anthropic models. "
|
||||
"Please install langchain-anthropic."
|
||||
)
|
||||
elif not isinstance(request.model, ChatAnthropic):
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
f"Anthropic models, not instances of {type(request.model)}"
|
||||
)
|
||||
|
||||
if msg is not None:
|
||||
if self.unsupported_model_behavior == "raise":
|
||||
raise ValueError(msg)
|
||||
if self.unsupported_model_behavior == "warn":
|
||||
warn(msg, stacklevel=3)
|
||||
else:
|
||||
return handler(request)
|
||||
|
||||
messages_count = (
|
||||
len(request.messages) + 1 if request.system_prompt else len(request.messages)
|
||||
)
|
||||
if messages_count < self.min_messages_to_cache:
|
||||
return handler(request)
|
||||
|
||||
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
return handler(request)
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Annotated, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
@@ -126,7 +126,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
|
||||
)
|
||||
|
||||
|
||||
class PlanningMiddleware(AgentMiddleware):
|
||||
class TodoListMiddleware(AgentMiddleware):
|
||||
"""Middleware that provides todo list management capabilities to agents.
|
||||
|
||||
This middleware adds a `write_todos` tool that allows agents to create and manage
|
||||
@@ -139,10 +139,10 @@ class PlanningMiddleware(AgentMiddleware):
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain.agents.middleware.planning import PlanningMiddleware
|
||||
from langchain.agents.middleware.todo import TodoListMiddleware
|
||||
from langchain.agents import create_agent
|
||||
|
||||
agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])
|
||||
agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
|
||||
|
||||
# Agent now has access to write_todos tool and todo state tracking
|
||||
result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
|
||||
@@ -165,7 +165,7 @@ class PlanningMiddleware(AgentMiddleware):
|
||||
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
|
||||
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
|
||||
) -> None:
|
||||
"""Initialize the PlanningMiddleware with optional custom prompts.
|
||||
"""Initialize the TodoListMiddleware with optional custom prompts.
|
||||
|
||||
Args:
|
||||
system_prompt: Custom system prompt to guide the agent on using the todo tool.
|
||||
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
|
||||
else self.system_prompt
|
||||
)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt (async version)."""
|
||||
request.system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return await handler(request)
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, replace
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -30,7 +30,7 @@ from langgraph.channels.untracked_value import UntrackedValue
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command # noqa: TC002
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
@@ -62,6 +62,18 @@ JumpTo = Literal["tools", "model", "end"]
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
class _ModelRequestOverrides(TypedDict, total=False):
|
||||
"""Possible overrides for ModelRequest.override() method."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage]
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
model_settings: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRequest:
|
||||
"""Model request information for the agent."""
|
||||
@@ -76,6 +88,36 @@ class ModelRequest:
|
||||
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
||||
"""Replace the request with a new request with the given overrides.
|
||||
|
||||
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
||||
This follows an immutable pattern, leaving the original request unchanged.
|
||||
|
||||
Args:
|
||||
**overrides: Keyword arguments for attributes to override. Supported keys:
|
||||
- model: BaseChatModel instance
|
||||
- system_prompt: Optional system prompt string
|
||||
- messages: List of messages
|
||||
- tool_choice: Tool choice configuration
|
||||
- tools: List of available tools
|
||||
- response_format: Response format specification
|
||||
- model_settings: Additional model settings
|
||||
|
||||
Returns:
|
||||
New ModelRequest instance with specified overrides applied.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Create a new request with different model
|
||||
new_request = request.override(model=different_model)
|
||||
|
||||
# Override multiple attributes
|
||||
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
|
||||
```
|
||||
"""
|
||||
return replace(self, **overrides)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain_core.tools import (
|
||||
tool,
|
||||
)
|
||||
|
||||
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolInvocationError
|
||||
from langchain.tools.tool_node import InjectedState, InjectedStore
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
@@ -17,6 +17,5 @@ __all__ = [
|
||||
"InjectedToolArg",
|
||||
"InjectedToolCallId",
|
||||
"ToolException",
|
||||
"ToolInvocationError",
|
||||
"tool",
|
||||
]
|
||||
|
||||
@@ -81,6 +81,7 @@ from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import get_runtime
|
||||
from langgraph.types import Command, Send
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Unpack
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
@@ -104,6 +105,12 @@ TOOL_INVOCATION_ERROR_TEMPLATE = (
|
||||
)
|
||||
|
||||
|
||||
class _ToolCallRequestOverrides(TypedDict, total=False):
|
||||
"""Possible overrides for ToolCallRequest.override() method."""
|
||||
|
||||
tool_call: ToolCall
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ToolCallRequest:
|
||||
"""Tool execution request passed to tool call interceptors.
|
||||
@@ -120,6 +127,31 @@ class ToolCallRequest:
|
||||
state: Any
|
||||
runtime: Any
|
||||
|
||||
def override(self, **overrides: Unpack[_ToolCallRequestOverrides]) -> ToolCallRequest:
|
||||
"""Replace the request with a new request with the given overrides.
|
||||
|
||||
Returns a new `ToolCallRequest` instance with the specified attributes replaced.
|
||||
This follows an immutable pattern, leaving the original request unchanged.
|
||||
|
||||
Args:
|
||||
**overrides: Keyword arguments for attributes to override. Supported keys:
|
||||
- tool_call: Tool call dict with name, args, and id
|
||||
|
||||
Returns:
|
||||
New ToolCallRequest instance with specified overrides applied.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Modify tool call arguments without mutating original
|
||||
modified_call = {**request.tool_call, "args": {"value": 10}}
|
||||
new_request = request.override(tool_call=modified_call)
|
||||
|
||||
# Override multiple attributes
|
||||
new_request = request.override(tool_call=modified_call, state=new_state)
|
||||
```
|
||||
"""
|
||||
return replace(self, **overrides)
|
||||
|
||||
|
||||
ToolCallWrapper = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ToolMessage,
|
||||
ToolCall,
|
||||
RemoveMessage,
|
||||
)
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class TestHumanInTheLoopMiddlewareBeforeModel:
|
||||
"""Test HumanInTheLoopMiddleware before_model behavior."""
|
||||
|
||||
def test_first_message(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 3
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1].type == "system"
|
||||
assert state_update["messages"][1].content == "You are a helpful assistant."
|
||||
assert state_update["messages"][2].type == "human"
|
||||
assert state_update["messages"][2].content == "Hello, how are you?"
|
||||
assert state_update["messages"][2].id == "2"
|
||||
|
||||
def test_missing_tool_call(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="4"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 6
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1] == input_messages[0]
|
||||
assert state_update["messages"][2] == input_messages[1]
|
||||
assert state_update["messages"][3] == input_messages[2]
|
||||
assert state_update["messages"][4].type == "tool"
|
||||
assert state_update["messages"][4].tool_call_id == "123"
|
||||
assert state_update["messages"][4].name == "get_events_for_days"
|
||||
assert state_update["messages"][5] == input_messages[3]
|
||||
updated_messages = add_messages(input_messages, state_update["messages"])
|
||||
assert len(updated_messages) == 5
|
||||
assert updated_messages[0] == input_messages[0]
|
||||
assert updated_messages[1] == input_messages[1]
|
||||
assert updated_messages[2] == input_messages[2]
|
||||
assert updated_messages[3].type == "tool"
|
||||
assert updated_messages[3].tool_call_id == "123"
|
||||
assert updated_messages[3].name == "get_events_for_days"
|
||||
assert updated_messages[4] == input_messages[3]
|
||||
|
||||
def test_no_missing_tool_calls(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
ToolMessage(content="I have no events for that date.", tool_call_id="123", id="4"),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="5"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 6
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1:] == input_messages
|
||||
updated_messages = add_messages(input_messages, state_update["messages"])
|
||||
assert len(updated_messages) == 5
|
||||
assert updated_messages == input_messages
|
||||
|
||||
def test_two_missing_tool_calls(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="4"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="456", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="5",
|
||||
),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="6"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 9
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1] == input_messages[0]
|
||||
assert state_update["messages"][2] == input_messages[1]
|
||||
assert state_update["messages"][3] == input_messages[2]
|
||||
assert state_update["messages"][4].type == "tool"
|
||||
assert state_update["messages"][4].tool_call_id == "123"
|
||||
assert state_update["messages"][4].name == "get_events_for_days"
|
||||
assert state_update["messages"][5] == input_messages[3]
|
||||
assert state_update["messages"][6] == input_messages[4]
|
||||
assert state_update["messages"][7].type == "tool"
|
||||
assert state_update["messages"][7].tool_call_id == "456"
|
||||
assert state_update["messages"][7].name == "get_events_for_days"
|
||||
assert state_update["messages"][8] == input_messages[5]
|
||||
updated_messages = add_messages(input_messages, state_update["messages"])
|
||||
assert len(updated_messages) == 8
|
||||
assert updated_messages[0] == input_messages[0]
|
||||
assert updated_messages[1] == input_messages[1]
|
||||
assert updated_messages[2] == input_messages[2]
|
||||
assert updated_messages[3].type == "tool"
|
||||
assert updated_messages[3].tool_call_id == "123"
|
||||
assert updated_messages[3].name == "get_events_for_days"
|
||||
assert updated_messages[4] == input_messages[3]
|
||||
assert updated_messages[5] == input_messages[4]
|
||||
assert updated_messages[6].type == "tool"
|
||||
assert updated_messages[6].tool_call_id == "456"
|
||||
assert updated_messages[6].name == "get_events_for_days"
|
||||
assert updated_messages[7] == input_messages[5]
|
||||
@@ -0,0 +1,381 @@
|
||||
"""Unit tests for override() methods on ModelRequest and ToolCallRequest."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
|
||||
|
||||
class TestModelRequestOverride:
|
||||
"""Test the ModelRequest.override() method."""
|
||||
|
||||
def test_override_single_attribute(self) -> None:
|
||||
"""Test overriding a single attribute."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
messages=[HumanMessage("Hi")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(system_prompt="New prompt")
|
||||
|
||||
# New request should have the overridden value
|
||||
assert new_request.system_prompt == "New prompt"
|
||||
# Original request should be unchanged (immutability)
|
||||
assert original_request.system_prompt == "Original prompt"
|
||||
# Other attributes should be the same
|
||||
assert new_request.model == original_request.model
|
||||
assert new_request.messages == original_request.messages
|
||||
|
||||
def test_override_multiple_attributes(self) -> None:
|
||||
"""Test overriding multiple attributes at once."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
messages=[HumanMessage("Hi")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"count": 1},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(
|
||||
system_prompt="New prompt",
|
||||
tool_choice="auto",
|
||||
state={"count": 2},
|
||||
)
|
||||
|
||||
# Overridden values should be changed
|
||||
assert new_request.system_prompt == "New prompt"
|
||||
assert new_request.tool_choice == "auto"
|
||||
assert new_request.state == {"count": 2}
|
||||
# Original should be unchanged
|
||||
assert original_request.system_prompt == "Original prompt"
|
||||
assert original_request.tool_choice is None
|
||||
assert original_request.state == {"count": 1}
|
||||
|
||||
def test_override_messages(self) -> None:
|
||||
"""Test overriding messages list."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
original_messages = [HumanMessage("Hi")]
|
||||
new_messages = [HumanMessage("Hello"), AIMessage("Hi there")]
|
||||
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=None,
|
||||
messages=original_messages,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(messages=new_messages)
|
||||
|
||||
assert new_request.messages == new_messages
|
||||
assert original_request.messages == original_messages
|
||||
assert len(new_request.messages) == 2
|
||||
assert len(original_request.messages) == 1
|
||||
|
||||
def test_override_model_settings(self) -> None:
|
||||
"""Test overriding model_settings dict."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=None,
|
||||
messages=[HumanMessage("Hi")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={},
|
||||
runtime=None,
|
||||
model_settings={"temperature": 0.5},
|
||||
)
|
||||
|
||||
new_request = original_request.override(
|
||||
model_settings={"temperature": 0.9, "max_tokens": 100}
|
||||
)
|
||||
|
||||
assert new_request.model_settings == {"temperature": 0.9, "max_tokens": 100}
|
||||
assert original_request.model_settings == {"temperature": 0.5}
|
||||
|
||||
def test_override_with_none_value(self) -> None:
|
||||
"""Test overriding with None value."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
messages=[HumanMessage("Hi")],
|
||||
tool_choice="auto",
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
assert new_request.system_prompt is None
|
||||
assert new_request.tool_choice is None
|
||||
assert original_request.system_prompt == "Original prompt"
|
||||
assert original_request.tool_choice == "auto"
|
||||
|
||||
def test_override_preserves_identity_of_unchanged_objects(self) -> None:
|
||||
"""Test that unchanged attributes maintain object identity."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
messages = [HumanMessage("Hi")]
|
||||
state = {"key": "value"}
|
||||
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
messages=messages,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state=state,
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(system_prompt="New prompt")
|
||||
|
||||
# Unchanged objects should be the same instance
|
||||
assert new_request.messages is messages
|
||||
assert new_request.state is state
|
||||
assert new_request.model is model
|
||||
|
||||
def test_override_chaining(self) -> None:
|
||||
"""Test chaining multiple override calls."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
original_request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Prompt 1",
|
||||
messages=[HumanMessage("Hi")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"count": 1},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
final_request = (
|
||||
original_request.override(system_prompt="Prompt 2")
|
||||
.override(state={"count": 2})
|
||||
.override(tool_choice="auto")
|
||||
)
|
||||
|
||||
assert final_request.system_prompt == "Prompt 2"
|
||||
assert final_request.state == {"count": 2}
|
||||
assert final_request.tool_choice == "auto"
|
||||
# Original should be unchanged
|
||||
assert original_request.system_prompt == "Prompt 1"
|
||||
assert original_request.state == {"count": 1}
|
||||
assert original_request.tool_choice is None
|
||||
|
||||
|
||||
class TestToolCallRequestOverride:
|
||||
"""Test the ToolCallRequest.override() method."""
|
||||
|
||||
def test_override_tool_call(self) -> None:
|
||||
"""Test overriding tool_call dict."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(x: int) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {x}"
|
||||
|
||||
original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
|
||||
modified_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
|
||||
|
||||
original_request = ToolCallRequest(
|
||||
tool_call=original_call,
|
||||
tool=test_tool,
|
||||
state={"messages": []},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(tool_call=modified_call)
|
||||
|
||||
# New request should have modified tool_call
|
||||
assert new_request.tool_call["args"]["x"] == 10
|
||||
# Original should be unchanged
|
||||
assert original_request.tool_call["args"]["x"] == 5
|
||||
# Other attributes should be the same
|
||||
assert new_request.tool is original_request.tool
|
||||
assert new_request.state is original_request.state
|
||||
|
||||
def test_override_state(self) -> None:
|
||||
"""Test overriding state."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(x: int) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {x}"
|
||||
|
||||
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
|
||||
original_state = {"messages": [HumanMessage("Hi")]}
|
||||
new_state = {"messages": [HumanMessage("Hi"), AIMessage("Hello")]}
|
||||
|
||||
original_request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=test_tool,
|
||||
state=original_state,
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(state=new_state)
|
||||
|
||||
assert len(new_request.state["messages"]) == 2
|
||||
assert len(original_request.state["messages"]) == 1
|
||||
|
||||
def test_override_multiple_attributes(self) -> None:
|
||||
"""Test overriding multiple attributes at once."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(x: int) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {x}"
|
||||
|
||||
@tool
|
||||
def another_tool(y: str) -> str:
|
||||
"""Another test tool."""
|
||||
return f"Output: {y}"
|
||||
|
||||
original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
|
||||
modified_call = {
|
||||
"name": "another_tool",
|
||||
"args": {"y": "hello"},
|
||||
"id": "2",
|
||||
"type": "tool_call",
|
||||
}
|
||||
|
||||
original_request = ToolCallRequest(
|
||||
tool_call=original_call,
|
||||
tool=test_tool,
|
||||
state={"count": 1},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_request = original_request.override(
|
||||
tool_call=modified_call,
|
||||
tool=another_tool,
|
||||
state={"count": 2},
|
||||
)
|
||||
|
||||
assert new_request.tool_call["name"] == "another_tool"
|
||||
assert new_request.tool.name == "another_tool"
|
||||
assert new_request.state == {"count": 2}
|
||||
# Original unchanged
|
||||
assert original_request.tool_call["name"] == "test_tool"
|
||||
assert original_request.tool.name == "test_tool"
|
||||
assert original_request.state == {"count": 1}
|
||||
|
||||
def test_override_with_copy_pattern(self) -> None:
|
||||
"""Test common pattern of copying and modifying tool_call."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(value: int) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {value}"
|
||||
|
||||
original_call = {
|
||||
"name": "test_tool",
|
||||
"args": {"value": 5},
|
||||
"id": "call_123",
|
||||
"type": "tool_call",
|
||||
}
|
||||
|
||||
original_request = ToolCallRequest(
|
||||
tool_call=original_call,
|
||||
tool=test_tool,
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
# Common pattern: copy tool_call and modify args
|
||||
modified_call = {**original_request.tool_call, "args": {"value": 10}}
|
||||
new_request = original_request.override(tool_call=modified_call)
|
||||
|
||||
assert new_request.tool_call["args"]["value"] == 10
|
||||
assert new_request.tool_call["id"] == "call_123"
|
||||
assert new_request.tool_call["name"] == "test_tool"
|
||||
# Original unchanged
|
||||
assert original_request.tool_call["args"]["value"] == 5
|
||||
|
||||
def test_override_preserves_identity(self) -> None:
|
||||
"""Test that unchanged attributes maintain object identity."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(x: int) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {x}"
|
||||
|
||||
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
|
||||
state = {"messages": []}
|
||||
|
||||
original_request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=test_tool,
|
||||
state=state,
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
new_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
|
||||
new_request = original_request.override(tool_call=new_call)
|
||||
|
||||
# Unchanged objects should be the same instance
|
||||
assert new_request.tool is test_tool
|
||||
assert new_request.state is state
|
||||
|
||||
def test_override_chaining(self) -> None:
|
||||
"""Test chaining multiple override calls."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(x: int) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {x}"
|
||||
|
||||
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
|
||||
|
||||
original_request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=test_tool,
|
||||
state={"count": 1},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
call_2 = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
|
||||
call_3 = {"name": "test_tool", "args": {"x": 15}, "id": "1", "type": "tool_call"}
|
||||
|
||||
final_request = (
|
||||
original_request.override(tool_call=call_2)
|
||||
.override(state={"count": 2})
|
||||
.override(tool_call=call_3)
|
||||
)
|
||||
|
||||
assert final_request.tool_call["args"]["x"] == 15
|
||||
assert final_request.state == {"count": 2}
|
||||
# Original unchanged
|
||||
assert original_request.tool_call["args"]["x"] == 5
|
||||
assert original_request.state == {"count": 1}
|
||||
@@ -233,3 +233,169 @@ def test_exclude_tools_prevents_clearing() -> None:
|
||||
|
||||
def _fake_runtime() -> Runtime:
|
||||
return cast(Runtime, object())
|
||||
|
||||
|
||||
async def test_no_edit_when_below_trigger_async() -> None:
|
||||
"""Test async version of context editing with no edit when below trigger."""
|
||||
tool_call_id = "call-1"
|
||||
ai_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
||||
)
|
||||
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
|
||||
|
||||
state, request = _make_state_and_request([ai_message, tool_message])
|
||||
middleware = ContextEditingMiddleware(
|
||||
edits=[ClearToolUsesEdit(trigger=50)],
|
||||
)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# The request should have been modified in place
|
||||
assert request.messages[0].content == ""
|
||||
assert request.messages[1].content == "12345"
|
||||
assert state["messages"] == request.messages
|
||||
|
||||
|
||||
async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
"""Test async version of clearing tool outputs and inputs."""
|
||||
tool_call_id = "call-2"
|
||||
ai_message = AIMessage(
|
||||
content=[
|
||||
{"type": "tool_call", "id": tool_call_id, "name": "search", "args": {"query": "foo"}}
|
||||
],
|
||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {"query": "foo"}}],
|
||||
)
|
||||
tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
|
||||
|
||||
state, request = _make_state_and_request([ai_message, tool_message])
|
||||
|
||||
edit = ClearToolUsesEdit(
|
||||
trigger=50,
|
||||
clear_at_least=10,
|
||||
clear_tool_inputs=True,
|
||||
keep=0,
|
||||
placeholder="[cleared output]",
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_ai = request.messages[0]
|
||||
cleared_tool = request.messages[1]
|
||||
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared output]"
|
||||
assert cleared_tool.response_metadata["context_editing"]["cleared"] is True
|
||||
|
||||
assert isinstance(cleared_ai, AIMessage)
|
||||
assert cleared_ai.tool_calls[0]["args"] == {}
|
||||
context_meta = cleared_ai.response_metadata.get("context_editing")
|
||||
assert context_meta is not None
|
||||
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
|
||||
|
||||
assert state["messages"] == request.messages
|
||||
|
||||
|
||||
async def test_respects_keep_last_tool_results_async() -> None:
|
||||
"""Test async version respects keep parameter for last tool results."""
|
||||
conversation: list[AIMessage | ToolMessage] = []
|
||||
edits = [
|
||||
("call-a", "tool-output-a" * 5),
|
||||
("call-b", "tool-output-b" * 5),
|
||||
("call-c", "tool-output-c" * 5),
|
||||
]
|
||||
|
||||
for call_id, text in edits:
|
||||
conversation.append(
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
|
||||
)
|
||||
)
|
||||
conversation.append(ToolMessage(content=text, tool_call_id=call_id))
|
||||
|
||||
state, request = _make_state_and_request(conversation)
|
||||
|
||||
middleware = ContextEditingMiddleware(
|
||||
edits=[
|
||||
ClearToolUsesEdit(
|
||||
trigger=50,
|
||||
keep=1,
|
||||
placeholder="[cleared]",
|
||||
)
|
||||
],
|
||||
token_count_method="model",
|
||||
)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_messages = [
|
||||
msg
|
||||
for msg in request.messages
|
||||
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
|
||||
]
|
||||
|
||||
assert len(cleared_messages) == 2
|
||||
assert isinstance(request.messages[-1], ToolMessage)
|
||||
assert request.messages[-1].content != "[cleared]"
|
||||
|
||||
|
||||
async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
"""Test async version of excluding tools from clearing."""
|
||||
search_call = "call-search"
|
||||
calc_call = "call-calc"
|
||||
|
||||
state, request = _make_state_and_request(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": search_call, "name": "search", "args": {"query": "foo"}}],
|
||||
),
|
||||
ToolMessage(content="search-results" * 20, tool_call_id=search_call),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": calc_call, "name": "calculator", "args": {"a": 1, "b": 2}}],
|
||||
),
|
||||
ToolMessage(content="42", tool_call_id=calc_call),
|
||||
]
|
||||
)
|
||||
|
||||
middleware = ContextEditingMiddleware(
|
||||
edits=[
|
||||
ClearToolUsesEdit(
|
||||
trigger=50,
|
||||
clear_at_least=10,
|
||||
keep=0,
|
||||
exclude_tools=("search",),
|
||||
placeholder="[cleared]",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
search_tool = request.messages[1]
|
||||
calc_tool = request.messages[3]
|
||||
|
||||
assert isinstance(search_tool, ToolMessage)
|
||||
assert search_tool.content == "search-results" * 20
|
||||
|
||||
assert isinstance(calc_tool, ToolMessage)
|
||||
assert calc_tool.content == "[cleared]"
|
||||
|
||||
@@ -32,8 +32,8 @@ from langchain.agents.middleware.human_in_the_loop import (
|
||||
Action,
|
||||
HumanInTheLoopMiddleware,
|
||||
)
|
||||
from langchain.agents.middleware.planning import (
|
||||
PlanningMiddleware,
|
||||
from langchain.agents.middleware.todo import (
|
||||
TodoListMiddleware,
|
||||
PlanningState,
|
||||
WRITE_TODOS_SYSTEM_PROMPT,
|
||||
write_todos,
|
||||
@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
|
||||
ModelCallLimitExceededError,
|
||||
)
|
||||
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
||||
assert captured_request["action_requests"][1]["description"] == "Static description"
|
||||
|
||||
|
||||
# Tests for AnthropicPromptCachingMiddleware
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware initialization."""
|
||||
# Test with custom values
|
||||
middleware = AnthropicPromptCachingMiddleware(
|
||||
type="ephemeral", ttl="1h", min_messages_to_cache=5
|
||||
)
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "1h"
|
||||
assert middleware.min_messages_to_cache == 5
|
||||
|
||||
# Test with default values
|
||||
middleware = AnthropicPromptCachingMiddleware()
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "5m"
|
||||
assert middleware.min_messages_to_cache == 0
|
||||
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")]},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response", **req.model_settings)
|
||||
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
# Check that model_settings were passed through via the request
|
||||
assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
|
||||
from typing import cast
|
||||
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")]},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
|
||||
):
|
||||
middleware.wrap_model_call(fake_request, mock_handler)
|
||||
|
||||
langchain_anthropic = ModuleType("langchain_anthropic")
|
||||
|
||||
class MockChatAnthropic:
|
||||
pass
|
||||
|
||||
langchain_anthropic.ChatAnthropic = MockChatAnthropic
|
||||
|
||||
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
|
||||
):
|
||||
middleware.wrap_model_call(fake_request, mock_handler)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
assert len(w) == 1
|
||||
assert (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
|
||||
in str(w[-1].message)
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert len(w) == 1
|
||||
assert (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
|
||||
in str(w[-1].message)
|
||||
)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, AIMessage)
|
||||
|
||||
|
||||
# Tests for SummarizationMiddleware
|
||||
def test_summarization_middleware_initialization() -> None:
|
||||
"""Test SummarizationMiddleware initialization."""
|
||||
@@ -1567,10 +1457,10 @@ def test_jump_to_is_ephemeral() -> None:
|
||||
assert "jump_to" not in result
|
||||
|
||||
|
||||
# Tests for PlanningMiddleware
|
||||
def test_planning_middleware_initialization() -> None:
|
||||
"""Test that PlanningMiddleware initializes correctly."""
|
||||
middleware = PlanningMiddleware()
|
||||
# Tests for TodoListMiddleware
|
||||
def test_todo_middleware_initialization() -> None:
|
||||
"""Test that TodoListMiddleware initializes correctly."""
|
||||
middleware = TodoListMiddleware()
|
||||
assert middleware.state_schema == PlanningState
|
||||
assert len(middleware.tools) == 1
|
||||
assert middleware.tools[0].name == "write_todos"
|
||||
@@ -1583,9 +1473,9 @@ def test_planning_middleware_initialization() -> None:
|
||||
(None, "## `write_todos`"),
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_on_model_call(original_prompt, expected_prompt_prefix) -> None:
|
||||
def test_todo_middleware_on_model_call(original_prompt, expected_prompt_prefix) -> None:
|
||||
"""Test that wrap_model_call handles system prompts correctly."""
|
||||
middleware = PlanningMiddleware()
|
||||
middleware = TodoListMiddleware()
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
@@ -1636,7 +1526,7 @@ def test_planning_middleware_on_model_call(original_prompt, expected_prompt_pref
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
|
||||
def test_todo_middleware_write_todos_tool_execution(todos, expected_message) -> None:
|
||||
"""Test that the write_todos tool executes correctly."""
|
||||
tool_call = {
|
||||
"args": {"todos": todos},
|
||||
@@ -1656,7 +1546,7 @@ def test_planning_middleware_write_todos_tool_execution(todos, expected_message)
|
||||
[{"status": "pending"}],
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
|
||||
def test_todo_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
|
||||
"""Test that the write_todos tool rejects invalid input."""
|
||||
tool_call = {
|
||||
"args": {"todos": invalid_todos},
|
||||
@@ -1668,7 +1558,7 @@ def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -
|
||||
write_todos.invoke(tool_call)
|
||||
|
||||
|
||||
def test_planning_middleware_agent_creation_with_middleware() -> None:
|
||||
def test_todo_middleware_agent_creation_with_middleware() -> None:
|
||||
"""Test that an agent can be created with the planning middleware."""
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
@@ -1699,7 +1589,7 @@ def test_planning_middleware_agent_creation_with_middleware() -> None:
|
||||
[],
|
||||
]
|
||||
)
|
||||
middleware = PlanningMiddleware()
|
||||
middleware = TodoListMiddleware()
|
||||
agent = create_agent(model=model, middleware=[middleware])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
@@ -1716,12 +1606,14 @@ def test_planning_middleware_agent_creation_with_middleware() -> None:
|
||||
assert len(result["messages"]) == 8
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with custom system prompt."""
|
||||
def test_todo_middleware_custom_system_prompt() -> None:
|
||||
"""Test that TodoListMiddleware can be initialized with custom system prompt."""
|
||||
custom_system_prompt = "Custom todo system prompt for testing"
|
||||
middleware = PlanningMiddleware(system_prompt=custom_system_prompt)
|
||||
middleware = TodoListMiddleware(system_prompt=custom_system_prompt)
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
@@ -1730,10 +1622,10 @@ def test_planning_middleware_custom_system_prompt() -> None:
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
state=state,
|
||||
runtime=cast(Runtime, object()),
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
@@ -1743,21 +1635,21 @@ def test_planning_middleware_custom_system_prompt() -> None:
|
||||
assert request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
|
||||
|
||||
def test_planning_middleware_custom_tool_description() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with custom tool description."""
|
||||
def test_todo_middleware_custom_tool_description() -> None:
|
||||
"""Test that TodoListMiddleware can be initialized with custom tool description."""
|
||||
custom_tool_description = "Custom tool description for testing"
|
||||
middleware = PlanningMiddleware(tool_description=custom_tool_description)
|
||||
middleware = TodoListMiddleware(tool_description=custom_tool_description)
|
||||
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with both custom prompts."""
|
||||
def test_todo_middleware_custom_system_prompt_and_tool_description() -> None:
|
||||
"""Test that TodoListMiddleware can be initialized with both custom prompts."""
|
||||
custom_system_prompt = "Custom system prompt"
|
||||
custom_tool_description = "Custom tool description"
|
||||
middleware = PlanningMiddleware(
|
||||
middleware = TodoListMiddleware(
|
||||
system_prompt=custom_system_prompt,
|
||||
tool_description=custom_tool_description,
|
||||
)
|
||||
@@ -1792,9 +1684,9 @@ def test_planning_middleware_custom_system_prompt_and_tool_description() -> None
|
||||
assert tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_default_prompts() -> None:
|
||||
"""Test that PlanningMiddleware uses default prompts when none provided."""
|
||||
middleware = PlanningMiddleware()
|
||||
def test_todo_middleware_default_prompts() -> None:
|
||||
"""Test that TodoListMiddleware uses default prompts when none provided."""
|
||||
middleware = TodoListMiddleware()
|
||||
|
||||
# Verify default system prompt
|
||||
assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
|
||||
@@ -1806,9 +1698,9 @@ def test_planning_middleware_default_prompts() -> None:
|
||||
assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt() -> None:
|
||||
def test_todo_middleware_custom_system_prompt_in_agent() -> None:
|
||||
"""Test that custom tool executes correctly in an agent."""
|
||||
middleware = PlanningMiddleware(system_prompt="call the write_todos tool")
|
||||
middleware = TodoListMiddleware(system_prompt="call the write_todos tool")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Unit tests for ModelFallbackMiddleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def _fake_runtime() -> Runtime:
|
||||
return cast(Runtime, object())
|
||||
|
||||
|
||||
def _make_request() -> ModelRequest:
|
||||
"""Create a minimal ModelRequest for testing."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="primary")]))
|
||||
return ModelRequest(
|
||||
model=model,
|
||||
system_prompt=None,
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state=cast("AgentState", {}), # type: ignore[name-defined]
|
||||
runtime=_fake_runtime(),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
|
||||
def test_primary_model_succeeds() -> None:
|
||||
"""Test that primary model is used when it succeeds."""
|
||||
primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
|
||||
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
# Simulate successful model call
|
||||
result = req.model.invoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
response = middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.result[0].content == "primary response"
|
||||
|
||||
|
||||
def test_fallback_on_primary_failure() -> None:
|
||||
"""Test that fallback model is used when primary fails."""
|
||||
|
||||
class FailingPrimaryModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Primary model failed")
|
||||
|
||||
primary_model = FailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
|
||||
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
response = middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.result[0].content == "fallback response"
|
||||
|
||||
|
||||
def test_multiple_fallbacks() -> None:
|
||||
"""Test that multiple fallback models are tried in sequence."""
|
||||
|
||||
class FailingModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Model failed")
|
||||
|
||||
primary_model = FailingModel(messages=iter([AIMessage(content="should not see")]))
|
||||
fallback1 = FailingModel(messages=iter([AIMessage(content="fallback1")]))
|
||||
fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
response = middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.result[0].content == "fallback2"
|
||||
|
||||
|
||||
def test_all_models_fail() -> None:
|
||||
"""Test that exception is raised when all models fail."""
|
||||
|
||||
class AlwaysFailingModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Model failed")
|
||||
|
||||
primary_model = AlwaysFailingModel(messages=iter([]))
|
||||
fallback_model = AlwaysFailingModel(messages=iter([]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
with pytest.raises(ValueError, match="Model failed"):
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
|
||||
async def test_primary_model_succeeds_async() -> None:
|
||||
"""Test async version - primary model is used when it succeeds."""
|
||||
primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
|
||||
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
# Simulate successful async model call
|
||||
result = await req.model.ainvoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
response = await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.result[0].content == "primary response"
|
||||
|
||||
|
||||
async def test_fallback_on_primary_failure_async() -> None:
|
||||
"""Test async version - fallback model is used when primary fails."""
|
||||
|
||||
class AsyncFailingPrimaryModel(GenericFakeChatModel):
|
||||
async def _agenerate(self, messages, **kwargs):
|
||||
raise ValueError("Primary model failed")
|
||||
|
||||
primary_model = AsyncFailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
|
||||
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
response = await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.result[0].content == "fallback response"
|
||||
|
||||
|
||||
async def test_multiple_fallbacks_async() -> None:
|
||||
"""Test async version - multiple fallback models are tried in sequence."""
|
||||
|
||||
class AsyncFailingModel(GenericFakeChatModel):
|
||||
async def _agenerate(self, messages, **kwargs):
|
||||
raise ValueError("Model failed")
|
||||
|
||||
primary_model = AsyncFailingModel(messages=iter([AIMessage(content="should not see")]))
|
||||
fallback1 = AsyncFailingModel(messages=iter([AIMessage(content="fallback1")]))
|
||||
fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
response = await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert response.result[0].content == "fallback2"
|
||||
|
||||
|
||||
async def test_all_models_fail_async() -> None:
|
||||
"""Test async version - exception is raised when all models fail."""
|
||||
|
||||
class AsyncAlwaysFailingModel(GenericFakeChatModel):
|
||||
async def _agenerate(self, messages, **kwargs):
|
||||
raise ValueError("Model failed")
|
||||
|
||||
primary_model = AsyncAlwaysFailingModel(messages=iter([]))
|
||||
fallback_model = AsyncAlwaysFailingModel(messages=iter([]))
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
return ModelResponse(result=[result])
|
||||
|
||||
with pytest.raises(ValueError, match="Model failed"):
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Unit tests for TodoListMiddleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain.agents.middleware.todo import TodoListMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def _fake_runtime() -> Runtime:
|
||||
return cast(Runtime, object())
|
||||
|
||||
|
||||
def _make_request(system_prompt: str | None = None) -> ModelRequest:
|
||||
"""Create a minimal ModelRequest for testing."""
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
|
||||
return ModelRequest(
|
||||
model=model,
|
||||
system_prompt=system_prompt,
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state=cast("AgentState", {}), # type: ignore[name-defined]
|
||||
runtime=_fake_runtime(),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
|
||||
def test_adds_system_prompt_when_none_exists() -> None:
|
||||
"""Test that middleware adds system prompt when request has none."""
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should be set
|
||||
assert request.system_prompt is not None
|
||||
assert "write_todos" in request.system_prompt
|
||||
|
||||
|
||||
def test_appends_to_existing_system_prompt() -> None:
|
||||
"""Test that middleware appends to existing system prompt."""
|
||||
existing_prompt = "You are a helpful assistant."
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=existing_prompt)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should contain both
|
||||
assert request.system_prompt is not None
|
||||
assert existing_prompt in request.system_prompt
|
||||
assert "write_todos" in request.system_prompt
|
||||
assert request.system_prompt.startswith(existing_prompt)
|
||||
|
||||
|
||||
def test_custom_system_prompt() -> None:
|
||||
"""Test that middleware uses custom system prompt."""
|
||||
custom_prompt = "Custom planning instructions"
|
||||
middleware = TodoListMiddleware(system_prompt=custom_prompt)
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# Should use custom prompt
|
||||
assert request.system_prompt == custom_prompt
|
||||
|
||||
|
||||
def test_has_write_todos_tool() -> None:
|
||||
"""Test that middleware registers the write_todos tool."""
|
||||
middleware = TodoListMiddleware()
|
||||
|
||||
# Should have one tool registered
|
||||
assert len(middleware.tools) == 1
|
||||
assert middleware.tools[0].name == "write_todos"
|
||||
|
||||
|
||||
def test_custom_tool_description() -> None:
|
||||
"""Test that middleware uses custom tool description."""
|
||||
custom_description = "Custom todo tool description"
|
||||
middleware = TodoListMiddleware(tool_description=custom_description)
|
||||
|
||||
# Tool should use custom description
|
||||
assert len(middleware.tools) == 1
|
||||
assert middleware.tools[0].description == custom_description
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Async Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
async def test_adds_system_prompt_when_none_exists_async() -> None:
|
||||
"""Test async version - middleware adds system prompt when request has none."""
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should be set
|
||||
assert request.system_prompt is not None
|
||||
assert "write_todos" in request.system_prompt
|
||||
|
||||
|
||||
async def test_appends_to_existing_system_prompt_async() -> None:
|
||||
"""Test async version - middleware appends to existing system prompt."""
|
||||
existing_prompt = "You are a helpful assistant."
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=existing_prompt)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should contain both
|
||||
assert request.system_prompt is not None
|
||||
assert existing_prompt in request.system_prompt
|
||||
assert "write_todos" in request.system_prompt
|
||||
assert request.system_prompt.startswith(existing_prompt)
|
||||
|
||||
|
||||
async def test_custom_system_prompt_async() -> None:
|
||||
"""Test async version - middleware uses custom system prompt."""
|
||||
custom_prompt = "Custom planning instructions"
|
||||
middleware = TodoListMiddleware(system_prompt=custom_prompt)
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# Should use custom prompt
|
||||
assert request.system_prompt == custom_prompt
|
||||
|
||||
|
||||
async def test_handler_called_with_modified_request_async() -> None:
|
||||
"""Test async version - handler receives the modified request."""
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt="Original")
|
||||
handler_called = {"value": False}
|
||||
received_prompt = {"value": None}
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
handler_called["value"] = True
|
||||
received_prompt["value"] = req.system_prompt
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
assert handler_called["value"]
|
||||
assert received_prompt["value"] is not None
|
||||
assert "Original" in received_prompt["value"]
|
||||
assert "write_todos" in received_prompt["value"]
|
||||
@@ -7,7 +7,6 @@ EXPECTED_ALL = {
|
||||
"InjectedToolArg",
|
||||
"InjectedToolCallId",
|
||||
"ToolException",
|
||||
"ToolInvocationError",
|
||||
"tool",
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Middleware for Anthropic models."""
|
||||
|
||||
from langchain_anthropic.middleware.prompt_caching import (
|
||||
AnthropicPromptCachingMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
]
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Anthropic prompt caching middleware.
|
||||
|
||||
Requires:
|
||||
- langchain: For agent middleware framework
|
||||
- langchain-anthropic: For ChatAnthropic model (already a dependency)
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
try:
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
|
||||
"This middleware is designed for use with LangChain agents. "
|
||||
"Install it with: pip install langchain"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
|
||||
class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
"""Prompt Caching Middleware.
|
||||
|
||||
Optimizes API usage by caching conversation prefixes for Anthropic models.
|
||||
|
||||
Requires both 'langchain' and 'langchain-anthropic' packages to be installed.
|
||||
|
||||
Learn more about Anthropic prompt caching
|
||||
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Literal["ephemeral"] = "ephemeral", # noqa: A002
|
||||
ttl: Literal["5m", "1h"] = "5m",
|
||||
min_messages_to_cache: int = 0,
|
||||
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
|
||||
) -> None:
|
||||
"""Initialize the middleware with cache control settings.
|
||||
|
||||
Args:
|
||||
type: The type of cache to use, only "ephemeral" is supported.
|
||||
ttl: The time to live for the cache, only "5m" and "1h" are
|
||||
supported.
|
||||
min_messages_to_cache: The minimum number of messages until the
|
||||
cache is used, default is 0.
|
||||
unsupported_model_behavior: The behavior to take when an
|
||||
unsupported model is used. "ignore" will ignore the unsupported
|
||||
model and continue without caching. "warn" will warn the user
|
||||
and continue without caching. "raise" will raise an error and
|
||||
stop the agent.
|
||||
"""
|
||||
self.type = type
|
||||
self.ttl = ttl
|
||||
self.min_messages_to_cache = min_messages_to_cache
|
||||
self.unsupported_model_behavior = unsupported_model_behavior
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Modify the model request to add cache control blocks."""
|
||||
if not isinstance(request.model, ChatAnthropic):
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
f"Anthropic models, not instances of {type(request.model)}"
|
||||
)
|
||||
if self.unsupported_model_behavior == "raise":
|
||||
raise ValueError(msg)
|
||||
if self.unsupported_model_behavior == "warn":
|
||||
warn(msg, stacklevel=3)
|
||||
return handler(request)
|
||||
|
||||
messages_count = (
|
||||
len(request.messages) + 1
|
||||
if request.system_prompt
|
||||
else len(request.messages)
|
||||
)
|
||||
if messages_count < self.min_messages_to_cache:
|
||||
return handler(request)
|
||||
|
||||
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Modify the model request to add cache control blocks (async version)."""
|
||||
if not isinstance(request.model, ChatAnthropic):
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
f"Anthropic models, not instances of {type(request.model)}"
|
||||
)
|
||||
if self.unsupported_model_behavior == "raise":
|
||||
raise ValueError(msg)
|
||||
if self.unsupported_model_behavior == "warn":
|
||||
warn(msg, stacklevel=3)
|
||||
return await handler(request)
|
||||
|
||||
messages_count = (
|
||||
len(request.messages) + 1
|
||||
if request.system_prompt
|
||||
else len(request.messages)
|
||||
)
|
||||
if messages_count < self.min_messages_to_cache:
|
||||
return await handler(request)
|
||||
|
||||
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
return await handler(request)
|
||||
@@ -41,6 +41,7 @@ test = [
|
||||
"vcrpy>=7.0.0,<8.0.0",
|
||||
"langchain-core",
|
||||
"langchain-tests",
|
||||
"langchain",
|
||||
]
|
||||
lint = ["ruff>=0.13.1,<0.14.0"]
|
||||
dev = ["langchain-core"]
|
||||
@@ -55,6 +56,7 @@ typing = [
|
||||
[tool.uv.sources]
|
||||
langchain-core = { path = "../../core", editable = true }
|
||||
langchain-tests = { path = "../../standard-tests", editable = true }
|
||||
langchain = { path = "../../langchain_v1", editable = true }
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for Anthropic middleware."""
|
||||
@@ -0,0 +1,246 @@
|
||||
"""Tests for Anthropic prompt caching middleware."""
|
||||
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
|
||||
|
||||
class FakeToolCallingModel(BaseChatModel):
|
||||
"""Fake model for testing middleware."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
messages_string = "-".join([str(m.content) for m in messages])
|
||||
message = AIMessage(content=messages_string, id="0")
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Async top level call"""
|
||||
messages_string = "-".join([str(m.content) for m in messages])
|
||||
message = AIMessage(content=messages_string, id="0")
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-tool-call-model"
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware initialization."""
|
||||
# Test with custom values
|
||||
middleware = AnthropicPromptCachingMiddleware(
|
||||
type="ephemeral", ttl="1h", min_messages_to_cache=5
|
||||
)
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "1h"
|
||||
assert middleware.min_messages_to_cache == 5
|
||||
|
||||
# Test with default values
|
||||
middleware = AnthropicPromptCachingMiddleware()
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "5m"
|
||||
assert middleware.min_messages_to_cache == 0
|
||||
|
||||
# Create a mock ChatAnthropic instance
|
||||
mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
|
||||
|
||||
fake_request = ModelRequest(
|
||||
model=mock_chat_anthropic,
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")]},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
middleware.wrap_model_call(fake_request, mock_handler)
|
||||
# Check that model_settings were passed through via the request
|
||||
assert fake_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"}
|
||||
}
|
||||
|
||||
|
||||
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")]},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
# Since we're in the langchain-anthropic package, ChatAnthropic is always
|
||||
# available. Test that it raises an error for unsupported model instances
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
"Anthropic models, not instances of"
|
||||
),
|
||||
):
|
||||
middleware.wrap_model_call(fake_request, mock_handler)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
|
||||
|
||||
# Test warn behavior for unsupported model instances
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert len(w) == 1
|
||||
assert (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
"Anthropic models, not instances of"
|
||||
) in str(w[-1].message)
|
||||
|
||||
# Test ignore behavior
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
result = middleware.wrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware async path."""
|
||||
# Test with custom values
|
||||
middleware = AnthropicPromptCachingMiddleware(
|
||||
type="ephemeral", ttl="1h", min_messages_to_cache=5
|
||||
)
|
||||
|
||||
# Create a mock ChatAnthropic instance
|
||||
mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
|
||||
|
||||
fake_request = ModelRequest(
|
||||
model=mock_chat_anthropic,
|
||||
messages=[HumanMessage("Hello")] * 6,
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")] * 6},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
# Check that model_settings were passed through via the request
|
||||
assert fake_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "1h"}
|
||||
}
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware async path with unsupported model."""
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")],
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")]},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
# Test that it raises an error for unsupported model instances
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
"Anthropic models, not instances of"
|
||||
),
|
||||
):
|
||||
await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
|
||||
|
||||
# Test warn behavior for unsupported model instances
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
assert len(w) == 1
|
||||
assert (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
||||
"Anthropic models, not instances of"
|
||||
) in str(w[-1].message)
|
||||
|
||||
# Test ignore behavior
|
||||
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
|
||||
"""Test async path respects min_messages_to_cache."""
|
||||
middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)
|
||||
|
||||
# Test with fewer messages than minimum
|
||||
fake_request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
messages=[HumanMessage("Hello")] * 3,
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": [HumanMessage("Hello")] * 3},
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
# Cache control should NOT be added when message count is below minimum
|
||||
assert fake_request.model_settings == {}
|
||||
1382
libs/partners/anthropic/uv.lock
generated
1382
libs/partners/anthropic/uv.lock
generated
File diff suppressed because it is too large
Load Diff
71
libs/partners/groq/langchain_groq/_compat.py
Normal file
71
libs/partners/groq/langchain_groq/_compat.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
def _convert_from_v1_to_groq(
|
||||
content: list[types.ContentBlock],
|
||||
model_provider: str | None,
|
||||
) -> tuple[list[dict[str, Any] | str], dict]:
|
||||
new_content: list = []
|
||||
new_additional_kwargs: dict = {}
|
||||
for i, block in enumerate(content):
|
||||
if block["type"] == "text":
|
||||
new_content.append({"text": block.get("text", ""), "type": "text"})
|
||||
|
||||
elif (
|
||||
block["type"] == "reasoning"
|
||||
and (reasoning := block.get("reasoning"))
|
||||
and model_provider == "groq"
|
||||
):
|
||||
new_additional_kwargs["reasoning_content"] = reasoning
|
||||
|
||||
elif block["type"] == "server_tool_call" and model_provider == "groq":
|
||||
new_block = {}
|
||||
if "args" in block:
|
||||
new_block["arguments"] = json.dumps(block["args"])
|
||||
if idx := block.get("extras", {}).get("index"):
|
||||
new_block["index"] = idx
|
||||
if block.get("name") == "web_search":
|
||||
new_block["type"] = "search"
|
||||
elif block.get("name") == "code_interpreter":
|
||||
new_block["type"] = "python"
|
||||
else:
|
||||
new_block["type"] = ""
|
||||
|
||||
if i < len(content) - 1 and content[i + 1]["type"] == "server_tool_result":
|
||||
result = cast("types.ServerToolResult", content[i + 1])
|
||||
for k, v in result.get("extras", {}).items():
|
||||
new_block[k] = v # noqa: PERF403
|
||||
if "output" in result:
|
||||
new_block["output"] = result["output"]
|
||||
|
||||
if "executed_tools" not in new_additional_kwargs:
|
||||
new_additional_kwargs["executed_tools"] = []
|
||||
new_additional_kwargs["executed_tools"].append(new_block)
|
||||
elif block["type"] == "server_tool_result":
|
||||
continue
|
||||
|
||||
elif (
|
||||
block["type"] == "non_standard"
|
||||
and "value" in block
|
||||
and model_provider == "groq"
|
||||
):
|
||||
new_content.append(block["value"])
|
||||
else:
|
||||
new_content.append(block)
|
||||
|
||||
# For consistency with v0 payloads, we cast single text blocks to str
|
||||
if (
|
||||
len(new_content) == 1
|
||||
and isinstance(new_content[0], dict)
|
||||
and new_content[0].get("type") == "text"
|
||||
and (text_content := new_content[0].get("text"))
|
||||
and isinstance(text_content, str)
|
||||
):
|
||||
return text_content, new_additional_kwargs
|
||||
|
||||
return new_content, new_additional_kwargs
|
||||
@@ -57,6 +57,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_groq._compat import _convert_from_v1_to_groq
|
||||
from langchain_groq.version import __version__
|
||||
|
||||
|
||||
@@ -1187,6 +1188,17 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
# Translate v1 content
|
||||
if message.response_metadata.get("output_version") == "v1":
|
||||
new_content, new_additional_kwargs = _convert_from_v1_to_groq(
|
||||
message.content_blocks, message.response_metadata.get("model_provider")
|
||||
)
|
||||
message = message.model_copy(
|
||||
update={
|
||||
"content": new_content,
|
||||
"additional_kwargs": new_additional_kwargs,
|
||||
}
|
||||
)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
|
||||
# If content is a list of content blocks, filter out tool_call blocks
|
||||
@@ -1268,6 +1280,22 @@ def _convert_chunk_to_message_chunk(
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
if reasoning := _dict.get("reasoning"):
|
||||
additional_kwargs["reasoning_content"] = reasoning
|
||||
if executed_tools := _dict.get("executed_tools"):
|
||||
additional_kwargs["executed_tools"] = []
|
||||
for executed_tool in executed_tools:
|
||||
if executed_tool.get("output"):
|
||||
# Tool output duplicates query and other server tool call data
|
||||
additional_kwargs["executed_tools"].append(
|
||||
{
|
||||
k: executed_tool[k]
|
||||
for k in ("index", "output")
|
||||
if k in executed_tool
|
||||
}
|
||||
)
|
||||
else:
|
||||
additional_kwargs["executed_tools"].append(
|
||||
{k: executed_tool[k] for k in executed_tool if k != "output"}
|
||||
)
|
||||
if usage := (chunk.get("x_groq") or {}).get("usage"):
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
@@ -1282,6 +1310,7 @@ def _convert_chunk_to_message_chunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "groq"},
|
||||
)
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
@@ -1313,6 +1342,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
additional_kwargs: dict = {}
|
||||
if reasoning := _dict.get("reasoning"):
|
||||
additional_kwargs["reasoning_content"] = reasoning
|
||||
if executed_tools := _dict.get("executed_tools"):
|
||||
additional_kwargs["executed_tools"] = executed_tools
|
||||
if function_call := _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
tool_calls = []
|
||||
@@ -1332,6 +1363,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
response_metadata={"model_provider": "groq"},
|
||||
)
|
||||
if role == "system":
|
||||
return SystemMessage(content=_dict.get("content", ""))
|
||||
|
||||
BIN
libs/partners/groq/tests/cassettes/test_code_interpreter.yaml.gz
Normal file
BIN
libs/partners/groq/tests/cassettes/test_code_interpreter.yaml.gz
Normal file
Binary file not shown.
BIN
libs/partners/groq/tests/cassettes/test_web_search.yaml.gz
Normal file
BIN
libs/partners/groq/tests/cassettes/test_web_search.yaml.gz
Normal file
Binary file not shown.
37
libs/partners/groq/tests/conftest.py
Normal file
37
libs/partners/groq/tests/conftest.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_tests.conftest import CustomPersister, CustomSerializer
|
||||
from langchain_tests.conftest import (
|
||||
_base_vcr_config as _base_vcr_config, # noqa: PLC0414
|
||||
)
|
||||
from vcr import VCR # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def remove_request_headers(request: Any) -> Any:
|
||||
for k in request.headers:
|
||||
request.headers[k] = "**REDACTED**"
|
||||
return request
|
||||
|
||||
|
||||
def remove_response_headers(response: dict) -> dict:
|
||||
for k in response["headers"]:
|
||||
response["headers"][k] = "**REDACTED**"
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vcr_config(_base_vcr_config: dict) -> dict: # noqa: F811
|
||||
"""Extend the default configuration coming from langchain_tests."""
|
||||
config = _base_vcr_config.copy()
|
||||
config["before_record_request"] = remove_request_headers
|
||||
config["before_record_response"] = remove_response_headers
|
||||
config["serializer"] = "yaml.gz"
|
||||
config["path_transformer"] = VCR.ensure_suffix(".yaml.gz")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def pytest_recording_configure(config: dict, vcr: VCR) -> None:
|
||||
vcr.register_persister(CustomPersister())
|
||||
vcr.register_serializer("yaml.gz", CustomSerializer())
|
||||
@@ -111,7 +111,9 @@ async def test_astream() -> None:
|
||||
full = token if full is None else full + token
|
||||
if token.usage_metadata is not None:
|
||||
chunks_with_token_counts += 1
|
||||
if token.response_metadata:
|
||||
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
|
||||
{"model_provider", "output_version"}
|
||||
):
|
||||
chunks_with_response_metadata += 1
|
||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||
msg = (
|
||||
@@ -665,6 +667,146 @@ async def test_setting_service_tier_request_async() -> None:
|
||||
assert response.response_metadata.get("service_tier") == "on_demand"
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
def test_web_search() -> None:
|
||||
llm = ChatGroq(model="groq/compound")
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": "Search for the weather in Boston today.",
|
||||
}
|
||||
full: AIMessageChunk | None = None
|
||||
for chunk in llm.stream([input_message]):
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.additional_kwargs["reasoning_content"]
|
||||
assert full.additional_kwargs["executed_tools"]
|
||||
assert [block["type"] for block in full.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
next_message = {
|
||||
"role": "user",
|
||||
"content": "Now search for the weather in San Francisco.",
|
||||
}
|
||||
response = llm.invoke([input_message, full, next_message])
|
||||
assert [block["type"] for block in response.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_web_search.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_web_search_v1() -> None:
|
||||
llm = ChatGroq(model="groq/compound", output_version="v1")
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": "Search for the weather in Boston today.",
|
||||
}
|
||||
full: AIMessageChunk | None = None
|
||||
for chunk in llm.stream([input_message]):
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.additional_kwargs["reasoning_content"]
|
||||
assert full.additional_kwargs["executed_tools"]
|
||||
assert [block["type"] for block in full.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"reasoning",
|
||||
"text",
|
||||
]
|
||||
|
||||
next_message = {
|
||||
"role": "user",
|
||||
"content": "Now search for the weather in San Francisco.",
|
||||
}
|
||||
response = llm.invoke([input_message, full, next_message])
|
||||
assert [block["type"] for block in response.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
def test_code_interpreter() -> None:
|
||||
llm = ChatGroq(model="groq/compound-mini")
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Calculate the square root of 101 and show me the Python code you used."
|
||||
),
|
||||
}
|
||||
full: AIMessageChunk | None = None
|
||||
for chunk in llm.stream([input_message]):
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.additional_kwargs["reasoning_content"]
|
||||
assert full.additional_kwargs["executed_tools"]
|
||||
assert [block["type"] for block in full.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
next_message = {
|
||||
"role": "user",
|
||||
"content": "Now do the same for 102.",
|
||||
}
|
||||
response = llm.invoke([input_message, full, next_message])
|
||||
assert [block["type"] for block in response.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
def test_code_interpreter_v1() -> None:
|
||||
llm = ChatGroq(model="groq/compound-mini", output_version="v1")
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Calculate the square root of 101 and show me the Python code you used."
|
||||
),
|
||||
}
|
||||
full: AIMessageChunk | None = None
|
||||
for chunk in llm.stream([input_message]):
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
assert full.additional_kwargs["reasoning_content"]
|
||||
assert full.additional_kwargs["executed_tools"]
|
||||
assert [block["type"] for block in full.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"reasoning",
|
||||
"text",
|
||||
]
|
||||
|
||||
next_message = {
|
||||
"role": "user",
|
||||
"content": "Now do the same for 102.",
|
||||
}
|
||||
response = llm.invoke([input_message, full, next_message])
|
||||
assert [block["type"] for block in response.content_blocks] == [
|
||||
"reasoning",
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
|
||||
# Groq does not currently support N > 1
|
||||
# @pytest.mark.scheduled
|
||||
# def test_chat_multiple_completions() -> None:
|
||||
|
||||
@@ -54,7 +54,9 @@ def test__convert_dict_to_message_human() -> None:
|
||||
def test__convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
expected_output = AIMessage(
|
||||
content="foo", response_metadata={"model_provider": "groq"}
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
@@ -80,6 +82,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
response_metadata={"model_provider": "groq"},
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
@@ -124,6 +127,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
response_metadata={"model_provider": "groq"},
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
125
libs/partners/mistralai/langchain_mistralai/_compat.py
Normal file
125
libs/partners/mistralai/langchain_mistralai/_compat.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Derivations of standard content blocks from mistral content."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
|
||||
def _convert_from_v1_to_mistral(
|
||||
content: list[types.ContentBlock],
|
||||
model_provider: str | None,
|
||||
) -> str | list[str | dict]:
|
||||
new_content: list = []
|
||||
for block in content:
|
||||
if block["type"] == "text":
|
||||
new_content.append({"text": block.get("text", ""), "type": "text"})
|
||||
|
||||
elif (
|
||||
block["type"] == "reasoning"
|
||||
and (reasoning := block.get("reasoning"))
|
||||
and isinstance(reasoning, str)
|
||||
and model_provider == "mistralai"
|
||||
):
|
||||
new_content.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": [{"type": "text", "text": reasoning}],
|
||||
}
|
||||
)
|
||||
|
||||
elif (
|
||||
block["type"] == "non_standard"
|
||||
and "value" in block
|
||||
and model_provider == "mistralai"
|
||||
):
|
||||
new_content.append(block["value"])
|
||||
elif block["type"] == "tool_call":
|
||||
continue
|
||||
else:
|
||||
new_content.append(block)
|
||||
|
||||
return new_content
|
||||
|
||||
|
||||
def _convert_to_v1_from_mistral(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Convert mistral message content to v1 format."""
|
||||
if isinstance(message.content, str):
|
||||
content_blocks: list[types.ContentBlock] = [
|
||||
{"type": "text", "text": message.content}
|
||||
]
|
||||
|
||||
else:
|
||||
content_blocks = []
|
||||
for block in message.content:
|
||||
if isinstance(block, str):
|
||||
content_blocks.append({"type": "text", "text": block})
|
||||
|
||||
elif isinstance(block, dict):
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text_block: types.TextContentBlock = {
|
||||
"type": "text",
|
||||
"text": block["text"],
|
||||
}
|
||||
if "index" in block:
|
||||
text_block["index"] = block["index"]
|
||||
content_blocks.append(text_block)
|
||||
|
||||
elif block.get("type") == "thinking" and isinstance(
|
||||
block.get("thinking"), list
|
||||
):
|
||||
for sub_block in block["thinking"]:
|
||||
if (
|
||||
isinstance(sub_block, dict)
|
||||
and sub_block.get("type") == "text"
|
||||
):
|
||||
reasoning_block: types.ReasoningContentBlock = {
|
||||
"type": "reasoning",
|
||||
"reasoning": sub_block.get("text", ""),
|
||||
}
|
||||
if "index" in block:
|
||||
reasoning_block["index"] = block["index"]
|
||||
content_blocks.append(reasoning_block)
|
||||
|
||||
else:
|
||||
non_standard_block: types.NonStandardContentBlock = {
|
||||
"type": "non_standard",
|
||||
"value": block,
|
||||
}
|
||||
content_blocks.append(non_standard_block)
|
||||
else:
|
||||
continue
|
||||
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0].get("type") == "text"
|
||||
and content_blocks[0].get("text") == ""
|
||||
and message.tool_calls
|
||||
):
|
||||
content_blocks = []
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
"id": tool_call.get("id"),
|
||||
}
|
||||
)
|
||||
|
||||
return content_blocks
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
    """Derive standard content blocks from a message with Mistral content.

    Thin wrapper so a complete-message translator can be registered for the
    ``mistralai`` provider; all conversion logic lives in
    ``_convert_to_v1_from_mistral``.
    """
    return _convert_to_v1_from_mistral(message)
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
    """Derive standard content blocks from a message chunk with Mistral content.

    Chunks carry the same content shapes as full messages, so the same
    conversion routine is reused for the streaming case.
    """
    return _convert_to_v1_from_mistral(message)
|
||||
|
||||
|
||||
# Register both translators under the "mistralai" provider key so that
# langchain-core can standardize Mistral content blocks on demand.
register_translator("mistralai", translate_content, translate_content_chunk)
|
||||
@@ -24,12 +24,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
LangSmithParams,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -74,6 +69,8 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_mistralai._compat import _convert_from_v1_to_mistral
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
@@ -160,6 +157,7 @@ def _convert_mistral_chat_message_to_message(
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
response_metadata={"model_provider": "mistralai"},
|
||||
)
|
||||
|
||||
|
||||
@@ -231,14 +229,34 @@ async def acompletion_with_retry(
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: dict, default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
chunk: dict,
|
||||
default_class: type[BaseMessageChunk],
|
||||
index: int,
|
||||
index_type: str,
|
||||
output_version: str | None,
|
||||
) -> tuple[BaseMessageChunk, int, str]:
|
||||
_choice = chunk["choices"][0]
|
||||
_delta = _choice["delta"]
|
||||
role = _delta.get("role")
|
||||
content = _delta.get("content") or ""
|
||||
if output_version == "v1" and isinstance(content, str):
|
||||
content = [{"type": "text", "text": content}]
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
if "type" in block and block["type"] != index_type:
|
||||
index_type = block["type"]
|
||||
index = index + 1
|
||||
if "index" not in block:
|
||||
block["index"] = index
|
||||
if block.get("type") == "thinking" and isinstance(
|
||||
block.get("thinking"), list
|
||||
):
|
||||
for sub_block in block["thinking"]:
|
||||
if isinstance(sub_block, dict) and "index" not in sub_block:
|
||||
sub_block["index"] = 0
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
return HumanMessageChunk(content=content), index, index_type
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: dict = {}
|
||||
response_metadata = {}
|
||||
@@ -276,18 +294,22 @@ def _convert_chunk_to_message_chunk(
|
||||
):
|
||||
response_metadata["model_name"] = chunk["model"]
|
||||
response_metadata["finish_reason"] = _choice["finish_reason"]
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata=response_metadata,
|
||||
return (
|
||||
AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "mistralai", **response_metadata},
|
||||
),
|
||||
index,
|
||||
index_type,
|
||||
)
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
return SystemMessageChunk(content=content), index, index_type
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
return ChatMessageChunk(content=content, role=role), index, index_type
|
||||
return default_class(content=content), index, index_type # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
@@ -318,6 +340,21 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
|
||||
return result
|
||||
|
||||
|
||||
def _clean_block(block: dict) -> dict:
|
||||
# Remove "index" key added for message aggregation in langchain-core
|
||||
new_block = {k: v for k, v in block.items() if k != "index"}
|
||||
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
|
||||
new_block["thinking"] = [
|
||||
(
|
||||
{k: v for k, v in sb.items() if k != "index"}
|
||||
if isinstance(sb, dict) and "index" in sb
|
||||
else sb
|
||||
)
|
||||
for sb in block["thinking"]
|
||||
]
|
||||
return new_block
|
||||
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> dict:
|
||||
@@ -356,13 +393,40 @@ def _convert_message_to_mistral_chat_message(
|
||||
pass
|
||||
if tool_calls: # do not populate empty list tool_calls
|
||||
message_dict["tool_calls"] = tool_calls
|
||||
if tool_calls and message.content:
|
||||
|
||||
# Message content
|
||||
# Translate v1 content
|
||||
if message.response_metadata.get("output_version") == "v1":
|
||||
content = _convert_from_v1_to_mistral(
|
||||
message.content_blocks, message.response_metadata.get("model_provider")
|
||||
)
|
||||
else:
|
||||
content = message.content
|
||||
|
||||
if tool_calls and content:
|
||||
# Assistant message must have either content or tool_calls, but not both.
|
||||
# Some providers may not support tool_calls in the same message as content.
|
||||
# This is done to ensure compatibility with messages from other providers.
|
||||
message_dict["content"] = ""
|
||||
content = ""
|
||||
|
||||
elif isinstance(content, list):
|
||||
content = [
|
||||
_clean_block(block)
|
||||
if isinstance(block, dict) and "index" in block
|
||||
else block
|
||||
for block in content
|
||||
]
|
||||
else:
|
||||
message_dict["content"] = message.content
|
||||
content = message.content
|
||||
|
||||
# if any blocks are dicts, cast strings to text blocks
|
||||
if any(isinstance(block, dict) for block in content):
|
||||
content = [
|
||||
block if isinstance(block, dict) else {"type": "text", "text": block}
|
||||
for block in content
|
||||
]
|
||||
message_dict["content"] = content
|
||||
|
||||
if "prefix" in message.additional_kwargs:
|
||||
message_dict["prefix"] = message.additional_kwargs["prefix"]
|
||||
return message_dict
|
||||
@@ -564,13 +628,6 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
@@ -627,12 +684,16 @@ class ChatMistralAI(BaseChatModel):
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
index = -1
|
||||
index_type = ""
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.get("choices", [])) == 0:
|
||||
continue
|
||||
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
|
||||
chunk, default_chunk_class, index, index_type, self.output_version
|
||||
)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
@@ -653,12 +714,16 @@ class ChatMistralAI(BaseChatModel):
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
index = -1
|
||||
index_type = ""
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.get("choices", [])) == 0:
|
||||
continue
|
||||
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
|
||||
chunk, default_chunk_class, index, index_type, self.output_version
|
||||
)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
@@ -676,13 +741,6 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = await acompletion_with_retry(
|
||||
|
||||
@@ -2,33 +2,19 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from httpx import ReadTimeout
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
|
||||
|
||||
def test_stream() -> None:
    """Test streaming tokens from ChatMistralAI."""
    # Integration test: requires Mistral API credentials in the environment.
    llm = ChatMistralAI()

    # Each streamed chunk should expose string content.
    for token in llm.stream("Hello"):
        assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
@@ -42,7 +28,9 @@ async def test_astream() -> None:
|
||||
full = token if full is None else full + token
|
||||
if token.usage_metadata is not None:
|
||||
chunks_with_token_counts += 1
|
||||
if token.response_metadata:
|
||||
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
|
||||
{"model_provider", "output_version"}
|
||||
):
|
||||
chunks_with_response_metadata += 1
|
||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||
msg = (
|
||||
@@ -63,131 +51,6 @@ async def test_astream() -> None:
|
||||
assert full.response_metadata["model_name"]
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
assert "model_name" in result.response_metadata
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatMistralAI."""
|
||||
llm = ChatMistralAI()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_chat_mistralai_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model
|
||||
|
||||
|
||||
def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model
|
||||
|
||||
|
||||
def test_chat_mistralai_llm_output_contains_token_usage() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatMistralAI(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert "token_usage" in llm_result.llm_output
|
||||
token_usage = llm_result.llm_output["token_usage"]
|
||||
assert "prompt_tokens" in token_usage
|
||||
assert "completion_tokens" in token_usage
|
||||
assert "total_tokens" in token_usage
|
||||
|
||||
|
||||
def test_chat_mistralai_streaming_llm_output_not_contain_token_usage() -> None:
|
||||
"""Mistral currently doesn't return token usage when streaming."""
|
||||
chat = ChatMistralAI(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert "token_usage" in llm_result.llm_output
|
||||
token_usage = llm_result.llm_output["token_usage"]
|
||||
assert not token_usage
|
||||
|
||||
|
||||
def test_structured_output() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
|
||||
schema = {
|
||||
"title": "AnswerWithJustification",
|
||||
"description": (
|
||||
"An answer to the user question along with justification for the answer."
|
||||
),
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": {"title": "Answer", "type": "string"},
|
||||
"justification": {"title": "Justification", "type": "string"},
|
||||
},
|
||||
"required": ["answer", "justification"],
|
||||
}
|
||||
structured_llm = llm.with_structured_output(schema)
|
||||
result = structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_streaming_structured_output() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
structured_llm = llm.with_structured_output(Person)
|
||||
strm = structured_llm.stream("Erick, 27 years old")
|
||||
for chunk_num, chunk in enumerate(strm):
|
||||
assert chunk_num == 0, "should only have one chunk with model"
|
||||
assert isinstance(chunk, Person)
|
||||
assert chunk.name == "Erick"
|
||||
assert chunk.age == 27
|
||||
|
||||
|
||||
class Book(BaseModel):
|
||||
name: str
|
||||
authors: list[str]
|
||||
@@ -247,66 +110,6 @@ async def test_structured_output_json_schema_async(schema: Any) -> None:
|
||||
_check_parsed_result(chunk, schema)
|
||||
|
||||
|
||||
def test_tool_call() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
result = tool_llm.invoke("Erick, 27 years old")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert len(result.tool_calls) == 1
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call["name"] == "Person"
|
||||
assert tool_call["args"] == {"name": "Erick", "age": 27}
|
||||
|
||||
|
||||
def test_streaming_tool_call() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
# where it calls the tool
|
||||
strm = tool_llm.stream("Erick, 27 years old")
|
||||
|
||||
additional_kwargs = None
|
||||
for chunk in strm:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
assert chunk.content == ""
|
||||
additional_kwargs = chunk.additional_kwargs
|
||||
|
||||
assert additional_kwargs is not None
|
||||
assert "tool_calls" in additional_kwargs
|
||||
assert len(additional_kwargs["tool_calls"]) == 1
|
||||
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
|
||||
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
|
||||
"name": "Erick",
|
||||
"age": 27,
|
||||
}
|
||||
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
assert len(chunk.tool_call_chunks) == 1
|
||||
tool_call_chunk = chunk.tool_call_chunks[0]
|
||||
assert tool_call_chunk["name"] == "Person"
|
||||
assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}'
|
||||
|
||||
# where it doesn't call the tool
|
||||
strm = tool_llm.stream("What is 2+2?")
|
||||
acc: Any = None
|
||||
for chunk in strm:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
acc = chunk if acc is None else acc + chunk
|
||||
assert acc.content != ""
|
||||
assert "tool_calls" not in acc.additional_kwargs
|
||||
|
||||
|
||||
def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test that retry parameters are honored in ChatMistralAI."""
|
||||
# Create a model with intentionally short timeout and multiple retries
|
||||
@@ -342,3 +145,51 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
|
||||
|
||||
def test_reasoning() -> None:
    """Stream a reasoning model and verify thinking blocks standardize.

    Integration test: streams from ``magistral-medium-latest``, checks that
    at least one native ``thinking`` block appears and that each maps to a
    standard ``reasoning`` block via ``content_blocks``, then round-trips the
    aggregated message back through ``invoke`` to confirm it is accepted as
    chat history.
    """
    model = ChatMistralAI(model="magistral-medium-latest")  # type: ignore[call-arg]
    input_message = {
        "role": "user",
        "content": "Hello, my name is Bob.",
    }
    full: AIMessageChunk | None = None
    for chunk in model.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    thinking_blocks = 0
    for i, block in enumerate(full.content):
        if isinstance(block, dict) and block.get("type") == "thinking":
            thinking_blocks += 1
            # Native "thinking" blocks should translate to standard
            # "reasoning" blocks at the same position.
            reasoning_block = full.content_blocks[i]
            assert reasoning_block["type"] == "reasoning"
            assert isinstance(reasoning_block.get("reasoning"), str)
    assert thinking_blocks > 0

    # Passing the aggregated message back should not raise.
    next_message = {"role": "user", "content": "What is my name?"}
    _ = model.invoke([input_message, full, next_message])
|
||||
|
||||
|
||||
def test_reasoning_v1() -> None:
    """Stream a reasoning model with ``output_version="v1"``.

    Integration test: with v1 output, reasoning content should already be
    surfaced as standard ``reasoning`` blocks directly in ``content`` (no
    translation step needed), and the aggregated message should round-trip
    through ``invoke`` as chat history.
    """
    model = ChatMistralAI(model="magistral-medium-latest", output_version="v1")  # type: ignore[call-arg]
    input_message = {
        "role": "user",
        "content": "Hello, my name is Bob.",
    }
    full: AIMessageChunk | None = None
    chunks = []
    for chunk in model.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
        chunks.append(chunk)
    assert isinstance(full, AIMessageChunk)
    reasoning_blocks = 0
    for block in full.content:
        if isinstance(block, dict) and block.get("type") == "reasoning":
            reasoning_blocks += 1
            assert isinstance(block.get("reasoning"), str)
    assert reasoning_blocks > 0

    # Passing the aggregated message back should not raise.
    next_message = {"role": "user", "content": "What is my name?"}
    _ = model.invoke([input_message, full, next_message])
|
||||
|
||||
@@ -188,6 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
response_metadata={"model_provider": "mistralai"},
|
||||
)
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_mistral_chat_message(expected_output) == message
|
||||
@@ -231,6 +232,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
response_metadata={"model_provider": "mistralai"},
|
||||
)
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_mistral_chat_message(expected_output) == message
|
||||
|
||||
Reference in New Issue
Block a user