Compare commits

...

16 Commits

Author SHA1 Message Date
Nick Huang
526cbb36e6 Prevent dangling tool calls from breaking react agents 2025-10-14 19:59:47 -04:00
Mason Daugherty
68ceeb64f6 chore(core): delete function_calling.py utils marked for removal (#33376) 2025-10-14 16:13:19 -04:00
Mason Daugherty
edae976b81 chore(core): delete pydantic_v1/ (#33374) 2025-10-14 16:08:24 -04:00
ccurme
9f4366bc9d feat(mistralai): support reasoning feature and v1 content (#33485)
Not yet supported: server-side tool calls
2025-10-14 15:19:44 -04:00
Eugene Yurtsev
99e0a60aab chore(langchain_v1): remove invocation request (#33482)
Remove ToolNode primitives from langchain
2025-10-14 15:07:30 -04:00
Eugene Yurtsev
d38729fbac feat(langchain_v1): add async implementations to wrap_model_call (#33467)
Add async implementations to wrap_model_call for prebuilt middleware
2025-10-14 17:39:38 +00:00
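As an illustration of the handler-wrapping middleware pattern that decorators like `wrap_model_call` follow, here is a plain-Python sketch. The request shape and all function names below are illustrative stand-ins, not the actual langchain API:

```python
import asyncio

# Plain-Python sketch of the handler-wrapping middleware pattern.
# All names here are illustrative, not the actual langchain API.

async def base_handler(request: dict) -> str:
    """Stand-in for the underlying model call."""
    return f"model saw: {request['prompt']}"

async def normalizing_middleware(request: dict, handler) -> str:
    # A middleware receives the request plus the next handler; it may
    # modify the request, short-circuit, or post-process the response.
    request = {**request, "prompt": request["prompt"].strip()}
    response = await handler(request)
    return response.upper()

async def main() -> str:
    return await normalizing_middleware({"prompt": "  hello  "}, base_handler)

print(asyncio.run(main()))  # MODEL SAW: HELLO
```

The async variant mirrors the sync one: the middleware simply awaits the handler instead of calling it directly.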
gsmini
ff0d21cfd5 fix(langchain_v1): can not import "wrap_tool_call" from agents.… (#33472)
fix: cannot import `wrap_tool_call` from `langchain.agents.middleware`:
```python

from langchain.agents import create_agent
from langchain.agents.middleware import wrap_tool_call  # this import previously failed
from langchain_core.messages import ToolMessage

@wrap_tool_call
def handle_tool_errors(request, handler):
    """Handle tool execution errors with custom messages."""
    try:
        return handler(request)
    except Exception as e:
        # Return a custom error message to the model
        return ToolMessage(
            content=f"Tool error: Please check your input and try again. ({str(e)})",
            tool_call_id=request.tool_call["id"]
        )

agent = create_agent(
    model="openai:gpt-4o",
    tools=[search, calculate],  # placeholder tools defined elsewhere in the docs example
    middleware=[handle_tool_errors]
)
```
> Example code from: https://docs.langchain.com/oss/python/langchain/agents#tool-error-handling
2025-10-14 13:39:25 -04:00
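The tool-error-handling pattern in the example above can be reproduced with the standard library alone; a minimal sketch, with names that are illustrative rather than the langchain API:

```python
from typing import Callable

# Stdlib sketch of the error-handling pattern from the example above:
# wrap a tool handler so exceptions become a structured error message
# the model can read, instead of propagating and breaking the loop.

def with_tool_errors(handler: Callable[[dict], str]) -> Callable[[dict], str]:
    def wrapped(request: dict) -> str:
        try:
            return handler(request)
        except Exception as e:
            # Return a recoverable message rather than raising.
            return f"Tool error: please check your input and try again. ({e})"
    return wrapped

@with_tool_errors
def divide(request: dict) -> str:
    return str(request["a"] / request["b"])

print(divide({"a": 6, "b": 2}))  # 3.0
print(divide({"a": 1, "b": 0}))  # Tool error: ... (division by zero)
```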
Eugene Yurtsev
9140a7cb86 feat(langchain_v1): add override to model request and tool call request (#33465)
Add override to model request and tool call request
2025-10-14 10:31:46 -04:00
ccurme
41fe18bc80 chore(groq): fix integration tests (#33478)
- add missing cassette
- update streaming metadata test for v1
2025-10-14 14:16:34 +00:00
Mason Daugherty
9105573cb3 docs: create_agent style and clarify system_prompt (#33470) 2025-10-14 09:56:54 -04:00
Sydney Runkle
fff87e95d1 fix(langchain): rename PlanningMiddleware to TodoListMiddleware (#33476) 2025-10-14 09:06:06 -04:00
ccurme
9beb29a34c chore(mistralai): delete redundant tests (#33468) 2025-10-13 21:28:51 +00:00
ChoYongHo | 조용호
ca00f5aed9 fix(langchain_v1): export ModelResponse from agents.middleware (#33453) (#33454)
## Description

Fixes #33453

`ModelResponse` was defined in `types.py` and included in its `__all__`
list, but was not exported from the middleware package's `__init__.py`.
This caused an `ImportError` when attempting to import it directly
from `langchain.agents.middleware`, despite being documented as a public
export.

## Changes

- Added `ModelResponse` to the import statement in
`langchain/agents/middleware/__init__.py`
- Added `ModelResponse` to the `__all__` list in
`langchain/agents/middleware/__init__.py`
- Added comprehensive unit tests in `test_imports.py` to verify the
import works correctly

## Issue

The original issue reported that the following import failed:

```python
from langchain.agents.middleware import ModelResponse
# ImportError: cannot import name 'ModelResponse' from 'langchain.agents.middleware'
```

The workaround was to import from the submodule:

```python
from langchain.agents.middleware.types import ModelResponse  # Workaround
```

## Solution

After this fix, `ModelResponse` can be imported directly as documented:

```python
from langchain.agents.middleware import ModelResponse  # Now works!
```

## Testing

- Added 3 unit tests in
`tests/unit_tests/agents/middleware/test_imports.py`
- All tests pass locally: `make format`, `make lint`, `make test`
- Verified `ModelResponse` is properly exported and importable
- Verified `ModelResponse` appears in the `__all__` list

## Dependencies

None. This is a simple export fix with no new dependencies.

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2025-10-13 16:02:30 -04:00
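The export bug fixed here is easy to reproduce with synthetic modules: having a name in a submodule's `__all__` is not enough; the package `__init__` must also import it. A self-contained sketch (the `demo_pkg` module names are stand-ins):

```python
import sys
import types

# Synthetic package and submodule standing in for
# langchain.agents.middleware and its types module.
pkg = types.ModuleType("demo_pkg")
sub = types.ModuleType("demo_pkg.types")

class ModelResponse:  # stand-in for the real class
    pass

sub.ModelResponse = ModelResponse
sub.__all__ = ["ModelResponse"]
pkg.__all__ = []  # before the fix: package neither imports nor exports it

sys.modules["demo_pkg"] = pkg
sys.modules["demo_pkg.types"] = sub

try:
    from demo_pkg import ModelResponse as _  # fails: not on the package
except ImportError:
    print("ImportError, as reported in #33453")

# The fix: import into the package namespace and list it in __all__.
pkg.ModelResponse = sub.ModelResponse
pkg.__all__ = ["ModelResponse"]

from demo_pkg import ModelResponse as Exported
print(Exported is ModelResponse)  # True
```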
dependabot[bot]
637777b8e7 chore(infra): bump astral-sh/setup-uv from 6 to 7 (#33457)
Bumps [astral-sh/setup-uv](https://github.com/astral-sh/setup-uv) from 6
to 7.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/astral-sh/setup-uv/releases">astral-sh/setup-uv's
releases</a>.</em></p>
<blockquote>
<h2>v7.0.0 🌈 node24 and a lot of bugfixes</h2>
<h2>Changes</h2>
<p>This release comes with a load of bug fixes and a speed-up. Because
it switches from node20 to node24, it is also a breaking change. If you
are running on GitHub-hosted runners this will just work; if you are
using self-hosted runners, make sure your runners are up to date.
If you followed the normal installation instructions, your self-hosted
runner will keep itself updated.</p>
<p>This release also removes the deprecated input
<code>server-url</code> which was used to download uv releases from a
different server.
The <a
href="https://github.com/astral-sh/setup-uv?tab=readme-ov-file#manifest-file">manifest-file</a>
input supersedes that functionality by adding a flexible way to define
available versions and where they should be downloaded from.</p>
<h3>Fixes</h3>
<ul>
<li>The action now respects when the environment variable
<code>UV_CACHE_DIR</code> is already set and does not overwrite it. It
now also finds <a
href="https://docs.astral.sh/uv/reference/settings/#cache-dir">cache-dir</a>
settings in config files if you set them.</li>
<li>Some users encountered problems where <a
href="https://github.com/astral-sh/setup-uv?tab=readme-ov-file#disable-cache-pruning">cache
pruning</a> took forever because they had some <code>uv</code> processes
running in the background. Starting with uv version <code>0.8.24</code>,
this action uses <code>uv cache prune --ci --force</code> to ignore
running processes.</li>
<li>If you just want to install uv but not have it available on the PATH,
this action now respects <code>UV_NO_MODIFY_PATH</code>.</li>
<li>Some other actions also set the env var <code>UV_CACHE_DIR</code>.
This action can now deal with that, but as this could lead to unwanted
behavior in some edge cases, a warning is now displayed.</li>
</ul>
<h3>Improvements</h3>
<p>If you are using minimum version specifiers for the version of uv to
install, for example:</p>
<pre lang="toml"><code>[tool.uv]
required-version = &quot;&gt;=0.8.17&quot;
</code></pre>
<p>This action now detects that and directly uses the latest version.
Previously it would download all available releases from the uv repo
to determine the highest matching candidate for the version specifier,
which took much more time.</p>
<p>If you are using other specifiers like <code>0.8.x</code> this action
still needs to download all available releases because the specifier
defines an upper bound (not 0.9.0 or later) and &quot;latest&quot; would
possibly not satisfy that.</p>
<h2>🚨 Breaking changes</h2>
<ul>
<li>Use node24 instead of node20 <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/608">#608</a>)</li>
<li>Remove deprecated input server-url <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/607">#607</a>)</li>
</ul>
<h2>🐛 Bug fixes</h2>
<ul>
<li>Respect UV_CACHE_DIR and cache-dir <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/612">#612</a>)</li>
<li>Use --force when pruning cache <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/611">#611</a>)</li>
<li>Respect UV_NO_MODIFY_PATH <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/603">#603</a>)</li>
<li>Warn when <code>UV_CACHE_DIR</code> has changed <a
href="https://github.com/jamesbraza"><code>@​jamesbraza</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/601">#601</a>)</li>
</ul>
<h2>🚀 Enhancements</h2>
<ul>
<li>Shortcut to latest version for minimum version specifier <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/598">#598</a>)</li>
</ul>
<h2>🧰 Maintenance</h2>
<ul>
<li>Bump dependencies <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/613">#613</a>)</li>
<li>Fix test-uv-no-modify-path <a
href="https://github.com/eifinger"><code>@​eifinger</code></a> (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/604">#604</a>)</li>
</ul>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
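The version-specifier shortcut described in the release notes above comes down to a simple decision: a pure minimum-version specifier has no upper bound, so "latest" always satisfies it. A simplified Python sketch of that logic (the action itself is TypeScript; this is only an illustration):

```python
# Sketch of the resolution shortcut: a pure ">=X" specifier is always
# satisfied by the latest release, so the full release list can be
# skipped. Bounded specifiers like "0.8.x" still need the full list.

def can_shortcut_to_latest(specifier: str) -> bool:
    """True if "latest" is guaranteed to satisfy the specifier."""
    spec = specifier.strip()
    # ">=0.8.17" imposes no upper bound; latest always qualifies.
    # "0.8.x" or a compound specifier bounds the version from above,
    # so every release must be enumerated to find the best match.
    return spec.startswith(">=") and "," not in spec

print(can_shortcut_to_latest(">=0.8.17"))  # True
print(can_shortcut_to_latest("0.8.x"))     # False
```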
<details>
<summary>Commits</summary>
<ul>
<li><a
href="3259c6206f"><code>3259c62</code></a>
Bump deps (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/633">#633</a>)</li>
<li><a
href="bf8e8ed895"><code>bf8e8ed</code></a>
Split up documentation (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/632">#632</a>)</li>
<li><a
href="9c6b5e9fb5"><code>9c6b5e9</code></a>
Add resolution-strategy input to support oldest compatible version
selection ...</li>
<li><a
href="a5129e99f4"><code>a5129e9</code></a>
Add copilot-instructions.md (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/630">#630</a>)</li>
<li><a
href="d18bcc753a"><code>d18bcc7</code></a>
Add value of UV_PYTHON_INSTALL_DIR to path (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/628">#628</a>)</li>
<li><a
href="bd1f875aba"><code>bd1f875</code></a>
Set output venv when activate-environment is used (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/627">#627</a>)</li>
<li><a
href="1a91c3851d"><code>1a91c38</code></a>
chore: update known checksums for 0.9.2 (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/626">#626</a>)</li>
<li><a
href="c79f606987"><code>c79f606</code></a>
chore: update known checksums for 0.9.1 (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/625">#625</a>)</li>
<li><a
href="e0249f1599"><code>e0249f1</code></a>
Fall back to PR for updating known versions (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/623">#623</a>)</li>
<li><a
href="6d2eb15b49"><code>6d2eb15</code></a>
Cache python installs (<a
href="https://redirect.github.com/astral-sh/setup-uv/issues/621">#621</a>)</li>
<li>Additional commits viewable in <a
href="https://github.com/astral-sh/setup-uv/compare/v6...v7">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=astral-sh/setup-uv&package-manager=github_actions&previous-version=6&new-version=7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)



Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 15:21:12 -04:00
Eugene Yurtsev
1cf851e054 chore(langchain_v1,anthropic): migrate anthropic middleware to langchain_anthropic (#33463)
Migrate prompt caching implementation into langchain_anthropic.middleware
2025-10-13 15:12:54 -04:00
ccurme
961f965f0c feat(groq): support built-in tools in message content (#33459) 2025-10-13 15:06:01 -04:00
49 changed files with 3312 additions and 1222 deletions

View File

@@ -186,7 +186,7 @@ jobs:
# We have to use 3.12 as 3.13 is not yet supported
- name: "📦 Install UV Package Manager"
uses: astral-sh/setup-uv@v6
uses: astral-sh/setup-uv@v7
with:
python-version: "3.12"

View File

@@ -1,39 +1,135 @@
"""Derivations of standard content blocks from Groq content."""
import warnings
import json
import re
from typing import Any
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages import content as types
WARNED = False
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
def translate_content(message: AIMessage) -> list[types.ContentBlock]: # noqa: ARG001
"""Derive standard content blocks from a message with Groq content."""
global WARNED # noqa: PLW0603
if not WARNED:
warning_message = (
"Content block standardization is not yet fully supported for Groq."
def _populate_extras(
standard_block: types.ContentBlock, block: dict[str, Any], known_fields: set[str]
) -> types.ContentBlock:
"""Mutate a block, populating extras."""
if standard_block.get("type") == "non_standard":
return standard_block
for key, value in block.items():
if key not in known_fields:
if "extras" not in standard_block:
# Below type-ignores are because mypy thinks a non-standard block can
# get here, although we exclude them above.
standard_block["extras"] = {} # type: ignore[typeddict-unknown-key]
standard_block["extras"][key] = value # type: ignore[typeddict-item]
return standard_block
def _parse_code_json(s: str) -> dict:
"""Extract Python code from Groq built-in tool content.
Extracts the value of the 'code' field from a string of the form:
{"code": some_arbitrary_text_with_unescaped_quotes}
As Groq may not escape quotes in the executed tools, e.g.:
```
'{"code": "import math; print("The square root of 101 is: "); print(math.sqrt(101))"}'
```
""" # noqa: E501
m = re.fullmatch(r'\s*\{\s*"code"\s*:\s*"(.*)"\s*\}\s*', s, flags=re.DOTALL)
if not m:
msg = (
"Could not extract Python code from Groq tool arguments. "
"Expected a JSON object with a 'code' field."
)
warnings.warn(warning_message, stacklevel=2)
WARNED = True
raise NotImplementedError
raise ValueError(msg)
return {"code": m.group(1)}
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]: # noqa: ARG001
"""Derive standard content blocks from a message chunk with Groq content."""
global WARNED # noqa: PLW0603
if not WARNED:
warning_message = (
"Content block standardization is not yet fully supported for Groq."
def _convert_to_v1_from_groq(message: AIMessage) -> list[types.ContentBlock]:
"""Convert groq message content to v1 format."""
content_blocks: list[types.ContentBlock] = []
if reasoning_block := _extract_reasoning_from_additional_kwargs(message):
content_blocks.append(reasoning_block)
if executed_tools := message.additional_kwargs.get("executed_tools"):
for idx, executed_tool in enumerate(executed_tools):
args: dict[str, Any] | None = None
if arguments := executed_tool.get("arguments"):
try:
args = json.loads(arguments)
except json.JSONDecodeError:
if executed_tool.get("type") == "python":
try:
args = _parse_code_json(arguments)
except ValueError:
continue
elif (
executed_tool.get("type") == "function"
and executed_tool.get("name") == "python"
):
# GPT-OSS
args = {"code": arguments}
else:
continue
if isinstance(args, dict):
name = ""
if executed_tool.get("type") == "search":
name = "web_search"
elif executed_tool.get("type") == "python" or (
executed_tool.get("type") == "function"
and executed_tool.get("name") == "python"
):
name = "code_interpreter"
server_tool_call: types.ServerToolCall = {
"type": "server_tool_call",
"name": name,
"id": str(idx),
"args": args,
}
content_blocks.append(server_tool_call)
if tool_output := executed_tool.get("output"):
tool_result: types.ServerToolResult = {
"type": "server_tool_result",
"tool_call_id": str(idx),
"output": tool_output,
"status": "success",
}
known_fields = {"type", "arguments", "index", "output"}
_populate_extras(tool_result, executed_tool, known_fields)
content_blocks.append(tool_result)
if isinstance(message.content, str) and message.content:
content_blocks.append({"type": "text", "text": message.content})
for tool_call in message.tool_calls:
content_blocks.append( # noqa: PERF401
{
"type": "tool_call",
"name": tool_call["name"],
"args": tool_call["args"],
"id": tool_call.get("id"),
}
)
warnings.warn(warning_message, stacklevel=2)
WARNED = True
raise NotImplementedError
return content_blocks
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
"""Derive standard content blocks from a message with groq content."""
return _convert_to_v1_from_groq(message)
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
"""Derive standard content blocks from a message chunk with groq content."""
return _convert_to_v1_from_groq(message)
def _register_groq_translator() -> None:
"""Register the Groq translator with the central registry.
"""Register the groq translator with the central registry.
Run automatically when the module is imported.
"""

View File

@@ -1,30 +0,0 @@
"""Pydantic v1 compatibility shim."""
from importlib import metadata
from pydantic.v1 import * # noqa: F403
from langchain_core._api.deprecation import warn_deprecated
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain_core.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_core.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)
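The deleted shim modules above all follow the same import-time deprecation pattern: re-export the real API and warn once on import. A stdlib sketch of that pattern (the function name here is illustrative, standing in for the module-level code of a shim):

```python
import warnings

# Stdlib sketch of an import-time deprecation shim: warn, then hand
# back the real API. Names are illustrative stand-ins.

def load_shim():
    """Simulate importing a deprecated compatibility module."""
    warnings.warn(
        "this_shim is deprecated; import from the real package instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return "real_api"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    api = load_shim()

print(api)  # real_api
print(any(issubclass(w.category, DeprecationWarning) for w in caught))  # True
```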

View File

@@ -1,23 +0,0 @@
"""Pydantic v1 compatibility shim."""
from pydantic.v1.dataclasses import * # noqa: F403
from langchain_core._api import warn_deprecated
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain_core.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_core.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@@ -1,23 +0,0 @@
"""Pydantic v1 compatibility shim."""
from pydantic.v1.main import * # noqa: F403
from langchain_core._api import warn_deprecated
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain_core.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_core.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@@ -27,7 +27,7 @@ from pydantic.v1 import create_model as create_model_v1
from typing_extensions import TypedDict, is_typeddict
import langchain_core
from langchain_core._api import beta, deprecated
from langchain_core._api import beta
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -168,42 +168,6 @@ def _convert_pydantic_to_openai_function(
)
convert_pydantic_to_openai_function = deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
removal="1.0",
)(_convert_pydantic_to_openai_function)
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
removal="1.0",
)
def convert_pydantic_to_openai_tool(
model: type[BaseModel],
*,
name: str | None = None,
description: str | None = None,
) -> ToolDescription:
"""Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
Returns:
The tool description.
"""
function = _convert_pydantic_to_openai_function(
model, name=name, description=description
)
return {"type": "function", "function": function}
def _get_python_function_name(function: Callable) -> str:
"""Get the name of a Python function."""
return function.__name__
@@ -240,13 +204,6 @@ def _convert_python_function_to_openai_function(
)
convert_python_function_to_openai_function = deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
removal="1.0",
)(_convert_python_function_to_openai_function)
def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription:
visited: dict = {}
@@ -368,31 +325,6 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
}
format_tool_to_openai_function = deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
removal="1.0",
)(_format_tool_to_openai_function)
@deprecated(
"0.1.16",
alternative="langchain_core.utils.function_calling.convert_to_openai_tool()",
removal="1.0",
)
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
"""Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The tool description.
"""
function = _format_tool_to_openai_function(tool)
return {"type": "function", "function": function}
def convert_to_openai_function(
function: dict[str, Any] | type | Callable | BaseTool,
*,

View File

@@ -486,8 +486,10 @@ def test_provider_warns() -> None:
# implemented.
# This test should be removed when all major providers support content block
# standardization.
message = AIMessage("Hello.", response_metadata={"model_provider": "groq"})
with pytest.warns(match="not yet fully supported for Groq"):
message = AIMessage(
"Hello.", response_metadata={"model_provider": "google_vertexai"}
)
with pytest.warns(match="not yet fully supported for Google VertexAI"):
content_blocks = message.content_blocks
assert content_blocks == [{"type": "text", "text": "Hello."}]

View File

@@ -3,7 +3,9 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
from langchain_core.utils.function_calling import (
convert_to_openai_function as convert_pydantic_to_openai_function,
)
from pydantic import BaseModel
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \

View File

@@ -1,38 +0,0 @@
from importlib import metadata
from langchain_core._api import warn_deprecated
## Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * LangChain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.
from pydantic.v1 import * # noqa: F403
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_classic.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@@ -1,20 +0,0 @@
from langchain_core._api import warn_deprecated
from pydantic.v1.dataclasses import * # noqa: F403
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_classic.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@@ -1,20 +0,0 @@
from langchain_core._api import warn_deprecated
from pydantic.v1.main import * # noqa: F403
warn_deprecated(
"0.3.0",
removal="1.0.0",
alternative="pydantic.v1 or pydantic",
message=(
"As of langchain-core 0.3.0, LangChain uses pydantic v2 internally. "
"The langchain.pydantic_v1 module was a "
"compatibility shim for pydantic v1, and should no longer be used. "
"Please update the code to import from Pydantic directly.\n\n"
"For example, replace imports like: "
"`from langchain_classic.pydantic_v1 import BaseModel`\n"
"with: `from pydantic import BaseModel`\n"
"or the v1 compatibility namespace if you are working in a code base "
"that has not been fully upgraded to pydantic 2 yet. "
"\tfrom pydantic.v1 import BaseModel\n"
),
)

View File

@@ -1,4 +1,6 @@
from langchain_core.utils.function_calling import format_tool_to_openai_function
from langchain_core.utils.function_calling import (
convert_to_openai_function as format_tool_to_openai_function,
)
# For backwards compatibility
__all__ = ["format_tool_to_openai_function"]

View File

@@ -11,8 +11,10 @@ from langchain_core.tools import (
render_text_description_and_args,
)
from langchain_core.utils.function_calling import (
format_tool_to_openai_function,
format_tool_to_openai_tool,
convert_to_openai_function as format_tool_to_openai_function,
)
from langchain_core.utils.function_calling import (
convert_to_openai_tool as format_tool_to_openai_tool,
)
__all__ = [

View File

@@ -1,8 +1,9 @@
from langchain_core.utils.function_calling import FunctionDescription, ToolDescription
from langchain_core.utils.function_calling import (
FunctionDescription,
ToolDescription,
convert_pydantic_to_openai_function,
convert_pydantic_to_openai_tool,
convert_to_openai_function as convert_pydantic_to_openai_function,
)
from langchain_core.utils.function_calling import (
convert_to_openai_tool as convert_pydantic_to_openai_tool,
)
__all__ = [

View File

@@ -1,4 +1,4 @@
from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
from langchain_core.utils.function_calling import convert_to_openai_function
from pydantic import BaseModel, Field
@@ -9,7 +9,7 @@ def test_convert_pydantic_to_openai_function() -> None:
key: str = Field(..., description="API key")
days: int = Field(default=0, description="Number of days to forecast")
actual = convert_pydantic_to_openai_function(Data)
actual = convert_to_openai_function(Data)
expected = {
"name": "Data",
"description": "The data to return.",
@@ -41,7 +41,7 @@ def test_convert_pydantic_to_openai_function_nested() -> None:
data: Data
actual = convert_pydantic_to_openai_function(Model)
actual = convert_to_openai_function(Model)
expected = {
"name": "Model",
"description": "The model to return.",

View File

@@ -537,13 +537,12 @@ def create_agent( # noqa: PLR0915
(e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
tools: A list of tools, dicts, or callables. If `None` or an empty list,
the agent will consist of a model node without a tool calling loop.
system_prompt: An optional system prompt for the LLM. If provided as a string,
it will be converted to a SystemMessage and added to the beginning
of the message list.
system_prompt: An optional system prompt for the LLM. Prompts are converted to a
`SystemMessage` and added to the beginning of the message list.
middleware: A sequence of middleware instances to apply to the agent.
Middleware can intercept and modify agent behavior at various stages.
response_format: An optional configuration for structured responses.
Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
If provided, the agent will handle structured output during the
conversation flow. Raw schemas will be wrapped in an appropriate strategy
based on model capabilities.
@@ -560,14 +559,14 @@ def create_agent( # noqa: PLR0915
This is useful if you want to return directly or run additional processing
on an output.
debug: A flag indicating whether to enable debug mode.
name: An optional name for the CompiledStateGraph.
name: An optional name for the `CompiledStateGraph`.
This name will be automatically used when adding the agent graph to
another graph as a subgraph node - particularly useful for building
multi-agent systems.
cache: An optional BaseCache instance to enable caching of graph execution.
cache: An optional `BaseCache` instance to enable caching of graph execution.
Returns:
A compiled StateGraph that can be used for chat interactions.
A compiled `StateGraph` that can be used for chat interactions.
The agent node calls the language model with the messages list (after applying
the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will

View File

@@ -11,9 +11,8 @@ from .human_in_the_loop import (
from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .pii import PIIDetectionError, PIIMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .todo import TodoListMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_emulator import LLMToolEmulator
from .tool_selection import LLMToolSelectorMiddleware
@@ -21,6 +20,7 @@ from .types import (
AgentMiddleware,
AgentState,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
@@ -28,13 +28,12 @@ from .types import (
dynamic_prompt,
hook_config,
wrap_model_call,
wrap_tool_call,
)
__all__ = [
"AgentMiddleware",
"AgentState",
# should move to langchain-anthropic if we decide to keep it
"AnthropicPromptCachingMiddleware",
"ClearToolUsesEdit",
"ContextEditingMiddleware",
"HumanInTheLoopMiddleware",
@@ -44,10 +43,11 @@ __all__ = [
"ModelCallLimitMiddleware",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"SummarizationMiddleware",
"TodoListMiddleware",
"ToolCallLimitMiddleware",
"after_agent",
"after_model",
@@ -56,4 +56,5 @@ __all__ = [
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]

View File

@@ -8,7 +8,7 @@ with any LangChain chat model.
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import Literal
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Apply context edits before invoking the model via handler (async version)."""
if not request.messages:
return await handler(request)
if self.token_count_method == "approximate": # noqa: S105
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
)
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
)
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
return await handler(request)
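The trigger-based editing idea behind `ClearToolUsesEdit` can be illustrated with a stdlib-only sketch; the function names, the dict message shape, and the 4-characters-per-token heuristic below are assumptions for illustration, not the real API:

```python
# Minimal sketch of trigger-based context editing, assuming plain dicts
# stand in for messages and a crude character-count token estimate.
def approx_tokens(messages):
    # Rough heuristic: ~4 characters per token.
    return sum(len(str(m.get("content", ""))) for m in messages) // 4


def clear_tool_outputs(messages, trigger, placeholder="[cleared output]"):
    """Once the estimate exceeds `trigger`, replace tool outputs
    (oldest first) with a placeholder until back under the limit."""
    if approx_tokens(messages) <= trigger:
        return messages  # below trigger: no edit, mirrors the early return
    edited = [dict(m) for m in messages]
    for m in edited:
        if m.get("type") == "tool" and m["content"] != placeholder:
            m["content"] = placeholder
            if approx_tokens(edited) <= trigger:
                break
    return edited


msgs = [
    {"type": "ai", "content": ""},
    {"type": "tool", "content": "x" * 400},
    {"type": "human", "content": "next question"},
]
edited = clear_tool_outputs(msgs, trigger=50)
```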
__all__ = [
"ClearToolUsesEdit",

View File

@@ -2,7 +2,8 @@
from typing import Any, Literal, Protocol
from langchain_core.messages import AIMessage, RemoveMessage, ToolCall, ToolMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict
@@ -269,6 +270,42 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
)
raise ValueError(msg)
def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Before the agent runs, add ToolMessages for any dangling tool calls in the history."""
messages = state["messages"]
if not messages:
return None
patched_messages = []
# Iterate over the messages and add any dangling tool calls
for i, msg in enumerate(messages):
patched_messages.append(msg)
if msg.type == "ai" and msg.tool_calls:
for tool_call in msg.tool_calls:
corresponding_tool_msg = next(
(
m
for m in messages[i:]
if m.type == "tool" and m.tool_call_id == tool_call["id"]
),
None,
)
if corresponding_tool_msg is None:
# We have a dangling tool call which needs a ToolMessage
tool_msg = (
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
"cancelled - another message came in before it could be completed."
)
patched_messages.append(
ToolMessage(
content=tool_msg,
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_messages]}
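The repair performed by `before_agent` can be sketched independently of the LangChain message types, using plain dicts whose keys mirror the real fields. This is an illustrative stand-in, not the actual implementation:

```python
# Minimal sketch, assuming dicts with "type", "tool_calls", and
# "tool_call_id" keys stand in for AIMessage/ToolMessage.
def patch_dangling_tool_calls(messages):
    """Insert a synthetic tool result after any AI tool call that was
    never answered by a matching tool message."""
    patched = []
    for i, msg in enumerate(messages):
        patched.append(msg)
        if msg["type"] != "ai":
            continue
        for call in msg.get("tool_calls", []):
            answered = any(
                m["type"] == "tool" and m.get("tool_call_id") == call["id"]
                for m in messages[i:]
            )
            if not answered:
                patched.append({
                    "type": "tool",
                    "tool_call_id": call["id"],
                    "name": call["name"],
                    "content": f"Tool call {call['name']} was cancelled.",
                })
    return patched


history = [
    {"type": "ai", "tool_calls": [{"id": "123", "name": "search"}]},
    {"type": "human", "content": "never mind"},
]
patched = patch_dangling_tool_calls(history)
```

As in the middleware above, the synthetic tool message is inserted directly after the AIMessage that issued the call, so the provider sees a well-formed call/result pair.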
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Trigger interrupt flows for relevant tool calls after an AIMessage."""
messages = state["messages"]

View File

@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain_core.language_models.chat_models import BaseChatModel
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
continue
raise last_exception
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Try fallback models in sequence on errors (async version).
Args:
request: Initial model request.
handler: Async callback to execute the model.
Returns:
AIMessage from successful model call.
Raises:
Exception: If all models fail, re-raises last exception.
"""
# Try primary model first
last_exception: Exception
try:
return await handler(request)
except Exception as e: # noqa: BLE001
last_exception = e
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
try:
return await handler(request)
except Exception as e: # noqa: BLE001
last_exception = e
continue
raise last_exception
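The control flow of `awrap_model_call` above is the classic primary-then-fallbacks chain. A stdlib-only sketch, with plain callables standing in for the models the middleware swaps onto `request.model`:

```python
# Minimal sketch of the fallback chain: try each model in order,
# remember the last exception, re-raise if all fail.
def call_with_fallbacks(primary, fallbacks):
    last_exc = None
    for model in [primary, *fallbacks]:
        try:
            return model()
        except Exception as e:  # on any error, move to the next model
            last_exc = e
    raise last_exc


def failing():
    raise RuntimeError("primary down")


result = call_with_fallbacks(failing, [lambda: "fallback answer"])
```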

View File

@@ -1,89 +0,0 @@
"""Anthropic prompt caching middleware."""
from collections.abc import Callable
from typing import Literal
from warnings import warn
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
class AnthropicPromptCachingMiddleware(AgentMiddleware):
"""Prompt Caching Middleware.
Optimizes API usage by caching conversation prefixes for Anthropic models.
Learn more about Anthropic prompt caching
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
"""
def __init__(
self,
type: Literal["ephemeral"] = "ephemeral",
ttl: Literal["5m", "1h"] = "5m",
min_messages_to_cache: int = 0,
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
) -> None:
"""Initialize the middleware with cache control settings.
Args:
type: The type of cache to use, only "ephemeral" is supported.
ttl: The time to live for the cache, only "5m" and "1h" are supported.
min_messages_to_cache: The minimum number of messages until the cache is used,
default is 0.
unsupported_model_behavior: The behavior to take when an unsupported model is used.
"ignore" will ignore the unsupported model and continue without caching.
"warn" will warn the user and continue without caching.
"raise" will raise an error and stop the agent.
"""
self.type = type
self.ttl = ttl
self.min_messages_to_cache = min_messages_to_cache
self.unsupported_model_behavior = unsupported_model_behavior
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Modify the model request to add cache control blocks."""
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
ChatAnthropic = None # noqa: N806
msg: str | None = None
if ChatAnthropic is None:
msg = (
"AnthropicPromptCachingMiddleware only supports "
"Anthropic models. "
"Please install langchain-anthropic."
)
elif not isinstance(request.model, ChatAnthropic):
msg = (
"AnthropicPromptCachingMiddleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
)
if msg is not None:
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
else:
return handler(request)
messages_count = (
len(request.messages) + 1 if request.system_prompt else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return handler(request)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return handler(request)
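The gating logic at the end of `wrap_model_call` (count messages, skip short conversations, otherwise tag the request with a `cache_control` entry) can be sketched with a plain dict standing in for `ModelRequest`; the function name and dict shape below are illustrative assumptions:

```python
# Minimal sketch of the cache-control gate, assuming a dict request
# with "messages", "system_prompt", and "model_settings" keys.
def apply_cache_control(request, min_messages_to_cache=0, ttl="5m"):
    # System prompt counts as one extra message, as in the middleware.
    count = len(request["messages"]) + (1 if request.get("system_prompt") else 0)
    if count < min_messages_to_cache:
        return request  # conversation too short to be worth caching
    request["model_settings"]["cache_control"] = {"type": "ephemeral", "ttl": ttl}
    return request


req = {
    "messages": ["hi", "hello", "how are you?"],
    "system_prompt": "Be brief.",
    "model_settings": {},
}
apply_cache_control(req, min_messages_to_cache=3)
```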

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Literal
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
@@ -126,7 +126,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
)
class TodoListMiddleware(AgentMiddleware):
"""Middleware that provides todo list management capabilities to agents.
This middleware adds a `write_todos` tool that allows agents to create and manage
@@ -139,10 +139,10 @@ class PlanningMiddleware(AgentMiddleware):
Example:
```python
from langchain.agents.middleware.todo import TodoListMiddleware
from langchain.agents import create_agent
agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
# Agent now has access to write_todos tool and todo state tracking
result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
@@ -165,7 +165,7 @@ class PlanningMiddleware(AgentMiddleware):
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
) -> None:
"""Initialize the TodoListMiddleware with optional custom prompts.
Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
else self.system_prompt
)
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt (async version)."""
request.system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field, replace
from inspect import iscoroutinefunction
from typing import (
TYPE_CHECKING,
@@ -30,7 +30,7 @@ from langgraph.channels.untracked_value import UntrackedValue
from langgraph.graph.message import add_messages
from langgraph.types import Command # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
@@ -62,6 +62,18 @@ JumpTo = Literal["tools", "model", "end"]
ResponseT = TypeVar("ResponseT")
class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for ModelRequest.override() method."""
model: BaseChatModel
system_prompt: str | None
messages: list[AnyMessage]
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
model_settings: dict[str, Any]
@dataclass
class ModelRequest:
"""Model request information for the agent."""
@@ -76,6 +88,36 @@ class ModelRequest:
runtime: Runtime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
"""Replace the request with a new request with the given overrides.
Returns a new `ModelRequest` instance with the specified attributes replaced.
This follows an immutable pattern, leaving the original request unchanged.
Args:
**overrides: Keyword arguments for attributes to override. Supported keys:
- model: BaseChatModel instance
- system_prompt: Optional system prompt string
- messages: List of messages
- tool_choice: Tool choice configuration
- tools: List of available tools
- response_format: Response format specification
- model_settings: Additional model settings
Returns:
New ModelRequest instance with specified overrides applied.
Examples:
```python
# Create a new request with different model
new_request = request.override(model=different_model)
# Override multiple attributes
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
```
"""
return replace(self, **overrides)
@dataclass
class ModelResponse:

View File

@@ -8,7 +8,7 @@ from langchain_core.tools import (
tool,
)
from langchain.tools.tool_node import InjectedState, InjectedStore
__all__ = [
"BaseTool",
@@ -17,6 +17,5 @@ __all__ = [
"InjectedToolArg",
"InjectedToolCallId",
"ToolException",
"tool",
]

View File

@@ -81,6 +81,7 @@ from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import get_runtime
from langgraph.types import Command, Send
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
if TYPE_CHECKING:
from collections.abc import Sequence
@@ -104,6 +105,12 @@ TOOL_INVOCATION_ERROR_TEMPLATE = (
)
class _ToolCallRequestOverrides(TypedDict, total=False):
"""Possible overrides for ToolCallRequest.override() method."""
tool_call: ToolCall
@dataclass()
class ToolCallRequest:
"""Tool execution request passed to tool call interceptors.
@@ -120,6 +127,31 @@ class ToolCallRequest:
state: Any
runtime: Any
def override(self, **overrides: Unpack[_ToolCallRequestOverrides]) -> ToolCallRequest:
"""Replace the request with a new request with the given overrides.
Returns a new `ToolCallRequest` instance with the specified attributes replaced.
This follows an immutable pattern, leaving the original request unchanged.
Args:
**overrides: Keyword arguments for attributes to override. Supported keys:
- tool_call: Tool call dict with name, args, and id
Returns:
New ToolCallRequest instance with specified overrides applied.
Examples:
```python
# Modify tool call arguments without mutating original
modified_call = {**request.tool_call, "args": {"value": 10}}
new_request = request.override(tool_call=modified_call)
# Override multiple attributes
new_request = request.override(tool_call=modified_call, state=new_state)
```
"""
return replace(self, **overrides)
ToolCallWrapper = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],

View File

@@ -0,0 +1,142 @@
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain_core.messages import (
SystemMessage,
HumanMessage,
AIMessage,
ToolMessage,
ToolCall,
RemoveMessage,
)
from langgraph.graph.message import add_messages
class TestHumanInTheLoopMiddlewareBeforeModel:
"""Test HumanInTheLoopMiddleware before_model behavior."""
def test_first_message(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 3
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1].type == "system"
assert state_update["messages"][1].content == "You are a helpful assistant."
assert state_update["messages"][2].type == "human"
assert state_update["messages"][2].content == "Hello, how are you?"
assert state_update["messages"][2].id == "2"
def test_missing_tool_call(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="3",
),
HumanMessage(content="What is the weather in Tokyo?", id="4"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 6
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1] == input_messages[0]
assert state_update["messages"][2] == input_messages[1]
assert state_update["messages"][3] == input_messages[2]
assert state_update["messages"][4].type == "tool"
assert state_update["messages"][4].tool_call_id == "123"
assert state_update["messages"][4].name == "get_events_for_days"
assert state_update["messages"][5] == input_messages[3]
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 5
assert updated_messages[0] == input_messages[0]
assert updated_messages[1] == input_messages[1]
assert updated_messages[2] == input_messages[2]
assert updated_messages[3].type == "tool"
assert updated_messages[3].tool_call_id == "123"
assert updated_messages[3].name == "get_events_for_days"
assert updated_messages[4] == input_messages[3]
def test_no_missing_tool_calls(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="3",
),
ToolMessage(content="I have no events for that date.", tool_call_id="123", id="4"),
HumanMessage(content="What is the weather in Tokyo?", id="5"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 6
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1:] == input_messages
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 5
assert updated_messages == input_messages
def test_two_missing_tool_calls(self) -> None:
input_messages = [
SystemMessage(content="You are a helpful assistant.", id="1"),
HumanMessage(content="Hello, how are you?", id="2"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="3",
),
HumanMessage(content="What is the weather in Tokyo?", id="4"),
AIMessage(
content="I'm doing well, thank you!",
tool_calls=[
ToolCall(id="456", name="get_events_for_days", args={"date_str": "2025-01-01"})
],
id="5",
),
HumanMessage(content="What is the weather in Tokyo?", id="6"),
]
middleware = HumanInTheLoopMiddleware(interrupt_on={})
state_update = middleware.before_agent({"messages": input_messages}, None)
assert state_update is not None
assert len(state_update["messages"]) == 9
assert state_update["messages"][0].type == "remove"
assert state_update["messages"][1] == input_messages[0]
assert state_update["messages"][2] == input_messages[1]
assert state_update["messages"][3] == input_messages[2]
assert state_update["messages"][4].type == "tool"
assert state_update["messages"][4].tool_call_id == "123"
assert state_update["messages"][4].name == "get_events_for_days"
assert state_update["messages"][5] == input_messages[3]
assert state_update["messages"][6] == input_messages[4]
assert state_update["messages"][7].type == "tool"
assert state_update["messages"][7].tool_call_id == "456"
assert state_update["messages"][7].name == "get_events_for_days"
assert state_update["messages"][8] == input_messages[5]
updated_messages = add_messages(input_messages, state_update["messages"])
assert len(updated_messages) == 8
assert updated_messages[0] == input_messages[0]
assert updated_messages[1] == input_messages[1]
assert updated_messages[2] == input_messages[2]
assert updated_messages[3].type == "tool"
assert updated_messages[3].tool_call_id == "123"
assert updated_messages[3].name == "get_events_for_days"
assert updated_messages[4] == input_messages[3]
assert updated_messages[5] == input_messages[4]
assert updated_messages[6].type == "tool"
assert updated_messages[6].tool_call_id == "456"
assert updated_messages[6].name == "get_events_for_days"
assert updated_messages[7] == input_messages[5]

View File

@@ -0,0 +1,381 @@
"""Unit tests for override() methods on ModelRequest and ToolCallRequest."""
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import BaseTool
from langchain.agents.middleware.types import ModelRequest
from langchain.tools.tool_node import ToolCallRequest
class TestModelRequestOverride:
"""Test the ModelRequest.override() method."""
def test_override_single_attribute(self) -> None:
"""Test overriding a single attribute."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={},
runtime=None,
)
new_request = original_request.override(system_prompt="New prompt")
# New request should have the overridden value
assert new_request.system_prompt == "New prompt"
# Original request should be unchanged (immutability)
assert original_request.system_prompt == "Original prompt"
# Other attributes should be the same
assert new_request.model == original_request.model
assert new_request.messages == original_request.messages
def test_override_multiple_attributes(self) -> None:
"""Test overriding multiple attributes at once."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={"count": 1},
runtime=None,
)
new_request = original_request.override(
system_prompt="New prompt",
tool_choice="auto",
state={"count": 2},
)
# Overridden values should be changed
assert new_request.system_prompt == "New prompt"
assert new_request.tool_choice == "auto"
assert new_request.state == {"count": 2}
# Original should be unchanged
assert original_request.system_prompt == "Original prompt"
assert original_request.tool_choice is None
assert original_request.state == {"count": 1}
def test_override_messages(self) -> None:
"""Test overriding messages list."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_messages = [HumanMessage("Hi")]
new_messages = [HumanMessage("Hello"), AIMessage("Hi there")]
original_request = ModelRequest(
model=model,
system_prompt=None,
messages=original_messages,
tool_choice=None,
tools=[],
response_format=None,
state={},
runtime=None,
)
new_request = original_request.override(messages=new_messages)
assert new_request.messages == new_messages
assert original_request.messages == original_messages
assert len(new_request.messages) == 2
assert len(original_request.messages) == 1
def test_override_model_settings(self) -> None:
"""Test overriding model_settings dict."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt=None,
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={},
runtime=None,
model_settings={"temperature": 0.5},
)
new_request = original_request.override(
model_settings={"temperature": 0.9, "max_tokens": 100}
)
assert new_request.model_settings == {"temperature": 0.9, "max_tokens": 100}
assert original_request.model_settings == {"temperature": 0.5}
def test_override_with_none_value(self) -> None:
"""Test overriding with None value."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage("Hi")],
tool_choice="auto",
tools=[],
response_format=None,
state={},
runtime=None,
)
new_request = original_request.override(
system_prompt=None,
tool_choice=None,
)
assert new_request.system_prompt is None
assert new_request.tool_choice is None
assert original_request.system_prompt == "Original prompt"
assert original_request.tool_choice == "auto"
def test_override_preserves_identity_of_unchanged_objects(self) -> None:
"""Test that unchanged attributes maintain object identity."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
messages = [HumanMessage("Hi")]
state = {"key": "value"}
original_request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=messages,
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=None,
)
new_request = original_request.override(system_prompt="New prompt")
# Unchanged objects should be the same instance
assert new_request.messages is messages
assert new_request.state is state
assert new_request.model is model
def test_override_chaining(self) -> None:
"""Test chaining multiple override calls."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
original_request = ModelRequest(
model=model,
system_prompt="Prompt 1",
messages=[HumanMessage("Hi")],
tool_choice=None,
tools=[],
response_format=None,
state={"count": 1},
runtime=None,
)
final_request = (
original_request.override(system_prompt="Prompt 2")
.override(state={"count": 2})
.override(tool_choice="auto")
)
assert final_request.system_prompt == "Prompt 2"
assert final_request.state == {"count": 2}
assert final_request.tool_choice == "auto"
# Original should be unchanged
assert original_request.system_prompt == "Prompt 1"
assert original_request.state == {"count": 1}
assert original_request.tool_choice is None
class TestToolCallRequestOverride:
"""Test the ToolCallRequest.override() method."""
def test_override_tool_call(self) -> None:
"""Test overriding tool_call dict."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
modified_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
original_request = ToolCallRequest(
tool_call=original_call,
tool=test_tool,
state={"messages": []},
runtime=None,
)
new_request = original_request.override(tool_call=modified_call)
# New request should have modified tool_call
assert new_request.tool_call["args"]["x"] == 10
# Original should be unchanged
assert original_request.tool_call["args"]["x"] == 5
# Other attributes should be the same
assert new_request.tool is original_request.tool
assert new_request.state is original_request.state
def test_override_state(self) -> None:
"""Test overriding state."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
original_state = {"messages": [HumanMessage("Hi")]}
new_state = {"messages": [HumanMessage("Hi"), AIMessage("Hello")]}
original_request = ToolCallRequest(
tool_call=tool_call,
tool=test_tool,
state=original_state,
runtime=None,
)
new_request = original_request.override(state=new_state)
assert len(new_request.state["messages"]) == 2
assert len(original_request.state["messages"]) == 1
def test_override_multiple_attributes(self) -> None:
"""Test overriding multiple attributes at once."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
@tool
def another_tool(y: str) -> str:
"""Another test tool."""
return f"Output: {y}"
original_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
modified_call = {
"name": "another_tool",
"args": {"y": "hello"},
"id": "2",
"type": "tool_call",
}
original_request = ToolCallRequest(
tool_call=original_call,
tool=test_tool,
state={"count": 1},
runtime=None,
)
new_request = original_request.override(
tool_call=modified_call,
tool=another_tool,
state={"count": 2},
)
assert new_request.tool_call["name"] == "another_tool"
assert new_request.tool.name == "another_tool"
assert new_request.state == {"count": 2}
# Original unchanged
assert original_request.tool_call["name"] == "test_tool"
assert original_request.tool.name == "test_tool"
assert original_request.state == {"count": 1}
def test_override_with_copy_pattern(self) -> None:
"""Test common pattern of copying and modifying tool_call."""
from langchain_core.tools import tool
@tool
def test_tool(value: int) -> str:
"""A test tool."""
return f"Result: {value}"
original_call = {
"name": "test_tool",
"args": {"value": 5},
"id": "call_123",
"type": "tool_call",
}
original_request = ToolCallRequest(
tool_call=original_call,
tool=test_tool,
state={},
runtime=None,
)
# Common pattern: copy tool_call and modify args
modified_call = {**original_request.tool_call, "args": {"value": 10}}
new_request = original_request.override(tool_call=modified_call)
assert new_request.tool_call["args"]["value"] == 10
assert new_request.tool_call["id"] == "call_123"
assert new_request.tool_call["name"] == "test_tool"
# Original unchanged
assert original_request.tool_call["args"]["value"] == 5
def test_override_preserves_identity(self) -> None:
"""Test that unchanged attributes maintain object identity."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
state = {"messages": []}
original_request = ToolCallRequest(
tool_call=tool_call,
tool=test_tool,
state=state,
runtime=None,
)
new_call = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
new_request = original_request.override(tool_call=new_call)
# Unchanged objects should be the same instance
assert new_request.tool is test_tool
assert new_request.state is state
def test_override_chaining(self) -> None:
"""Test chaining multiple override calls."""
from langchain_core.tools import tool
@tool
def test_tool(x: int) -> str:
"""A test tool."""
return f"Result: {x}"
tool_call = {"name": "test_tool", "args": {"x": 5}, "id": "1", "type": "tool_call"}
original_request = ToolCallRequest(
tool_call=tool_call,
tool=test_tool,
state={"count": 1},
runtime=None,
)
call_2 = {"name": "test_tool", "args": {"x": 10}, "id": "1", "type": "tool_call"}
call_3 = {"name": "test_tool", "args": {"x": 15}, "id": "1", "type": "tool_call"}
final_request = (
original_request.override(tool_call=call_2)
.override(state={"count": 2})
.override(tool_call=call_3)
)
assert final_request.tool_call["args"]["x"] == 15
assert final_request.state == {"count": 2}
# Original unchanged
assert original_request.tool_call["args"]["x"] == 5
assert original_request.state == {"count": 1}

View File

@@ -233,3 +233,169 @@ def test_exclude_tools_prevents_clearing() -> None:
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
async def test_no_edit_when_below_trigger_async() -> None:
"""Test async version of context editing with no edit when below trigger."""
tool_call_id = "call-1"
ai_message = AIMessage(
content="",
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
)
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
state, request = _make_state_and_request([ai_message, tool_message])
middleware = ContextEditingMiddleware(
edits=[ClearToolUsesEdit(trigger=50)],
)
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
# The request should have been modified in place
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
async def test_clear_tool_outputs_and_inputs_async() -> None:
"""Test async version of clearing tool outputs and inputs."""
tool_call_id = "call-2"
ai_message = AIMessage(
content=[
{"type": "tool_call", "id": tool_call_id, "name": "search", "args": {"query": "foo"}}
],
tool_calls=[{"id": tool_call_id, "name": "search", "args": {"query": "foo"}}],
)
tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
state, request = _make_state_and_request([ai_message, tool_message])
edit = ClearToolUsesEdit(
trigger=50,
clear_at_least=10,
clear_tool_inputs=True,
keep=0,
placeholder="[cleared output]",
)
middleware = ContextEditingMiddleware(edits=[edit])
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
assert cleared_tool.response_metadata["context_editing"]["cleared"] is True
assert isinstance(cleared_ai, AIMessage)
assert cleared_ai.tool_calls[0]["args"] == {}
context_meta = cleared_ai.response_metadata.get("context_editing")
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
async def test_respects_keep_last_tool_results_async() -> None:
"""Test async version respects keep parameter for last tool results."""
conversation: list[AIMessage | ToolMessage] = []
edits = [
("call-a", "tool-output-a" * 5),
("call-b", "tool-output-b" * 5),
("call-c", "tool-output-c" * 5),
]
for call_id, text in edits:
conversation.append(
AIMessage(
content="",
tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
)
)
conversation.append(ToolMessage(content=text, tool_call_id=call_id))
state, request = _make_state_and_request(conversation)
middleware = ContextEditingMiddleware(
edits=[
ClearToolUsesEdit(
trigger=50,
keep=1,
placeholder="[cleared]",
)
],
token_count_method="model",
)
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
cleared_messages = [
msg
for msg in request.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
async def test_exclude_tools_prevents_clearing_async() -> None:
"""Test async version of excluding tools from clearing."""
search_call = "call-search"
calc_call = "call-calc"
state, request = _make_state_and_request(
[
AIMessage(
content="",
tool_calls=[{"id": search_call, "name": "search", "args": {"query": "foo"}}],
),
ToolMessage(content="search-results" * 20, tool_call_id=search_call),
AIMessage(
content="",
tool_calls=[{"id": calc_call, "name": "calculator", "args": {"a": 1, "b": 2}}],
),
ToolMessage(content="42", tool_call_id=calc_call),
]
)
middleware = ContextEditingMiddleware(
edits=[
ClearToolUsesEdit(
trigger=50,
clear_at_least=10,
keep=0,
exclude_tools=("search",),
placeholder="[cleared]",
)
],
)
async def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
await middleware.awrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20
assert isinstance(calc_tool, ToolMessage)
assert calc_tool.content == "[cleared]"
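The clearing behavior these async tests exercise can be reduced to a small standalone sketch (hypothetical helper; the real `ClearToolUsesEdit` counts tokens and also honors `clear_at_least`, which this sketch omits):

```python
def clear_tool_uses(results, *, trigger, keep=0, exclude=(), placeholder="[cleared]"):
    """Replace oldest tool outputs with a placeholder once a size trigger is hit.

    `results` is a list of (tool_name, output) pairs. The last `keep`
    results and any tool named in `exclude` are preserved.
    """
    total = sum(len(output) for _, output in results)
    if total <= trigger:
        return results  # below trigger: no edit
    edited = list(results)
    # Only results before the last `keep` are candidates for clearing.
    for i in range(len(edited) - keep):
        name, _ = edited[i]
        if name not in exclude:
            edited[i] = (name, placeholder)
    return edited


results = [("search", "x" * 60), ("calculator", "y" * 60), ("search", "z" * 60)]
edited = clear_tool_uses(results, trigger=50, keep=1, exclude=("calculator",))
assert edited[0] == ("search", "[cleared]")   # oldest result cleared
assert edited[1] == ("calculator", "y" * 60)  # excluded tool untouched
assert edited[2] == ("search", "z" * 60)      # last `keep` result untouched
```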


@@ -32,8 +32,8 @@ from langchain.agents.middleware.human_in_the_loop import (
Action,
HumanInTheLoopMiddleware,
)
-from langchain.agents.middleware.planning import (
-PlanningMiddleware,
+from langchain.agents.middleware.todo import (
+TodoListMiddleware,
PlanningState,
WRITE_TODOS_SYSTEM_PROMPT,
write_todos,
@@ -44,7 +44,6 @@ from langchain.agents.middleware.model_call_limit import (
ModelCallLimitExceededError,
)
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
-from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
@@ -1024,115 +1023,6 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
assert captured_request["action_requests"][1]["description"] == "Static description"
# Tests for AnthropicPromptCachingMiddleware
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response", **req.model_settings)
result = middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
from typing import cast
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic.",
):
middleware.wrap_model_call(fake_request, mock_handler)
langchain_anthropic = ModuleType("langchain_anthropic")
class MockChatAnthropic:
pass
langchain_anthropic.ChatAnthropic = MockChatAnthropic
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
with pytest.raises(
ValueError,
match="AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of",
):
middleware.wrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert isinstance(result, AIMessage)
with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
in str(w[-1].message)
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
# Tests for SummarizationMiddleware
def test_summarization_middleware_initialization() -> None:
"""Test SummarizationMiddleware initialization."""
@@ -1567,10 +1457,10 @@ def test_jump_to_is_ephemeral() -> None:
assert "jump_to" not in result
-# Tests for PlanningMiddleware
-def test_planning_middleware_initialization() -> None:
-"""Test that PlanningMiddleware initializes correctly."""
-middleware = PlanningMiddleware()
+# Tests for TodoListMiddleware
+def test_todo_middleware_initialization() -> None:
+"""Test that TodoListMiddleware initializes correctly."""
+middleware = TodoListMiddleware()
assert middleware.state_schema == PlanningState
assert len(middleware.tools) == 1
assert middleware.tools[0].name == "write_todos"
@@ -1583,9 +1473,9 @@ def test_planning_middleware_initialization() -> None:
(None, "## `write_todos`"),
],
)
-def test_planning_middleware_on_model_call(original_prompt, expected_prompt_prefix) -> None:
+def test_todo_middleware_on_model_call(original_prompt, expected_prompt_prefix) -> None:
"""Test that wrap_model_call handles system prompts correctly."""
-middleware = PlanningMiddleware()
+middleware = TodoListMiddleware()
model = FakeToolCallingModel()
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
@@ -1636,7 +1526,7 @@ def test_planning_middleware_on_model_call(original_prompt, expected_prompt_pref
),
],
)
-def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
+def test_todo_middleware_write_todos_tool_execution(todos, expected_message) -> None:
"""Test that the write_todos tool executes correctly."""
tool_call = {
"args": {"todos": todos},
@@ -1656,7 +1546,7 @@ def test_planning_middleware_write_todos_tool_execution(todos, expected_message)
[{"status": "pending"}],
],
)
-def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
+def test_todo_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
"""Test that the write_todos tool rejects invalid input."""
tool_call = {
"args": {"todos": invalid_todos},
@@ -1668,7 +1558,7 @@ def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -
write_todos.invoke(tool_call)
-def test_planning_middleware_agent_creation_with_middleware() -> None:
+def test_todo_middleware_agent_creation_with_middleware() -> None:
"""Test that an agent can be created with the planning middleware."""
model = FakeToolCallingModel(
tool_calls=[
@@ -1699,7 +1589,7 @@ def test_planning_middleware_agent_creation_with_middleware() -> None:
[],
]
)
-middleware = PlanningMiddleware()
+middleware = TodoListMiddleware()
agent = create_agent(model=model, middleware=[middleware])
result = agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -1716,12 +1606,14 @@ def test_planning_middleware_agent_creation_with_middleware() -> None:
assert len(result["messages"]) == 8
-def test_planning_middleware_custom_system_prompt() -> None:
-"""Test that PlanningMiddleware can be initialized with custom system prompt."""
+def test_todo_middleware_custom_system_prompt() -> None:
+"""Test that TodoListMiddleware can be initialized with custom system prompt."""
custom_system_prompt = "Custom todo system prompt for testing"
-middleware = PlanningMiddleware(system_prompt=custom_system_prompt)
+middleware = TodoListMiddleware(system_prompt=custom_system_prompt)
model = FakeToolCallingModel()
+state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
request = ModelRequest(
model=model,
system_prompt="Original prompt",
@@ -1730,10 +1622,10 @@ def test_planning_middleware_custom_system_prompt() -> None:
tools=[],
response_format=None,
model_settings={},
state=state,
runtime=cast(Runtime, object()),
)
-state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
@@ -1743,21 +1635,21 @@ def test_planning_middleware_custom_system_prompt() -> None:
assert request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
-def test_planning_middleware_custom_tool_description() -> None:
-"""Test that PlanningMiddleware can be initialized with custom tool description."""
+def test_todo_middleware_custom_tool_description() -> None:
+"""Test that TodoListMiddleware can be initialized with custom tool description."""
custom_tool_description = "Custom tool description for testing"
-middleware = PlanningMiddleware(tool_description=custom_tool_description)
+middleware = TodoListMiddleware(tool_description=custom_tool_description)
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == custom_tool_description
-def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
-"""Test that PlanningMiddleware can be initialized with both custom prompts."""
+def test_todo_middleware_custom_system_prompt_and_tool_description() -> None:
+"""Test that TodoListMiddleware can be initialized with both custom prompts."""
custom_system_prompt = "Custom system prompt"
custom_tool_description = "Custom tool description"
-middleware = PlanningMiddleware(
+middleware = TodoListMiddleware(
system_prompt=custom_system_prompt,
tool_description=custom_tool_description,
)
@@ -1792,9 +1684,9 @@ def test_planning_middleware_custom_system_prompt_and_tool_description() -> None
assert tool.description == custom_tool_description
-def test_planning_middleware_default_prompts() -> None:
-"""Test that PlanningMiddleware uses default prompts when none provided."""
-middleware = PlanningMiddleware()
+def test_todo_middleware_default_prompts() -> None:
+"""Test that TodoListMiddleware uses default prompts when none provided."""
+middleware = TodoListMiddleware()
# Verify default system prompt
assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
@@ -1806,9 +1698,9 @@ def test_planning_middleware_default_prompts() -> None:
assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION
-def test_planning_middleware_custom_system_prompt() -> None:
+def test_todo_middleware_custom_system_prompt_in_agent() -> None:
"""Test that custom tool executes correctly in an agent."""
-middleware = PlanningMiddleware(system_prompt="call the write_todos tool")
+middleware = TodoListMiddleware(system_prompt="call the write_todos tool")
model = FakeToolCallingModel(
tool_calls=[


@@ -0,0 +1,215 @@
"""Unit tests for ModelFallbackMiddleware."""
from __future__ import annotations
from typing import cast
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
def _make_request() -> ModelRequest:
"""Create a minimal ModelRequest for testing."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="primary")]))
return ModelRequest(
model=model,
system_prompt=None,
messages=[],
tool_choice=None,
tools=[],
response_format=None,
state=cast("AgentState", {}), # type: ignore[name-defined]
runtime=_fake_runtime(),
model_settings={},
)
def test_primary_model_succeeds() -> None:
"""Test that primary model is used when it succeeds."""
primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
# Simulate successful model call
result = req.model.invoke([])
return ModelResponse(result=[result])
response = middleware.wrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "primary response"
def test_fallback_on_primary_failure() -> None:
"""Test that fallback model is used when primary fails."""
class FailingPrimaryModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Primary model failed")
primary_model = FailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
return ModelResponse(result=[result])
response = middleware.wrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback response"
def test_multiple_fallbacks() -> None:
"""Test that multiple fallback models are tried in sequence."""
class FailingModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = FailingModel(messages=iter([AIMessage(content="should not see")]))
fallback1 = FailingModel(messages=iter([AIMessage(content="fallback1")]))
fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))
middleware = ModelFallbackMiddleware(fallback1, fallback2)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
return ModelResponse(result=[result])
response = middleware.wrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback2"
def test_all_models_fail() -> None:
"""Test that exception is raised when all models fail."""
class AlwaysFailingModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = AlwaysFailingModel(messages=iter([]))
fallback_model = AlwaysFailingModel(messages=iter([]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
return ModelResponse(result=[result])
with pytest.raises(ValueError, match="Model failed"):
middleware.wrap_model_call(request, mock_handler)
async def test_primary_model_succeeds_async() -> None:
"""Test async version - primary model is used when it succeeds."""
primary_model = GenericFakeChatModel(messages=iter([AIMessage(content="primary response")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
# Simulate successful async model call
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
response = await middleware.awrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "primary response"
async def test_fallback_on_primary_failure_async() -> None:
"""Test async version - fallback model is used when primary fails."""
class AsyncFailingPrimaryModel(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
raise ValueError("Primary model failed")
primary_model = AsyncFailingPrimaryModel(messages=iter([AIMessage(content="should not see")]))
fallback_model = GenericFakeChatModel(messages=iter([AIMessage(content="fallback response")]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
response = await middleware.awrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback response"
async def test_multiple_fallbacks_async() -> None:
"""Test async version - multiple fallback models are tried in sequence."""
class AsyncFailingModel(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = AsyncFailingModel(messages=iter([AIMessage(content="should not see")]))
fallback1 = AsyncFailingModel(messages=iter([AIMessage(content="fallback1")]))
fallback2 = GenericFakeChatModel(messages=iter([AIMessage(content="fallback2")]))
middleware = ModelFallbackMiddleware(fallback1, fallback2)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
response = await middleware.awrap_model_call(request, mock_handler)
assert isinstance(response, ModelResponse)
assert response.result[0].content == "fallback2"
async def test_all_models_fail_async() -> None:
"""Test async version - exception is raised when all models fail."""
class AsyncAlwaysFailingModel(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
raise ValueError("Model failed")
primary_model = AsyncAlwaysFailingModel(messages=iter([]))
fallback_model = AsyncAlwaysFailingModel(messages=iter([]))
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
return ModelResponse(result=[result])
with pytest.raises(ValueError, match="Model failed"):
await middleware.awrap_model_call(request, mock_handler)
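At its core, the fallback sequence these tests cover is a try-in-order loop. A minimal standalone sketch (hypothetical names, not the middleware internals):

```python
def call_with_fallbacks(models, call):
    """Invoke `call` with each model in order; return the first success.

    If every model raises, re-raise the last exception, matching the
    all-models-fail case asserted above.
    """
    last_exc = None
    for model in models:
        try:
            return call(model)
        except Exception as exc:
            last_exc = exc
    raise last_exc


def invoke(model_name):
    # Stand-in for a model call: "bad-*" models always fail.
    if model_name.startswith("bad"):
        raise ValueError(f"{model_name} failed")
    return f"{model_name} response"


assert call_with_fallbacks(["bad-1", "bad-2", "good"], invoke) == "good response"
```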


@@ -0,0 +1,172 @@
"""Unit tests for TodoListMiddleware."""
from __future__ import annotations
from typing import cast
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain.agents.middleware.todo import TodoListMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
def _make_request(system_prompt: str | None = None) -> ModelRequest:
"""Create a minimal ModelRequest for testing."""
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
return ModelRequest(
model=model,
system_prompt=system_prompt,
messages=[],
tool_choice=None,
tools=[],
response_format=None,
state=cast("AgentState", {}), # type: ignore[name-defined]
runtime=_fake_runtime(),
model_settings={},
)
def test_adds_system_prompt_when_none_exists() -> None:
"""Test that middleware adds system prompt when request has none."""
middleware = TodoListMiddleware()
request = _make_request(system_prompt=None)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# System prompt should be set
assert request.system_prompt is not None
assert "write_todos" in request.system_prompt
def test_appends_to_existing_system_prompt() -> None:
"""Test that middleware appends to existing system prompt."""
existing_prompt = "You are a helpful assistant."
middleware = TodoListMiddleware()
request = _make_request(system_prompt=existing_prompt)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# System prompt should contain both
assert request.system_prompt is not None
assert existing_prompt in request.system_prompt
assert "write_todos" in request.system_prompt
assert request.system_prompt.startswith(existing_prompt)
def test_custom_system_prompt() -> None:
"""Test that middleware uses custom system prompt."""
custom_prompt = "Custom planning instructions"
middleware = TodoListMiddleware(system_prompt=custom_prompt)
request = _make_request(system_prompt=None)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# Should use custom prompt
assert request.system_prompt == custom_prompt
def test_has_write_todos_tool() -> None:
"""Test that middleware registers the write_todos tool."""
middleware = TodoListMiddleware()
# Should have one tool registered
assert len(middleware.tools) == 1
assert middleware.tools[0].name == "write_todos"
def test_custom_tool_description() -> None:
"""Test that middleware uses custom tool description."""
custom_description = "Custom todo tool description"
middleware = TodoListMiddleware(tool_description=custom_description)
# Tool should use custom description
assert len(middleware.tools) == 1
assert middleware.tools[0].description == custom_description
# ==============================================================================
# Async Tests
# ==============================================================================
async def test_adds_system_prompt_when_none_exists_async() -> None:
"""Test async version - middleware adds system prompt when request has none."""
middleware = TodoListMiddleware()
request = _make_request(system_prompt=None)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# System prompt should be set
assert request.system_prompt is not None
assert "write_todos" in request.system_prompt
async def test_appends_to_existing_system_prompt_async() -> None:
"""Test async version - middleware appends to existing system prompt."""
existing_prompt = "You are a helpful assistant."
middleware = TodoListMiddleware()
request = _make_request(system_prompt=existing_prompt)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# System prompt should contain both
assert request.system_prompt is not None
assert existing_prompt in request.system_prompt
assert "write_todos" in request.system_prompt
assert request.system_prompt.startswith(existing_prompt)
async def test_custom_system_prompt_async() -> None:
"""Test async version - middleware uses custom system prompt."""
custom_prompt = "Custom planning instructions"
middleware = TodoListMiddleware(system_prompt=custom_prompt)
request = _make_request(system_prompt=None)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# Should use custom prompt
assert request.system_prompt == custom_prompt
async def test_handler_called_with_modified_request_async() -> None:
"""Test async version - handler receives the modified request."""
middleware = TodoListMiddleware()
request = _make_request(system_prompt="Original")
handler_called = {"value": False}
received_prompt = {"value": None}
async def mock_handler(req: ModelRequest) -> ModelResponse:
handler_called["value"] = True
received_prompt["value"] = req.system_prompt
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
assert handler_called["value"]
assert received_prompt["value"] is not None
assert "Original" in received_prompt["value"]
assert "write_todos" in received_prompt["value"]
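The prompt handling asserted above (use the middleware prompt when none exists, otherwise append it after a blank line) boils down to one small helper; a sketch with a hypothetical name:

```python
def merge_system_prompt(existing, addition):
    """Append middleware instructions to any existing system prompt."""
    if existing is None:
        return addition
    return f"{existing}\n\n{addition}"


assert merge_system_prompt(None, "## `write_todos`") == "## `write_todos`"
assert merge_system_prompt("You are a helpful assistant.", "## `write_todos`") == (
    "You are a helpful assistant.\n\n## `write_todos`"
)
```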


@@ -7,7 +7,6 @@ EXPECTED_ALL = {
"InjectedToolArg",
"InjectedToolCallId",
"ToolException",
-"ToolInvocationError",
"tool",
}


@@ -0,0 +1,9 @@
"""Middleware for Anthropic models."""
from langchain_anthropic.middleware.prompt_caching import (
AnthropicPromptCachingMiddleware,
)
__all__ = [
"AnthropicPromptCachingMiddleware",
]


@@ -0,0 +1,123 @@
"""Anthropic prompt caching middleware.
Requires:
- langchain: For agent middleware framework
- langchain-anthropic: For ChatAnthropic model (already a dependency)
"""
from collections.abc import Awaitable, Callable
from typing import Literal
from warnings import warn
from langchain_anthropic.chat_models import ChatAnthropic
try:
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
except ImportError as e:
msg = (
"AnthropicPromptCachingMiddleware requires 'langchain' to be installed. "
"This middleware is designed for use with LangChain agents. "
"Install it with: pip install langchain"
)
raise ImportError(msg) from e
class AnthropicPromptCachingMiddleware(AgentMiddleware):
"""Prompt Caching Middleware.
Optimizes API usage by caching conversation prefixes for Anthropic models.
Requires both 'langchain' and 'langchain-anthropic' packages to be installed.
Learn more about Anthropic prompt caching
[here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
"""
def __init__(
self,
type: Literal["ephemeral"] = "ephemeral", # noqa: A002
ttl: Literal["5m", "1h"] = "5m",
min_messages_to_cache: int = 0,
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
) -> None:
"""Initialize the middleware with cache control settings.
Args:
type: The type of cache to use, only "ephemeral" is supported.
ttl: The time to live for the cache, only "5m" and "1h" are
supported.
min_messages_to_cache: The minimum number of messages until the
cache is used, default is 0.
unsupported_model_behavior: The behavior to take when an
unsupported model is used. "ignore" will ignore the unsupported
model and continue without caching. "warn" will warn the user
and continue without caching. "raise" will raise an error and
stop the agent.
"""
self.type = type
self.ttl = ttl
self.min_messages_to_cache = min_messages_to_cache
self.unsupported_model_behavior = unsupported_model_behavior
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Modify the model request to add cache control blocks."""
if not isinstance(request.model, ChatAnthropic):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
)
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
return handler(request)
messages_count = (
len(request.messages) + 1
if request.system_prompt
else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return handler(request)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Modify the model request to add cache control blocks (async version)."""
if not isinstance(request.model, ChatAnthropic):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
)
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
return await handler(request)
messages_count = (
len(request.messages) + 1
if request.system_prompt
else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return await handler(request)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return await handler(request)
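The caching decision in both wrappers above follows the same two-step rule; a condensed standalone sketch (hypothetical helper, not part of the shipped API):

```python
def cache_settings(num_messages, has_system_prompt, *, min_messages_to_cache=0,
                   cache_type="ephemeral", ttl="5m"):
    """Return the cache_control dict to attach, or None to skip caching."""
    # The system prompt counts as one extra message, as in the middleware above.
    count = num_messages + 1 if has_system_prompt else num_messages
    if count < min_messages_to_cache:
        return None
    return {"type": cache_type, "ttl": ttl}


assert cache_settings(1, True, min_messages_to_cache=5) is None  # too few messages
assert cache_settings(4, True, min_messages_to_cache=5) == {"type": "ephemeral", "ttl": "5m"}
```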


@@ -41,6 +41,7 @@ test = [
"vcrpy>=7.0.0,<8.0.0",
"langchain-core",
"langchain-tests",
"langchain",
]
lint = ["ruff>=0.13.1,<0.14.0"]
dev = ["langchain-core"]
@@ -55,6 +56,7 @@ typing = [
[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true }
langchain = { path = "../../langchain_v1", editable = true }
[tool.mypy]
disallow_untyped_defs = "True"

View File

@@ -0,0 +1 @@
"""Tests for Anthropic middleware."""

View File

@@ -0,0 +1,246 @@
"""Tests for Anthropic prompt caching middleware."""
import warnings
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
class FakeToolCallingModel(BaseChatModel):
"""Fake model for testing middleware."""
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
messages_string = "-".join([str(m.content) for m in messages])
message = AIMessage(content=messages_string, id="0")
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Async top level call"""
messages_string = "-".join([str(m.content) for m in messages])
message = AIMessage(content=messages_string, id="0")
return ChatResult(generations=[ChatGeneration(message=message)])
@property
def _llm_type(self) -> str:
return "fake-tool-call-model"
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
assert middleware.type == "ephemeral"
assert middleware.ttl == "1h"
assert middleware.min_messages_to_cache == 5
# Test with default values
middleware = AnthropicPromptCachingMiddleware()
assert middleware.type == "ephemeral"
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
# Create a mock ChatAnthropic instance
mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
fake_request = ModelRequest(
model=mock_chat_anthropic,
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Since we're in the langchain-anthropic package, ChatAnthropic is always
# available. Test that it raises an error for unsupported model instances
with pytest.raises(
ValueError,
match=(
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
),
):
middleware.wrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
# Test warn behavior for unsupported model instances
with warnings.catch_warnings(record=True) as w:
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
) in str(w[-1].message)
# Test ignore behavior
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
async def test_anthropic_prompt_caching_middleware_async() -> None:
"""Test AnthropicPromptCachingMiddleware async path."""
# Test with custom values
middleware = AnthropicPromptCachingMiddleware(
type="ephemeral", ttl="1h", min_messages_to_cache=5
)
# Create a mock ChatAnthropic instance
mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
fake_request = ModelRequest(
model=mock_chat_anthropic,
messages=[HumanMessage("Hello")] * 6,
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")] * 6},
runtime=cast(Runtime, object()),
model_settings={},
)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "1h"}
}
async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware async path with unsupported model."""
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Test that it raises an error for unsupported model instances
with pytest.raises(
ValueError,
match=(
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
),
):
await middleware.awrap_model_call(fake_request, mock_handler)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="warn")
# Test warn behavior for unsupported model instances
with warnings.catch_warnings(record=True) as w:
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports "
"Anthropic models, not instances of"
) in str(w[-1].message)
# Test ignore behavior
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
"""Test async path respects min_messages_to_cache."""
middleware = AnthropicPromptCachingMiddleware(min_messages_to_cache=5)
# Test with fewer messages than minimum
fake_request = ModelRequest(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")] * 3,
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")] * 3},
runtime=cast(Runtime, object()),
model_settings={},
)
async def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Cache control should NOT be added when message count is below minimum
assert fake_request.model_settings == {}

File diff suppressed because it is too large

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
import json
from typing import Any, cast
from langchain_core.messages import content as types
def _convert_from_v1_to_groq(
content: list[types.ContentBlock],
model_provider: str | None,
) -> tuple[list[dict[str, Any] | str], dict]:
new_content: list = []
new_additional_kwargs: dict = {}
for i, block in enumerate(content):
if block["type"] == "text":
new_content.append({"text": block.get("text", ""), "type": "text"})
elif (
block["type"] == "reasoning"
and (reasoning := block.get("reasoning"))
and model_provider == "groq"
):
new_additional_kwargs["reasoning_content"] = reasoning
elif block["type"] == "server_tool_call" and model_provider == "groq":
new_block = {}
if "args" in block:
new_block["arguments"] = json.dumps(block["args"])
if (idx := block.get("extras", {}).get("index")) is not None:
new_block["index"] = idx
if block.get("name") == "web_search":
new_block["type"] = "search"
elif block.get("name") == "code_interpreter":
new_block["type"] = "python"
else:
new_block["type"] = ""
if i < len(content) - 1 and content[i + 1]["type"] == "server_tool_result":
result = cast("types.ServerToolResult", content[i + 1])
for k, v in result.get("extras", {}).items():
new_block[k] = v # noqa: PERF403
if "output" in result:
new_block["output"] = result["output"]
if "executed_tools" not in new_additional_kwargs:
new_additional_kwargs["executed_tools"] = []
new_additional_kwargs["executed_tools"].append(new_block)
elif block["type"] == "server_tool_result":
continue
elif (
block["type"] == "non_standard"
and "value" in block
and model_provider == "groq"
):
new_content.append(block["value"])
else:
new_content.append(block)
# For consistency with v0 payloads, we cast single text blocks to str
if (
len(new_content) == 1
and isinstance(new_content[0], dict)
and new_content[0].get("type") == "text"
and (text_content := new_content[0].get("text"))
and isinstance(text_content, str)
):
return text_content, new_additional_kwargs
return new_content, new_additional_kwargs
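The final cast above collapses a lone non-empty text block into a bare string for v0-payload compatibility. A hedged stdlib sketch of that rule in isolation (`collapse_single_text` is a hypothetical name):

```python
def collapse_single_text(content: list):
    """Mirror the v0-compat cast: a single non-empty text block
    becomes a bare string; anything else is returned unchanged."""
    if (
        len(content) == 1
        and isinstance(content[0], dict)
        and content[0].get("type") == "text"
        and isinstance(content[0].get("text"), str)
        and content[0]["text"]  # empty strings stay as a block list
    ):
        return content[0]["text"]
    return content
```

Note the truthiness check: an empty text block is deliberately left as a list, matching the `(text_content := ...)` condition in the converter.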

View File

@@ -57,6 +57,7 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_groq._compat import _convert_from_v1_to_groq
from langchain_groq.version import __version__
@@ -1187,6 +1188,17 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
# Translate v1 content
if message.response_metadata.get("output_version") == "v1":
new_content, new_additional_kwargs = _convert_from_v1_to_groq(
message.content_blocks, message.response_metadata.get("model_provider")
)
message = message.model_copy(
update={
"content": new_content,
"additional_kwargs": new_additional_kwargs,
}
)
message_dict = {"role": "assistant", "content": message.content}
# If content is a list of content blocks, filter out tool_call blocks
@@ -1268,6 +1280,22 @@ def _convert_chunk_to_message_chunk(
if role == "assistant" or default_class == AIMessageChunk:
if reasoning := _dict.get("reasoning"):
additional_kwargs["reasoning_content"] = reasoning
if executed_tools := _dict.get("executed_tools"):
additional_kwargs["executed_tools"] = []
for executed_tool in executed_tools:
if executed_tool.get("output"):
# Tool output duplicates query and other server tool call data
additional_kwargs["executed_tools"].append(
{
k: executed_tool[k]
for k in ("index", "output")
if k in executed_tool
}
)
else:
additional_kwargs["executed_tools"].append(
{k: executed_tool[k] for k in executed_tool if k != "output"}
)
if usage := (chunk.get("x_groq") or {}).get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
@@ -1282,6 +1310,7 @@ def _convert_chunk_to_message_chunk(
content=content,
additional_kwargs=additional_kwargs,
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "groq"},
)
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
@@ -1313,6 +1342,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs: dict = {}
if reasoning := _dict.get("reasoning"):
additional_kwargs["reasoning_content"] = reasoning
if executed_tools := _dict.get("executed_tools"):
additional_kwargs["executed_tools"] = executed_tools
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
@@ -1332,6 +1363,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata={"model_provider": "groq"},
)
if role == "system":
return SystemMessage(content=_dict.get("content", ""))
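The streaming `executed_tools` branch above deduplicates: once a tool chunk carries `output`, only `index` and `output` are kept, since the output chunk would otherwise repeat the query and other server tool call data. A stdlib-only sketch of that merge policy (function name and dict shapes are assumptions for illustration):

```python
def merge_executed_tool_chunks(executed_tools: list) -> list:
    """Keep only {"index", "output"} once a tool's output arrives;
    otherwise keep everything except "output"."""
    merged = []
    for tool in executed_tools:
        if tool.get("output"):
            merged.append({k: tool[k] for k in ("index", "output") if k in tool})
        else:
            merged.append({k: tool[k] for k in tool if k != "output"})
    return merged

chunks = [
    {"index": 0, "type": "search", "arguments": "{}"},      # call chunk
    {"index": 0, "output": "Boston: 60F and sunny"},        # output chunk
]
merged = merge_executed_tool_chunks(chunks)
```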

View File

@@ -0,0 +1,37 @@
from typing import Any
import pytest
from langchain_tests.conftest import CustomPersister, CustomSerializer
from langchain_tests.conftest import (
_base_vcr_config as _base_vcr_config, # noqa: PLC0414
)
from vcr import VCR # type: ignore[import-untyped]
def remove_request_headers(request: Any) -> Any:
for k in request.headers:
request.headers[k] = "**REDACTED**"
return request
def remove_response_headers(response: dict) -> dict:
for k in response["headers"]:
response["headers"][k] = "**REDACTED**"
return response
@pytest.fixture(scope="session")
def vcr_config(_base_vcr_config: dict) -> dict: # noqa: F811
"""Extend the default configuration coming from langchain_tests."""
config = _base_vcr_config.copy()
config["before_record_request"] = remove_request_headers
config["before_record_response"] = remove_response_headers
config["serializer"] = "yaml.gz"
config["path_transformer"] = VCR.ensure_suffix(".yaml.gz")
return config
def pytest_recording_configure(config: dict, vcr: VCR) -> None:
vcr.register_persister(CustomPersister())
vcr.register_serializer("yaml.gz", CustomSerializer())
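Both redaction hooks above do the same thing to their payloads: overwrite every header value before the cassette is written. The core idea, as a minimal sketch with no VCR dependency (`redact` is a hypothetical helper):

```python
def redact(headers: dict) -> dict:
    """Replace every header value so credentials never land in a recorded cassette."""
    return {name: "**REDACTED**" for name in headers}

recorded = redact({"authorization": "Bearer sk-123", "x-api-key": "abc"})
```

Redacting values while keeping the header names lets replayed cassettes still match on request shape without leaking secrets.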

View File

@@ -111,7 +111,9 @@ async def test_astream() -> None:
full = token if full is None else full + token
if token.usage_metadata is not None:
chunks_with_token_counts += 1
if token.response_metadata:
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
{"model_provider", "output_version"}
):
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
msg = (
@@ -665,6 +667,146 @@ async def test_setting_service_tier_request_async() -> None:
assert response.response_metadata.get("service_tier") == "on_demand"
@pytest.mark.vcr
def test_web_search() -> None:
llm = ChatGroq(model="groq/compound")
input_message = {
"role": "user",
"content": "Search for the weather in Boston today.",
}
full: AIMessageChunk | None = None
for chunk in llm.stream([input_message]):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"]
assert full.additional_kwargs["executed_tools"]
assert [block["type"] for block in full.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"text",
]
next_message = {
"role": "user",
"content": "Now search for the weather in San Francisco.",
}
response = llm.invoke([input_message, full, next_message])
assert [block["type"] for block in response.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"text",
]
@pytest.mark.default_cassette("test_web_search.yaml.gz")
@pytest.mark.vcr
def test_web_search_v1() -> None:
llm = ChatGroq(model="groq/compound", output_version="v1")
input_message = {
"role": "user",
"content": "Search for the weather in Boston today.",
}
full: AIMessageChunk | None = None
for chunk in llm.stream([input_message]):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"]
assert full.additional_kwargs["executed_tools"]
assert [block["type"] for block in full.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"reasoning",
"text",
]
next_message = {
"role": "user",
"content": "Now search for the weather in San Francisco.",
}
response = llm.invoke([input_message, full, next_message])
assert [block["type"] for block in response.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"text",
]
@pytest.mark.vcr
def test_code_interpreter() -> None:
llm = ChatGroq(model="groq/compound-mini")
input_message = {
"role": "user",
"content": (
"Calculate the square root of 101 and show me the Python code you used."
),
}
full: AIMessageChunk | None = None
for chunk in llm.stream([input_message]):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"]
assert full.additional_kwargs["executed_tools"]
assert [block["type"] for block in full.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"text",
]
next_message = {
"role": "user",
"content": "Now do the same for 102.",
}
response = llm.invoke([input_message, full, next_message])
assert [block["type"] for block in response.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"text",
]
@pytest.mark.default_cassette("test_code_interpreter.yaml.gz")
@pytest.mark.vcr
def test_code_interpreter_v1() -> None:
llm = ChatGroq(model="groq/compound-mini", output_version="v1")
input_message = {
"role": "user",
"content": (
"Calculate the square root of 101 and show me the Python code you used."
),
}
full: AIMessageChunk | None = None
for chunk in llm.stream([input_message]):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.additional_kwargs["reasoning_content"]
assert full.additional_kwargs["executed_tools"]
assert [block["type"] for block in full.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"reasoning",
"text",
]
next_message = {
"role": "user",
"content": "Now do the same for 102.",
}
response = llm.invoke([input_message, full, next_message])
assert [block["type"] for block in response.content_blocks] == [
"reasoning",
"server_tool_call",
"server_tool_result",
"text",
]
# Groq does not currently support N > 1
# @pytest.mark.scheduled
# def test_chat_multiple_completions() -> None:

View File

@@ -54,7 +54,9 @@ def test__convert_dict_to_message_human() -> None:
def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
expected_output = AIMessage(
content="foo", response_metadata={"model_provider": "groq"}
)
assert result == expected_output
@@ -80,6 +82,7 @@ def test__convert_dict_to_message_tool_call() -> None:
type="tool_call",
)
],
response_metadata={"model_provider": "groq"},
)
assert result == expected_output
@@ -124,6 +127,7 @@ def test__convert_dict_to_message_tool_call() -> None:
type="tool_call",
),
],
response_metadata={"model_provider": "groq"},
)
assert result == expected_output

View File

@@ -0,0 +1,125 @@
"""Derivations of standard content blocks from mistral content."""
from __future__ import annotations
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages import content as types
from langchain_core.messages.block_translators import register_translator
def _convert_from_v1_to_mistral(
content: list[types.ContentBlock],
model_provider: str | None,
) -> str | list[str | dict]:
new_content: list = []
for block in content:
if block["type"] == "text":
new_content.append({"text": block.get("text", ""), "type": "text"})
elif (
block["type"] == "reasoning"
and (reasoning := block.get("reasoning"))
and isinstance(reasoning, str)
and model_provider == "mistralai"
):
new_content.append(
{
"type": "thinking",
"thinking": [{"type": "text", "text": reasoning}],
}
)
elif (
block["type"] == "non_standard"
and "value" in block
and model_provider == "mistralai"
):
new_content.append(block["value"])
elif block["type"] == "tool_call":
continue
else:
new_content.append(block)
return new_content
def _convert_to_v1_from_mistral(message: AIMessage) -> list[types.ContentBlock]:
"""Convert mistral message content to v1 format."""
if isinstance(message.content, str):
content_blocks: list[types.ContentBlock] = [
{"type": "text", "text": message.content}
]
else:
content_blocks = []
for block in message.content:
if isinstance(block, str):
content_blocks.append({"type": "text", "text": block})
elif isinstance(block, dict):
if block.get("type") == "text" and isinstance(block.get("text"), str):
text_block: types.TextContentBlock = {
"type": "text",
"text": block["text"],
}
if "index" in block:
text_block["index"] = block["index"]
content_blocks.append(text_block)
elif block.get("type") == "thinking" and isinstance(
block.get("thinking"), list
):
for sub_block in block["thinking"]:
if (
isinstance(sub_block, dict)
and sub_block.get("type") == "text"
):
reasoning_block: types.ReasoningContentBlock = {
"type": "reasoning",
"reasoning": sub_block.get("text", ""),
}
if "index" in block:
reasoning_block["index"] = block["index"]
content_blocks.append(reasoning_block)
else:
non_standard_block: types.NonStandardContentBlock = {
"type": "non_standard",
"value": block,
}
content_blocks.append(non_standard_block)
else:
continue
if (
len(content_blocks) == 1
and content_blocks[0].get("type") == "text"
and content_blocks[0].get("text") == ""
and message.tool_calls
):
content_blocks = []
for tool_call in message.tool_calls:
content_blocks.append(
{
"type": "tool_call",
"name": tool_call["name"],
"args": tool_call["args"],
"id": tool_call.get("id"),
}
)
return content_blocks
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
"""Derive standard content blocks from a message with mistral content."""
return _convert_to_v1_from_mistral(message)
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
"""Derive standard content blocks from a message chunk with mistral content."""
return _convert_to_v1_from_mistral(message)
register_translator("mistralai", translate_content, translate_content_chunk)

View File

@@ -24,12 +24,7 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
@@ -74,6 +69,8 @@ from pydantic import (
)
from typing_extensions import Self
from langchain_mistralai._compat import _convert_from_v1_to_mistral
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
@@ -160,6 +157,7 @@ def _convert_mistral_chat_message_to_message(
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata={"model_provider": "mistralai"},
)
@@ -231,14 +229,34 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk(
chunk: dict, default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
chunk: dict,
default_class: type[BaseMessageChunk],
index: int,
index_type: str,
output_version: str | None,
) -> tuple[BaseMessageChunk, int, str]:
_choice = chunk["choices"][0]
_delta = _choice["delta"]
role = _delta.get("role")
content = _delta.get("content") or ""
if output_version == "v1" and isinstance(content, str):
content = [{"type": "text", "text": content}]
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
if "type" in block and block["type"] != index_type:
index_type = block["type"]
index = index + 1
if "index" not in block:
block["index"] = index
if block.get("type") == "thinking" and isinstance(
block.get("thinking"), list
):
for sub_block in block["thinking"]:
if isinstance(sub_block, dict) and "index" not in sub_block:
sub_block["index"] = 0
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
return HumanMessageChunk(content=content), index, index_type
if role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: dict = {}
response_metadata = {}
@@ -276,18 +294,22 @@ def _convert_chunk_to_message_chunk(
):
response_metadata["model_name"] = chunk["model"]
response_metadata["finish_reason"] = _choice["finish_reason"]
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata=response_metadata,
return (
AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "mistralai", **response_metadata},
),
index,
index_type,
)
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
return SystemMessageChunk(content=content), index, index_type
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
return default_class(content=content) # type: ignore[call-arg]
return ChatMessageChunk(content=content, role=role), index, index_type
return default_class(content=content), index, index_type # type: ignore[call-arg]
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
@@ -318,6 +340,21 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
return result
def _clean_block(block: dict) -> dict:
# Remove "index" key added for message aggregation in langchain-core
new_block = {k: v for k, v in block.items() if k != "index"}
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
new_block["thinking"] = [
(
{k: v for k, v in sb.items() if k != "index"}
if isinstance(sb, dict) and "index" in sb
else sb
)
for sb in block["thinking"]
]
return new_block
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> dict:
@@ -356,13 +393,40 @@ def _convert_message_to_mistral_chat_message(
pass
if tool_calls: # do not populate empty list tool_calls
message_dict["tool_calls"] = tool_calls
if tool_calls and message.content:
# Message content
# Translate v1 content
if message.response_metadata.get("output_version") == "v1":
content = _convert_from_v1_to_mistral(
message.content_blocks, message.response_metadata.get("model_provider")
)
else:
content = message.content
if tool_calls and content:
# Assistant message must have either content or tool_calls, but not both.
# Some providers may not support tool_calls in the same message as content.
# This is done to ensure compatibility with messages from other providers.
message_dict["content"] = ""
content = ""
elif isinstance(content, list):
content = [
_clean_block(block)
if isinstance(block, dict) and "index" in block
else block
for block in content
]
else:
message_dict["content"] = message.content
content = message.content
# if any blocks are dicts, cast strings to text blocks
if any(isinstance(block, dict) for block in content):
content = [
block if isinstance(block, dict) else {"type": "text", "text": block}
for block in content
]
message_dict["content"] = content
if "prefix" in message.additional_kwargs:
message_dict["prefix"] = message.additional_kwargs["prefix"]
return message_dict
@@ -564,13 +628,6 @@ class ChatMistralAI(BaseChatModel):
stream: bool | None = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
@@ -627,12 +684,16 @@ class ChatMistralAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
index = -1
index_type = ""
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.get("choices", [])) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
chunk, default_chunk_class, index, index_type, self.output_version
)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -653,12 +714,16 @@ class ChatMistralAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
index = -1
index_type = ""
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.get("choices", [])) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
chunk, default_chunk_class, index, index_type, self.output_version
)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -676,13 +741,6 @@ class ChatMistralAI(BaseChatModel):
stream: bool | None = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(

View File

@@ -2,33 +2,19 @@
from __future__ import annotations
import json
import logging
import time
from typing import Any
import pytest
from httpx import ReadTimeout
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from pydantic import BaseModel
from typing_extensions import TypedDict
from langchain_mistralai.chat_models import ChatMistralAI
def test_stream() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()
for token in llm.stream("Hello"):
assert isinstance(token.content, str)
async def test_astream() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()
@@ -42,7 +28,9 @@ async def test_astream() -> None:
full = token if full is None else full + token
if token.usage_metadata is not None:
chunks_with_token_counts += 1
if token.response_metadata:
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
{"model_provider", "output_version"}
):
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
msg = (
@@ -63,131 +51,6 @@ async def test_astream() -> None:
assert full.response_metadata["model_name"]
async def test_abatch() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_abatch_tags() -> None:
    """Test batch tokens from ChatMistralAI."""
    llm = ChatMistralAI()
    result = await llm.abatch(
        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
    )
    for token in result:
        assert isinstance(token.content, str)


def test_batch() -> None:
    """Test batch tokens from ChatMistralAI."""
    llm = ChatMistralAI()
    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token.content, str)
async def test_ainvoke() -> None:
    """Test invoke tokens from ChatMistralAI."""
    llm = ChatMistralAI()
    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result.content, str)
    assert "model_name" in result.response_metadata


def test_invoke() -> None:
    """Test invoke tokens from ChatMistralAI."""
    llm = ChatMistralAI()
    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result.content, str)
def test_chat_mistralai_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatMistralAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model


def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatMistralAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model
def test_chat_mistralai_llm_output_contains_token_usage() -> None:
    """Test llm_output contains token_usage."""
    chat = ChatMistralAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert "token_usage" in llm_result.llm_output
    token_usage = llm_result.llm_output["token_usage"]
    assert "prompt_tokens" in token_usage
    assert "completion_tokens" in token_usage
    assert "total_tokens" in token_usage
def test_chat_mistralai_streaming_llm_output_not_contain_token_usage() -> None:
    """Mistral currently doesn't return token usage when streaming."""
    chat = ChatMistralAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert "token_usage" in llm_result.llm_output
    token_usage = llm_result.llm_output["token_usage"]
    assert not token_usage
def test_structured_output() -> None:
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]
    schema = {
        "title": "AnswerWithJustification",
        "description": (
            "An answer to the user question along with justification for the answer."
        ),
        "type": "object",
        "properties": {
            "answer": {"title": "Answer", "type": "string"},
            "justification": {"title": "Justification", "type": "string"},
        },
        "required": ["answer", "justification"],
    }
    structured_llm = llm.with_structured_output(schema)
    result = structured_llm.invoke(
        "What weighs more a pound of bricks or a pound of feathers"
    )
    assert isinstance(result, dict)
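`with_structured_output` is driven in the test above by a plain JSON-schema dict rather than a Pydantic class. A rough stdlib-only sketch of what checking a result against such a schema's `required` and `properties` entries could look like (`roughly_matches` is an illustrative approximation, not a real JSON-schema validator):

```python
schema = {
    "type": "object",
    "properties": {
        "answer": {"type": "string"},
        "justification": {"type": "string"},
    },
    "required": ["answer", "justification"],
}


def roughly_matches(result: dict, schema: dict) -> bool:
    """Check required keys exist and string-typed properties hold strings."""
    if not all(key in result for key in schema.get("required", [])):
        return False
    for key, spec in schema.get("properties", {}).items():
        if key in result and spec.get("type") == "string":
            if not isinstance(result[key], str):
                return False
    return True


assert roughly_matches({"answer": "same", "justification": "a pound is a pound"}, schema)
assert not roughly_matches({"answer": "same"}, schema)  # missing required key
```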
def test_streaming_structured_output() -> None:
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]

    class Person(BaseModel):
        name: str
        age: int

    structured_llm = llm.with_structured_output(Person)
    strm = structured_llm.stream("Erick, 27 years old")
    for chunk_num, chunk in enumerate(strm):
        assert chunk_num == 0, "should only have one chunk with model"
        assert isinstance(chunk, Person)
        assert chunk.name == "Erick"
        assert chunk.age == 27


class Book(BaseModel):
    name: str
    authors: list[str]
@@ -247,66 +110,6 @@ async def test_structured_output_json_schema_async(schema: Any) -> None:
        _check_parsed_result(chunk, schema)


def test_tool_call() -> None:
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]

    class Person(BaseModel):
        name: str
        age: int

    tool_llm = llm.bind_tools([Person])
    result = tool_llm.invoke("Erick, 27 years old")
    assert isinstance(result, AIMessage)
    assert len(result.tool_calls) == 1
    tool_call = result.tool_calls[0]
    assert tool_call["name"] == "Person"
    assert tool_call["args"] == {"name": "Erick", "age": 27}
def test_streaming_tool_call() -> None:
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]

    class Person(BaseModel):
        name: str
        age: int

    tool_llm = llm.bind_tools([Person])

    # where it calls the tool
    strm = tool_llm.stream("Erick, 27 years old")
    additional_kwargs = None
    for chunk in strm:
        assert isinstance(chunk, AIMessageChunk)
        assert chunk.content == ""
        additional_kwargs = chunk.additional_kwargs
    assert additional_kwargs is not None
    assert "tool_calls" in additional_kwargs
    assert len(additional_kwargs["tool_calls"]) == 1
    assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
    assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
        "name": "Erick",
        "age": 27,
    }
    assert isinstance(chunk, AIMessageChunk)
    assert len(chunk.tool_call_chunks) == 1
    tool_call_chunk = chunk.tool_call_chunks[0]
    assert tool_call_chunk["name"] == "Person"
    assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}'

    # where it doesn't call the tool
    strm = tool_llm.stream("What is 2+2?")
    acc: Any = None
    for chunk in strm:
        assert isinstance(chunk, AIMessageChunk)
        acc = chunk if acc is None else acc + chunk
    assert acc.content != ""
    assert "tool_calls" not in acc.additional_kwargs
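As the streaming tool-call test shows, tool-call arguments arrive as a JSON string that is only parseable once the stream completes. A hedged sketch of that accumulation pattern, using hard-coded fragments in place of real stream chunks (the fragment split is invented for illustration):

```python
import json

# Hypothetical argument fragments as they might arrive across stream chunks.
fragments = ['{"name": ', '"Erick", ', '"age": 27}']

buffer = ""
for fragment in fragments:
    buffer += fragment  # partial JSON: not parseable mid-stream

args = json.loads(buffer)  # parse only after the stream ends
assert args == {"name": "Erick", "age": 27}
```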
def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
    """Test that retry parameters are honored in ChatMistralAI."""
    # Create a model with intentionally short timeout and multiple retries
@@ -342,3 +145,51 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
    except Exception:
        logger.exception("Unexpected exception")
        raise
def test_reasoning() -> None:
    model = ChatMistralAI(model="magistral-medium-latest")  # type: ignore[call-arg]
    input_message = {
        "role": "user",
        "content": "Hello, my name is Bob.",
    }
    full: AIMessageChunk | None = None
    for chunk in model.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    thinking_blocks = 0
    for i, block in enumerate(full.content):
        if isinstance(block, dict) and block.get("type") == "thinking":
            thinking_blocks += 1
            reasoning_block = full.content_blocks[i]
            assert reasoning_block["type"] == "reasoning"
            assert isinstance(reasoning_block.get("reasoning"), str)
    assert thinking_blocks > 0

    next_message = {"role": "user", "content": "What is my name?"}
    _ = model.invoke([input_message, full, next_message])
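The loop in `test_reasoning` pairs each raw `thinking` block in `content` with a standardized `reasoning` block in `content_blocks`. A minimal, hypothetical sketch of that normalization step (the block shapes and the `normalize_block` helper are assumptions for illustration, not Mistral's actual wire format):

```python
def normalize_block(block: dict) -> dict:
    """Map a provider-specific 'thinking' block to a standard 'reasoning' block."""
    if block.get("type") == "thinking":
        return {"type": "reasoning", "reasoning": block.get("thinking", "")}
    return block  # pass other block types through unchanged


content = [
    {"type": "thinking", "thinking": "The user said their name is Bob."},
    {"type": "text", "text": "Hello Bob!"},
]
content_blocks = [normalize_block(b) for b in content]

assert content_blocks[0]["type"] == "reasoning"
assert isinstance(content_blocks[0]["reasoning"], str)
assert content_blocks[1] == {"type": "text", "text": "Hello Bob!"}
```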
def test_reasoning_v1() -> None:
    model = ChatMistralAI(model="magistral-medium-latest", output_version="v1")  # type: ignore[call-arg]
    input_message = {
        "role": "user",
        "content": "Hello, my name is Bob.",
    }
    full: AIMessageChunk | None = None
    chunks = []
    for chunk in model.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
        chunks.append(chunk)
    assert isinstance(full, AIMessageChunk)
    reasoning_blocks = 0
    for block in full.content:
        if isinstance(block, dict) and block.get("type") == "reasoning":
            reasoning_blocks += 1
            assert isinstance(block.get("reasoning"), str)
    assert reasoning_blocks > 0

    next_message = {"role": "user", "content": "What is my name?"}
    _ = model.invoke([input_message, full, next_message])


@@ -188,6 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None:
                type="tool_call",
            )
        ],
        response_metadata={"model_provider": "mistralai"},
    )
    assert result == expected_output
    assert _convert_message_to_mistral_chat_message(expected_output) == message
@@ -231,6 +232,7 @@ def test__convert_dict_to_message_tool_call() -> None:
                type="tool_call",
            ),
        ],
        response_metadata={"model_provider": "mistralai"},
    )
    assert result == expected_output
    assert _convert_message_to_mistral_chat_message(expected_output) == message
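Both updated expectations in `test__convert_dict_to_message_tool_call` add `response_metadata={"model_provider": "mistralai"}` to the converted message. A stripped-down sketch of a dict-to-message converter stamping that provider tag (the `convert_dict_to_message` helper and plain-dict message shape are illustrative, not the library's actual `_convert_dict_to_message`):

```python
def convert_dict_to_message(raw: dict) -> dict:
    """Illustrative conversion: copy role/content/tool_calls and stamp the provider."""
    return {
        "role": raw.get("role", "assistant"),
        "content": raw.get("content", ""),
        "tool_calls": raw.get("tool_calls", []),
        "response_metadata": {"model_provider": "mistralai"},
    }


message = convert_dict_to_message({"role": "assistant", "content": "", "tool_calls": []})
assert message["response_metadata"] == {"model_provider": "mistralai"}
```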