Compare commits

...

46 Commits

Author SHA1 Message Date
Sydney Runkle
4cc3a15af8 more generics 2025-11-21 11:31:38 -05:00
Sydney Runkle
bacafb2fc6 improved types 2025-11-19 15:42:28 -05:00
Sydney Runkle
b7d1831f9d fix: deprecate setattr on ModelCallRequest (#34022)
* One alternative considered was setting `frozen=True` on the dataclass, but that would be a breaking change, so a deprecation is the nicer approach.
2025-11-19 11:08:55 -05:00
ccurme
328ba36601 chore(openai): skip Azure text completions tests (#34021) 2025-11-19 09:29:12 -05:00
Sydney Runkle
6f677ef5c1 chore: temporarily skip openai integration tests (#34020)
Works around deprecated Azure model issues that are blocking the core release.
2025-11-19 14:05:22 +00:00
Sydney Runkle
d47d41cbd3 release: langchain-core 1.0.6 (#34018) 2025-11-19 08:16:34 -05:00
William FH
32bbe99efc chore: Support tool runtime injection when custom args schema is prov… (#33999)
Support injection of runtime args (like `InjectedToolCallId`,
`ToolRuntime`) when an `args_schema` is specified that doesn't contain
those args.

This allows for Pydantic validation of the other args while retaining the
ability to inject LangChain-specific arguments.
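
A minimal sketch of the pattern this enables, adapted from the unit tests added in this PR:

```python
from typing import Annotated

from pydantic import BaseModel

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool


class InputSchema(BaseModel):
    x: int


@tool(args_schema=InputSchema)
def injected_tool(
    x: int, tool_call_id: Annotated[str, InjectedToolCallId]
) -> ToolMessage:
    """Tool with an injected tool_call_id and a custom args schema."""
    return ToolMessage(str(x), tool_call_id=tool_call_id)


# tool_call_id is injected from the ToolCall even though it isn't in InputSchema.
result = injected_tool.invoke(
    {
        "type": "tool_call",
        "args": {"x": 42},
        "name": "injected_tool",
        "id": "test_call_id",
    }
)
# result == ToolMessage("42", tool_call_id="test_call_id")
```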

fixes https://github.com/langchain-ai/langchain/issues/33646
fixes https://github.com/langchain-ai/langchain/issues/31688

Taking a deep dive here reminded me that we definitely need to revisit
our internal tooling logic, but I don't think we should do that in this
PR.

---------

Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Co-authored-by: Sydney Runkle <sydneymarierunkle@gmail.com>
2025-11-18 17:09:59 +00:00
ccurme
990e346c46 release(anthropic): 1.1 (#33997) 2025-11-17 16:24:29 -05:00
ccurme
9b7792631d feat(anthropic): support native structured output feature and strict tool calling (#33980) 2025-11-17 16:14:20 -05:00
CKLogic
558a8fe25b feat(core): add proxy support for mermaid png rendering (#32400)
### Description

This PR adds support for configuring HTTP/HTTPS proxies when rendering
Mermaid diagrams as PNG images using the remote Mermaid.INK API. This
enhancement allows users in restricted network environments to access
the API via a proxy, making the remote rendering feature more robust and
accessible.

The changes include:
- Added optional `proxies` parameter to `draw_mermaid_png` and
`_render_mermaid_using_api` functions
- Updated `Graph.draw_mermaid_png` method to support and pass through
proxy configuration
- Enhanced docstrings with usage examples for the new parameter
- Maintained full backward compatibility with existing code

### Usage Example

```python
proxies = {
    "http": "http://127.0.0.1:7890",
    "https": "http://127.0.0.1:7890",
}

display(Image(chain.get_graph().draw_mermaid_png(proxies=proxies)))
```

### Dependencies

No new dependencies required. Uses existing `requests` library for HTTP
requests.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
2025-11-17 12:45:17 -06:00
Mason Daugherty
52b1516d44 style(langchain): fix some middleware ref syntax (#33988) 2025-11-16 00:33:17 -05:00
Mason Daugherty
8a3bb73c05 release(openai): 1.0.3 (#33981)
- Respect 300k token limit for embeddings API requests #33668
- Fix `create_agent` / `response_format` for Responses API #33939
- Fix unhandled `response.incomplete` event when using `stream_mode=['messages']` #33871
2025-11-14 19:18:50 -05:00
Mason Daugherty
099c042395 refactor(openai): embedding utils and calculations (#33982)
Now returns (`_iter`, `tokens`, `indices`, `token_counts`). The
`token_counts` are calculated directly during tokenization, which is
more accurate and efficient than splitting strings later.
2025-11-14 19:18:37 -05:00
Kaparthy Reddy
2d4f00a451 fix(openai): Respect 300k token limit for embeddings API requests (#33668)
## Description

Fixes #31227 - Resolves the issue where `OpenAIEmbeddings` exceeds
OpenAI's 300,000 token per request limit, causing 400 BadRequest errors.

## Problem

When embedding large document sets, LangChain would send batches
containing more than 300,000 tokens in a single API request, causing
this error:
```
openai.BadRequestError: Error code: 400 - {'error': {'message': 'Requested 673477 tokens, max 300000 tokens per request'}}
```

The issue occurred because:
- The code chunks texts by `embedding_ctx_length` (8191 tokens per
chunk)
- Then batches chunks by `chunk_size` (default 1000 chunks per request)
- **But didn't check**: Total tokens per batch against OpenAI's 300k
limit
- Result: `1000 chunks × 8191 tokens = 8,191,000 tokens` → Exceeds
limit!

## Solution

This PR implements dynamic batching that respects the 300k token limit:

1. **Added constant**: `MAX_TOKENS_PER_REQUEST = 300000`
2. **Track token counts**: Calculate actual tokens for each chunk
3. **Dynamic batching**: Instead of fixed `chunk_size` batches,
accumulate chunks until approaching the 300k limit (sketched below)
4. **Applied to both sync and async**: Fixed both
`_get_len_safe_embeddings` and `_aget_len_safe_embeddings`
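
A minimal, illustrative sketch of the token-aware batching idea (the helper name `batch_by_token_limit` is hypothetical; the actual change lives inside `_get_len_safe_embeddings`):

```python
MAX_TOKENS_PER_REQUEST = 300_000  # OpenAI's per-request embedding token limit


def batch_by_token_limit(
    chunks: list[list[int]],  # each chunk is a list of token ids (<= embedding_ctx_length)
    chunk_size: int = 1000,  # still cap the number of chunks per request
) -> list[list[list[int]]]:
    """Group chunks into batches that stay under both limits."""
    batches: list[list[list[int]]] = []
    current: list[list[int]] = []
    current_tokens = 0
    for chunk in chunks:
        chunk_tokens = len(chunk)
        # Start a new batch if adding this chunk would exceed either limit.
        if current and (
            current_tokens + chunk_tokens > MAX_TOKENS_PER_REQUEST
            or len(current) >= chunk_size
        ):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(chunk)
        current_tokens += chunk_tokens
    if current:
        batches.append(current)
    return batches
```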

## Changes

- Modified `langchain_openai/embeddings/base.py`:
  - Added `MAX_TOKENS_PER_REQUEST` constant
  - Replaced fixed-size batching with token-aware dynamic batching
  - Applied to both sync (line ~478) and async (line ~527) methods
- Added test in `tests/unit_tests/embeddings/test_base.py`:
- `test_embeddings_respects_token_limit()` - Verifies large document
sets are properly batched

## Testing

All existing tests pass (280 passed, 4 xfailed, 1 xpassed).

New test verifies:
- Large document sets (500 texts × 1000 tokens = 500k tokens) are split
into multiple API calls
- Each API call respects the 300k token limit

## Usage

After this fix, users can embed large document sets without errors:
```python
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import CharacterTextSplitter

# This will now work without exceeding token limits
embeddings = OpenAIEmbeddings()
documents = CharacterTextSplitter().split_documents(large_documents)
Chroma.from_documents(documents, embeddings)
```

Resolves #31227

---------

Co-authored-by: Kaparthy Reddy <kaparthyreddy@Kaparthys-MacBook-Air.local>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
2025-11-14 18:12:07 -05:00
Sydney Runkle
9bd401a6d4 fix: resumable shell, works w/ interrupts (#33978)
fixes https://github.com/langchain-ai/langchain/issues/33684

Now able to run this minimal snippet successfully

```py
import os

from langchain.agents import create_agent
from langchain.agents.middleware import (
    HostExecutionPolicy,
    HumanInTheLoopMiddleware,
    ShellToolMiddleware,
)
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command


shell_middleware = ShellToolMiddleware(
    workspace_root=os.getcwd(),
    env=os.environ,  # danger
    execution_policy=HostExecutionPolicy()
)

hil_middleware = HumanInTheLoopMiddleware(interrupt_on={"shell": True})

checkpointer = InMemorySaver()

agent = create_agent(
    "openai:gpt-4.1-mini",
    middleware=[shell_middleware, hil_middleware],
    checkpointer=checkpointer,
)

input_message = {"role": "user", "content": "run `which python`"}

config = {"configurable": {"thread_id": "1"}}

result = agent.invoke(
    {"messages": [input_message]},
    config=config,
    durability="exit",
)
```
2025-11-14 15:32:25 -05:00
ccurme
6aa3794b74 feat(langchain): reference model profiles for provider strategy (#33974) 2025-11-14 19:24:18 +00:00
Sydney Runkle
189dcf7295 chore: increase coverage for shell, filesystem, and summarization middleware (#33928)
cc generated; just a start here, but wanted to bump things up from ~70%.
2025-11-14 13:30:36 -05:00
Sydney Runkle
1bc88028e6 fix(anthropic): execute bash + file tools via tool node (#33960)
* use `override` instead of directly patching things on `ModelRequest`
* rely on `ToolNode` for execution of tools related to said middleware,
using `wrap_model_call` to inject the relevant claude tool specs +
allowing tool node to forward them along to corresponding langchain tool
implementations
* making the same change for the native shell tool middleware
* allowing shell tool middleware to specify a name for the shell tool
(negative diff then for claude bash middleware)


Long term, I think the solution might be to attach metadata to a tool that
maps the provider spec to a LangChain implementation; we could also take
some lessons from the MCP front here.
2025-11-14 13:17:01 -05:00
Mason Daugherty
d2942351ce release(core): 1.0.5 (#33973) 2025-11-14 11:51:27 -05:00
Sydney Runkle
83c078f363 fix: adding missing async hooks (#33957)
* filling in missing async gaps
* using recommended tool runtime injection instead of injected state
  * updating tests to use helper function as well
2025-11-14 09:13:39 -05:00
ZhangShenao
26d39ffc4a docs: Fix doc links (#33964) 2025-11-14 09:07:32 -05:00
Mason Daugherty
421e2ceeee fix(core): don't mask exceptions (#33959) 2025-11-14 09:05:29 -05:00
Mason Daugherty
275dcbf69f docs(core): add clarity to base token counting methods (#33958)
It wasn't immediately obvious that `get_num_tokens_from_messages` adds
prefixes to messages to represent user roles in the conversation, which adds
to the overall token count.

```python
from langchain_google_genai import GoogleGenerativeAI

llm = GoogleGenerativeAI(model="gemini-2.5-flash")
num_tokens = llm.get_num_tokens("Hello, world!")
print(f"Number of tokens: {num_tokens}")
# Number of tokens: 4
```

```python
from langchain.messages import HumanMessage

messages = [HumanMessage(content="Hello, world!")]

num_tokens = llm.get_num_tokens_from_messages(messages)
print(f"Number of tokens: {num_tokens}")
# Number of tokens: 6
```
2025-11-13 17:15:47 -05:00
Sydney Runkle
9f87b27a5b fix: add filesystem middleware in init (#33955) 2025-11-13 15:07:33 -05:00
Mason Daugherty
b2e1196e29 chore(core,infra): nits (#33954) 2025-11-13 14:50:54 -05:00
Sydney Runkle
2dc1396380 chore(langchain): update deps (#33951) 2025-11-13 14:21:25 -05:00
Mason Daugherty
77941ab3ce feat(infra): add automatic issue labeling (#33952) 2025-11-13 14:13:52 -05:00
Mason Daugherty
ee19a30dde fix(groq): bump min ver for core dep (#33949)
Due to an issue with unit tests and the docs URL for exceptions.
2025-11-13 11:46:54 -05:00
Mason Daugherty
5d799b3174 release(nomic): 1.0.1 (#33948)
support Python 3.14 #33655
2025-11-13 11:25:39 -05:00
Mason Daugherty
8f33a985a2 release(groq): 1.0.1 (#33947)
- fix: handle tool calls with no args #33896
- add prompt caching token usage details #33708
2025-11-13 11:25:00 -05:00
Mason Daugherty
78eeccef0e release(deepseek): 1.0.1 (#33946)
- support strict beta structured output #32727
2025-11-13 11:24:39 -05:00
ccurme
3d415441e8 fix(langchain, openai): backward compat for response_format (#33945) 2025-11-13 11:11:35 -05:00
ccurme
74385e0ebd fix(langchain, openai): fix create_agent / response_format for Responses API (#33939) 2025-11-13 10:18:15 -05:00
Christophe Bornet
2bfbc29ccc chore(core): fix some ruff TC rules (#33929)
Fixes some ruff TC rules, but still doesn't enforce them, as Pydantic model
fields use type annotations at runtime.
2025-11-12 14:07:19 -05:00
Christophe Bornet
ef79c26f18 chore(cli,standard-tests,text-splitters): fix some ruff TC rules (#33934)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-11-12 14:06:31 -05:00
ccurme
fbe32c8e89 release(anthropic): 1.0.3 (#33935) 2025-11-12 10:55:28 -05:00
Mohammad Mohtashim
2511c28f92 feat(anthropic): support code_execution_20250825 (#33925) 2025-11-12 10:44:51 -05:00
Sydney Runkle
637bb1cbbc feat: refactor tests coverage (#33927)
Middleware tests have gotten quite unwieldy; this is a major restructuring
that sets the stage for a coverage increase.

This is super hard to review -- as proof that we've retained the important
tests, I ran coverage on `master` and this branch and confirmed
identical coverage.

* moving all middleware related tests to `agents/middleware` folder
* consolidating related test files
* adding coverage utility to makefile
2025-11-11 10:40:12 -05:00
Mason Daugherty
3dfea96ec1 chore: update README.md files (#33919) 2025-11-10 22:51:35 -05:00
ccurme
68643153e5 feat(langchain): support async summarization in SummarizationMiddleware (#33918) 2025-11-10 15:48:51 -05:00
Abbas Syed
462762f75b test(core): add comprehensive tests for groq block translator (#33906) 2025-11-10 15:45:36 -05:00
ccurme
4f3729c004 release(model-profiles): 0.0.4 (#33917) 2025-11-10 12:06:32 -05:00
Mason Daugherty
ba428cdf54 chore(infra): add note to pr linting workflow (#33916) 2025-11-10 11:49:31 -05:00
Mason Daugherty
69c7d1b01b test(groq,openai): add retries for flaky tests (#33914) 2025-11-10 10:36:11 -05:00
Mason Daugherty
733299ec13 revert(core): "applied secrets_map in load to plain string values" (#33913)
Reverts langchain-ai/langchain#33678

Breaking API change
2025-11-10 10:29:30 -05:00
ccurme
e1adf781c6 feat(langchain): (SummarizationMiddleware) support use of model context windows when triggering summarization (#33825) 2025-11-10 10:08:52 -05:00
175 changed files with 12744 additions and 8389 deletions

View File

@@ -8,16 +8,15 @@ body:
value: |
Thank you for taking the time to file a bug report.
Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
Relevant links to check before filing a bug report to see if your issue has already been reported, fixed or
if there's another way to solve your problem:
Check these before submitting to see if your issue has already been reported, fixed or if there's another way to solve your problem:
* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -36,16 +35,48 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- label: I read what a minimal reproducible example is (https://stackoverflow.com/help/minimal-reproducible-example).
required: true
- label: I posted a self-contained, minimal, reproducible example. A maintainer can copy it and run it AS IS.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this bug related to? Select at least one.
Note that if the package you are reporting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).
Please report issues for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Example Code
label: Example Code (Python)
description: |
Please add a self-contained, [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example) with your use case.
@@ -53,15 +84,12 @@ body:
**Important!**
* Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.
* Use code tags (e.g., ```python ... ```) to correctly [format your code](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting).
* INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).
* Avoid screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible.
(This will be automatically formatted into code, so no need for backticks.)
render: python
placeholder: |
The following code:
```python
from langchain_core.runnables import RunnableLambda
def bad_code(inputs) -> int:
@@ -69,17 +97,14 @@ body:
chain = RunnableLambda(bad_code)
chain.invoke('Hello!')
```
- type: textarea
id: error
validations:
required: false
attributes:
label: Error Message and Stack Trace (if applicable)
description: |
If you are reporting an error, please include the full error message and stack trace.
placeholder: |
Exception + full stack trace
If you are reporting an error, please copy and paste the full error message and
stack trace.
(This will be automatically formatted into code, so no need for backticks.)
render: shell
- type: textarea
id: description
attributes:
@@ -99,9 +124,7 @@ body:
attributes:
label: System Info
description: |
Please share your system info with us. Do NOT skip this step and please don't trim
the output. Most users don't include enough information here and it makes it harder
for us to help you.
Please share your system info with us.
Run the following command in your terminal and paste the output here:
@@ -113,8 +136,6 @@ body:
from langchain_core import sys_info
sys_info.print_sys_info()
```
alternatively, put the entire output of `pip freeze` here.
placeholder: |
python -m langchain_core.sys_info
validations:

View File

@@ -1,9 +1,18 @@
blank_issues_enabled: false
version: 2.1
contact_links:
- name: 📚 Documentation
- name: 📚 Documentation issue
url: https://github.com/langchain-ai/docs/issues/new?template=01-langchain.yml
about: Report an issue related to the LangChain documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: General community discussions and support
- name: 📚 LangChain Documentation
url: https://docs.langchain.com/oss/python/langchain/overview
about: View the official LangChain documentation
- name: 📚 API Reference Documentation
url: https://reference.langchain.com/python/
about: View the official LangChain API reference documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: Ask questions and get help from the community

View File

@@ -13,11 +13,11 @@ body:
Relevant links to check before filing a feature request to see if your request has already been made or
if there's another way to achieve what you want:
* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -34,6 +34,40 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this request related to? Select at least one.
Note that if the package you are requesting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).
Please submit feature requests for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: feature-description
validations:

View File

@@ -18,3 +18,33 @@ body:
attributes:
label: Issue Content
description: Add the content of the issue here.
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this issue is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general

View File

@@ -25,13 +25,13 @@ body:
label: Task Description
description: |
Provide a clear and detailed description of the task.
What needs to be done? Be specific about the scope and requirements.
placeholder: |
This task involves...
The goal is to...
Specific requirements:
- ...
- ...
@@ -43,7 +43,7 @@ body:
label: Acceptance Criteria
description: |
Define the criteria that must be met for this task to be considered complete.
What are the specific deliverables or outcomes expected?
placeholder: |
This task will be complete when:
@@ -58,15 +58,15 @@ body:
label: Context and Background
description: |
Provide any relevant context, background information, or links to related issues/PRs.
Why is this task needed? What problem does it solve?
placeholder: |
Background:
- ...
Related issues/PRs:
- #...
Additional context:
- ...
validations:
@@ -77,15 +77,45 @@ body:
label: Dependencies
description: |
List any dependencies or blockers for this task.
Are there other tasks, issues, or external factors that need to be completed first?
placeholder: |
This task depends on:
- [ ] Issue #...
- [ ] PR #...
- [ ] External dependency: ...
Blocked by:
- ...
validations:
required: false
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this task is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general

View File

@@ -396,7 +396,7 @@ jobs:
contents: read
strategy:
matrix:
partner: [openai, anthropic]
partner: [anthropic]
fail-fast: false # Continue testing other partners if one fails
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

View File

@@ -0,0 +1,107 @@
name: Auto Label Issues by Package
on:
issues:
types: [opened, edited]
jobs:
label-by-package:
permissions:
issues: write
runs-on: ubuntu-latest
steps:
- name: Sync package labels
uses: actions/github-script@v6
with:
script: |
const body = context.payload.issue.body || "";
// Extract text under "### Package"
const match = body.match(/### Package\s+([\s\S]*?)\n###/i);
if (!match) return;
const packageSection = match[1].trim();
// Mapping table for package names to labels
const mapping = {
"langchain": "langchain",
"langchain-openai": "openai",
"langchain-anthropic": "anthropic",
"langchain-classic": "langchain-classic",
"langchain-core": "core",
"langchain-cli": "cli",
"langchain-model-profiles": "model-profiles",
"langchain-tests": "standard-tests",
"langchain-text-splitters": "text-splitters",
"langchain-chroma": "chroma",
"langchain-deepseek": "deepseek",
"langchain-exa": "exa",
"langchain-fireworks": "fireworks",
"langchain-groq": "groq",
"langchain-huggingface": "huggingface",
"langchain-mistralai": "mistralai",
"langchain-nomic": "nomic",
"langchain-ollama": "ollama",
"langchain-perplexity": "perplexity",
"langchain-prompty": "prompty",
"langchain-qdrant": "qdrant",
"langchain-xai": "xai",
};
// All possible package labels we manage
const allPackageLabels = Object.values(mapping);
const selectedLabels = [];
// Check if this is checkbox format (multiple selection)
const checkboxMatches = packageSection.match(/- \[x\]\s+([^\n\r]+)/gi);
if (checkboxMatches) {
// Handle checkbox format
for (const match of checkboxMatches) {
const packageName = match.replace(/- \[x\]\s+/i, '').trim();
const label = mapping[packageName];
if (label && !selectedLabels.includes(label)) {
selectedLabels.push(label);
}
}
} else {
// Handle dropdown format (single selection)
const label = mapping[packageSection];
if (label) {
selectedLabels.push(label);
}
}
// Get current issue labels
const issue = await github.rest.issues.get({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number
});
const currentLabels = issue.data.labels.map(label => label.name);
const currentPackageLabels = currentLabels.filter(label => allPackageLabels.includes(label));
// Determine labels to add and remove
const labelsToAdd = selectedLabels.filter(label => !currentPackageLabels.includes(label));
const labelsToRemove = currentPackageLabels.filter(label => !selectedLabels.includes(label));
// Add new labels
if (labelsToAdd.length > 0) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: labelsToAdd
});
}
// Remove old labels
for (const label of labelsToRemove) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: label
});
}

View File

@@ -26,12 +26,14 @@
# * revert — reverts a previous commit
# * release — prepare a new release
#
# Allowed Scopes (optional):
# Allowed Scope(s) (optional):
# core, cli, langchain, langchain_v1, langchain-classic, standard-tests,
# text-splitters, docs, anthropic, chroma, deepseek, exa, fireworks, groq,
# huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant,
# xai, infra, deps
#
# Multiple scopes can be used by separating them with a comma.
#
# Rules:
# 1. The 'Type' must start with a lowercase letter.
# 2. Breaking changes: append "!" after type/scope (e.g., feat!: drop x support)

3
.gitignore vendored
View File

@@ -163,3 +163,6 @@ node_modules
prof
virtualenv/
scratch/
.langgraph_api/

View File

@@ -6,9 +6,8 @@ import hashlib
import logging
import re
import shutil
from collections.abc import Sequence
from pathlib import Path
from typing import Any, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict
from git import Repo
@@ -18,6 +17,9 @@ from langchain_cli.constants import (
DEFAULT_GIT_SUBDIRECTORY,
)
if TYPE_CHECKING:
from collections.abc import Sequence
logger = logging.getLogger(__name__)

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .file import File
from .folder import Folder
if TYPE_CHECKING:
from .file import File
from .folder import Folder
@dataclass

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from .file import File
if TYPE_CHECKING:
from pathlib import Path
class Folder:
def __init__(self, name: str, *files: Folder | File) -> None:

View File

@@ -34,7 +34,7 @@ The LangChain ecosystem is built on top of `langchain-core`. Some of the benefit
## 📖 Documentation
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
## 📕 Releases & Versioning

View File

@@ -5,13 +5,12 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from typing_extensions import Self
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from tenacity import RetryCallState
from typing_extensions import Self
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document

View File

@@ -39,7 +39,6 @@ from langchain_core.tracers.context import (
tracing_v2_callback_var,
)
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.schemas import Run
from langchain_core.tracers.stdout import ConsoleCallbackHandler
from langchain_core.utils.env import env_var_is_set
@@ -52,6 +51,7 @@ if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run
logger = logging.getLogger(__name__)

View File

@@ -6,16 +6,9 @@ import hashlib
import json
import uuid
import warnings
from collections.abc import (
AsyncIterable,
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from itertools import islice
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypedDict,
@@ -29,6 +22,16 @@ from langchain_core.exceptions import LangChainException
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from collections.abc import (
AsyncIterable,
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
# Magic UUID to use as a namespace for hashing.
# Used to try and generate a unique UUID for each document
# from hashing the document content and metadata.

View File

@@ -299,6 +299,9 @@ class BaseLanguageModel(
Useful for checking if an input fits in a model's context window.
This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.
Args:
text: The string input to tokenize.
@@ -317,9 +320,17 @@ class BaseLanguageModel(
Useful for checking if an input fits in a model's context window.
This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.
!!! note
The base implementation of `get_num_tokens_from_messages` ignores tool
schemas.
* The base implementation of `get_num_tokens_from_messages` ignores tool
schemas.
* The base implementation of `get_num_tokens_from_messages` adds additional
prefixes to messages to represent user roles, which will add to the
overall token count. Model-specific implementations may choose to
handle this differently.
Args:
messages: The message inputs to tokenize.

View File

@@ -91,7 +91,10 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
try:
metadata["body"] = response.json()
except Exception:
metadata["body"] = getattr(response, "text", None)
try:
metadata["body"] = getattr(response, "text", None)
except Exception:
metadata["body"] = None
if hasattr(response, "headers"):
try:
metadata["headers"] = dict(response.headers)

View File

@@ -61,13 +61,15 @@ class Reviver:
"""Initialize the reviver.
Args:
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the
environment if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.
"""
@@ -195,13 +197,15 @@ def loads(
Args:
text: The string to load.
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.
@@ -237,13 +241,15 @@ def load(
Args:
obj: The object to load.
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.
@@ -265,8 +271,6 @@ def load(
return reviver(loaded_obj)
if isinstance(obj, list):
return [_load(o) for o in obj]
if isinstance(obj, str) and obj in reviver.secrets_map:
return reviver.secrets_map[obj]
return obj
return _load(obj)

View File

@@ -5,11 +5,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast, overload
from pydantic import ConfigDict, Field
from typing_extensions import Self
from langchain_core._api.deprecation import warn_deprecated
from langchain_core.load.serializable import Serializable
from langchain_core.messages import content as types
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.interactive_env import is_interactive_env
@@ -17,6 +15,9 @@ from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self
from langchain_core.messages import content as types
from langchain_core.prompts.chat import ChatPromptTemplate

View File

@@ -12,10 +12,11 @@ the implementation in `BaseMessage`.
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages import content as types

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Literal, cast
from langchain_core.language_models._utils import (
@@ -14,6 +13,8 @@ from langchain_core.language_models._utils import (
from langchain_core.messages import content as types
if TYPE_CHECKING:
from collections.abc import Iterable
from langchain_core.messages import AIMessage, AIMessageChunk

View File

@@ -2,15 +2,17 @@
from __future__ import annotations
from typing import Literal
from typing import TYPE_CHECKING, Literal
from pydantic import model_validator
from typing_extensions import Self
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.utils._merge import merge_dicts
if TYPE_CHECKING:
from typing_extensions import Self
class ChatGeneration(Generation):
"""A single chat generation output.

View File

@@ -6,7 +6,7 @@ import contextlib
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
@@ -33,6 +33,8 @@ from langchain_core.runnables.config import ensure_config
from langchain_core.utils.pydantic import create_model_v2
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.documents import Document

View File

@@ -6,10 +6,10 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from langchain_core.load import Serializable
from langchain_core.messages import BaseMessage
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import ChatPromptTemplate

View File

@@ -4,9 +4,8 @@ from __future__ import annotations
import warnings
from abc import ABC
from collections.abc import Callable, Sequence
from string import Formatter
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from pydantic import BaseModel, create_model
@@ -16,6 +15,9 @@ from langchain_core.utils import get_colored_text, mustache
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
try:
from jinja2 import Environment, meta
from jinja2.sandbox import SandboxedEnvironment

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import inspect
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@@ -22,7 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from pydantic import BaseModel
@@ -642,6 +641,7 @@ class Graph:
retry_delay: float = 1.0,
frontmatter_config: dict[str, Any] | None = None,
base_url: str | None = None,
proxies: dict[str, str] | None = None,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid.
@@ -674,11 +674,10 @@ class Graph:
}
```
base_url: The base URL of the Mermaid server for rendering via API.
proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).
Returns:
The PNG image as bytes.
"""
# Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import ( # noqa: PLC0415
@@ -699,6 +698,7 @@ class Graph:
padding=padding,
max_retries=max_retries,
retry_delay=retry_delay,
proxies=proxies,
base_url=base_url,
)

View File

@@ -7,7 +7,6 @@ from __future__ import annotations
import math
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
try:
@@ -20,6 +19,8 @@ except ImportError:
_HAS_GRANDALF = False
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langchain_core.runnables.graph import Edge as LangEdge

View File

@@ -281,6 +281,7 @@ def draw_mermaid_png(
max_retries: int = 1,
retry_delay: float = 1.0,
base_url: str | None = None,
proxies: dict[str, str] | None = None,
) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax.
@@ -293,6 +294,7 @@ def draw_mermaid_png(
max_retries: Maximum number of retries (MermaidDrawMethod.API).
retry_delay: Delay between retries (MermaidDrawMethod.API).
base_url: Base URL for the Mermaid.ink API.
proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).
Returns:
PNG image bytes.
@@ -314,6 +316,7 @@ def draw_mermaid_png(
max_retries=max_retries,
retry_delay=retry_delay,
base_url=base_url,
proxies=proxies,
)
else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
@@ -405,6 +408,7 @@ def _render_mermaid_using_api(
file_type: Literal["jpeg", "png", "webp"] | None = "png",
max_retries: int = 1,
retry_delay: float = 1.0,
proxies: dict[str, str] | None = None,
base_url: str | None = None,
) -> bytes:
"""Renders Mermaid graph using the Mermaid.INK API."""
@@ -445,7 +449,7 @@ def _render_mermaid_using_api(
for attempt in range(max_retries + 1):
try:
response = requests.get(image_url, timeout=10)
response = requests.get(image_url, timeout=10, proxies=proxies)
if response.status_code == requests.codes.ok:
img_bytes = response.content
if output_file_path is not None:

View File

@@ -7,8 +7,7 @@ import asyncio
import inspect
import sys
import textwrap
from collections.abc import Callable, Mapping, Sequence
from contextvars import Context
from collections.abc import Mapping, Sequence
from functools import lru_cache
from inspect import signature
from itertools import groupby
@@ -31,9 +30,11 @@ if TYPE_CHECKING:
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Iterable,
)
from contextvars import Context
from langchain_core.runnables.schema import StreamEvent

View File

@@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818
ArgsSchema = TypeBaseModel | dict[str, Any]
_EMPTY_SET: frozenset[str] = frozenset()
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
"""Base class for all LangChain tools.
@@ -569,6 +571,11 @@ class ChildTool(BaseTool):
self.name, full_schema, fields, fn_description=self.description
)
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
# base implementation doesn't manage injected args
return _EMPTY_SET
# --- Runnable ---
@override
@@ -649,6 +656,7 @@ class ChildTool(BaseTool):
if isinstance(input_args, dict):
return tool_input
if issubclass(input_args, BaseModel):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
@@ -664,6 +672,7 @@ class ChildTool(BaseTool):
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
@@ -683,9 +692,25 @@ class ChildTool(BaseTool):
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
)
raise NotImplementedError(msg)
return {
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
validated_input = {
k: getattr(result, k) for k in result_dict if k in tool_input
}
for k in self._injected_args_keys:
if k == "tool_call_id":
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
validated_input[k] = tool_call_id
if k in tool_input:
injected_val = tool_input[k]
validated_input[k] = injected_val
return validated_input
return tool_input
@abstractmethod

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import functools
import textwrap
from collections.abc import Awaitable, Callable
from inspect import signature
@@ -21,10 +22,12 @@ from langchain_core.callbacks import (
)
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
_EMPTY_SET,
FILTERED_ARGS,
ArgsSchema,
BaseTool,
_get_runnable_config_param,
_is_injected_arg_type,
create_schema_from_function,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -241,6 +244,17 @@ class StructuredTool(BaseTool):
**kwargs,
)
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
fn = self.func or self.coroutine
if fn is None:
return _EMPTY_SET
return frozenset(
k
for k, v in signature(fn).parameters.items()
if _is_injected_arg_type(v.annotation)
)
def _filter_schema_args(func: Callable) -> list[str]:
filter_args = list(FILTERED_ARGS)

View File

@@ -15,12 +15,6 @@ from typing import (
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
@@ -31,6 +25,12 @@ if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
logger = logging.getLogger(__name__)

View File

@@ -8,7 +8,6 @@ import logging
import types
import typing
import uuid
from collections.abc import Callable
from typing import (
TYPE_CHECKING,
Annotated,
@@ -33,6 +32,8 @@ from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
import json
import re
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any
from langchain_core.exceptions import OutputParserException
if TYPE_CHECKING:
from collections.abc import Callable
def _replace_new_line(match: re.Match[str]) -> str:
value = match.group(2)

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import inspect
import textwrap
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from functools import lru_cache, wraps
from types import GenericAlias
@@ -41,10 +40,12 @@ from pydantic.json_schema import (
)
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import create_model as create_model_v1
from pydantic.v1.fields import ModelField
from typing_extensions import deprecated, override
if TYPE_CHECKING:
from collections.abc import Callable
from pydantic.v1.fields import ModelField
from pydantic_core import core_schema
PYDANTIC_VERSION = version.parse(pydantic.__version__)

View File

@@ -11,7 +11,6 @@ import logging
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from itertools import cycle
from typing import (
TYPE_CHECKING,
@@ -29,7 +28,7 @@ from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import uuid
from collections.abc import Callable
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -20,7 +19,7 @@ from langchain_core.vectorstores.utils import _cosine_similarity as cosine_simil
from langchain_core.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
from langchain_core.embeddings import Embeddings

View File

@@ -1,3 +1,3 @@
"""langchain-core version information and utilities."""
VERSION = "1.0.4"
VERSION = "1.0.6"

View File

@@ -9,7 +9,7 @@ license = {text = "MIT"}
readme = "README.md"
authors = []
version = "1.0.4"
version = "1.0.6"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langsmith>=0.3.45,<1.0.0",

View File

@@ -18,6 +18,7 @@ from langchain_core.language_models import (
ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.chat_models import _generate_response_from_error
from langchain_core.language_models.fake_chat_models import (
FakeListChatModelError,
GenericFakeChatModel,
@@ -1234,3 +1235,93 @@ def test_model_profiles() -> None:
model = MyModel(messages=iter([]))
profile = model.profile
assert profile
class MockResponse:
"""Mock response for testing _generate_response_from_error."""
def __init__(
self,
status_code: int = 400,
headers: dict[str, str] | None = None,
json_data: dict[str, Any] | None = None,
json_raises: type[Exception] | None = None,
text_raises: type[Exception] | None = None,
):
self.status_code = status_code
self.headers = headers or {}
self._json_data = json_data
self._json_raises = json_raises
self._text_raises = text_raises
def json(self) -> dict[str, Any]:
if self._json_raises:
msg = "JSON parsing failed"
raise self._json_raises(msg)
return self._json_data or {}
@property
def text(self) -> str:
if self._text_raises:
msg = "Text access failed"
raise self._text_raises(msg)
return ""
class MockAPIError(Exception):
"""Mock API error with response attribute."""
def __init__(self, message: str, response: MockResponse | None = None):
super().__init__(message)
self.message = message
if response is not None:
self.response = response
def test_generate_response_from_error_with_valid_json() -> None:
"""Test `_generate_response_from_error` with valid JSON response."""
response = MockResponse(
status_code=400,
headers={"content-type": "application/json"},
json_data={"error": {"message": "Bad request", "type": "invalid_request"}},
)
error = MockAPIError("API Error", response=response)
generations = _generate_response_from_error(error)
assert len(generations) == 1
generation = generations[0]
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.message, AIMessage)
assert generation.message.content == ""
metadata = generation.message.response_metadata
assert metadata["body"] == {
"error": {"message": "Bad request", "type": "invalid_request"}
}
assert metadata["headers"] == {"content-type": "application/json"}
assert metadata["status_code"] == 400
def test_generate_response_from_error_handles_streaming_response_failure() -> None:
# Simulates scenario where accessing response.json() or response.text
# raises ResponseNotRead on streaming responses
response = MockResponse(
status_code=400,
headers={"content-type": "application/json"},
json_raises=Exception, # Simulates ResponseNotRead or similar
text_raises=Exception,
)
error = MockAPIError("API Error", response=response)
# This should NOT raise an exception, but should handle it gracefully
generations = _generate_response_from_error(error)
assert len(generations) == 1
generation = generations[0]
metadata = generation.message.response_metadata
# When both fail, body should be None instead of raising an exception
assert metadata["body"] is None
assert metadata["headers"] == {"content-type": "application/json"}
assert metadata["status_code"] == 400

View File

@@ -1,11 +0,0 @@
"""Test for Serializable base class."""
from langchain_core.load.load import load
def test_load_with_string_secrets() -> None:
obj = {"api_key": "__SECRET_API_KEY__"}
secrets_map = {"__SECRET_API_KEY__": "hello"}
result = load(obj, secrets_map=secrets_map)
assert result["api_key"] == "hello"

View File

@@ -0,0 +1,140 @@
"""Test groq block translator."""
from typing import cast
import pytest
from langchain_core.messages import AIMessage
from langchain_core.messages import content as types
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
from langchain_core.messages.block_translators import PROVIDER_TRANSLATORS
from langchain_core.messages.block_translators.groq import (
_parse_code_json,
translate_content,
)
def test_groq_translator_registered() -> None:
"""Test that groq translator is properly registered."""
assert "groq" in PROVIDER_TRANSLATORS
assert "translate_content" in PROVIDER_TRANSLATORS["groq"]
assert "translate_content_chunk" in PROVIDER_TRANSLATORS["groq"]
def test_extract_reasoning_from_additional_kwargs_exists() -> None:
"""Test that _extract_reasoning_from_additional_kwargs can be imported."""
# Verify it's callable
assert callable(_extract_reasoning_from_additional_kwargs)
def test_groq_translate_content_basic() -> None:
"""Test basic groq content translation."""
# Test with simple text message
message = AIMessage(content="Hello world")
blocks = translate_content(message)
assert isinstance(blocks, list)
assert len(blocks) == 1
assert blocks[0]["type"] == "text"
assert blocks[0]["text"] == "Hello world"
def test_groq_translate_content_with_reasoning() -> None:
"""Test groq content translation with reasoning content."""
# Test with reasoning content in additional_kwargs
message = AIMessage(
content="Final answer",
additional_kwargs={"reasoning_content": "Let me think about this..."},
)
blocks = translate_content(message)
assert isinstance(blocks, list)
assert len(blocks) == 2
# First block should be reasoning
assert blocks[0]["type"] == "reasoning"
assert blocks[0]["reasoning"] == "Let me think about this..."
# Second block should be text
assert blocks[1]["type"] == "text"
assert blocks[1]["text"] == "Final answer"
def test_groq_translate_content_with_tool_calls() -> None:
"""Test groq content translation with tool calls."""
# Test with tool calls
message = AIMessage(
content="",
tool_calls=[
{
"name": "search",
"args": {"query": "test"},
"id": "call_123",
}
],
)
blocks = translate_content(message)
assert isinstance(blocks, list)
assert len(blocks) == 1
assert blocks[0]["type"] == "tool_call"
assert blocks[0]["name"] == "search"
assert blocks[0]["args"] == {"query": "test"}
assert blocks[0]["id"] == "call_123"
def test_groq_translate_content_with_executed_tools() -> None:
"""Test groq content translation with executed tools (built-in tools)."""
# Test with executed_tools in additional_kwargs (Groq built-in tools)
message = AIMessage(
content="",
additional_kwargs={
"executed_tools": [
{
"type": "python",
"arguments": '{"code": "print(\\"hello\\")"}',
"output": "hello\\n",
}
]
},
)
blocks = translate_content(message)
assert isinstance(blocks, list)
# Should have server_tool_call and server_tool_result
assert len(blocks) >= 2
# Check for server_tool_call
tool_call_blocks = [
cast("types.ServerToolCall", b)
for b in blocks
if b.get("type") == "server_tool_call"
]
assert len(tool_call_blocks) == 1
assert tool_call_blocks[0]["name"] == "code_interpreter"
assert "code" in tool_call_blocks[0]["args"]
# Check for server_tool_result
tool_result_blocks = [
cast("types.ServerToolResult", b)
for b in blocks
if b.get("type") == "server_tool_result"
]
assert len(tool_result_blocks) == 1
assert tool_result_blocks[0]["output"] == "hello\\n"
assert tool_result_blocks[0]["status"] == "success"
def test_parse_code_json() -> None:
"""Test the _parse_code_json helper function."""
# Test valid code JSON
result = _parse_code_json('{"code": "print(\'hello\')"}')
assert result == {"code": "print('hello')"}
# Test code with unescaped quotes (Groq format)
result = _parse_code_json('{"code": "print("hello")"}')
assert result == {"code": 'print("hello")'}
# Test invalid format raises ValueError
with pytest.raises(ValueError, match="Could not extract Python code"):
_parse_code_json('{"invalid": "format"}')

View File

@@ -3,12 +3,14 @@
import asyncio
import time
from threading import Lock
from typing import Any
from typing import TYPE_CHECKING, Any
import pytest
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable
@pytest.mark.asyncio

View File

@@ -3,21 +3,25 @@ from __future__ import annotations
import json
import sys
import uuid
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
from inspect import isasyncgenfunction
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from unittest.mock import MagicMock, patch
import pytest
from langsmith import Client, get_current_run_tree, traceable
from langsmith.run_helpers import tracing_context
from langsmith.run_trees import RunTree
from langsmith.utils import get_env_var
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
from langsmith.run_trees import RunTree
from langchain_core.callbacks import BaseCallbackHandler
def _get_posts(client: Client) -> list:
mock_calls = client.session.request.mock_calls # type: ignore[attr-defined]

View File

@@ -6,6 +6,7 @@ import sys
import textwrap
import threading
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import partial
@@ -55,6 +56,7 @@ from langchain_core.tools.base import (
InjectedToolArg,
InjectedToolCallId,
SchemaAnnotationError,
_DirectlyInjectedToolArg,
_is_message_content_block,
_is_message_content_type,
get_all_basemodel_annotations,
@@ -2331,6 +2333,101 @@ def test_injected_arg_with_complex_type() -> None:
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar"
@pytest.mark.parametrize("schema_format", ["model", "json_schema"])
def test_tool_allows_extra_runtime_args_with_custom_schema(
schema_format: Literal["model", "json_schema"],
) -> None:
"""Ensure runtime args are preserved even if not in the args schema."""
class InputSchema(BaseModel):
query: str
captured: dict[str, Any] = {}
@dataclass
class MyRuntime(_DirectlyInjectedToolArg):
some_obj: object
args_schema = (
InputSchema if schema_format == "model" else InputSchema.model_json_schema()
)
@tool(args_schema=args_schema)
def runtime_tool(query: str, runtime: MyRuntime) -> str:
"""Echo the query and capture runtime value."""
captured["runtime"] = runtime
return query
runtime_obj = object()
runtime = MyRuntime(some_obj=runtime_obj)
assert runtime_tool.invoke({"query": "hello", "runtime": runtime}) == "hello"
assert captured["runtime"] is runtime
def test_tool_injected_tool_call_id_with_custom_schema() -> None:
"""Ensure InjectedToolCallId works with custom args schema."""
class InputSchema(BaseModel):
x: int
@tool(args_schema=InputSchema)
def injected_tool(
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
) -> ToolMessage:
"""Tool with injected tool_call_id and custom schema."""
return ToolMessage(str(x), tool_call_id=tool_call_id)
# Test that tool_call_id is properly injected even though not in custom schema
result = injected_tool.invoke(
{
"type": "tool_call",
"args": {"x": 42},
"name": "injected_tool",
"id": "test_call_id",
}
)
assert result == ToolMessage("42", tool_call_id="test_call_id")
# Test that it still raises error when invoked without a ToolCall
with pytest.raises(
ValueError,
match="When tool includes an InjectedToolCallId argument, "
"tool must always be invoked with a full model ToolCall",
):
injected_tool.invoke({"x": 42})
def test_tool_injected_arg_with_custom_schema() -> None:
"""Ensure InjectedToolArg works with custom args schema."""
class InputSchema(BaseModel):
query: str
class CustomContext:
"""Custom context object to be injected."""
def __init__(self, value: str) -> None:
self.value = value
captured: dict[str, Any] = {}
@tool(args_schema=InputSchema)
def search_tool(
query: str, context: Annotated[CustomContext, InjectedToolArg]
) -> str:
"""Search with custom context."""
captured["context"] = context
return f"Results for {query} with context {context.value}"
# Test that context is properly injected even though not in custom schema
ctx = CustomContext("test_context")
result = search_tool.invoke({"query": "hello", "context": ctx})
assert result == "Results for hello with context test_context"
assert captured["context"] is ctx
assert captured["context"].value == "test_context"
def test_tool_injected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:

View File

@@ -1,10 +1,9 @@
import os
import re
import sys
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from typing import Any
from typing import TYPE_CHECKING, Any
from unittest.mock import patch
import pytest
@@ -23,6 +22,9 @@ from langchain_core.utils import (
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.utils import secret_from_env
if TYPE_CHECKING:
from collections.abc import Callable
@pytest.mark.parametrize(
("package", "check_kwargs", "actual_version", "expected"),

libs/core/uv.lock (generated)
View File

@@ -960,7 +960,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.0.4"
version = "1.0.6"
source = { editable = "." }
dependencies = [
{ name = "jsonpatch" },
@@ -1057,7 +1057,7 @@ typing = [
[[package]]
name = "langchain-model-profiles"
version = "0.0.3"
version = "0.0.4"
source = { directory = "../model-profiles" }
dependencies = [
{ name = "tomli", marker = "python_full_version < '3.11'" },

View File

@@ -24,7 +24,7 @@ In most cases, you should be using the main [`langchain`](https://pypi.org/proje
## 📖 Documentation
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
## 📕 Releases & Versioning

View File

@@ -1,4 +1,4 @@
.PHONY: all start_services stop_services coverage test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
.PHONY: all start_services stop_services coverage coverage_agents test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
# Default target executed when no arguments are given to make.
all: help
@@ -27,8 +27,17 @@ coverage:
--cov-report term-missing:skip-covered \
$(TEST_FILE)
# Run middleware and agent tests with coverage report.
coverage_agents:
uv run --group test pytest \
tests/unit_tests/agents/middleware/ \
tests/unit_tests/agents/test_*.py \
--cov=langchain.agents \
--cov-report=term-missing \
--cov-report=html:htmlcov \
test:
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered --snapshot-update; \
EXIT_CODE=$$?; \
make stop_services; \
exit $$EXIT_CODE
@@ -93,6 +102,7 @@ help:
@echo 'lint - run linters'
@echo '-- TESTS --'
@echo 'coverage - run unit tests and generate coverage report'
@echo 'coverage_agents - run middleware and agent tests with coverage report'
@echo 'test - run unit tests with all services'
@echo 'test_fast - run unit tests with in-memory services only'
@echo 'tests - run unit tests (alias for "make test")'

View File

@@ -26,7 +26,7 @@ LangChain [agents](https://docs.langchain.com/oss/python/langchain/agents) are b
## 📖 Documentation
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
## 📕 Releases & Versioning

View File

@@ -22,17 +22,20 @@ from langgraph.graph.state import StateGraph
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
from langgraph.runtime import Runtime # noqa: TC002
from langgraph.types import Command, Send
from langgraph.typing import ContextT # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
AsyncModelCallHandler,
JumpTo,
ModelCallHandler,
ModelRequest,
ModelResponse,
OmitFromSchema,
ResponseT,
StateT,
StateT_co,
_InputAgentState,
_OutputAgentState,
@@ -63,6 +66,18 @@ if TYPE_CHECKING:
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
# if langchain-model-profiles is not installed, these models are assumed to support
# structured output
"grok",
"gpt-5",
"gpt-4.1",
"gpt-4o",
"gpt-oss",
"o3-pro",
"o3-mini",
]
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
"""Normalize middleware return value to ModelResponse."""
@@ -74,13 +89,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse | AIMessage,
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse,
]
| None
@@ -128,8 +143,8 @@ def _chain_model_call_handlers(
single_handler = handlers[0]
def normalized_single(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelResponse:
result = single_handler(request, handler)
return _normalize_to_model_response(result)
@@ -138,25 +153,25 @@ def _chain_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse | AIMessage,
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse | AIMessage,
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse,
]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
def inner_handler(req: ModelRequest) -> ModelResponse:
def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
inner_result = inner(req, handler)
return _normalize_to_model_response(inner_result)
@@ -173,8 +188,8 @@ def _chain_model_call_handlers(
# Wrap to ensure final return type is exactly ModelResponse
def final_normalized(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = result(request, handler)
@@ -186,13 +201,13 @@ def _chain_model_call_handlers(
def _chain_async_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse | AIMessage],
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse],
]
| None
@@ -213,8 +228,8 @@ def _chain_async_model_call_handlers(
single_handler = handlers[0]
async def normalized_single(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelResponse:
result = await single_handler(request, handler)
return _normalize_to_model_response(result)
@@ -223,25 +238,25 @@ def _chain_async_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse | AIMessage],
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse | AIMessage],
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse],
]:
"""Compose two async handlers where outer wraps inner."""
async def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
async def inner_handler(req: ModelRequest) -> ModelResponse:
async def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
inner_result = await inner(req, handler)
return _normalize_to_model_response(inner_result)
@@ -258,8 +273,8 @@ def _chain_async_model_call_handlers(
# Wrap to ensure final return type is exactly ModelResponse
async def final_normalized(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = await result(request, handler)
@@ -349,11 +364,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
return []
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
"""Check if a model supports provider-specific structured output.
Args:
model: Model name string or `BaseChatModel` instance.
tools: Optional list of tools provided to the agent. Needed because some models
don't support structured output together with tool calling.
Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
@@ -362,11 +379,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
if isinstance(model, str):
model_name = model
elif isinstance(model, BaseChatModel):
model_name = getattr(model, "model_name", None)
model_name = (
getattr(model, "model_name", None)
or getattr(model, "model", None)
or getattr(model, "model_id", "")
)
try:
model_profile = model.profile
except ImportError:
pass
else:
if (
model_profile.get("structured_output")
# We make an exception for Gemini models, which currently do not support
# simultaneous tool use with structured output
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
):
return True
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
if model_name
else False
)
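The name-based branch is only a fallback: when `model.profile` is available and reports structured-output support (and the Gemini-with-tools exception does not apply), the profile wins. A simplified, hypothetical re-creation of just the name check shown above:

```python
# Simplified, name-only approximation of the fallback branch; the real helper
# first consults `model.profile` when langchain-model-profiles is installed.
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
    "grok",
    "gpt-5",
    "gpt-4.1",
    "gpt-4o",
    "gpt-oss",
    "o3-pro",
    "o3-mini",
]


def supports_provider_strategy_by_name(model_name: str) -> bool:
    if not model_name:
        return False
    return any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)


assert supports_provider_strategy_by_name("openai:gpt-4.1-mini")
assert not supports_provider_strategy_by_name("claude-sonnet-4-5")
```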
@@ -517,9 +549,9 @@ def create_agent( # noqa: PLR0915
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
*,
system_prompt: str | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
state_schema: type[AgentState[ResponseT]] | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
context_schema: type[ContextT] | None = None,
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
@@ -939,7 +971,9 @@ def create_agent( # noqa: PLR0915
return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
def _get_bound_model(
request: ModelRequest[StateT, ContextT],
) -> tuple[Runnable, ResponseFormat | None]:
"""Get the model with appropriate tool bindings.
Performs auto-detection of strategy if needed based on model capabilities.
@@ -988,7 +1022,7 @@ def create_agent( # noqa: PLR0915
effective_response_format: ResponseFormat | None
if isinstance(request.response_format, AutoStrategy):
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
if _supports_provider_strategy(request.model):
if _supports_provider_strategy(request.model, tools=request.tools):
# Model supports provider strategy - use it
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
else:
@@ -1009,7 +1043,7 @@ def create_agent( # noqa: PLR0915
# Bind model based on effective response format
if isinstance(effective_response_format, ProviderStrategy):
# Use provider-specific structured output
# (Backward compatibility) Use OpenAI format structured output
kwargs = effective_response_format.to_model_kwargs()
return (
request.model.bind_tools(
@@ -1053,7 +1087,7 @@ def create_agent( # noqa: PLR0915
)
return request.model.bind(**request.model_settings), None
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
def _execute_model_sync(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
"""Execute model and return response.
This is the core model execution logic wrapped by `wrap_model_call` handlers.
@@ -1077,9 +1111,9 @@ def create_agent( # noqa: PLR0915
structured_response=structured_response,
)
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
def model_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
request = ModelRequest[StateT, ContextT](
model=model,
tools=default_tools,
system_prompt=system_prompt,
@@ -1104,7 +1138,7 @@ def create_agent( # noqa: PLR0915
return state_updates
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
async def _execute_model_async(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
"""Execute model asynchronously and return response.
This is the core async model execution logic wrapped by `wrap_model_call`
@@ -1130,9 +1164,9 @@ def create_agent( # noqa: PLR0915
structured_response=structured_response,
)
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
async def amodel_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
request = ModelRequest[StateT, ContextT](
model=model,
tools=default_tools,
system_prompt=system_prompt,

View File

@@ -4,6 +4,7 @@ from .context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
)
from .file_search import FilesystemFileSearchMiddleware
from .human_in_the_loop import (
HumanInTheLoopMiddleware,
InterruptOnConfig,
@@ -46,6 +47,7 @@ __all__ = [
"CodexSandboxExecutionPolicy",
"ContextEditingMiddleware",
"DockerExecutionPolicy",
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware",
"InterruptOnConfig",

View File

@@ -10,6 +10,7 @@ chat model.
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Literal
@@ -238,10 +239,11 @@ class ContextEditingMiddleware(AgentMiddleware):
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
edit.apply(edited_messages, count_tokens=count_tokens)
return handler(request)
return handler(request.override(messages=edited_messages))
async def awrap_model_call(
self,
@@ -266,10 +268,11 @@ class ContextEditingMiddleware(AgentMiddleware):
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
edit.apply(edited_messages, count_tokens=count_tokens)
return await handler(request)
return await handler(request.override(messages=edited_messages))
__all__ = [
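The change above is behavioral as well as cosmetic: instead of mutating `request.messages` in place, the middleware now edits a deep copy and passes an overridden request to the handler. A minimal sketch of the same pattern in a custom middleware (the class and edit logic are illustrative, not part of the library):

```python
from collections.abc import Callable
from copy import deepcopy

from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse


class DropOldestMessageMiddleware(AgentMiddleware):
    """Hypothetical middleware that trims one message without mutating the request."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        # Edit a copy, then hand the handler an overridden request instead of
        # assigning back onto `request.messages`.
        edited_messages = deepcopy(list(request.messages))[1:]
        return handler(request.override(messages=edited_messages))
```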

View File

@@ -120,9 +120,9 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
Args:
root_path: Root directory to search.
use_ripgrep: Whether to use ripgrep for search.
use_ripgrep: Whether to use `ripgrep` for search.
Falls back to Python if ripgrep unavailable.
Falls back to Python if `ripgrep` unavailable.
max_file_size_mb: Maximum file size to search in MB.
"""
self.root_path = Path(root_path).resolve()
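For reference, constructing the newly exported middleware with the parameters documented above might look like this; the argument values are placeholders:

```python
from langchain.agents.middleware import FilesystemFileSearchMiddleware

# Parameter names follow the docstring above; the values are placeholders.
file_search = FilesystemFileSearchMiddleware(
    root_path="./docs",
    use_ripgrep=True,      # falls back to a pure-Python search if ripgrep is unavailable
    max_file_size_mb=5,
)
```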

View File

@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
last_ai_msg.tool_calls = revised_tool_calls
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
return self.after_model(state, runtime)

View File

@@ -133,6 +133,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
`None` means no limit.
exit_behavior: What to do when limits are exceeded.
- `'end'`: Jump to the end of the agent execution and
inject an artificial AI message indicating that the limit was
exceeded.
@@ -198,6 +199,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check model call limits before making a model call.
Args:
state: The current agent state containing call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
return self.before_model(state, runtime)
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment model call counts after a model call.
@@ -212,3 +236,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
}
async def aafter_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async increment model call counts after a model call.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented call counts.
"""
return self.after_model(state, runtime)
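The new async hooks simply delegate to their sync counterparts, which is safe because the sync bodies do no blocking I/O. A generic sketch of that pattern (the class and state key are hypothetical):

```python
from typing import Any

from langgraph.runtime import Runtime

from langchain.agents.middleware.types import AgentMiddleware, AgentState


class CallCounterMiddleware(AgentMiddleware):
    """Hypothetical middleware illustrating the async-delegates-to-sync pattern."""

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # `model_call_count` is an illustrative state key, not a built-in one.
        return {"model_call_count": state.get("model_call_count", 0) + 1}

    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # No awaitable work in the sync hook, so the async variant can delegate.
        return self.after_model(state, runtime)
```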

View File

@@ -92,9 +92,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
try:
return handler(request)
return handler(request.override(model=fallback_model))
except Exception as e: # noqa: BLE001
last_exception = e
continue
@@ -127,9 +126,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
try:
return await handler(request)
return await handler(request.override(model=fallback_model))
except Exception as e: # noqa: BLE001
last_exception = e
continue
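The fallback loop now goes through `request.override(model=...)` rather than assigning to `request.model`. A self-contained sketch of the same retry-over-fallbacks idea, with the constructor and typing simplified relative to the real `ModelFallbackMiddleware`:

```python
from collections.abc import Callable

from langchain_core.language_models import BaseChatModel

from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse


class SimpleFallbackMiddleware(AgentMiddleware):
    """Hypothetical middleware retrying the call on each fallback model in turn."""

    def __init__(self, *fallback_models: BaseChatModel) -> None:
        super().__init__()
        self.models = list(fallback_models)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        try:
            return handler(request)  # primary model first
        except Exception as primary_exc:  # noqa: BLE001
            last_exception: Exception = primary_exc
        for fallback_model in self.models:
            try:
                # Swap models via override() instead of mutating the request.
                return handler(request.override(model=fallback_model))
            except Exception as e:  # noqa: BLE001
                last_exception = e
                continue
        raise last_exception
```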

View File

@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check user messages and tool results for PII before model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.before_model(state, runtime)
def after_model(
self,
state: AgentState,
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
return {"messages": new_messages}
async def aafter_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check AI messages for PII after model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII

detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.after_model(state, runtime)
__all__ = [
"PIIDetectionError",

View File

@@ -11,17 +11,17 @@ import subprocess
import tempfile
import threading
import time
import typing
import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool, ToolException
from langchain_core.tools.base import ToolException
from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import NotRequired
from langchain.agents.middleware._execution import (
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
ResolvedRedactionRule,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
from langchain.tools import ToolRuntime, tool
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware.types import ToolCallRequest
LOGGER = logging.getLogger(__name__)
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
"session remains stable. Outputs may be truncated when they become very large, and long "
"running commands will be terminated once their configured timeout elapses."
)
SHELL_TOOL_NAME = "shell"
def _cleanup_resources(
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
"""Input schema for the persistent shell tool."""
command: str | None = None
"""The shell command to execute."""
restart: bool | None = None
"""Whether to restart the shell session."""
runtime: Annotated[Any, SkipJsonSchema()] = None
"""The runtime for the shell tool.
Included as a workaround because `args_schema` doesn't currently work with an
injected `ToolRuntime`.
"""
@model_validator(mode="after")
def validate_payload(self) -> _ShellToolInput:
@@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel):
return self
class _PersistentShellTool(BaseTool):
"""Tool wrapper that relies on middleware interception for execution."""
name: str = "shell"
description: str = DEFAULT_TOOL_DESCRIPTION
args_schema: type[BaseModel] = _ShellToolInput
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
super().__init__()
self._middleware = middleware
if description is not None:
self.description = description
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
raise RuntimeError(msg)
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""Middleware that registers a persistent shell tool for agents.
@@ -393,10 +385,11 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
execution_policy: BaseExecutionPolicy | None = None,
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
tool_description: str | None = None,
tool_name: str = SHELL_TOOL_NAME,
shell_command: Sequence[str] | str | None = None,
env: Mapping[str, Any] | None = None,
) -> None:
"""Initialize the middleware.
"""Initialize an instance of `ShellToolMiddleware`.
Args:
workspace_root: Base directory for the shell session.
@@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
returning it to the model.
tool_description: Optional override for the registered shell tool
description.
tool_name: Name for the registered shell tool.
Defaults to `"shell"`.
shell_command: Optional shell executable (string) or argument sequence used
to launch the persistent session.
@@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""
super().__init__()
self._workspace_root = Path(workspace_root) if workspace_root else None
self._tool_name = tool_name
self._shell_command = self._normalize_shell_command(shell_command)
self._environment = self._normalize_env(env)
if execution_policy is not None:
@@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
self._startup_commands = self._normalize_commands(startup_commands)
self._shutdown_commands = self._normalize_commands(shutdown_commands)
# Create a proper tool that executes directly (no interception needed)
description = tool_description or DEFAULT_TOOL_DESCRIPTION
self._tool = _PersistentShellTool(self, description=description)
self.tools = [self._tool]
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
def shell_tool(
*,
runtime: ToolRuntime[None, ShellToolState],
command: str | None = None,
restart: bool = False,
) -> ToolMessage | str:
resources = self._get_or_create_resources(runtime.state)
return self._run_shell_tool(
resources,
{"command": command, "restart": restart},
tool_call_id=runtime.tool_call_id,
)
self._shell_tool = shell_tool
self.tools = [self._shell_tool]
@staticmethod
def _normalize_commands(
@@ -478,36 +491,48 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Start the shell session and run startup commands."""
resources = self._create_resources()
resources = self._get_or_create_resources(state)
return {"shell_session_resources": resources}
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Async counterpart to `before_agent`."""
"""Async start the shell session and run startup commands."""
return self.before_agent(state, runtime)
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
"""Run shutdown commands and release resources when an agent completes."""
resources = self._ensure_resources(state)
resources = state.get("shell_session_resources")
if not isinstance(resources, _SessionResources):
# Resources were never created, nothing to clean up
return
try:
self._run_shutdown_commands(resources.session)
finally:
resources._finalizer()
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async counterpart to `after_agent`."""
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
"""Get existing resources from state or create new ones if they don't exist.
This method enables resumability by checking if resources already exist in the state
(e.g., after an interrupt), and only creating new resources if they're not present.
Args:
state: The agent state which may contain shell session resources.
Returns:
Session resources, either retrieved from state or newly created.
"""
resources = state.get("shell_session_resources")
if resources is not None and not isinstance(resources, _SessionResources):
resources = None
if resources is None:
msg = (
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
"before invoking the shell tool."
)
raise ToolException(msg)
return resources
if isinstance(resources, _SessionResources):
return resources
new_resources = self._create_resources()
# Cast needed to make state dict-like for mutation
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
return new_resources
def _create_resources(self) -> _SessionResources:
workspace = self._workspace_root
@@ -669,36 +694,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
artifact=artifact,
)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept local shell tool calls and execute them via the managed session."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return handler(request)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async interception mirroring the synchronous tool handler."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return await handler(request)
def _format_tool_message(
self,
content: str,
@@ -713,7 +708,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=self._tool.name,
name=self._tool_name,
status=status,
artifact=artifact,
)
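The refactor replaces the `_PersistentShellTool` wrapper and its `wrap_tool_call` interception with a closure-based tool built via `@tool(...)` that receives a `ToolRuntime` directly. A stripped-down sketch of that construction pattern; the tool name, body, and state handling here are illustrative:

```python
from __future__ import annotations

from langchain_core.messages import ToolMessage

from langchain.tools import ToolRuntime, tool


def build_echo_tool(tool_name: str = "echo"):
    """Hypothetical factory mirroring how the middleware now builds its shell tool."""

    @tool(tool_name, description="Echo a command back as a ToolMessage.")
    def echo_tool(*, runtime: ToolRuntime, command: str | None = None) -> ToolMessage:
        """Echo tool used only to illustrate the closure-based construction."""
        # `runtime.tool_call_id` (and `runtime.state`, unused here) are injected
        # by the framework rather than supplied by the model.
        return ToolMessage(
            content=f"echo: {command or ''}",
            tool_call_id=runtime.tool_call_id,
            name=tool_name,
        )

    return echo_tool
```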

View File

@@ -1,8 +1,9 @@
"""Summarization middleware."""
import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, cast
from langchain_core.messages import (
AIMessage,
@@ -51,13 +52,17 @@ Messages to summarize:
{messages}
</messages>""" # noqa: E501
SUMMARY_PREFIX = "## Previous conversation summary:"
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
ContextFraction = tuple[Literal["fraction"], float]
ContextTokens = tuple[Literal["tokens"], int]
ContextMessages = tuple[Literal["messages"], int]
ContextSize = ContextFraction | ContextTokens | ContextMessages
class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.
@@ -70,34 +75,95 @@ class SummarizationMiddleware(AgentMiddleware):
def __init__(
self,
model: str | BaseChatModel,
max_tokens_before_summary: int | None = None,
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
*,
trigger: ContextSize | list[ContextSize] | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
summary_prefix: str = SUMMARY_PREFIX,
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
**deprecated_kwargs: Any,
) -> None:
"""Initialize summarization middleware.
Args:
model: The language model to use for generating summaries.
max_tokens_before_summary: Token threshold to trigger summarization.
If `None`, summarization is disabled.
messages_to_keep: Number of recent messages to preserve after summarization.
trigger: One or more thresholds that trigger summarization.
Provide a single `ContextSize` tuple or a list of tuples, in which case
summarization runs when any threshold is breached.
Examples: `("messages", 50)`, `("tokens", 3000)`, `[("fraction", 0.8),
("messages", 100)]`.
keep: Context retention policy applied after summarization.
Provide a `ContextSize` tuple to specify how much history to preserve.
Defaults to keeping the most recent 20 messages.
Examples: `("messages", 20)`, `("tokens", 3000)`, or
`("fraction", 0.3)`.
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
summary_prefix: Prefix added to system message when including summary.
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
the summarization call.
Pass `None` to skip trimming entirely.
"""
# Handle deprecated parameters
if "max_tokens_before_summary" in deprecated_kwargs:
value = deprecated_kwargs["max_tokens_before_summary"]
warnings.warn(
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if trigger is None and value is not None:
trigger = ("tokens", value)
if "messages_to_keep" in deprecated_kwargs:
value = deprecated_kwargs["messages_to_keep"]
warnings.warn(
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
keep = ("messages", value)
super().__init__()
if isinstance(model, str):
model = init_chat_model(model)
self.model = model
self.max_tokens_before_summary = max_tokens_before_summary
self.messages_to_keep = messages_to_keep
if trigger is None:
self.trigger: ContextSize | list[ContextSize] | None = None
trigger_conditions: list[ContextSize] = []
elif isinstance(trigger, list):
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
self.trigger = validated_list
trigger_conditions = validated_list
else:
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
self._trigger_conditions = trigger_conditions
self.keep = self._validate_context_size(keep, "keep")
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
self.trim_tokens_to_summarize = trim_tokens_to_summarize
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
msg = (
"Model profile information is required to use fractional token limits. "
'pip install "langchain[model-profiles]" or use absolute token counts '
"instead."
)
raise ValueError(msg)
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
@@ -105,13 +171,10 @@ class SummarizationMiddleware(AgentMiddleware):
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if (
self.max_tokens_before_summary is not None
and total_tokens < self.max_tokens_before_summary
):
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._find_safe_cutoff(messages)
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
@@ -129,6 +192,151 @@ class SummarizationMiddleware(AgentMiddleware):
]
}
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
"""Determine whether summarization should run for the current token usage."""
if not self._trigger_conditions:
return False
for kind, value in self._trigger_conditions:
if kind == "messages" and len(messages) >= value:
return True
if kind == "tokens" and total_tokens >= value:
return True
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
continue
threshold = int(max_input_tokens * value)
if threshold <= 0:
threshold = 1
if total_tokens >= threshold:
return True
return False
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
"""Choose cutoff index respecting retention configuration."""
kind, value = self.keep
if kind in {"tokens", "fraction"}:
token_based_cutoff = self._find_token_based_cutoff(messages)
if token_based_cutoff is not None:
return token_based_cutoff
# A None cutoff means model profile data is unavailable (caught in __init__,
# but re-checked here for safety); fall back to a message-count cutoff.
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
return self._find_safe_cutoff(messages, cast("int", value))
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
"""Find cutoff index based on target token retention."""
if not messages:
return 0
kind, value = self.keep
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return None
target_token_count = int(max_input_tokens * value)
elif kind == "tokens":
target_token_count = int(value)
else:
return None
if target_token_count <= 0:
target_token_count = 1
if self.token_counter(messages) <= target_token_count:
return 0
# Use binary search to identify the earliest message index that keeps the
# suffix within the token budget.
left, right = 0, len(messages)
cutoff_candidate = len(messages)
max_iterations = len(messages).bit_length() + 1
for _ in range(max_iterations):
if left >= right:
break
mid = (left + right) // 2
if self.token_counter(messages[mid:]) <= target_token_count:
cutoff_candidate = mid
right = mid
else:
left = mid + 1
if cutoff_candidate == len(messages):
cutoff_candidate = left
if cutoff_candidate >= len(messages):
if len(messages) == 1:
return 0
cutoff_candidate = len(messages) - 1
for i in range(cutoff_candidate, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i
return 0
def _get_profile_limits(self) -> int | None:
"""Retrieve max input token limit from the model profile."""
try:
profile = self.model.profile
except (AttributeError, ImportError):
return None
if not isinstance(profile, Mapping):
return None
max_input_tokens = profile.get("max_input_tokens")
if not isinstance(max_input_tokens, int):
return None
return max_input_tokens
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
"""Validate context configuration tuples."""
kind, value = context
if kind == "fraction":
if not 0 < value <= 1:
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
raise ValueError(msg)
elif kind in {"tokens", "messages"}:
if value <= 0:
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
raise ValueError(msg)
else:
msg = f"Unsupported context size type {kind} for {parameter_name}."
raise ValueError(msg)
return context
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
@@ -151,16 +359,16 @@ class SummarizationMiddleware(AgentMiddleware):
return messages_to_summarize, preserved_messages
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.
Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
related AI and Tool messages. Returns `0` if no safe cutoff is found.
"""
if len(messages) <= self.messages_to_keep:
if len(messages) <= messages_to_keep:
return 0
target_cutoff = len(messages) - self.messages_to_keep
target_cutoff = len(messages) - messages_to_keep
for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
@@ -229,16 +437,35 @@ class SummarizationMiddleware(AgentMiddleware):
try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
return cast("str", response.content).strip()
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
if not messages_to_summarize:
return "No previous conversation history."
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return "Previous conversation was too long to summarize."
try:
response = await self.model.ainvoke(
self.summary_prompt.format(messages=trimmed_messages)
)
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
"""Trim messages to fit within summary generation limits."""
try:
if self.trim_tokens_to_summarize is None:
return messages
return trim_messages(
messages,
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",

View File

@@ -150,10 +150,6 @@ class TodoListMiddleware(AgentMiddleware):
print(result["todos"]) # Array of todo items with status tracking
```
Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
tool_description: Custom description for the write_todos tool.
"""
state_schema = PlanningState
@@ -198,12 +194,12 @@ class TodoListMiddleware(AgentMiddleware):
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt."""
request.system_prompt = (
new_system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.override(system_prompt=new_system_prompt))
async def awrap_model_call(
self,
@@ -211,9 +207,9 @@ class TodoListMiddleware(AgentMiddleware):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt (async version)."""
request.system_prompt = (
new_system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
return await handler(request.override(system_prompt=new_system_prompt))
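Same override pattern, applied to the system prompt. A compact, hypothetical middleware that appends guidance the way `TodoListMiddleware` now does:

```python
from collections.abc import Callable

from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse


class StyleGuideMiddleware(AgentMiddleware):
    """Hypothetical middleware that appends extra guidance to the system prompt."""

    GUIDANCE = "Answer concisely and cite sources when possible."

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        new_system_prompt = (
            request.system_prompt + "\n\n" + self.GUIDANCE
            if request.system_prompt
            else self.GUIDANCE
        )
        return handler(request.override(system_prompt=new_system_prompt))
```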

View File

@@ -153,38 +153,46 @@ class ToolCallLimitMiddleware(
are other pending tool calls (due to parallel tool calling).
Examples:
```python title="Continue execution with blocked tools (default)"
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
!!! example "Continue execution with blocked tools (default)"
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)
```python title="Stop immediately when limit exceeded"
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
!!! example "Stop immediately when limit exceeded"
```python title="Raise exception on limit"
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
try:
result = await agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
!!! example "Raise exception on limit"
```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, exit_behavior="error"
)
agent = create_agent("openai:gpt-4o", middleware=[limiter])
try:
result = agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
"""
@@ -208,6 +216,7 @@ class ToolCallLimitMiddleware(
run_limit: Maximum number of tool calls allowed per run.
`None` means no limit.
exit_behavior: How to handle when limits are exceeded.
- `'continue'`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end.
- `'error'`: Raise a `ToolCallLimitExceededError` exception
@@ -218,7 +227,7 @@ class ToolCallLimitMiddleware(
Raises:
ValueError: If both limits are `None`, if `exit_behavior` is invalid,
or if `run_limit` exceeds thread_limit.
or if `run_limit` exceeds `thread_limit`.
"""
super().__init__()
@@ -451,3 +460,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}
@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

View File

@@ -25,34 +25,42 @@ class LLMToolEmulator(AgentMiddleware):
This middleware allows selective emulation of tools for testing purposes.
By default (when `tools=None`), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or BaseTool instances.
tools to emulate by passing a list of tool names or `BaseTool` instances.
Examples:
```python title="Emulate all tools (default behavior)"
from langchain.agents.middleware import LLMToolEmulator
!!! example "Emulate all tools (default behavior)"
middleware = LLMToolEmulator()
```python
from langchain.agents.middleware import LLMToolEmulator
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
middleware = LLMToolEmulator()
```python title="Emulate specific tools by name"
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
```python title="Use a custom model for emulation"
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by name"
```python title="Emulate specific tools by passing tool instances"
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
!!! example "Use a custom model for emulation"
```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by passing tool instances"
```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
"""
def __init__(

View File

@@ -26,96 +26,96 @@ class ToolRetryMiddleware(AgentMiddleware):
Supports retrying on specific exceptions and exponential backoff.
Examples:
Basic usage with default settings (2 retries, exponential backoff):
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
Retry specific exceptions only:
!!! example "Retry specific exceptions only"
```python
from requests.exceptions import RequestException, Timeout
```python
from requests.exceptions import RequestException, Timeout
retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
Custom exception filtering:
!!! example "Custom exception filtering"
```python
from requests.exceptions import HTTPError
```python
from requests.exceptions import HTTPError
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
Apply to specific tools with custom error handling:
!!! example "Apply to specific tools with custom error handling"
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
Apply to specific tools using BaseTool instances:
!!! example "Apply to specific tools using `BaseTool` instances"
```python
from langchain_core.tools import tool
```python
from langchain_core.tools import tool
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```
Constant backoff (no exponential growth):
!!! example "Constant backoff (no exponential growth)"
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
Raise exception on failure:
!!! example "Raise exception on failure"
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise", # Re-raise exception instead of returning message
)
```
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise", # Re-raise exception instead of returning message
)
```
"""
def __init__(
@@ -136,7 +136,10 @@ class ToolRetryMiddleware(AgentMiddleware):
Args:
max_retries: Maximum number of retry attempts after the initial call.
Default is `2` retries (`3` total attempts). Must be `>= 0`.
Default is `2` retries (`3` total attempts).
Must be `>= 0`.
tools: Optional list of tools or tool names to apply retry logic to.
Can be a list of `BaseTool` instances or tool name strings.
@@ -146,12 +149,14 @@ class ToolRetryMiddleware(AgentMiddleware):
that takes an exception and returns `True` if it should be retried.
Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted. Options:
on_failure: Behavior when all retries are exhausted.
Options:
- `'return_message'`: Return a `ToolMessage` with error details,
allowing the LLM to handle the failure and potentially recover.
- `'raise'`: Re-raise the exception, stopping agent execution.
- Custom callable: Function that takes the exception and returns a
- **Custom callable:** Function that takes the exception and returns a
string for the `ToolMessage` content, allowing custom error
formatting.
backoff_factor: Multiplier for exponential backoff.
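Taken together, a hedged usage sketch for these arguments looks like the following. The import path for `ToolRetryMiddleware` is assumed to match the `langchain.agents.middleware` package used elsewhere in this diff, and `search_database` is the example tool defined above.
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware  # assumed import path

retry = ToolRetryMiddleware(
    max_retries=3,       # 3 retries => up to 4 total attempts
    initial_delay=1.0,   # delay before the first retry, in seconds
    backoff_factor=2.0,  # grow the delay between subsequent attempts
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[search_database],  # placeholder tool from the examples above
    middleware=[retry],
)
```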

View File

@@ -93,21 +93,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
and helps the main model focus on the right tools.
Examples:
!!! example "Limit to 3 tools"
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware

middleware = LLMToolSelectorMiddleware(max_tools=3)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[tool1, tool2, tool3, tool4, tool5],
    middleware=[middleware],
)
```
!!! example "Use a smaller model for selection"
```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
"""
def __init__(
@@ -131,7 +135,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
If the model selects more, only the first `max_tools` will be used.
If not specified, there is no limit.
always_include: Tool names to always include regardless of selection.
These do not count against the `max_tools` limit.
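As a hedged illustration of how `max_tools` and `always_include` combine (the parameter names come from the docstring above; the tool name is a placeholder):
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware

# Let a smaller model pick at most 2 tools, but always keep `search_database`
# available; per the docstring, always-included tools don't count toward the limit.
middleware = LLMToolSelectorMiddleware(
    model="openai:gpt-4o-mini",
    max_tools=2,
    always_include=["search_database"],  # placeholder tool name
)
```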
@@ -251,8 +255,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
return request.override(tools=[*selected_tools, *provider_tools])
def wrap_model_call(
self,

File diff suppressed because it is too large

View File

@@ -125,7 +125,7 @@ def init_chat_model(
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
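For reference, the provider prefixes listed above are used with `init_chat_model` roughly as follows (a minimal sketch; the model identifiers are placeholders):
```python
from langchain.chat_models import init_chat_model

# Pass a "provider:model" string...
model = init_chat_model("xai:grok-2")

# ...or name the provider separately.
model = init_chat_model("grok-2", model_provider="xai")
```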

View File

@@ -57,6 +57,7 @@ test = [
"pytest-mock",
"syrupy>=4.0.2,<5.0.0",
"toml>=0.10.2,<1.0.0",
"langchain-model-profiles",
"langchain-tests",
"langchain-openai",
]
@@ -75,6 +76,7 @@ test_integration = [
"cassio>=0.1.0,<1.0.0",
"langchainhub>=0.1.16,<1.0.0",
"langchain-core",
"langchain-model-profiles",
"langchain-text-splitters",
]
@@ -83,6 +85,7 @@ prerelease = "allow"
[tool.uv.sources]
langchain-core = { path = "../core", editable = true }
langchain-model-profiles = { path = "../model-profiles", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }

View File

@@ -1,79 +0,0 @@
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy
class WeatherBaseModel(BaseModel):
"""Weather response."""
temperature: float = Field(description="The temperature in fahrenheit")
condition: str = Field(description="Weather condition")
def get_weather(city: str) -> str: # noqa: ARG001
"""Get the weather for a city."""
return "The weather is sunny and 75°F."
@pytest.mark.requires("langchain_openai")
def test_inference_to_native_output() -> None:
"""Test that native output is inferred when a model supports it."""
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-5")
agent = create_agent(
model,
system_prompt=(
"You are a helpful weather assistant. Please call the get_weather tool, "
"then use the WeatherReport tool to generate the final response."
),
tools=[get_weather],
response_format=WeatherBaseModel,
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert isinstance(response["structured_response"], WeatherBaseModel)
assert response["structured_response"].temperature == 75.0
assert response["structured_response"].condition.lower() == "sunny"
assert len(response["messages"]) == 4
assert [m.type for m in response["messages"]] == [
"human", # "What's the weather?"
"ai", # "What's the weather?"
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
]
@pytest.mark.requires("langchain_openai")
def test_inference_to_tool_output() -> None:
"""Test that tool output is inferred when a model supports it."""
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-4")
agent = create_agent(
model,
system_prompt=(
"You are a helpful weather assistant. Please call the get_weather tool, "
"then use the WeatherReport tool to generate the final response."
),
tools=[get_weather],
response_format=ToolStrategy(WeatherBaseModel),
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert isinstance(response["structured_response"], WeatherBaseModel)
assert response["structured_response"].temperature == 75.0
assert response["structured_response"].condition.lower() == "sunny"
assert len(response["messages"]) == 5
assert [m.type for m in response["messages"]] == [
"human", # "What's the weather?"
"ai", # "What's the weather?"
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
"tool", # artificial tool message
]

View File

@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,95 @@
# serializer version: 1
# name: test_async_middleware_with_can_jump_to_graph_snapshot
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_with_jump\2ebefore_model;
async_before_with_jump\2ebefore_model -.-> __end__;
async_before_with_jump\2ebefore_model -.-> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
async_after_with_jump\2eafter_model -.-> __end__;
async_after_with_jump\2eafter_model -.-> model;
model --> async_after_with_jump\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
async_after_retry\2eafter_model(async_after_retry.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_early_exit\2ebefore_model;
async_after_retry\2eafter_model -.-> __end__;
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
async_before_early_exit\2ebefore_model -.-> __end__;
async_before_early_exit\2ebefore_model -.-> model;
model --> async_after_retry\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> sync_before_with_jump\2ebefore_model;
async_after_with_jumps\2eafter_model -.-> __end__;
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
model --> async_after_with_jumps\2eafter_model;
sync_before_with_jump\2ebefore_model -.-> __end__;
sync_before_with_jump\2ebefore_model -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,289 @@
# serializer version: 1
# name: test_create_agent_diagram
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.10
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
__end__([<p>__end__</p>]):::last
NoopTen\2ebefore_model --> model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopTen\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.11
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2ebefore_model(NoopEleven.before_model)
NoopEleven\2eafter_model(NoopEleven.after_model)
__end__([<p>__end__</p>]):::last
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
NoopEleven\2ebefore_model --> model;
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopEleven\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopTwo\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
NoopThree\2ebefore_model(NoopThree.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopThree\2ebefore_model --> model;
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.4
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> NoopFour\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.5
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
__start__ --> model;
model --> NoopFive\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.6
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
NoopSix\2eafter_model(NoopSix.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
NoopSix\2eafter_model --> NoopFive\2eafter_model;
__start__ --> model;
model --> NoopSix\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.7
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
__end__([<p>__end__</p>]):::last
NoopSeven\2ebefore_model --> model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopSeven\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.8
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.9
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
NoopNine\2eafter_model --> NoopEight\2eafter_model;
NoopNine\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopNine\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,95 @@
# serializer version: 1
# name: test_async_middleware_with_can_jump_to_graph_snapshot
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_with_jump\2ebefore_model;
async_before_with_jump\2ebefore_model -.-> __end__;
async_before_with_jump\2ebefore_model -.-> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
async_after_with_jump\2eafter_model -.-> __end__;
async_after_with_jump\2eafter_model -.-> model;
model --> async_after_with_jump\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
async_after_retry\2eafter_model(async_after_retry.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_early_exit\2ebefore_model;
async_after_retry\2eafter_model -.-> __end__;
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
async_before_early_exit\2ebefore_model -.-> __end__;
async_before_early_exit\2ebefore_model -.-> model;
model --> async_after_retry\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> sync_before_with_jump\2ebefore_model;
async_after_with_jumps\2eafter_model -.-> __end__;
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
model --> async_after_with_jumps\2eafter_model;
sync_before_with_jump\2ebefore_model -.-> __end__;
sync_before_with_jump\2ebefore_model -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,289 @@
# serializer version: 1
# name: test_create_agent_diagram
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.10
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
__end__([<p>__end__</p>]):::last
NoopTen\2ebefore_model --> model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopTen\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.11
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2ebefore_model(NoopEleven.before_model)
NoopEleven\2eafter_model(NoopEleven.after_model)
__end__([<p>__end__</p>]):::last
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
NoopEleven\2ebefore_model --> model;
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopEleven\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopTwo\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
NoopThree\2ebefore_model(NoopThree.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopThree\2ebefore_model --> model;
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.4
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> NoopFour\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.5
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
__start__ --> model;
model --> NoopFive\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.6
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
NoopSix\2eafter_model(NoopSix.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
NoopSix\2eafter_model --> NoopFive\2eafter_model;
__start__ --> model;
model --> NoopSix\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.7
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
__end__([<p>__end__</p>]):::last
NoopSeven\2ebefore_model --> model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopSeven\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.8
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.9
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
NoopNine\2eafter_model --> NoopEight\2eafter_model;
NoopNine\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopNine\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -230,9 +230,7 @@ class TestChainModelCallHandlers:
test_runtime = {"test": "runtime"}
# Create request with state and runtime
test_request = create_test_request(state=test_state, runtime=test_runtime)
result = composed(test_request, create_mock_base_handler())
# Both handlers should see same state and runtime

View File

@@ -22,7 +22,7 @@ from langchain.agents.middleware.types import (
hook_config,
)
from langchain.agents.factory import create_agent, _get_can_jump_to
from ...model import FakeToolCallingModel
class CustomState(AgentState):
@@ -90,8 +90,7 @@ def test_on_model_call_decorator() -> None:
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="CustomOnModelCall")
def custom_on_model_call(request, handler):
request.system_prompt = "Modified"
return handler(request)
return handler(request.override(system_prompt="Modified"))
# Verify all options were applied
assert isinstance(custom_on_model_call, AgentMiddleware)
@@ -277,8 +276,7 @@ def test_async_on_model_call_decorator() -> None:
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="AsyncOnModelCall")
async def async_on_model_call(request, handler):
request.system_prompt = "Modified async"
return await handler(request)
return await handler(request.override(system_prompt="Modified async"))
assert isinstance(async_on_model_call, AgentMiddleware)
assert async_on_model_call.state_schema == CustomState
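The pattern exercised in these tests, building a modified copy via `request.override(...)` instead of mutating the request in place, looks like this in a standalone middleware (a hedged sketch; the `override` behavior is inferred from the tests in this diff):
```python
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest

class PromptRewriteMiddleware(AgentMiddleware):
    def wrap_model_call(self, request: ModelRequest, handler):
        # Create an updated copy of the request rather than assigning to
        # `request.system_prompt`, mirroring the decorator tests above.
        updated = request.override(system_prompt="Modified")
        return handler(updated)
```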

View File

@@ -0,0 +1,193 @@
from collections.abc import Callable
from syrupy.assertion import SnapshotAssertion
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain_core.messages import AIMessage
from ...model import FakeToolCallingModel
def test_create_agent_diagram(
snapshot: SnapshotAssertion,
):
class NoopOne(AgentMiddleware):
def before_model(self, state, runtime):
pass
class NoopTwo(AgentMiddleware):
def before_model(self, state, runtime):
pass
class NoopThree(AgentMiddleware):
def before_model(self, state, runtime):
pass
class NoopFour(AgentMiddleware):
def after_model(self, state, runtime):
pass
class NoopFive(AgentMiddleware):
def after_model(self, state, runtime):
pass
class NoopSix(AgentMiddleware):
def after_model(self, state, runtime):
pass
class NoopSeven(AgentMiddleware):
def before_model(self, state, runtime):
pass
def after_model(self, state, runtime):
pass
class NoopEight(AgentMiddleware):
def before_model(self, state, runtime):
pass
def after_model(self, state, runtime):
pass
class NoopNine(AgentMiddleware):
def before_model(self, state, runtime):
pass
def after_model(self, state, runtime):
pass
class NoopTen(AgentMiddleware):
def before_model(self, state, runtime):
pass
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
def after_model(self, state, runtime):
pass
class NoopEleven(AgentMiddleware):
def before_model(self, state, runtime):
pass
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
def after_model(self, state, runtime):
pass
agent_zero = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
)
assert agent_zero.get_graph().draw_mermaid() == snapshot
agent_one = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopOne()],
)
assert agent_one.get_graph().draw_mermaid() == snapshot
agent_two = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopOne(), NoopTwo()],
)
assert agent_two.get_graph().draw_mermaid() == snapshot
agent_three = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopOne(), NoopTwo(), NoopThree()],
)
assert agent_three.get_graph().draw_mermaid() == snapshot
agent_four = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopFour()],
)
assert agent_four.get_graph().draw_mermaid() == snapshot
agent_five = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopFour(), NoopFive()],
)
assert agent_five.get_graph().draw_mermaid() == snapshot
agent_six = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopFour(), NoopFive(), NoopSix()],
)
assert agent_six.get_graph().draw_mermaid() == snapshot
agent_seven = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven()],
)
assert agent_seven.get_graph().draw_mermaid() == snapshot
agent_eight = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven(), NoopEight()],
)
assert agent_eight.get_graph().draw_mermaid() == snapshot
agent_nine = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven(), NoopEight(), NoopNine()],
)
assert agent_nine.get_graph().draw_mermaid() == snapshot
agent_ten = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopTen()],
)
assert agent_ten.get_graph().draw_mermaid() == snapshot
agent_eleven = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopTen(), NoopEleven()],
)
assert agent_eleven.get_graph().draw_mermaid() == snapshot

File diff suppressed because it is too large

View File

@@ -14,7 +14,7 @@ from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, wrap_tool_call
from langchain.agents.middleware.types import ToolCallRequest
from tests.unit_tests.agents.model import FakeToolCallingModel
@tool

View File

@@ -9,7 +9,7 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState, Model
from langgraph.prebuilt.tool_node import ToolNode
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from tests.unit_tests.agents.model import FakeToolCallingModel
def test_model_request_tools_are_base_tools() -> None:
@@ -79,8 +79,8 @@ def test_middleware_can_modify_tools() -> None:
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
# Only allow tool_a and tool_b
filtered_tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return handler(request.override(tools=filtered_tools))
# Model will try to call tool_a
model = FakeToolCallingModel(
@@ -123,8 +123,7 @@ def test_unknown_tool_raises_error() -> None:
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
# Add an unknown tool
return handler(request.override(tools=request.tools + [unknown_tool]))
agent = create_agent(
model=FakeToolCallingModel(),
@@ -163,7 +162,8 @@ def test_middleware_can_add_and_remove_tools() -> None:
) -> AIMessage:
# Remove admin_tool if not admin
if not request.state.get("is_admin", False):
request.tools = [t for t in request.tools if t.name != "admin_tool"]
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
request = request.override(tools=filtered_tools)
return handler(request)
model = FakeToolCallingModel()
@@ -200,7 +200,7 @@ def test_empty_tools_list_is_valid() -> None:
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
# Remove all tools
request = request.override(tools=[])
return handler(request)
model = FakeToolCallingModel()
@@ -244,7 +244,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
# Remove tool_c
request.tools = [t for t in request.tools if t.name != "tool_c"]
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
request = request.override(tools=filtered_tools)
return handler(request)
class SecondMiddleware(AgentMiddleware):
@@ -257,7 +258,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
# Should not see tool_c here
assert all(t.name != "tool_c" for t in request.tools)
# Remove tool_b
request.tools = [t for t in request.tools if t.name != "tool_b"]
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
request = request.override(tools=filtered_tools)
return handler(request)
agent = create_agent(

View File

@@ -1,18 +1,30 @@
"""Unit tests for wrap_model_call middleware generator protocol."""
"""Unit tests for wrap_model_call hook and @wrap_model_call decorator.
This module tests the wrap_model_call functionality in three forms:
1. As a middleware method (AgentMiddleware.wrap_model_call)
2. As a decorator (@wrap_model_call)
3. Async variant (AgentMiddleware.awrap_model_call)
"""
from collections.abc import Awaitable, Callable
import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain.agents import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
wrap_model_call,
)
from ...model import FakeToolCallingModel
class TestBasicWrapModelCall:
"""Test basic wrap_model_call functionality."""
def test_passthrough_middleware(self) -> None:
@@ -70,7 +82,7 @@ class TestBasicOnModelCall:
assert counter.call_count == 1
class TestRetryLogic:
"""Test retry logic with wrap_model_call."""
def test_simple_retry_on_error(self) -> None:
@@ -91,12 +103,10 @@ class TestRetryMiddleware:
def wrap_model_call(self, request, handler):
try:
return handler(request)
except Exception:
self.retry_count += 1
return handler(request)
retry_middleware = RetryOnceMiddleware()
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
@@ -125,8 +135,7 @@ class TestRetryMiddleware:
for attempt in range(self.max_retries):
self.attempts.append(attempt + 1)
try:
return handler(request)
except Exception as e:
last_exception = e
continue
@@ -143,6 +152,75 @@ class TestRetryMiddleware:
assert retry_middleware.attempts == [1, 2, 3]
def test_no_retry_propagates_error(self) -> None:
"""Test that error is propagated when middleware doesn't retry."""
class FailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
raise ValueError("Model error")
@property
def _llm_type(self):
return "failing"
class NoRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
return handler(request)
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
with pytest.raises(ValueError, match="Model error"):
agent.invoke({"messages": [HumanMessage("Test")]})
def test_max_attempts_limit(self) -> None:
"""Test that middleware controls termination via retry limits."""
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
raise ValueError("Always fails")
@property
def _llm_type(self):
return "always_failing"
class LimitedRetryMiddleware(AgentMiddleware):
"""Middleware that limits its own retries."""
def __init__(self, max_retries: int = 10):
super().__init__()
self.max_retries = max_retries
self.attempt_count = 0
def wrap_model_call(self, request, handler):
last_exception = None
for attempt in range(self.max_retries):
self.attempt_count += 1
try:
return handler(request)
except Exception as e:
last_exception = e
# Continue to retry
# All retries exhausted, re-raise the last error
if last_exception:
raise last_exception
model = AlwaysFailingModel()
middleware = LimitedRetryMiddleware(max_retries=10)
agent = create_agent(model=model, middleware=[middleware])
# Should fail with the model's error after middleware stops retrying
with pytest.raises(ValueError, match="Always fails"):
agent.invoke({"messages": [HumanMessage("Test")]})
# Should have attempted exactly 10 times as configured
assert middleware.attempt_count == 10
class TestResponseRewriting:
"""Test response content rewriting with wrap_model_call."""
@@ -185,6 +263,28 @@ class TestResponseRewriting:
assert result["messages"][1].content == "[BOT]: Response"
def test_multi_stage_transformation(self) -> None:
"""Test middleware applying multiple transformations."""
class MultiTransformMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
# First transformation: uppercase
content = ai_message.content.upper()
# Second transformation: add prefix and suffix
content = f"[START] {content} [END]"
return AIMessage(content=content)
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))
agent = create_agent(model=model, middleware=[MultiTransformMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert result["messages"][1].content == "[START] HELLO [END]"
class TestErrorHandling:
"""Test error handling with wrap_model_call."""
@@ -200,9 +300,8 @@ class TestErrorHandling:
def wrap_model_call(self, request, handler):
try:
return handler(request)
except Exception as e:
return AIMessage(content=f"Error occurred: {e}. Using fallback response.")
model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[ErrorToSuccessMiddleware()])
@@ -210,7 +309,8 @@ class TestErrorHandling:
# Should not raise, middleware converts error to response
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert "Error handled gracefully" in result["messages"][1].content
assert "Error occurred" in result["messages"][1].content
assert "fallback response" in result["messages"][1].content
def test_selective_error_handling(self) -> None:
"""Test middleware that only handles specific errors."""
@@ -224,8 +324,7 @@ class TestErrorHandling:
try:
return handler(request)
except ConnectionError:
fallback = AIMessage(content="Network issue, try again later")
return fallback
return AIMessage(content="Network issue, try again later")
model = SpecificErrorModel(messages=iter([]))
agent = create_agent(model=model, middleware=[SelectiveErrorMiddleware()])
@@ -247,8 +346,7 @@ class TestErrorHandling:
return result
except Exception:
call_log.append("caught-error")
fallback = AIMessage(content="Recovered from error")
return fallback
return AIMessage(content="Recovered from error")
# Test 1: Success path
call_log.clear()
@@ -403,7 +501,6 @@ class TestStateAndRuntime:
for attempt in range(max_retries):
try:
return handler(request)
break # Success
except Exception:
if attempt == max_retries - 1:
raise
@@ -460,6 +557,49 @@ class TestMiddlewareComposition:
"outer-after",
]
def test_three_middleware_composition(self) -> None:
"""Test composition of three middleware."""
execution_order = []
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("first-before")
response = handler(request)
execution_order.append("first-after")
return response
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("second-before")
response = handler(request)
execution_order.append("second-after")
return response
class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("third-before")
response = handler(request)
execution_order.append("third-after")
return response
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
model=model,
middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
)
agent.invoke({"messages": [HumanMessage("Test")]})
# First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
assert execution_order == [
"first-before",
"second-before",
"third-before",
"third-after",
"second-after",
"first-after",
]
def test_retry_with_logging(self) -> None:
"""Test retry middleware composed with logging middleware."""
call_count = {"value": 0}
@@ -549,11 +689,9 @@ class TestMiddlewareComposition:
class RetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
result = handler(request)
return result
return handler(request)
except Exception:
result = handler(request)
return result
return handler(request)
class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
@@ -571,49 +709,6 @@ class TestMiddlewareComposition:
# Should retry and uppercase the result
assert result["messages"][1].content == "SUCCESS"
def test_three_middleware_composition(self) -> None:
"""Test composition of three middleware."""
execution_order = []
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("first-before")
response = handler(request)
execution_order.append("first-after")
return response
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("second-before")
response = handler(request)
execution_order.append("second-after")
return response
class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("third-before")
response = handler(request)
execution_order.append("third-after")
return response
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
model=model,
middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
)
agent.invoke({"messages": [HumanMessage("Test")]})
# First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
assert execution_order == [
"first-before",
"second-before",
"third-before",
"third-after",
"second-after",
"first-after",
]
def test_middle_retry_middleware(self) -> None:
"""Test that middle middleware doing retry causes inner to execute twice."""
execution_order = []
@@ -674,7 +769,306 @@ class TestMiddlewareComposition:
assert len(model_calls) == 2
class TestAsyncOnModelCall:
class TestWrapModelCallDecorator:
"""Test the @wrap_model_call decorator for creating middleware."""
def test_basic_decorator_usage(self) -> None:
"""Test basic decorator usage without parameters."""
@wrap_model_call
def passthrough_middleware(request, handler):
return handler(request)
# Should return an AgentMiddleware instance
assert isinstance(passthrough_middleware, AgentMiddleware)
# Should work in agent
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[passthrough_middleware])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Hello"
def test_decorator_with_custom_name(self) -> None:
"""Test decorator with custom middleware name."""
@wrap_model_call(name="CustomMiddleware")
def my_middleware(request, handler):
return handler(request)
assert isinstance(my_middleware, AgentMiddleware)
assert my_middleware.__class__.__name__ == "CustomMiddleware"
def test_decorator_retry_logic(self) -> None:
"""Test decorator for implementing retry logic."""
call_count = {"value": 0}
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
call_count["value"] += 1
if call_count["value"] == 1:
raise ValueError("First call fails")
return super()._generate(messages, **kwargs)
@wrap_model_call
def retry_once(request, handler):
try:
return handler(request)
except Exception:
# Retry once
return handler(request)
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
agent = create_agent(model=model, middleware=[retry_once])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert call_count["value"] == 2
assert result["messages"][1].content == "Success"
def test_decorator_response_rewriting(self) -> None:
"""Test decorator for rewriting responses."""
@wrap_model_call
def uppercase_responses(request, handler):
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
return AIMessage(content=ai_message.content.upper())
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
agent = create_agent(model=model, middleware=[uppercase_responses])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert result["messages"][1].content == "HELLO WORLD"
def test_decorator_error_handling(self) -> None:
"""Test decorator for error recovery."""
class AlwaysFailModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Model error")
@wrap_model_call
def error_to_fallback(request, handler):
try:
return handler(request)
except Exception:
return AIMessage(content="Fallback response")
model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[error_to_fallback])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert result["messages"][1].content == "Fallback response"
def test_decorator_with_state_access(self) -> None:
"""Test decorator accessing agent state."""
state_values = []
@wrap_model_call
def log_state(request, handler):
state_values.append(request.state.get("messages"))
return handler(request)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[log_state])
agent.invoke({"messages": [HumanMessage("Test")]})
# State should contain the user message
assert len(state_values) == 1
assert len(state_values[0]) == 1
assert state_values[0][0].content == "Test"
def test_multiple_decorated_middleware(self) -> None:
"""Test composition of multiple decorated middleware."""
execution_order = []
@wrap_model_call
def outer_middleware(request, handler):
execution_order.append("outer-before")
result = handler(request)
execution_order.append("outer-after")
return result
@wrap_model_call
def inner_middleware(request, handler):
execution_order.append("inner-before")
result = handler(request)
execution_order.append("inner-after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[outer_middleware, inner_middleware])
agent.invoke({"messages": [HumanMessage("Test")]})
assert execution_order == [
"outer-before",
"inner-before",
"inner-after",
"outer-after",
]
def test_decorator_with_custom_state_schema(self) -> None:
"""Test decorator with custom state schema."""
from typing_extensions import TypedDict
class CustomState(TypedDict):
messages: list
custom_field: str
@wrap_model_call(state_schema=CustomState)
def middleware_with_schema(request, handler):
return handler(request)
assert isinstance(middleware_with_schema, AgentMiddleware)
# Custom state schema should be set
assert middleware_with_schema.state_schema == CustomState
def test_decorator_with_tools_parameter(self) -> None:
"""Test decorator with tools parameter."""
from langchain_core.tools import tool
@tool
def test_tool(query: str) -> str:
"""A test tool."""
return f"Result: {query}"
@wrap_model_call(tools=[test_tool])
def middleware_with_tools(request, handler):
return handler(request)
assert isinstance(middleware_with_tools, AgentMiddleware)
assert len(middleware_with_tools.tools) == 1
assert middleware_with_tools.tools[0].name == "test_tool"
def test_decorator_parentheses_optional(self) -> None:
"""Test that decorator works both with and without parentheses."""
# Without parentheses
@wrap_model_call
def middleware_no_parens(request, handler):
return handler(request)
# With parentheses
@wrap_model_call()
def middleware_with_parens(request, handler):
return handler(request)
assert isinstance(middleware_no_parens, AgentMiddleware)
assert isinstance(middleware_with_parens, AgentMiddleware)
def test_decorator_preserves_function_name(self) -> None:
"""Test that decorator uses function name for class name."""
@wrap_model_call
def my_custom_middleware(request, handler):
return handler(request)
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
def test_decorator_mixed_with_class_middleware(self) -> None:
"""Test decorated middleware mixed with class-based middleware."""
execution_order = []
@wrap_model_call
def decorated_middleware(request, handler):
execution_order.append("decorated-before")
result = handler(request)
execution_order.append("decorated-after")
return result
class ClassMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("class-before")
result = handler(request)
execution_order.append("class-after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
model=model,
middleware=[decorated_middleware, ClassMiddleware()],
)
agent.invoke({"messages": [HumanMessage("Test")]})
# Decorated is outer, class-based is inner
assert execution_order == [
"decorated-before",
"class-before",
"class-after",
"decorated-after",
]
def test_decorator_complex_retry_logic(self) -> None:
"""Test decorator with complex retry logic and backoff."""
attempts = []
call_count = {"value": 0}
class UnreliableModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
call_count["value"] += 1
if call_count["value"] <= 2:
raise ValueError(f"Attempt {call_count['value']} failed")
return super()._generate(messages, **kwargs)
@wrap_model_call
def retry_with_tracking(request, handler):
max_retries = 3
for attempt in range(max_retries):
attempts.append(attempt + 1)
try:
return handler(request)
except Exception:
# On error, continue to next attempt
if attempt < max_retries - 1:
continue # Retry
else:
raise # All retries failed
model = UnreliableModel(messages=iter([AIMessage(content="Finally worked")]))
agent = create_agent(model=model, middleware=[retry_with_tracking])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert attempts == [1, 2, 3]
assert result["messages"][1].content == "Finally worked"
def test_decorator_request_modification(self) -> None:
"""Test decorator modifying request before execution."""
modified_prompts = []
@wrap_model_call
def add_system_prompt(request, handler):
# Modify request to add system prompt
modified_request = ModelRequest(
messages=request.messages,
model=request.model,
system_prompt="You are a helpful assistant",
tool_choice=request.tool_choice,
tools=request.tools,
response_format=request.response_format,
state={},
runtime=None,
)
modified_prompts.append(modified_request.system_prompt)
return handler(modified_request)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[add_system_prompt])
agent.invoke({"messages": [HumanMessage("Test")]})
assert modified_prompts == ["You are a helpful assistant"]
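A more compact variant of the request-modification test above, sketched here (not part of this diff) against the `ModelRequest.override` helper that the model-fallback tests later in this diff call with `model=...`; whether `override` accepts `system_prompt` is an assumption:

```python
@wrap_model_call
def add_system_prompt_via_override(request, handler):
    # Assumption: override() returns a copy of the request with the given
    # fields replaced, leaving the original request untouched.
    return handler(request.override(system_prompt="You are a helpful assistant"))
```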
class TestAsyncWrapModelCall:
"""Test async execution with wrap_model_call."""
async def test_async_model_with_middleware(self) -> None:
@@ -686,7 +1080,6 @@ class TestAsyncOnModelCall:
log.append("before")
result = await handler(request)
log.append("after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
@@ -723,6 +1116,92 @@ class TestAsyncOnModelCall:
assert call_count["value"] == 2
assert result["messages"][1].content == "Async success"
async def test_decorator_with_async_agent(self) -> None:
"""Test that decorated middleware works with async agent invocation."""
call_log = []
@wrap_model_call
async def logging_middleware(request, handler):
call_log.append("before")
result = await handler(request)
call_log.append("after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
agent = create_agent(model=model, middleware=[logging_middleware])
result = await agent.ainvoke({"messages": [HumanMessage("Test")]})
assert call_log == ["before", "after"]
assert result["messages"][1].content == "Async response"
class TestSyncAsyncInterop:
"""Test sync/async interoperability."""
def test_sync_invoke_with_only_async_middleware_raises_error(self) -> None:
"""Test that sync invoke with only async middleware raises error."""
class AsyncOnlyMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
return await handler(request)
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncOnlyMiddleware()],
)
with pytest.raises(NotImplementedError):
agent.invoke({"messages": [HumanMessage("hello")]})
def test_sync_invoke_with_mixed_middleware(self) -> None:
"""Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
calls = []
class MixedMiddleware(AgentMiddleware):
def before_model(self, state, runtime) -> None:
calls.append("MixedMiddleware.before_model")
async def abefore_model(self, state, runtime) -> None:
calls.append("MixedMiddleware.abefore_model")
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
calls.append("MixedMiddleware.wrap_model_call")
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
calls.append("MixedMiddleware.awrap_model_call")
return await handler(request)
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[MixedMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("hello")]})
# In sync mode, only sync methods should be called
assert calls == [
"MixedMiddleware.before_model",
"MixedMiddleware.wrap_model_call",
]
class TestEdgeCases:
"""Test edge cases and error conditions."""
@@ -753,12 +1232,10 @@ class TestEdgeCases:
def wrap_model_call(self, request, handler):
attempts.append("first-attempt")
try:
result = handler(request)
return result
return handler(request)
except Exception:
attempts.append("retry-attempt")
result = handler(request)
return result
return handler(request)
call_count = {"value": 0}


@@ -14,7 +14,7 @@ from langgraph.types import Command
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import wrap_tool_call
from langchain.agents.middleware.types import ToolCallRequest
from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
from tests.unit_tests.agents.model import FakeToolCallingModel
@tool


@@ -82,16 +82,23 @@ def test_no_edit_when_below_trigger() -> None:
edits=[ClearToolUsesEdit(trigger=50)],
)
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request
middleware.wrap_model_call(request, mock_handler)
# The request should have been modified in place
# The modified request passed to the handler should be unchanged since no edits were applied
assert modified_request is not None
assert modified_request.messages[0].content == ""
assert modified_request.messages[1].content == "12345"
# Original request should be unchanged
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
def test_clear_tool_outputs_and_inputs() -> None:
@@ -115,14 +122,19 @@ def test_clear_tool_outputs_and_inputs() -> None:
)
middleware = ContextEditingMiddleware(edits=[edit])
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request with edits
middleware.wrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
assert modified_request is not None
cleared_ai = modified_request.messages[0]
cleared_tool = modified_request.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
@@ -134,7 +146,9 @@ def test_clear_tool_outputs_and_inputs() -> None:
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
# Original request should be unchanged
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
assert request.messages[1].content == "x" * 200
def test_respects_keep_last_tool_results() -> None:
@@ -167,21 +181,26 @@ def test_respects_keep_last_tool_results() -> None:
token_count_method="model",
)
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request with edits
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
cleared_messages = [
msg
for msg in request.messages
for msg in modified_request.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
assert isinstance(modified_request.messages[-1], ToolMessage)
assert modified_request.messages[-1].content != "[cleared]"
def test_exclude_tools_prevents_clearing() -> None:
@@ -215,14 +234,19 @@ def test_exclude_tools_prevents_clearing() -> None:
],
)
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request with edits
middleware.wrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
assert modified_request is not None
search_tool = modified_request.messages[1]
calc_tool = modified_request.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20
@@ -249,16 +273,23 @@ async def test_no_edit_when_below_trigger_async() -> None:
edits=[ClearToolUsesEdit(trigger=50)],
)
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request
await middleware.awrap_model_call(request, mock_handler)
# The request should have been modified in place
# The modified request passed to the handler should be unchanged since no edits were applied
assert modified_request is not None
assert modified_request.messages[0].content == ""
assert modified_request.messages[1].content == "12345"
# Original request should be unchanged
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
async def test_clear_tool_outputs_and_inputs_async() -> None:
@@ -283,14 +314,19 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
)
middleware = ContextEditingMiddleware(edits=[edit])
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request with edits
await middleware.awrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
assert modified_request is not None
cleared_ai = modified_request.messages[0]
cleared_tool = modified_request.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
@@ -302,7 +338,9 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
# Original request should be unchanged
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
assert request.messages[1].content == "x" * 200
async def test_respects_keep_last_tool_results_async() -> None:
@@ -336,21 +374,26 @@ async def test_respects_keep_last_tool_results_async() -> None:
token_count_method="model",
)
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request with edits
await middleware.awrap_model_call(request, mock_handler)
assert modified_request is not None
cleared_messages = [
msg
for msg in request.messages
for msg in modified_request.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
assert isinstance(modified_request.messages[-1], ToolMessage)
assert modified_request.messages[-1].content != "[cleared]"
async def test_exclude_tools_prevents_clearing_async() -> None:
@@ -385,14 +428,19 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
],
)
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request with edits
await middleware.awrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
assert modified_request is not None
search_tool = modified_request.messages[1]
calc_tool = modified_request.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20


@@ -7,6 +7,9 @@ import pytest
from langchain.agents.middleware.file_search import (
FilesystemFileSearchMiddleware,
_expand_include_patterns,
_is_valid_include_pattern,
_match_include_pattern,
)
@@ -259,3 +262,105 @@ class TestPathTraversalSecurity:
assert result == "No matches found"
assert "secret" not in result
class TestExpandIncludePatterns:
"""Tests for _expand_include_patterns helper function."""
def test_expand_patterns_basic_brace_expansion(self) -> None:
"""Test basic brace expansion with multiple options."""
result = _expand_include_patterns("*.{py,txt}")
assert result == ["*.py", "*.txt"]
def test_expand_patterns_nested_braces(self) -> None:
"""Test nested brace expansion."""
result = _expand_include_patterns("test.{a,b}.{c,d}")
assert result is not None
assert len(result) == 4
assert "test.a.c" in result
assert "test.b.d" in result
@pytest.mark.parametrize(
"pattern",
[
"*.py}", # closing brace without opening
"*.{}", # empty braces
"*.{py", # unclosed brace
],
)
def test_expand_patterns_invalid_braces(self, pattern: str) -> None:
"""Test patterns with invalid brace syntax return None."""
result = _expand_include_patterns(pattern)
assert result is None
class TestValidateIncludePattern:
"""Tests for _is_valid_include_pattern helper function."""
@pytest.mark.parametrize(
"pattern",
[
"", # empty pattern
"*.py\x00", # null byte
"*.py\n", # newline
],
)
def test_validate_invalid_patterns(self, pattern: str) -> None:
"""Test that invalid patterns are rejected."""
assert not _is_valid_include_pattern(pattern)
class TestMatchIncludePattern:
"""Tests for _match_include_pattern helper function."""
def test_match_pattern_with_braces(self) -> None:
"""Test matching with brace expansion."""
assert _match_include_pattern("test.py", "*.{py,txt}")
assert _match_include_pattern("test.txt", "*.{py,txt}")
assert not _match_include_pattern("test.md", "*.{py,txt}")
def test_match_pattern_invalid_expansion(self) -> None:
"""Test matching with pattern that cannot be expanded returns False."""
assert not _match_include_pattern("test.py", "*.{}")
class TestGrepEdgeCases:
"""Tests for edge cases in grep search."""
def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None:
"""Test grep with special characters in pattern."""
(tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
result = middleware.grep_search.func(pattern="def.*:")
assert "/test.py" in result
def test_grep_case_insensitive(self, tmp_path: Path) -> None:
"""Test grep with case-insensitive search."""
(tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
result = middleware.grep_search.func(pattern="(?i)hello")
assert "/test.py" in result
def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None:
"""Test that grep skips files larger than max_file_size_mb."""
# Create a file larger than 1MB
large_content = "x" * (2 * 1024 * 1024) # 2MB
(tmp_path / "large.txt").write_text(large_content, encoding="utf-8")
(tmp_path / "small.txt").write_text("x", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(
root_path=str(tmp_path),
use_ripgrep=False,
max_file_size_mb=1, # 1MB limit
)
result = middleware.grep_search.func(pattern="x")
# Large file should be skipped
assert "/small.txt" in result


@@ -0,0 +1,575 @@
from typing import Any
from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime
from langchain.agents.middleware.human_in_the_loop import (
Action,
HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.types import AgentState
def test_human_in_the_loop_middleware_initialization() -> None:
"""Test HumanInTheLoopMiddleware initialization."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
description_prefix="Custom prefix",
)
assert middleware.interrupt_on == {
"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}
}
assert middleware.description_prefix == "Custom prefix"
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
# Test with no messages
state: dict[str, Any] = {"messages": []}
result = middleware.after_model(state, None)
assert result is None
# Test with message but no tool calls
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
result = middleware.after_model(state, None)
assert result is None
# Test with tool calls that don't require interrupts
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
result = middleware.after_model(state, None)
assert result is None
def test_human_in_the_loop_middleware_single_tool_accept() -> None:
"""Test HumanInTheLoopMiddleware with single tool accept response."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return {"decisions": [{"type": "approve"}]}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0] == ai_message
assert result["messages"][0].tool_calls == ai_message.tool_calls
state["messages"].append(
ToolMessage(content="Tool message", name="test_tool", tool_call_id="1")
)
state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))
result = middleware.after_model(state, None)
# No interrupts needed
assert result is None
def test_human_in_the_loop_middleware_single_tool_edit() -> None:
"""Test HumanInTheLoopMiddleware with single tool edit response."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "edited"},
),
}
]
}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved
def test_human_in_the_loop_middleware_single_tool_response() -> None:
"""Test HumanInTheLoopMiddleware with single tool response with custom message."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_response(requests):
return {"decisions": [{"type": "reject", "message": "Custom response message"}]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 2
assert isinstance(result["messages"][0], AIMessage)
assert isinstance(result["messages"][1], ToolMessage)
assert result["messages"][1].content == "Custom response message"
assert result["messages"][1].name == "test_tool"
assert result["messages"][1].tool_call_id == "1"
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
"get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_mixed_responses(requests):
return {
"decisions": [
{"type": "approve"},
{"type": "reject", "message": "User rejected this tool call"},
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert (
len(result["messages"]) == 2
) # AI message with accepted tool call + tool message for the rejected one
# First message should be the AI message with both tool calls
updated_ai_message = result["messages"][0]
assert len(updated_ai_message.tool_calls) == 2 # Both tool calls remain
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
assert updated_ai_message.tool_calls[1]["name"] == "get_temperature" # Got response
# Second message should be the tool message for the rejected tool call
tool_message = result["messages"][1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "User rejected this tool call"
assert tool_message.name == "get_temperature"
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
"get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_edit_responses(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="get_forecast",
args={"location": "New York"},
),
},
{
"type": "edit",
"edited_action": Action(
name="get_temperature",
args={"location": "New York"},
),
},
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit_with_args(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "modified"},
),
}
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_edit_with_args,
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
# Should have modified args
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
"""Test HumanInTheLoopMiddleware with unknown response type."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_unknown(requests):
return {"decisions": [{"type": "unknown"}]}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
with pytest.raises(
ValueError,
match=r"Unexpected human decision: {'type': 'unknown'}. Decision type 'unknown' is not allowed for tool 'test_tool'. Expected one of \['approve', 'edit', 'reject'\] based on the tool's configuration.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_disallowed_action() -> None:
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
# edit is not allowed by tool config
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_disallowed_action(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "modified"},
),
}
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_disallowed_action,
):
with pytest.raises(
ValueError,
match=r"Unexpected human decision: {'type': 'edit', 'edited_action': {'name': 'test_tool', 'args': {'input': 'modified'}}}. Decision type 'edit' is not allowed for tool 'test_tool'. Expected one of \['approve', 'reject'\] based on the tool's configuration.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"interrupt_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return {"decisions": [{"type": "approve"}]}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
# Should have both tools: auto-approved first, then interrupt tool
assert len(updated_ai_message.tool_calls) == 2
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
"""Test that interrupt requests are structured correctly."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
description_prefix="Custom prefix",
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
captured_request = None
def mock_capture_requests(request):
nonlocal captured_request
captured_request = request
return {"decisions": [{"type": "approve"}]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
):
middleware.after_model(state, None)
assert captured_request is not None
assert "action_requests" in captured_request
assert "review_configs" in captured_request
assert len(captured_request["action_requests"]) == 1
action_request = captured_request["action_requests"][0]
assert action_request["name"] == "test_tool"
assert action_request["args"] == {"input": "test", "location": "SF"}
assert "Custom prefix" in action_request["description"]
assert "Tool: test_tool" in action_request["description"]
assert "Args: {'input': 'test', 'location': 'SF'}" in action_request["description"]
assert len(captured_request["review_configs"]) == 1
review_config = captured_request["review_configs"][0]
assert review_config["action_name"] == "test_tool"
assert review_config["allowed_decisions"] == ["approve", "edit", "reject"]
def test_human_in_the_loop_middleware_boolean_configs() -> None:
"""Test HITL middleware with boolean tool configs."""
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test accept
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": [{"type": "approve"}]},
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls == ai_message.tool_calls
# Test edit
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "edited"},
),
}
]
},
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})
result = middleware.after_model(state, None)
# No interruption should occur
assert result is None
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
"""Test that sequence mismatch in resume raises an error."""
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test with too few responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": []}, # No responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human decisions \(0\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state, None)
# Test with too many responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={
"decisions": [
{"type": "approve"},
{"type": "approve"},
]
}, # 2 responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human decisions \(2\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_description_as_callable() -> None:
"""Test that description field accepts both string and callable."""
def custom_description(tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
"""Generate a custom description."""
return f"Custom: {tool_call['name']} with args {tool_call['args']}"
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"tool_with_callable": {
"allowed_decisions": ["approve"],
"description": custom_description,
},
"tool_with_string": {
"allowed_decisions": ["approve"],
"description": "Static description",
},
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[
{"name": "tool_with_callable", "args": {"x": 1}, "id": "1"},
{"name": "tool_with_string", "args": {"y": 2}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
captured_request = None
def mock_capture_requests(request):
nonlocal captured_request
captured_request = request
return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
):
middleware.after_model(state, None)
assert captured_request is not None
assert "action_requests" in captured_request
assert len(captured_request["action_requests"]) == 2
# Check callable description
assert (
captured_request["action_requests"][0]["description"]
== "Custom: tool_with_callable with args {'x': 1}"
)
# Check string description
assert captured_request["action_requests"][1]["description"] == "Static description"


@@ -0,0 +1,224 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import (
ModelCallLimitMiddleware,
ModelCallLimitExceededError,
)
from ...model import FakeToolCallingModel
@tool
def simple_tool(input: str) -> str:
"""A simple tool"""
return input
def test_middleware_unit_functionality():
"""Test that the middleware works as expected in isolation."""
# Test with end behavior
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
# Mock runtime (not used in current implementation)
runtime = None
# Test when limits are not exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 0}
result = middleware.before_model(state, runtime)
assert result is None
# Test when thread limit is exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
result = middleware.before_model(state, runtime)
assert result is not None
assert result["jump_to"] == "end"
assert "messages" in result
assert len(result["messages"]) == 1
assert "thread limit (2/2)" in result["messages"][0].content
# Test when run limit is exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
result = middleware.before_model(state, runtime)
assert result is not None
assert result["jump_to"] == "end"
assert "messages" in result
assert len(result["messages"]) == 1
assert "run limit (1/1)" in result["messages"][0].content
# Test with error behavior
middleware_exception = ModelCallLimitMiddleware(
thread_limit=2, run_limit=1, exit_behavior="error"
)
# Test exception when thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime)
assert "thread limit (2/2)" in str(exc_info.value)
# Test exception when run limit exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime)
assert "run limit (1/1)" in str(exc_info.value)
def test_thread_limit_with_create_agent():
"""Test that thread limits work correctly with create_agent."""
model = FakeToolCallingModel()
# Set thread limit to 1 (should be exceeded after 1 call)
agent = create_agent(
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(thread_limit=1)],
checkpointer=InMemorySaver(),
)
# First invocation should work - 1 model call, within thread limit
result = agent.invoke(
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
)
# Should complete successfully with 1 model call
assert "messages" in result
assert len(result["messages"]) == 2 # Human + AI messages
# Second invocation in same thread should hit thread limit
# The agent should jump to end after detecting the limit
result2 = agent.invoke(
{"messages": [HumanMessage("Hello again")]}, {"configurable": {"thread_id": "thread1"}}
)
assert "messages" in result2
# The agent should have detected the limit and jumped to end with a limit exceeded message
# So we should have: previous messages + new human message + limit exceeded AI message
assert len(result2["messages"]) == 4 # Previous Human + AI + New Human + Limit AI
assert isinstance(result2["messages"][0], HumanMessage) # First human
assert isinstance(result2["messages"][1], AIMessage) # First AI response
assert isinstance(result2["messages"][2], HumanMessage) # Second human
assert isinstance(result2["messages"][3], AIMessage) # Limit exceeded message
assert "thread limit" in result2["messages"][3].content
def test_run_limit_with_create_agent():
"""Test that run limits work correctly with create_agent."""
# Create a model that will make 2 calls
model = FakeToolCallingModel(
tool_calls=[
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
[], # No tool calls on second call
]
)
# Set run limit to 1 (should be exceeded after 1 call)
agent = create_agent(
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(run_limit=1)],
checkpointer=InMemorySaver(),
)
# This should hit the run limit after the first model call
result = agent.invoke(
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
)
assert "messages" in result
# The agent should have made 1 model call then jumped to end with limit exceeded message
# So we should have: Human + AI + Tool + Limit exceeded AI message
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
assert isinstance(result["messages"][0], HumanMessage)
assert isinstance(result["messages"][1], AIMessage)
assert isinstance(result["messages"][2], ToolMessage)
assert isinstance(result["messages"][3], AIMessage) # Limit exceeded message
assert "run limit" in result["messages"][3].content
def test_middleware_initialization_validation():
"""Test that middleware initialization validates parameters correctly."""
# Test that at least one limit must be specified
with pytest.raises(ValueError, match="At least one limit must be specified"):
ModelCallLimitMiddleware()
# Test invalid exit behavior
with pytest.raises(ValueError, match="Invalid exit_behavior"):
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")
# Test valid initialization
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
assert middleware.thread_limit == 5
assert middleware.run_limit == 3
assert middleware.exit_behavior == "end"
# Test with only thread limit
middleware = ModelCallLimitMiddleware(thread_limit=5)
assert middleware.thread_limit == 5
assert middleware.run_limit is None
# Test with only run limit
middleware = ModelCallLimitMiddleware(run_limit=3)
assert middleware.thread_limit is None
assert middleware.run_limit == 3
def test_exception_error_message():
"""Test that the exception provides clear error messages."""
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
# Test thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
# Test run limit exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "run limit (1/1)" in error_msg
# Test both limits exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
assert "run limit (1/1)" in error_msg
def test_run_limit_resets_between_invocations() -> None:
"""Test that run_model_call_count resets between invocations, but thread_model_call_count accumulates."""
# No tool calls per invocation, so each run makes exactly one model call
middleware = ModelCallLimitMiddleware(thread_limit=3, run_limit=1, exit_behavior="error")
model = FakeToolCallingModel(
tool_calls=[[], [], [], []]
) # No tool calls, so only one model call per run
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
thread_config = {"configurable": {"thread_id": "test_thread"}}
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
# Fourth run: should raise, thread_model_call_count == 3 (limit)
with pytest.raises(ModelCallLimitExceededError) as exc_info:
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
error_msg = str(exc_info.value)
assert "thread limit (3/3)" in error_msg


@@ -2,16 +2,22 @@
from __future__ import annotations
import warnings
from typing import cast
import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime
from ...model import FakeToolCallingModel
def _fake_runtime() -> Runtime:
    return cast(Runtime, object())

@@ -40,7 +46,7 @@ def test_primary_model_succeeds() -> None:
    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    def mock_handler(req: ModelRequest) -> ModelResponse:
        # Simulate successful model call

@@ -65,7 +71,7 @@ def test_fallback_on_primary_failure() -> None:
    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])

@@ -90,7 +96,7 @@ def test_multiple_fallbacks() -> None:
    middleware = ModelFallbackMiddleware(fallback1, fallback2)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])

@@ -114,7 +120,7 @@ def test_all_models_fail() -> None:
    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])

@@ -131,7 +137,7 @@ async def test_primary_model_succeeds_async() -> None:
    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    async def mock_handler(req: ModelRequest) -> ModelResponse:
        # Simulate successful async model call

@@ -156,7 +162,7 @@ async def test_fallback_on_primary_failure_async() -> None:
    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])

@@ -181,7 +187,7 @@ async def test_multiple_fallbacks_async() -> None:
    middleware = ModelFallbackMiddleware(fallback1, fallback2)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])

@@ -205,7 +211,7 @@ async def test_all_models_fail_async() -> None:
    middleware = ModelFallbackMiddleware(fallback_model)
    request = _make_request()
    request.model = primary_model
    request = request.override(model=primary_model)
    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])

@@ -213,3 +219,133 @@ async def test_all_models_fail_async() -> None:
    with pytest.raises(ValueError, match="Model failed"):
        await middleware.awrap_model_call(request, mock_handler)

def test_model_fallback_middleware_with_agent() -> None:
    """Test ModelFallbackMiddleware with agent.invoke and fallback models only."""

    class FailingModel(BaseChatModel):
        """Model that always fails."""

        def _generate(self, messages, **kwargs):
            raise ValueError("Primary model failed")

        @property
        def _llm_type(self):
            return "failing"

    class SuccessModel(BaseChatModel):
        """Model that succeeds."""

        def _generate(self, messages, **kwargs):
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Fallback success"))]
            )

        @property
        def _llm_type(self):
            return "success"

    primary = FailingModel()
    fallback = SuccessModel()

    # Only pass fallback models to middleware (not the primary)
    fallback_middleware = ModelFallbackMiddleware(fallback)
    agent = create_agent(model=primary, middleware=[fallback_middleware])

    result = agent.invoke({"messages": [HumanMessage("Test")]})

    # Should have succeeded with fallback model
    assert len(result["messages"]) == 2
    assert result["messages"][1].content == "Fallback success"
def test_model_fallback_middleware_exhausted_with_agent() -> None:
    """Test ModelFallbackMiddleware with agent.invoke when all models fail."""

    class AlwaysFailingModel(BaseChatModel):
        """Model that always fails."""

        def __init__(self, name: str):
            super().__init__()
            self.name = name

        def _generate(self, messages, **kwargs):
            raise ValueError(f"{self.name} failed")

        @property
        def _llm_type(self):
            return self.name

    primary = AlwaysFailingModel("primary")
    fallback1 = AlwaysFailingModel("fallback1")
    fallback2 = AlwaysFailingModel("fallback2")

    # Primary fails (attempt 1), then fallback1 (attempt 2), then fallback2 (attempt 3)
    fallback_middleware = ModelFallbackMiddleware(fallback1, fallback2)
    agent = create_agent(model=primary, middleware=[fallback_middleware])

    # Should fail with the last fallback's error
    with pytest.raises(ValueError, match="fallback2 failed"):
        agent.invoke({"messages": [HumanMessage("Test")]})
def test_model_fallback_middleware_initialization() -> None:
    """Test ModelFallbackMiddleware initialization."""
    # Test with no models - now a TypeError (missing required argument)
    with pytest.raises(TypeError):
        ModelFallbackMiddleware()  # type: ignore[call-arg]

    # Test with one fallback model (valid)
    middleware = ModelFallbackMiddleware(FakeToolCallingModel())
    assert len(middleware.models) == 1

    # Test with multiple fallback models
    middleware = ModelFallbackMiddleware(FakeToolCallingModel(), FakeToolCallingModel())
    assert len(middleware.models) == 2
def test_model_request_is_frozen() -> None:
    """Test that ModelRequest raises deprecation warning on direct attribute assignment."""
    request = _make_request()
    new_model = GenericFakeChatModel(messages=iter([AIMessage(content="new model")]))

    # Direct attribute assignment should raise DeprecationWarning but still work
    with pytest.warns(
        DeprecationWarning, match="Direct attribute assignment to ModelRequest.model is deprecated"
    ):
        request.model = new_model  # type: ignore[misc]
    # Verify the assignment actually worked
    assert request.model == new_model

    with pytest.warns(
        DeprecationWarning,
        match="Direct attribute assignment to ModelRequest.system_prompt is deprecated",
    ):
        request.system_prompt = "new prompt"  # type: ignore[misc]
    assert request.system_prompt == "new prompt"

    with pytest.warns(
        DeprecationWarning,
        match="Direct attribute assignment to ModelRequest.messages is deprecated",
    ):
        request.messages = []  # type: ignore[misc]
    assert request.messages == []

    # Using override method should work without warnings
    request2 = _make_request()
    with warnings.catch_warnings():
        warnings.simplefilter("error")  # Turn warnings into errors
        new_request = request2.override(model=new_model, system_prompt="override prompt")
        assert new_request.model == new_model
        assert new_request.system_prompt == "override prompt"

    # Original request should be unchanged
    assert request2.model != new_model
    assert request2.system_prompt != "override prompt"
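
To summarize the migration these deprecation tests pin down: in-place attribute assignment on `ModelRequest` now warns, while `override(...)` returns a modified copy and leaves the original untouched. A small illustrative hook in that style, using only the `ModelRequest.override` API and handler signature exercised above; the replacement model and prompt are hypothetical values, and this is not the library's built-in middleware code.

```python
# Illustrative only: prefer request.override(...) over attribute assignment.
from collections.abc import Callable

from langchain_core.language_models.chat_models import BaseChatModel

from langchain.agents.middleware.types import ModelRequest, ModelResponse


def reroute_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
    replacement: BaseChatModel,
) -> ModelResponse:
    # Deprecated style (emits DeprecationWarning): request.model = replacement
    updated = request.override(model=replacement, system_prompt="Be concise.")
    return handler(updated)  # the original `request` is left unchanged
```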

View File

@@ -14,7 +14,7 @@ from langchain.agents.middleware.pii import (
)
from langchain.agents.factory import create_agent
from .model import FakeToolCallingModel
from tests.unit_tests.agents.model import FakeToolCallingModel
# ============================================================================

View File

@@ -0,0 +1,556 @@
from __future__ import annotations

import asyncio
import gc
import tempfile
import time
from pathlib import Path

import pytest
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.tools.base import ToolException

from langchain.agents.middleware.shell_tool import (
    HostExecutionPolicy,
    RedactionRule,
    ShellToolMiddleware,
    _SessionResources,
    _ShellToolInput,
)
from langchain.agents.middleware.types import AgentState


def _empty_state() -> AgentState:
    return {"messages": []}  # type: ignore[return-value]


def test_executes_command_and_persists_state(tmp_path: Path) -> None:
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        middleware._run_shell_tool(resources, {"command": "cd /"}, tool_call_id=None)
        result = middleware._run_shell_tool(resources, {"command": "pwd"}, tool_call_id=None)
        assert isinstance(result, str)
        assert result.strip() == "/"
        echo_result = middleware._run_shell_tool(
            resources, {"command": "echo ready"}, tool_call_id=None
        )
        assert "ready" in echo_result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
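
Nearly every test in this file repeats the same lifecycle: call `before_agent`, fold any returned updates into the state, do the work, then call `after_agent` in a `finally` block. A hypothetical helper (not part of the library) that bundles that bookkeeping into a context manager, assuming only the `AgentState` dict-update contract the tests exercise:

```python
# Hypothetical helper mirroring the before_agent/after_agent pattern above.
from contextlib import contextmanager
from typing import Iterator

from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.agents.middleware.types import AgentState


@contextmanager
def shell_session(middleware: ShellToolMiddleware) -> Iterator[AgentState]:
    state: AgentState = {"messages": []}  # same shape as the tests' _empty_state()
    updates = middleware.before_agent(state, None)
    if updates:
        state.update(updates)
    try:
        yield state
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
```

Used as `with shell_session(middleware) as state:`, the body can then acquire resources and run commands exactly as the tests do, with cleanup guaranteed on exit.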
def test_restart_resets_session_environment(tmp_path: Path) -> None:
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        middleware._run_shell_tool(resources, {"command": "export FOO=bar"}, tool_call_id=None)
        restart_message = middleware._run_shell_tool(
            resources, {"restart": True}, tool_call_id=None
        )
        assert "restarted" in restart_message.lower()
        resources = middleware._get_or_create_resources(state)  # reacquire after restart
        result = middleware._run_shell_tool(
            resources, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
        )
        assert "unset" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_truncation_indicator_present(tmp_path: Path) -> None:
    policy = HostExecutionPolicy(max_output_lines=5, command_timeout=5.0)
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(resources, {"command": "seq 1 20"}, tool_call_id=None)
        assert "Output truncated" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_timeout_returns_error(tmp_path: Path) -> None:
    policy = HostExecutionPolicy(command_timeout=0.5)
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        start = time.monotonic()
        result = middleware._run_shell_tool(resources, {"command": "sleep 2"}, tool_call_id=None)
        elapsed = time.monotonic() - start
        assert elapsed < policy.command_timeout + 2.0
        assert "timed out" in result.lower()
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_redaction_policy_applies(tmp_path: Path) -> None:
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        redaction_rules=(RedactionRule(pii_type="email", strategy="redact"),),
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        message = middleware._run_shell_tool(
            resources,
            {"command": "printf 'Contact: user@example.com\\n'"},
            tool_call_id=None,
        )
        assert "[REDACTED_EMAIL]" in message
        assert "user@example.com" not in message
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_startup_and_shutdown_commands(tmp_path: Path) -> None:
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(
        workspace_root=workspace,
        startup_commands=("touch startup.txt",),
        shutdown_commands=("touch shutdown.txt",),
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        assert (workspace / "startup.txt").exists()
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
    assert (workspace / "shutdown.txt").exists()


def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
    policy = HostExecutionPolicy(termination_timeout=0.1)

    class DummySession:
        def __init__(self) -> None:
            self.stopped: bool = False

        def stop(self, timeout: float) -> None:  # noqa: ARG002
            self.stopped = True

    session = DummySession()
    tempdir = tempfile.TemporaryDirectory(dir=tmp_path)
    tempdir_path = Path(tempdir.name)
    resources = _SessionResources(session=session, tempdir=tempdir, policy=policy)  # type: ignore[arg-type]
    finalizer = resources._finalizer

    # Drop our last strong reference and force collection.
    del resources
    gc.collect()

    assert not finalizer.alive
    assert session.stopped
    assert not tempdir_path.exists()
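
The `_finalizer` attribute and the `finalizer.alive` check above follow the standard library's `weakref.finalize` pattern. A minimal, self-contained illustration of that pattern with hypothetical `Session`/`Resources` names, unrelated to the middleware's internals:

```python
# Standalone illustration of the weakref.finalize cleanup pattern the test
# above relies on; Session and its stop() method are hypothetical stand-ins.
import gc
import weakref


class Session:
    def stop(self) -> None:
        print("session stopped")


class Resources:
    def __init__(self, session: Session) -> None:
        self.session = session
        # Runs session.stop() when this object is garbage collected
        # (or when the finalizer is called explicitly), whichever comes first.
        self._finalizer = weakref.finalize(self, session.stop)


resources = Resources(Session())
finalizer = resources._finalizer
del resources
gc.collect()
assert not finalizer.alive  # cleanup already ran, exactly once
```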
def test_shell_tool_input_validation() -> None:
    """Test _ShellToolInput validation rules."""
    # Both command and restart not allowed
    with pytest.raises(ValueError, match="only one"):
        _ShellToolInput(command="ls", restart=True)

    # Neither command nor restart provided
    with pytest.raises(ValueError, match="requires either"):
        _ShellToolInput()

    # Valid: command only
    valid_cmd = _ShellToolInput(command="ls")
    assert valid_cmd.command == "ls"
    assert not valid_cmd.restart

    # Valid: restart only
    valid_restart = _ShellToolInput(restart=True)
    assert valid_restart.restart is True
    assert valid_restart.command is None
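
The rule exercised above (exactly one of `command` or `restart`) is a common mutual-exclusion validation. A minimal sketch of how it could be expressed with a Pydantic model validator, using a hypothetical `ShellInput` model rather than the middleware's actual `_ShellToolInput` implementation; the error wording mirrors the `match=` strings in the test.

```python
# Hedged sketch of an "exactly one of command/restart" rule with Pydantic v2;
# ShellInput is a hypothetical stand-in, not the library's _ShellToolInput.
from typing import Optional

from pydantic import BaseModel, model_validator


class ShellInput(BaseModel):
    command: Optional[str] = None
    restart: bool = False

    @model_validator(mode="after")
    def _check_exactly_one(self) -> "ShellInput":
        if self.command is not None and self.restart:
            raise ValueError("Provide only one of 'command' or 'restart'.")
        if self.command is None and not self.restart:
            raise ValueError("The shell tool requires either 'command' or 'restart'.")
        return self


ShellInput(command="ls")   # ok
ShellInput(restart=True)   # ok
# ShellInput(command="ls", restart=True) -> validation error ("only one")
```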
def test_normalize_shell_command_empty() -> None:
    """Test that empty shell command raises an error."""
    with pytest.raises(ValueError, match="at least one argument"):
        ShellToolMiddleware(shell_command=[])


def test_normalize_env_non_string_keys() -> None:
    """Test that non-string environment keys raise an error."""
    with pytest.raises(TypeError, match="must be strings"):
        ShellToolMiddleware(env={123: "value"})  # type: ignore[dict-item]


def test_normalize_env_coercion(tmp_path: Path) -> None:
    """Test that environment values are coerced to strings."""
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", env={"NUM": 42, "BOOL": True}
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None
        )
        assert "42" in result
        assert "True" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
    """Test that shell tool raises an error when command is not a string."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        with pytest.raises(ToolException, match="expects a 'command' string"):
            middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None)
        with pytest.raises(ToolException, match="expects a 'command' string"):
            middleware._run_shell_tool(
                resources,
                {"command": 123},  # type: ignore[dict-item]
                tool_call_id=None,
            )
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
    """Test that tool messages are properly formatted with tool_call_id."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources, {"command": "echo test"}, tool_call_id="test-id-123"
        )
        assert isinstance(result, ToolMessage)
        assert result.tool_call_id == "test-id-123"
        assert result.name == "shell"
        assert result.status == "success"
        assert "test" in result.content
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
    """Test that non-zero exit codes are marked as errors."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources,
            {"command": "false"},  # Command that exits with 1 but doesn't kill shell
            tool_call_id="test-id",
        )
        assert isinstance(result, ToolMessage)
        assert result.status == "error"
        assert "Exit code: 1" in result.content
        assert result.artifact["exit_code"] == 1  # type: ignore[index]
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_truncation_by_bytes(tmp_path: Path) -> None:
    """Test that output is truncated by bytes when max_output_bytes is exceeded."""
    policy = HostExecutionPolicy(max_output_bytes=50, command_timeout=5.0)
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None
        )
        assert "truncated at 50 bytes" in result.lower()
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
def test_startup_command_failure(tmp_path: Path) -> None:
    """Test that startup command failure raises an error."""
    policy = HostExecutionPolicy(startup_timeout=1.0)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", startup_commands=("exit 1",), execution_policy=policy
    )
    state: AgentState = _empty_state()
    with pytest.raises(RuntimeError, match="Startup command.*failed"):
        middleware.before_agent(state, None)


def test_shutdown_command_failure_logged(tmp_path: Path) -> None:
    """Test that shutdown command failures are logged but don't raise."""
    policy = HostExecutionPolicy(command_timeout=1.0)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        shutdown_commands=("exit 1",),
        execution_policy=policy,
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
    finally:
        # Should not raise despite shutdown command failing
        middleware.after_agent(state, None)


def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
    """Test that shutdown command timeouts are logged but don't raise."""
    policy = HostExecutionPolicy(command_timeout=0.1)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=policy,
        shutdown_commands=("sleep 2",),
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
    finally:
        # Should not raise despite shutdown command timing out
        middleware.after_agent(state, None)


def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
    """Test that empty command output is replaced with '<no output>'."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources,
            {"command": "true"},  # Command that produces no output
            tool_call_id=None,
        )
        assert "<no output>" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_stderr_output_labeling(tmp_path: Path) -> None:
    """Test that stderr output is properly labeled."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        result = middleware._run_shell_tool(
            resources, {"command": "echo error >&2"}, tool_call_id=None
        )
        assert "[stderr] error" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
@pytest.mark.parametrize(
    ("startup_commands", "expected"),
    [
        ("echo test", ("echo test",)),  # String
        (["echo test", "pwd"], ("echo test", "pwd")),  # List
        (("echo test",), ("echo test",)),  # Tuple
        (None, ()),  # None
    ],
)
def test_normalize_commands_string_tuple_list(
    tmp_path: Path,
    startup_commands: str | list[str] | tuple[str, ...] | None,
    expected: tuple[str, ...],
) -> None:
    """Test various command normalization formats."""
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", startup_commands=startup_commands
    )
    assert middleware._startup_commands == expected  # type: ignore[attr-defined]
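
The normalization the parametrized test checks (a single string, a list, a tuple, or `None`, all coerced to a tuple of strings) can be written as a small helper. A sketch with a hypothetical `normalize_commands` function; the middleware's own private helper may differ in details.

```python
# Hypothetical normalize_commands helper mirroring the behavior asserted in
# the parametrized test above; the middleware's private helper may differ.
from __future__ import annotations

from collections.abc import Sequence


def normalize_commands(commands: str | Sequence[str] | None) -> tuple[str, ...]:
    if commands is None:
        return ()
    if isinstance(commands, str):
        # A bare string is one command, not a sequence of characters.
        return (commands,)
    return tuple(commands)


assert normalize_commands("echo test") == ("echo test",)
assert normalize_commands(["echo test", "pwd"]) == ("echo test", "pwd")
assert normalize_commands(None) == ()
```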
def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
    """Test that async methods properly delegate to sync methods."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        # Test abefore_agent
        updates = asyncio.run(middleware.abefore_agent(state, None))
        if updates:
            state.update(updates)
        # Test aafter_agent
        asyncio.run(middleware.aafter_agent(state, None))
    finally:
        pass


def test_shell_middleware_resumable_after_interrupt(tmp_path: Path) -> None:
    """Test that shell middleware is resumable after an interrupt.

    This test simulates a scenario where:
    1. The middleware creates a shell session
    2. A command is executed
    3. The agent is interrupted (state is preserved)
    4. The agent resumes with the same state
    5. The shell session is reused (not recreated)
    """
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)

    # Simulate first execution (before interrupt)
    state: AgentState = _empty_state()
    updates = middleware.before_agent(state, None)
    if updates:
        state.update(updates)

    # Get the resources and verify they exist
    resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    initial_session = resources.session
    initial_tempdir = resources.tempdir

    # Execute a command to set state
    middleware._run_shell_tool(resources, {"command": "export TEST_VAR=hello"}, tool_call_id=None)

    # Simulate interrupt - state is preserved, but we don't call after_agent
    # In a real scenario, the state would be checkpointed here

    # Simulate resumption - call before_agent again with same state
    # This should reuse existing resources, not create new ones
    updates = middleware.before_agent(state, None)
    if updates:
        state.update(updates)

    # Get resources again - should be the same session
    resumed_resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

    # Verify the session was reused (same object reference)
    assert resumed_resources.session is initial_session
    assert resumed_resources.tempdir is initial_tempdir

    # Verify the session state persisted (environment variable still set)
    result = middleware._run_shell_tool(
        resumed_resources, {"command": "echo ${TEST_VAR:-unset}"}, tool_call_id=None
    )
    assert "hello" in result
    assert "unset" not in result

    # Clean up
    middleware.after_agent(state, None)
def test_get_or_create_resources_creates_when_missing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources creates resources when they don't exist."""
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    # State has no resources initially
    assert "shell_session_resources" not in state

    # Call _get_or_create_resources - should create new resources
    resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    assert isinstance(resources, _SessionResources)
    assert resources.session is not None
    assert state.get("shell_session_resources") is resources

    # Clean up
    resources._finalizer()


def test_get_or_create_resources_reuses_existing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources reuses existing resources."""
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    # Create resources first time
    resources1 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

    # Call again - should return the same resources
    resources2 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    assert resources1 is resources2
    assert resources1.session is resources2.session

    # Clean up
    resources1._finalizer()
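
The behavior verified by these last two tests is a standard state-keyed get-or-create cache. A generic sketch of that pattern with hypothetical `Resources`/`make_resources` names; the only detail taken from the tests is the `"shell_session_resources"` state key they assert on.

```python
# Generic get-or-create sketch mirroring the behavior the two tests above
# assert; Resources and make_resources are hypothetical stand-ins.
from collections.abc import MutableMapping
from typing import Any

_RESOURCES_KEY = "shell_session_resources"  # key name taken from the test assertions


class Resources:
    """Placeholder for the real per-session resource bundle."""


def make_resources() -> Resources:
    return Resources()


def get_or_create_resources(state: MutableMapping[str, Any]) -> Resources:
    existing = state.get(_RESOURCES_KEY)
    if existing is not None:
        # Reuse: resuming with the same state must return the same object.
        return existing
    created = make_resources()
    state[_RESOURCES_KEY] = created
    return created


state: dict[str, Any] = {"messages": []}
first = get_or_create_resources(state)
assert get_or_create_resources(state) is first  # same object on subsequent calls
```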

Some files were not shown because too many files have changed in this diff.