Compare commits

...

53 Commits

Author SHA1 Message Date
Sydney Runkle
6ba1177f4f Merge branch 'master' into sr/looser-api-for-summarization 2025-11-17 08:55:42 -05:00
Mason Daugherty
52b1516d44 style(langchain): fix some middleware ref syntax (#33988) 2025-11-16 00:33:17 -05:00
Mason Daugherty
8a3bb73c05 release(openai): 1.0.3 (#33981)
- Respect 300k token limit for embeddings API requests #33668
- fix create_agent / response_format for Responses API #33939
- fix `response.incomplete` event not handled when using
`stream_mode=['messages']` #33871
2025-11-14 19:18:50 -05:00
Mason Daugherty
099c042395 refactor(openai): embedding utils and calculations (#33982)
Now returns (`_iter`, `tokens`, `indices`, `token_counts`). The
`token_counts` are calculated directly during tokenization, which is
more accurate and efficient than splitting strings later.
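Loosely, counting during tokenization looks like this (a hypothetical sketch only; the actual helper also returns the `_iter` batching iterator and handles context-length chunking):

```python
import tiktoken


def tokenize_with_counts(
    texts: list[str],
) -> tuple[list[list[int]], list[int], list[int]]:
    """Tokenize each text, recording token counts as the chunks are produced."""
    enc = tiktoken.get_encoding("cl100k_base")  # encoding used by OpenAI embedding models
    tokens: list[list[int]] = []
    indices: list[int] = []
    token_counts: list[int] = []
    for i, text in enumerate(texts):
        encoded = enc.encode(text)
        tokens.append(encoded)
        indices.append(i)
        token_counts.append(len(encoded))  # counted here; no string re-splitting later
    return tokens, indices, token_counts
```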
2025-11-14 19:18:37 -05:00
Kaparthy Reddy
2d4f00a451 fix(openai): Respect 300k token limit for embeddings API requests (#33668)
## Description

Fixes #31227 - Resolves the issue where `OpenAIEmbeddings` exceeds
OpenAI's 300,000 token per request limit, causing 400 BadRequest errors.

## Problem

When embedding large document sets, LangChain would send batches
containing more than 300,000 tokens in a single API request, causing
this error:
```
openai.BadRequestError: Error code: 400 - {'error': {'message': 'Requested 673477 tokens, max 300000 tokens per request'}}
```

The issue occurred because:
- The code chunks texts by `embedding_ctx_length` (8191 tokens per
chunk)
- Then batches chunks by `chunk_size` (default 1000 chunks per request)
- **But didn't check**: Total tokens per batch against OpenAI's 300k
limit
- Result: `1000 chunks × 8191 tokens = 8,191,000 tokens` → Exceeds
limit!

## Solution

This PR implements dynamic batching that respects the 300k token limit:

1. **Added constant**: `MAX_TOKENS_PER_REQUEST = 300000`
2. **Track token counts**: Calculate actual tokens for each chunk
3. **Dynamic batching**: Instead of fixed `chunk_size` batches,
accumulate chunks until approaching the 300k limit (see the sketch after
this list)
4. **Applied to both sync and async**: Fixed both
`_get_len_safe_embeddings` and `_aget_len_safe_embeddings`
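As a rough illustration, the token-aware batching boils down to something like the following (a minimal sketch; `MAX_TOKENS_PER_REQUEST` and `token_counts` come from this PR, the helper name and the rest are hypothetical):

```python
MAX_TOKENS_PER_REQUEST = 300_000  # OpenAI's per-request token limit for embeddings


def batch_by_token_limit(
    chunks: list[list[int]], token_counts: list[int]
) -> list[list[list[int]]]:
    """Group tokenized chunks so each batch stays under the 300k-token limit."""
    batches: list[list[list[int]]] = []
    current: list[list[int]] = []
    current_tokens = 0
    for chunk, n_tokens in zip(chunks, token_counts):
        # Close the current batch once adding this chunk would exceed the limit
        if current and current_tokens + n_tokens > MAX_TOKENS_PER_REQUEST:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(chunk)
        current_tokens += n_tokens
    if current:
        batches.append(current)
    return batches
```

Each resulting batch would then be sent as its own embeddings API call, replacing the old fixed `chunk_size` batching.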

## Changes

- Modified `langchain_openai/embeddings/base.py`:
  - Added `MAX_TOKENS_PER_REQUEST` constant
  - Replaced fixed-size batching with token-aware dynamic batching
  - Applied to both sync (line ~478) and async (line ~527) methods
- Added test in `tests/unit_tests/embeddings/test_base.py`:
- `test_embeddings_respects_token_limit()` - Verifies large document
sets are properly batched

## Testing

All existing tests pass (280 passed, 4 xfailed, 1 xpassed).

New test verifies:
- Large document sets (500 texts × 1000 tokens = 500k tokens) are split
into multiple API calls
- Each API call respects the 300k token limit

## Usage

After this fix, users can embed large document sets without errors:
```python
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import CharacterTextSplitter

# This will now work without exceeding token limits
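# `large_documents` is assumed to be a list of Document objects loaded elsewhere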
embeddings = OpenAIEmbeddings()
documents = CharacterTextSplitter().split_documents(large_documents)
Chroma.from_documents(documents, embeddings)
```

Resolves #31227

---------

Co-authored-by: Kaparthy Reddy <kaparthyreddy@Kaparthys-MacBook-Air.local>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
2025-11-14 18:12:07 -05:00
Sydney Runkle
9bd401a6d4 fix: resumable shell, works w/ interrupts (#33978)
fixes https://github.com/langchain-ai/langchain/issues/33684

Now able to run this minimal snippet successfully

```py
import os

from langchain.agents import create_agent
from langchain.agents.middleware import (
    HostExecutionPolicy,
    HumanInTheLoopMiddleware,
    ShellToolMiddleware,
)
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command


shell_middleware = ShellToolMiddleware(
    workspace_root=os.getcwd(),
    env=os.environ,  # danger
    execution_policy=HostExecutionPolicy()
)

hil_middleware = HumanInTheLoopMiddleware(interrupt_on={"shell": True})

checkpointer = InMemorySaver()

agent = create_agent(
    "openai:gpt-4.1-mini",
    middleware=[shell_middleware, hil_middleware],
    checkpointer=checkpointer,
)

input_message = {"role": "user", "content": "run `which python`"}

config = {"configurable": {"thread_id": "1"}}

result = agent.invoke(
    {"messages": [input_message]},
    config=config,
    durability="exit",
)
```
2025-11-14 15:32:25 -05:00
Sydney Runkle
62c05e09c1 readable tests 2025-11-14 14:34:24 -05:00
Sydney Runkle
83b9d9f810 parametrize 2025-11-14 14:31:28 -05:00
ccurme
6aa3794b74 feat(langchain): reference model profiles for provider strategy (#33974) 2025-11-14 19:24:18 +00:00
Sydney Runkle
6deee23d8d Revert "cleanup"
This reverts commit 690aabe8d4.
2025-11-14 14:14:32 -05:00
Sydney Runkle
690aabe8d4 cleanup 2025-11-14 14:12:14 -05:00
Sydney Runkle
80554df1e6 docs 2025-11-14 14:00:15 -05:00
Sydney Runkle
72c45e65e8 summarization edits 2025-11-14 13:56:26 -05:00
Sydney Runkle
189dcf7295 chore: increase coverage for shell, filesystem, and summarization middleware (#33928)
cc generated, just a start here, but wanted to bump things up from
~70%
2025-11-14 13:30:36 -05:00
Sydney Runkle
1bc88028e6 fix(anthropic): execute bash + file tools via tool node (#33960)
* use `override` instead of directly patching things on `ModelRequest`
* rely on `ToolNode` for execution of tools related to said middleware,
using `wrap_model_call` to inject the relevant claude tool specs +
allowing tool node to forward them along to corresponding langchain tool
implementations
* making the same change for the native shell tool middleware
* allowing shell tool middleware to specify a name for the shell tool
(negative diff then for claude bash middleware)


long term I think the solution might be to attach metadata to a tool to
map the provider spec to a langchain implementation, where we could also
take some lessons from the MCP front.
2025-11-14 13:17:01 -05:00
Mason Daugherty
d2942351ce release(core): 1.0.5 (#33973) 2025-11-14 11:51:27 -05:00
Sydney Runkle
83c078f363 fix: adding missing async hooks (#33957)
* filling in missing async gaps
* using recommended tool runtime injection instead of injected state
  * updating tests to use helper function as well
2025-11-14 09:13:39 -05:00
ZhangShenao
26d39ffc4a docs: Fix doc links (#33964) 2025-11-14 09:07:32 -05:00
Mason Daugherty
421e2ceeee fix(core): don't mask exceptions (#33959) 2025-11-14 09:05:29 -05:00
Mason Daugherty
275dcbf69f docs(core): add clarity to base token counting methods (#33958)
Wasn't immediately obvious that `get_num_tokens_from_messages` adds
additional prefixes to represent user roles in conversation, which adds
to the overall token count.

```python
from langchain_google_genai import GoogleGenerativeAI

llm = GoogleGenerativeAI(model="gemini-2.5-flash")
num_tokens = llm.get_num_tokens("Hello, world!")
print(f"Number of tokens: {num_tokens}")
# Number of tokens: 4
```

```python
from langchain.messages import HumanMessage

messages = [HumanMessage(content="Hello, world!")]

num_tokens = llm.get_num_tokens_from_messages(messages)
print(f"Number of tokens: {num_tokens}")
# Number of tokens: 6
```
2025-11-13 17:15:47 -05:00
Sydney Runkle
9f87b27a5b fix: add filesystem middleware in init (#33955) 2025-11-13 15:07:33 -05:00
Mason Daugherty
b2e1196e29 chore(core,infra): nits (#33954) 2025-11-13 14:50:54 -05:00
Sydney Runkle
2dc1396380 chore(langchain): update deps (#33951) 2025-11-13 14:21:25 -05:00
Mason Daugherty
77941ab3ce feat(infra): add automatic issue labeling (#33952) 2025-11-13 14:13:52 -05:00
Mason Daugherty
ee19a30dde fix(groq): bump min ver for core dep (#33949)
Due to an issue with unit tests and the docs URL for exceptions
2025-11-13 11:46:54 -05:00
Mason Daugherty
5d799b3174 release(nomic): 1.0.1 (#33948)
support Python 3.14 #33655
2025-11-13 11:25:39 -05:00
Mason Daugherty
8f33a985a2 release(groq): 1.0.1 (#33947)
- fix: handle tool calls with no args #33896
- add prompt caching token usage details #33708
2025-11-13 11:25:00 -05:00
Mason Daugherty
78eeccef0e release(deepseek): 1.0.1 (#33946)
- support strict beta structured output #32727
2025-11-13 11:24:39 -05:00
ccurme
3d415441e8 fix(langchain, openai): backward compat for response_format (#33945) 2025-11-13 11:11:35 -05:00
ccurme
74385e0ebd fix(langchain, openai): fix create_agent / response_format for Responses API (#33939) 2025-11-13 10:18:15 -05:00
Christophe Bornet
2bfbc29ccc chore(core): fix some ruff TC rules (#33929)
fix some ruff TC rules but still don't enforce them as Pydantic model
fields use type annotations at runtime.
2025-11-12 14:07:19 -05:00
Christophe Bornet
ef79c26f18 chore(cli,standard-tests,text-splitters): fix some ruff TC rules (#33934)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-11-12 14:06:31 -05:00
ccurme
fbe32c8e89 release(anthropic): 1.0.3 (#33935) 2025-11-12 10:55:28 -05:00
Mohammad Mohtashim
2511c28f92 feat(anthropic): support code_execution_20250825 (#33925) 2025-11-12 10:44:51 -05:00
Sydney Runkle
637bb1cbbc feat: refactor tests coverage (#33927)
middleware tests have gotten quite unwieldy, major restructuring, sets
the stage for coverage increase

this is super hard to review -- as a proof that we've retained important
tests, I ran coverage on `master` and this branch and confirmed
identical coverage.

* moving all middleware related tests to `agents/middleware` folder
* consolidating related test files
* adding coverage utility to makefile
2025-11-11 10:40:12 -05:00
Mason Daugherty
3dfea96ec1 chore: update README.md files (#33919) 2025-11-10 22:51:35 -05:00
ccurme
68643153e5 feat(langchain): support async summarization in SummarizationMiddleware (#33918) 2025-11-10 15:48:51 -05:00
Abbas Syed
462762f75b test(core): add comprehensive tests for groq block translator (#33906) 2025-11-10 15:45:36 -05:00
ccurme
4f3729c004 release(model-profiles): 0.0.4 (#33917) 2025-11-10 12:06:32 -05:00
Mason Daugherty
ba428cdf54 chore(infra): add note to pr linting workflow (#33916) 2025-11-10 11:49:31 -05:00
Mason Daugherty
69c7d1b01b test(groq,openai): add retries for flaky tests (#33914) 2025-11-10 10:36:11 -05:00
Mason Daugherty
733299ec13 revert(core): "applied secrets_map in load to plain string values" (#33913)
Reverts langchain-ai/langchain#33678

Breaking API change
2025-11-10 10:29:30 -05:00
ccurme
e1adf781c6 feat(langchain): (SummarizationMiddleware) support use of model context windows when triggering summarization (#33825) 2025-11-10 10:08:52 -05:00
Shahroz Ahmad
31b5e4810c feat(deepseek): support strict beta structured output (#32727)
**Description:** This PR adds support for DeepSeek's beta strict mode
feature for structured
outputs and tool calling. It overrides `bind_tools()` and
`with_structured_output()` to automatically use
DeepSeek's beta endpoint (https://api.deepseek.com/beta) when
`strict=True`. Both methods need overriding because they're independent
entry points and user can call either directly. When DeepSeek's strict
mode graduates from beta, we can just remove both overriden methods. You
can read more about the beta feature here:
https://api-docs.deepseek.com/guides/function_calling#strict-mode-beta
  
**Issue:** Implements #32670 


**Dependencies:** None


**Sample Code**

```python
from langchain_deepseek import ChatDeepSeek
from pydantic import BaseModel, Field
from typing import Optional
import os


# Enter your DeepSeek API Key here
API_KEY = "YOUR_API_KEY"


# location, temperature, condition are required fields
# humidity is optional field with default value
class WeatherInfo(BaseModel):
    location: str = Field(description="City name")
    temperature: int = Field(description="Temperature in Celsius")
    condition: str = Field(description="Weather condition (sunny, cloudy, rainy)")
    humidity: Optional[int] = Field(default=None, description="Humidity percentage")


llm = ChatDeepSeek(
    model="deepseek-chat",
    api_key=API_KEY,
)

# just to confirm that a new instance will use the default base url (instead of beta)
print(f"Default API base: {llm.api_base}")



# Test 1: bind_tools with strict=True should list all the tool calls
print("\nTest 1: bind_tools with strict=True")
llm_with_tools = llm.bind_tools([WeatherInfo], strict=True)
response = llm_with_tools.invoke("Tell me the weather in New York. It's 22 degrees, sunny.")
print(response.tool_calls)



# Test 2: with_structured_output with strict=True
print("\nTest 2: with_structured_output with strict=True")
structured_llm = llm.with_structured_output(WeatherInfo, strict=True)
result = structured_llm.invoke("Tell me the weather in New York.")
print(f"  Result: {result}")
assert isinstance(result, WeatherInfo), "Result should be a WeatherInfo instance"
```

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
2025-11-09 22:24:33 -05:00
Mason Daugherty
c6801fe159 chore: fix URL underlining in README.md (#33905) 2025-11-09 22:22:56 -05:00
AmazingcatAndrew
1b563067f8 fix(chroma): resolve OpenCLIP + Chroma image embedding test regression (#33899)
**Description:**  
Fixes the OpenCLIP × Chroma regression that caused nested embedding
errors when adding or searching image data.
The test case `test_openclip_chroma_embed_no_nesting_error` has been
restored and verified to work correctly with the current LangChain core
dependencies.
Functional validation confirms that `similarity_search_by_image` now
returns correct, metadata‑preserving results.

**Issue:**  
Fixes #33851

**Dependencies:**  
No new dependencies introduced.  

**Testing:**  
All unit tests pass under Python 3.13.9 in a uv-managed environment:
```bash
uv run --group test pytest tests/unit_tests
```
Result:
```
30 passed in 91.26s (0:01:31)
```
This confirms that the regression has been fixed.

Running  
```bash
make test
```  
still produces cleanup‑time `AttributeError: 'ProactorEventLoop' object
has no attribute '_ssock'` on Windows (Python 3.13+).
This is a benign asyncio teardown message rather than a functional
failure.
`uv run pytest` closes event loops immediately after tests, while `make
test` invokes pytest through a secondary process layer that leaves a
background loop alive at interpreter shutdown.
This difference in teardown behavior explains the extra messages seen
only when using `make test`.

**Summary:**  
- Verified the OpenCLIP + Chroma image pipeline works correctly.  
- `uv run --group test pytest` fully passes; the fix is complete.  
- The residual `_ssock` warnings occur only during
Windows asyncio cleanup and are not related to this code change.

This is my first time contributing code, please contact me with any
questions

---

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
2025-11-09 21:24:33 -05:00
Mason Daugherty
1996d81d72 chore(langchain): pass on reference docstrings (middleware) (#33904) 2025-11-09 21:18:28 -05:00
Mason Daugherty
ab0677c6f1 fix(groq): handle tool calls with no args (#33896)
When Groq returns tool calls with no arguments, it sends `arguments:
'null'` (JSON null). LangChain's core parsing expects a dict, and the
JSON null is converted to Python `None`, which fails the
`isinstance(args_, dict)` check and incorrectly marks the tool call as
invalid.

Related to #32017
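A minimal sketch of the kind of normalization involved (hypothetical helper; the actual patch lives in `langchain_groq`'s message conversion):

```python
import json
from typing import Any


def parse_tool_call_args(raw_arguments: str | None) -> dict[str, Any]:
    """Coerce missing or null tool-call arguments to an empty dict."""
    if not raw_arguments:
        return {}
    args = json.loads(raw_arguments)  # Groq may send the literal string 'null' here
    # JSON null parses to Python None, which previously tripped the dict check
    return args if isinstance(args, dict) else {}
```

With this, a tool call whose `arguments` string is `'null'` yields `{}` and passes the `isinstance(args_, dict)` check.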
2025-11-08 22:30:44 -05:00
artreimus
bdb53c93cc docs(langchain): correct IBM provider link in chat_models docstring (#33897)
**Description**
Fix broken link in the `chat_models` docstring. The **ibm** bullet
incorrectly linked to the DeepSeek provider page; update it to the
canonical IBM provider docs.

This only affects generated API reference content on
`reference.langchain.com`. No runtime behavior changes.

**Issue**
N/A (documentation-only).

**Dependencies**
None.

**Testing & quality**

* Ran `make format`, `make lint`, and `make test` in the package (no
code changes expected to affect tests).
2025-11-08 07:02:33 -06:00
Alazar Genene
94d5271cb5 fix(standard-tests): fix semantic typo in if statement (#33890) 2025-11-07 18:01:59 -05:00
ccurme
e499db4266 release(langchain): 1.0.5 (#33893) 2025-11-07 17:54:43 -05:00
npage902
cc3af82b47 fix(core): applied secrets_map in load to plain string values (#33678)
Replaces #33618 

**Description:** Fixes the bug in the `load()` function where secret
placeholders in plain dicts were not replaced, even if they match a key
in `secrets_map`, and adds a test case.

Example:
```py
obj = {"api_key": "__SECRET_API_KEY__"}
secret_key = "secret_key_1234"
secrets_map = {"__SECRET_API_KEY__": secret_key}
result = load(obj, secrets_map=secrets_map)
```
Before this change, printing `api_key` in `result` would output
`"__SECRET_API_KEY__"`. Now, it will properly output
`"secret_key_1234"`.

**Issue:** Fixes #31804 

**Dependencies:** None

`make format`, `make lint`, and `make test` have all passed on my
machine.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
2025-11-07 17:14:13 -05:00
Mshari
9383b78be1 feat(groq): add prompt caching token usage details (#33708)
**Description:** 
Adds support for prompt caching usage metadata in ChatGroq. The
integration now captures cached token information from the Groq API
response and includes it in the `input_token_details` field of the
`usage_metadata`.

Changes:
- Created new `_create_usage_metadata()` helper function to centralize
usage metadata creation logic
- Extracts `cached_tokens` from `prompt_tokens_details` in API responses
and maps to `input_token_details.cache_read`
- Integrated the helper function in both streaming
(`_convert_chunk_to_message_chunk`) and non-streaming
(`_create_chat_result`) code paths
- Added comprehensive unit tests to verify caching metadata handling and
backward compatibility

This enables users to monitor prompt caching effectiveness when using
Groq models with prompt caching enabled.
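Roughly what the mapping looks like (a hedged sketch based on the description above; the actual helper in `langchain_groq` may differ):

```python
from typing import Any

from langchain_core.messages.ai import UsageMetadata


def _create_usage_metadata(token_usage: dict[str, Any]) -> UsageMetadata:
    """Build usage metadata, surfacing Groq's cached prompt tokens when present."""
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    usage: UsageMetadata = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": token_usage.get("total_tokens", input_tokens + output_tokens),
    }
    details = token_usage.get("prompt_tokens_details") or {}
    if "cached_tokens" in details:
        usage["input_token_details"] = {"cache_read": details["cached_tokens"]}
    return usage
```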

**Issue:** N/A

**Dependencies:** None

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-11-07 17:05:22 -05:00
177 changed files with 12983 additions and 8147 deletions

View File

@@ -8,16 +8,15 @@ body:
value: |
Thank you for taking the time to file a bug report.
Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
Relevant links to check before filing a bug report to see if your issue has already been reported, fixed or
if there's another way to solve your problem:
Check these before submitting to see if your issue has already been reported, fixed or if there's another way to solve your problem:
* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -36,16 +35,48 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- label: I read what a minimal reproducible example is (https://stackoverflow.com/help/minimal-reproducible-example).
required: true
- label: I posted a self-contained, minimal, reproducible example. A maintainer can copy it and run it AS IS.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this bug related to? Select at least one.
Note that if the package you are reporting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).
Please report issues for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Example Code
label: Example Code (Python)
description: |
Please add a self-contained, [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example) with your use case.
@@ -53,15 +84,12 @@ body:
**Important!**
* Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.
* Use code tags (e.g., ```python ... ```) to correctly [format your code](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting).
* INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).
* Avoid screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible.
(This will be automatically formatted into code, so no need for backticks.)
render: python
placeholder: |
The following code:
```python
from langchain_core.runnables import RunnableLambda
def bad_code(inputs) -> int:
@@ -69,17 +97,14 @@ body:
chain = RunnableLambda(bad_code)
chain.invoke('Hello!')
```
- type: textarea
id: error
validations:
required: false
attributes:
label: Error Message and Stack Trace (if applicable)
description: |
If you are reporting an error, please include the full error message and stack trace.
placeholder: |
Exception + full stack trace
If you are reporting an error, please copy and paste the full error message and
stack trace.
(This will be automatically formatted into code, so no need for backticks.)
render: shell
- type: textarea
id: description
attributes:
@@ -99,9 +124,7 @@ body:
attributes:
label: System Info
description: |
Please share your system info with us. Do NOT skip this step and please don't trim
the output. Most users don't include enough information here and it makes it harder
for us to help you.
Please share your system info with us.
Run the following command in your terminal and paste the output here:
@@ -113,8 +136,6 @@ body:
from langchain_core import sys_info
sys_info.print_sys_info()
```
alternatively, put the entire output of `pip freeze` here.
placeholder: |
python -m langchain_core.sys_info
validations:

View File

@@ -1,9 +1,18 @@
blank_issues_enabled: false
version: 2.1
contact_links:
- name: 📚 Documentation
- name: 📚 Documentation issue
url: https://github.com/langchain-ai/docs/issues/new?template=01-langchain.yml
about: Report an issue related to the LangChain documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: General community discussions and support
- name: 📚 LangChain Documentation
url: https://docs.langchain.com/oss/python/langchain/overview
about: View the official LangChain documentation
- name: 📚 API Reference Documentation
url: https://reference.langchain.com/python/
about: View the official LangChain API reference documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: Ask questions and get help from the community

View File

@@ -13,11 +13,11 @@ body:
Relevant links to check before filing a feature request to see if your request has already been made or
if there's another way to achieve what you want:
* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -34,6 +34,40 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this request related to? Select at least one.
Note that if the package you are requesting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in `langchain-ai/langchain`).
Please submit feature requests for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: feature-description
validations:

View File

@@ -18,3 +18,33 @@ body:
attributes:
label: Issue Content
description: Add the content of the issue here.
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this issue is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general

View File

@@ -25,13 +25,13 @@ body:
label: Task Description
description: |
Provide a clear and detailed description of the task.
What needs to be done? Be specific about the scope and requirements.
placeholder: |
This task involves...
The goal is to...
Specific requirements:
- ...
- ...
@@ -43,7 +43,7 @@ body:
label: Acceptance Criteria
description: |
Define the criteria that must be met for this task to be considered complete.
What are the specific deliverables or outcomes expected?
placeholder: |
This task will be complete when:
@@ -58,15 +58,15 @@ body:
label: Context and Background
description: |
Provide any relevant context, background information, or links to related issues/PRs.
Why is this task needed? What problem does it solve?
placeholder: |
Background:
- ...
Related issues/PRs:
- #...
Additional context:
- ...
validations:
@@ -77,15 +77,45 @@ body:
label: Dependencies
description: |
List any dependencies or blockers for this task.
Are there other tasks, issues, or external factors that need to be completed first?
placeholder: |
This task depends on:
- [ ] Issue #...
- [ ] PR #...
- [ ] External dependency: ...
Blocked by:
- ...
validations:
required: false
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this task is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general

View File

@@ -98,7 +98,7 @@ def _check_python_version_from_requirement(
return True
else:
marker_str = str(requirement.marker)
if "python_version" or "python_full_version" in marker_str:
if "python_version" in marker_str or "python_full_version" in marker_str:
python_version_str = "".join(
char
for char in marker_str

View File

@@ -0,0 +1,107 @@
name: Auto Label Issues by Package
on:
issues:
types: [opened, edited]
jobs:
label-by-package:
permissions:
issues: write
runs-on: ubuntu-latest
steps:
- name: Sync package labels
uses: actions/github-script@v6
with:
script: |
const body = context.payload.issue.body || "";
// Extract text under "### Package"
const match = body.match(/### Package\s+([\s\S]*?)\n###/i);
if (!match) return;
const packageSection = match[1].trim();
// Mapping table for package names to labels
const mapping = {
"langchain": "langchain",
"langchain-openai": "openai",
"langchain-anthropic": "anthropic",
"langchain-classic": "langchain-classic",
"langchain-core": "core",
"langchain-cli": "cli",
"langchain-model-profiles": "model-profiles",
"langchain-tests": "standard-tests",
"langchain-text-splitters": "text-splitters",
"langchain-chroma": "chroma",
"langchain-deepseek": "deepseek",
"langchain-exa": "exa",
"langchain-fireworks": "fireworks",
"langchain-groq": "groq",
"langchain-huggingface": "huggingface",
"langchain-mistralai": "mistralai",
"langchain-nomic": "nomic",
"langchain-ollama": "ollama",
"langchain-perplexity": "perplexity",
"langchain-prompty": "prompty",
"langchain-qdrant": "qdrant",
"langchain-xai": "xai",
};
// All possible package labels we manage
const allPackageLabels = Object.values(mapping);
const selectedLabels = [];
// Check if this is checkbox format (multiple selection)
const checkboxMatches = packageSection.match(/- \[x\]\s+([^\n\r]+)/gi);
if (checkboxMatches) {
// Handle checkbox format
for (const match of checkboxMatches) {
const packageName = match.replace(/- \[x\]\s+/i, '').trim();
const label = mapping[packageName];
if (label && !selectedLabels.includes(label)) {
selectedLabels.push(label);
}
}
} else {
// Handle dropdown format (single selection)
const label = mapping[packageSection];
if (label) {
selectedLabels.push(label);
}
}
// Get current issue labels
const issue = await github.rest.issues.get({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number
});
const currentLabels = issue.data.labels.map(label => label.name);
const currentPackageLabels = currentLabels.filter(label => allPackageLabels.includes(label));
// Determine labels to add and remove
const labelsToAdd = selectedLabels.filter(label => !currentPackageLabels.includes(label));
const labelsToRemove = currentPackageLabels.filter(label => !selectedLabels.includes(label));
// Add new labels
if (labelsToAdd.length > 0) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: labelsToAdd
});
}
// Remove old labels
for (const label of labelsToRemove) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: label
});
}

View File

@@ -26,12 +26,14 @@
# * revert — reverts a previous commit
# * release — prepare a new release
#
# Allowed Scopes (optional):
# Allowed Scope(s) (optional):
# core, cli, langchain, langchain_v1, langchain-classic, standard-tests,
# text-splitters, docs, anthropic, chroma, deepseek, exa, fireworks, groq,
# huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant,
# xai, infra, deps
#
# Multiple scopes can be used by separating them with a comma.
#
# Rules:
# 1. The 'Type' must start with a lowercase letter.
# 2. Breaking changes: append "!" after type/scope (e.g., feat!: drop x support)

View File

@@ -1,38 +1,26 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: light)" srcset=".github/images/logo-dark.svg">
<source media="(prefers-color-scheme: dark)" srcset=".github/images/logo-light.svg">
<img alt="LangChain Logo" src=".github/images/logo-dark.svg" width="80%">
</picture>
</p>
<div align="center">
<a href="https://www.langchain.com/">
<picture>
<source media="(prefers-color-scheme: light)" srcset=".github/images/logo-dark.svg">
<source media="(prefers-color-scheme: dark)" srcset=".github/images/logo-light.svg">
<img alt="LangChain Logo" src=".github/images/logo-dark.svg" width="80%">
</picture>
</a>
</div>
<p align="center">
The platform for reliable agents.
</p>
<div align="center">
<h3>The platform for reliable agents.</h3>
</div>
<p align="center">
<a href="https://opensource.org/licenses/MIT" target="_blank">
<img src="https://img.shields.io/pypi/l/langchain" alt="PyPI - License">
</a>
<a href="https://pypistats.org/packages/langchain" target="_blank">
<img src="https://img.shields.io/pepy/dt/langchain" alt="PyPI - Downloads">
</a>
<a href="https://pypi.org/project/langchain/#history" target="_blank">
<img src="https://img.shields.io/pypi/v/langchain?label=%20" alt="Version">
</a>
<a href="https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain" target="_blank">
<img src="https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode" alt="Open in Dev Containers">
</a>
<a href="https://codespaces.new/langchain-ai/langchain" target="_blank">
<img src="https://github.com/codespaces/badge.svg" alt="Open in Github Codespace" title="Open in Github Codespace" width="150" height="20">
</a>
<a href="https://codspeed.io/langchain-ai/langchain" target="_blank">
<img src="https://img.shields.io/endpoint?url=https://codspeed.io/badge.json" alt="CodSpeed Badge">
</a>
<a href="https://twitter.com/langchainai" target="_blank">
<img src="https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI" alt="Twitter / X">
</a>
</p>
<div align="center">
<a href="https://opensource.org/licenses/MIT" target="_blank"><img src="https://img.shields.io/pypi/l/langchain" alt="PyPI - License"></a>
<a href="https://pypistats.org/packages/langchain" target="_blank"><img src="https://img.shields.io/pepy/dt/langchain" alt="PyPI - Downloads"></a>
<a href="https://pypi.org/project/langchain/#history" target="_blank"><img src="https://img.shields.io/pypi/v/langchain?label=%20" alt="Version"></a>
<a href="https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain" target="_blank"><img src="https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode" alt="Open in Dev Containers"></a>
<a href="https://codespaces.new/langchain-ai/langchain" target="_blank"><img src="https://github.com/codespaces/badge.svg" alt="Open in Github Codespace" title="Open in Github Codespace" width="150" height="20"></a>
<a href="https://codspeed.io/langchain-ai/langchain" target="_blank"><img src="https://img.shields.io/endpoint?url=https://codspeed.io/badge.json" alt="CodSpeed Badge"></a>
<a href="https://twitter.com/langchainai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI" alt="Twitter / X"></a>
</div>
LangChain is a framework for building agents and LLM-powered applications. It helps you chain together interoperable components and third-party integrations to simplify AI application development all while future-proofing decisions as the underlying technology evolves.

View File

@@ -6,9 +6,8 @@ import hashlib
import logging
import re
import shutil
from collections.abc import Sequence
from pathlib import Path
from typing import Any, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict
from git import Repo
@@ -18,6 +17,9 @@ from langchain_cli.constants import (
DEFAULT_GIT_SUBDIRECTORY,
)
if TYPE_CHECKING:
from collections.abc import Sequence
logger = logging.getLogger(__name__)

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .file import File
from .folder import Folder
if TYPE_CHECKING:
from .file import File
from .folder import Folder
@dataclass

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from .file import File
if TYPE_CHECKING:
from pathlib import Path
class Folder:
def __init__(self, name: str, *files: Folder | File) -> None:

View File

@@ -34,7 +34,7 @@ The LangChain ecosystem is built on top of `langchain-core`. Some of the benefit
## 📖 Documentation
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
## 📕 Releases & Versioning

View File

@@ -5,13 +5,12 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from typing_extensions import Self
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from tenacity import RetryCallState
from typing_extensions import Self
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document

View File

@@ -39,7 +39,6 @@ from langchain_core.tracers.context import (
tracing_v2_callback_var,
)
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.schemas import Run
from langchain_core.tracers.stdout import ConsoleCallbackHandler
from langchain_core.utils.env import env_var_is_set
@@ -52,6 +51,7 @@ if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run
logger = logging.getLogger(__name__)

View File

@@ -6,16 +6,9 @@ import hashlib
import json
import uuid
import warnings
from collections.abc import (
AsyncIterable,
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from itertools import islice
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypedDict,
@@ -29,6 +22,16 @@ from langchain_core.exceptions import LangChainException
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from collections.abc import (
AsyncIterable,
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
# Magic UUID to use as a namespace for hashing.
# Used to try and generate a unique UUID for each document
# from hashing the document content and metadata.

View File

@@ -299,6 +299,9 @@ class BaseLanguageModel(
Useful for checking if an input fits in a model's context window.
This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.
Args:
text: The string input to tokenize.
@@ -317,9 +320,17 @@ class BaseLanguageModel(
Useful for checking if an input fits in a model's context window.
This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.
!!! note
The base implementation of `get_num_tokens_from_messages` ignores tool
schemas.
* The base implementation of `get_num_tokens_from_messages` ignores tool
schemas.
* The base implementation of `get_num_tokens_from_messages` adds additional
prefixes to messages in represent user roles, which will add to the
overall token count. Model-specific implementations may choose to
handle this differently.
Args:
messages: The message inputs to tokenize.

View File

@@ -91,7 +91,10 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
try:
metadata["body"] = response.json()
except Exception:
metadata["body"] = getattr(response, "text", None)
try:
metadata["body"] = getattr(response, "text", None)
except Exception:
metadata["body"] = None
if hasattr(response, "headers"):
try:
metadata["headers"] = dict(response.headers)

View File

@@ -61,13 +61,15 @@ class Reviver:
"""Initialize the reviver.
Args:
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the
environment if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.
"""
@@ -195,13 +197,15 @@ def loads(
Args:
text: The string to load.
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.
@@ -237,13 +241,15 @@ def load(
Args:
obj: The object to load.
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.
If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.

View File

@@ -5,11 +5,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast, overload
from pydantic import ConfigDict, Field
from typing_extensions import Self
from langchain_core._api.deprecation import warn_deprecated
from langchain_core.load.serializable import Serializable
from langchain_core.messages import content as types
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.interactive_env import is_interactive_env
@@ -17,6 +15,9 @@ from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self
from langchain_core.messages import content as types
from langchain_core.prompts.chat import ChatPromptTemplate

View File

@@ -12,10 +12,11 @@ the implementation in `BaseMessage`.
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages import content as types

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Literal, cast
from langchain_core.language_models._utils import (
@@ -14,6 +13,8 @@ from langchain_core.language_models._utils import (
from langchain_core.messages import content as types
if TYPE_CHECKING:
from collections.abc import Iterable
from langchain_core.messages import AIMessage, AIMessageChunk

View File

@@ -2,15 +2,17 @@
from __future__ import annotations
from typing import Literal
from typing import TYPE_CHECKING, Literal
from pydantic import model_validator
from typing_extensions import Self
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.utils._merge import merge_dicts
if TYPE_CHECKING:
from typing_extensions import Self
class ChatGeneration(Generation):
"""A single chat generation output.

View File

@@ -6,7 +6,7 @@ import contextlib
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
@@ -33,6 +33,8 @@ from langchain_core.runnables.config import ensure_config
from langchain_core.utils.pydantic import create_model_v2
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.documents import Document

View File

@@ -6,10 +6,10 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from langchain_core.load import Serializable
from langchain_core.messages import BaseMessage
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import ChatPromptTemplate

View File

@@ -4,9 +4,8 @@ from __future__ import annotations
import warnings
from abc import ABC
from collections.abc import Callable, Sequence
from string import Formatter
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from pydantic import BaseModel, create_model
@@ -16,6 +15,9 @@ from langchain_core.utils import get_colored_text, mustache
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
try:
from jinja2 import Environment, meta
from jinja2.sandbox import SandboxedEnvironment

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import inspect
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@@ -22,7 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from pydantic import BaseModel

View File

@@ -7,7 +7,6 @@ from __future__ import annotations
import math
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
try:
@@ -20,6 +19,8 @@ except ImportError:
_HAS_GRANDALF = False
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langchain_core.runnables.graph import Edge as LangEdge

View File

@@ -7,8 +7,7 @@ import asyncio
import inspect
import sys
import textwrap
from collections.abc import Callable, Mapping, Sequence
from contextvars import Context
from collections.abc import Mapping, Sequence
from functools import lru_cache
from inspect import signature
from itertools import groupby
@@ -31,9 +30,11 @@ if TYPE_CHECKING:
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Iterable,
)
from contextvars import Context
from langchain_core.runnables.schema import StreamEvent

View File

@@ -15,12 +15,6 @@ from typing import (
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
@@ -31,6 +25,12 @@ if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
logger = logging.getLogger(__name__)

View File

@@ -8,7 +8,6 @@ import logging
import types
import typing
import uuid
from collections.abc import Callable
from typing import (
TYPE_CHECKING,
Annotated,
@@ -33,6 +32,8 @@ from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
import json
import re
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any
from langchain_core.exceptions import OutputParserException
if TYPE_CHECKING:
from collections.abc import Callable
def _replace_new_line(match: re.Match[str]) -> str:
value = match.group(2)

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import inspect
import textwrap
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from functools import lru_cache, wraps
from types import GenericAlias
@@ -41,10 +40,12 @@ from pydantic.json_schema import (
)
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import create_model as create_model_v1
from pydantic.v1.fields import ModelField
from typing_extensions import deprecated, override
if TYPE_CHECKING:
from collections.abc import Callable
from pydantic.v1.fields import ModelField
from pydantic_core import core_schema
PYDANTIC_VERSION = version.parse(pydantic.__version__)

View File

@@ -11,7 +11,6 @@ import logging
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from itertools import cycle
from typing import (
TYPE_CHECKING,
@@ -29,7 +28,7 @@ from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import json
import uuid
from collections.abc import Callable
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -20,7 +19,7 @@ from langchain_core.vectorstores.utils import _cosine_similarity as cosine_simil
from langchain_core.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
from langchain_core.embeddings import Embeddings

View File

@@ -1,3 +1,3 @@
"""langchain-core version information and utilities."""
VERSION = "1.0.4"
VERSION = "1.0.5"

View File

@@ -9,7 +9,7 @@ license = {text = "MIT"}
readme = "README.md"
authors = []
version = "1.0.4"
version = "1.0.5"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langsmith>=0.3.45,<1.0.0",

View File

@@ -18,6 +18,7 @@ from langchain_core.language_models import (
ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.chat_models import _generate_response_from_error
from langchain_core.language_models.fake_chat_models import (
FakeListChatModelError,
GenericFakeChatModel,
@@ -1234,3 +1235,93 @@ def test_model_profiles() -> None:
model = MyModel(messages=iter([]))
profile = model.profile
assert profile
class MockResponse:
"""Mock response for testing _generate_response_from_error."""
def __init__(
self,
status_code: int = 400,
headers: dict[str, str] | None = None,
json_data: dict[str, Any] | None = None,
json_raises: type[Exception] | None = None,
text_raises: type[Exception] | None = None,
):
self.status_code = status_code
self.headers = headers or {}
self._json_data = json_data
self._json_raises = json_raises
self._text_raises = text_raises
def json(self) -> dict[str, Any]:
if self._json_raises:
msg = "JSON parsing failed"
raise self._json_raises(msg)
return self._json_data or {}
@property
def text(self) -> str:
if self._text_raises:
msg = "Text access failed"
raise self._text_raises(msg)
return ""
class MockAPIError(Exception):
"""Mock API error with response attribute."""
def __init__(self, message: str, response: MockResponse | None = None):
super().__init__(message)
self.message = message
if response is not None:
self.response = response
def test_generate_response_from_error_with_valid_json() -> None:
"""Test `_generate_response_from_error` with valid JSON response."""
response = MockResponse(
status_code=400,
headers={"content-type": "application/json"},
json_data={"error": {"message": "Bad request", "type": "invalid_request"}},
)
error = MockAPIError("API Error", response=response)
generations = _generate_response_from_error(error)
assert len(generations) == 1
generation = generations[0]
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.message, AIMessage)
assert generation.message.content == ""
metadata = generation.message.response_metadata
assert metadata["body"] == {
"error": {"message": "Bad request", "type": "invalid_request"}
}
assert metadata["headers"] == {"content-type": "application/json"}
assert metadata["status_code"] == 400
def test_generate_response_from_error_handles_streaming_response_failure() -> None:
# Simulates scenario where accessing response.json() or response.text
# raises ResponseNotRead on streaming responses
response = MockResponse(
status_code=400,
headers={"content-type": "application/json"},
json_raises=Exception, # Simulates ResponseNotRead or similar
text_raises=Exception,
)
error = MockAPIError("API Error", response=response)
# This should NOT raise an exception, but should handle it gracefully
generations = _generate_response_from_error(error)
assert len(generations) == 1
generation = generations[0]
metadata = generation.message.response_metadata
# When both fail, body should be None instead of raising an exception
assert metadata["body"] is None
assert metadata["headers"] == {"content-type": "application/json"}
assert metadata["status_code"] == 400

View File

@@ -0,0 +1,140 @@
"""Test groq block translator."""
from typing import cast
import pytest
from langchain_core.messages import AIMessage
from langchain_core.messages import content as types
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
from langchain_core.messages.block_translators import PROVIDER_TRANSLATORS
from langchain_core.messages.block_translators.groq import (
_parse_code_json,
translate_content,
)
def test_groq_translator_registered() -> None:
"""Test that groq translator is properly registered."""
assert "groq" in PROVIDER_TRANSLATORS
assert "translate_content" in PROVIDER_TRANSLATORS["groq"]
assert "translate_content_chunk" in PROVIDER_TRANSLATORS["groq"]
def test_extract_reasoning_from_additional_kwargs_exists() -> None:
"""Test that _extract_reasoning_from_additional_kwargs can be imported."""
# Verify it's callable
assert callable(_extract_reasoning_from_additional_kwargs)
def test_groq_translate_content_basic() -> None:
"""Test basic groq content translation."""
# Test with simple text message
message = AIMessage(content="Hello world")
blocks = translate_content(message)
assert isinstance(blocks, list)
assert len(blocks) == 1
assert blocks[0]["type"] == "text"
assert blocks[0]["text"] == "Hello world"
def test_groq_translate_content_with_reasoning() -> None:
"""Test groq content translation with reasoning content."""
# Test with reasoning content in additional_kwargs
message = AIMessage(
content="Final answer",
additional_kwargs={"reasoning_content": "Let me think about this..."},
)
blocks = translate_content(message)
assert isinstance(blocks, list)
assert len(blocks) == 2
# First block should be reasoning
assert blocks[0]["type"] == "reasoning"
assert blocks[0]["reasoning"] == "Let me think about this..."
# Second block should be text
assert blocks[1]["type"] == "text"
assert blocks[1]["text"] == "Final answer"
def test_groq_translate_content_with_tool_calls() -> None:
"""Test groq content translation with tool calls."""
# Test with tool calls
message = AIMessage(
content="",
tool_calls=[
{
"name": "search",
"args": {"query": "test"},
"id": "call_123",
}
],
)
blocks = translate_content(message)
assert isinstance(blocks, list)
assert len(blocks) == 1
assert blocks[0]["type"] == "tool_call"
assert blocks[0]["name"] == "search"
assert blocks[0]["args"] == {"query": "test"}
assert blocks[0]["id"] == "call_123"
def test_groq_translate_content_with_executed_tools() -> None:
"""Test groq content translation with executed tools (built-in tools)."""
# Test with executed_tools in additional_kwargs (Groq built-in tools)
message = AIMessage(
content="",
additional_kwargs={
"executed_tools": [
{
"type": "python",
"arguments": '{"code": "print(\\"hello\\")"}',
"output": "hello\\n",
}
]
},
)
blocks = translate_content(message)
assert isinstance(blocks, list)
# Should have server_tool_call and server_tool_result
assert len(blocks) >= 2
# Check for server_tool_call
tool_call_blocks = [
cast("types.ServerToolCall", b)
for b in blocks
if b.get("type") == "server_tool_call"
]
assert len(tool_call_blocks) == 1
assert tool_call_blocks[0]["name"] == "code_interpreter"
assert "code" in tool_call_blocks[0]["args"]
# Check for server_tool_result
tool_result_blocks = [
cast("types.ServerToolResult", b)
for b in blocks
if b.get("type") == "server_tool_result"
]
assert len(tool_result_blocks) == 1
assert tool_result_blocks[0]["output"] == "hello\\n"
assert tool_result_blocks[0]["status"] == "success"
def test_parse_code_json() -> None:
"""Test the _parse_code_json helper function."""
# Test valid code JSON
result = _parse_code_json('{"code": "print(\'hello\')"}')
assert result == {"code": "print('hello')"}
# Test code with unescaped quotes (Groq format)
result = _parse_code_json('{"code": "print("hello")"}')
assert result == {"code": 'print("hello")'}
# Test invalid format raises ValueError
with pytest.raises(ValueError, match="Could not extract Python code"):
_parse_code_json('{"invalid": "format"}')

View File

@@ -3,12 +3,14 @@
import asyncio
import time
from threading import Lock
from typing import Any
from typing import TYPE_CHECKING, Any
import pytest
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable
@pytest.mark.asyncio

View File

@@ -3,21 +3,25 @@ from __future__ import annotations
import json
import sys
import uuid
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
from inspect import isasyncgenfunction
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from unittest.mock import MagicMock, patch
import pytest
from langsmith import Client, get_current_run_tree, traceable
from langsmith.run_helpers import tracing_context
from langsmith.run_trees import RunTree
from langsmith.utils import get_env_var
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
from langsmith.run_trees import RunTree
from langchain_core.callbacks import BaseCallbackHandler
def _get_posts(client: Client) -> list:
mock_calls = client.session.request.mock_calls # type: ignore[attr-defined]

View File

@@ -1,10 +1,9 @@
import os
import re
import sys
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from typing import Any
from typing import TYPE_CHECKING, Any
from unittest.mock import patch
import pytest
@@ -23,6 +22,9 @@ from langchain_core.utils import (
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.utils import secret_from_env
if TYPE_CHECKING:
from collections.abc import Callable
@pytest.mark.parametrize(
("package", "check_kwargs", "actual_version", "expected"),

libs/core/uv.lock generated
View File

@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [
"python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
@@ -960,7 +960,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.0.4"
version = "1.0.5"
source = { editable = "." }
dependencies = [
{ name = "jsonpatch" },
@@ -1057,7 +1057,7 @@ typing = [
[[package]]
name = "langchain-model-profiles"
version = "0.0.3"
version = "0.0.4"
source = { directory = "../model-profiles" }
dependencies = [
{ name = "tomli", marker = "python_full_version < '3.11'" },

View File

@@ -24,7 +24,7 @@ In most cases, you should be using the main [`langchain`](https://pypi.org/proje
## 📖 Documentation
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
## 📕 Releases & Versioning

View File

@@ -100,6 +100,21 @@ def init_chat_model(
You can also specify model and model provider in a single argument using
`'{model_provider}:{model}'` format, e.g. `'openai:o1'`.
Will attempt to infer `model_provider` from model if not specified.
The following providers will be inferred based on these model prefixes:
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
model_provider: The model provider if not specified as part of the model arg
(see above).
@@ -123,24 +138,10 @@ def init_chat_model(
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
Will attempt to infer `model_provider` from model if not specified. The
following providers will be inferred based on these model prefixes:
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
configurable_fields: Which model parameters are configurable at runtime:
- `None`: No configurable fields (i.e., a fixed model).
@@ -155,6 +156,7 @@ def init_chat_model(
If `model` is not specified, then defaults to `("model", "model_provider")`.
!!! warning "Security note"
Setting `configurable_fields="any"` means fields like `api_key`,
`base_url`, etc., can be altered at runtime, potentially redirecting
model requests to a different service/user.
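To make the provider-inference behavior described above concrete, here is a minimal sketch of calling `init_chat_model` with and without an explicit provider prefix. It assumes `langchain` and `langchain-openai` are installed and `OPENAI_API_KEY` is set; the model name and `temperature` kwarg are illustrative.

```python
from langchain.chat_models import init_chat_model

# Provider inferred from the "gpt-" prefix -> langchain-openai
model = init_chat_model("gpt-4o", temperature=0)

# Equivalent, using the "{model_provider}:{model}" form
same_model = init_chat_model("openai:gpt-4o", temperature=0)

print(model.invoke("Say hello").content)
```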

View File

@@ -1,4 +1,4 @@
.PHONY: all start_services stop_services coverage test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
.PHONY: all start_services stop_services coverage coverage_agents test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
# Default target executed when no arguments are given to make.
all: help
@@ -27,8 +27,17 @@ coverage:
--cov-report term-missing:skip-covered \
$(TEST_FILE)
# Run middleware and agent tests with coverage report.
coverage_agents:
uv run --group test pytest \
tests/unit_tests/agents/middleware/ \
tests/unit_tests/agents/test_*.py \
--cov=langchain.agents \
--cov-report=term-missing \
--cov-report=html:htmlcov \
test:
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered --snapshot-update; \
EXIT_CODE=$$?; \
make stop_services; \
exit $$EXIT_CODE
@@ -93,6 +102,7 @@ help:
@echo 'lint - run linters'
@echo '-- TESTS --'
@echo 'coverage - run unit tests and generate coverage report'
@echo 'coverage_agents - run middleware and agent tests with coverage report'
@echo 'test - run unit tests with all services'
@echo 'test_fast - run unit tests with in-memory services only'
@echo 'tests - run unit tests (alias for "make test")'

View File

@@ -26,7 +26,7 @@ LangChain [agents](https://docs.langchain.com/oss/python/langchain/agents) are b
## 📖 Documentation
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
## 📕 Releases & Versioning

View File

@@ -1,3 +1,3 @@
"""Main entrypoint into LangChain."""
__version__ = "1.0.4"
__version__ = "1.0.5"

View File

@@ -1,10 +1,4 @@
"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain.
!!! warning "Reference docs"
This page contains **reference documentation** for Agents. See
[the docs](https://docs.langchain.com/oss/python/langchain/agents) for conceptual
guides, tutorials, and examples on using Agents.
""" # noqa: E501
"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain.""" # noqa: E501
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentState

View File

@@ -63,6 +63,18 @@ if TYPE_CHECKING:
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
# if langchain-model-profiles is not installed, these models are assumed to support
# structured output
"grok",
"gpt-5",
"gpt-4.1",
"gpt-4o",
"gpt-oss",
"o3-pro",
"o3-mini",
]
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
"""Normalize middleware return value to ModelResponse."""
@@ -349,11 +361,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
return []
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
"""Check if a model supports provider-specific structured output.
Args:
model: Model name string or `BaseChatModel` instance.
tools: Optional list of tools provided to the agent. Needed because some models
don't support structured output together with tool calling.
Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
@@ -362,11 +376,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
if isinstance(model, str):
model_name = model
elif isinstance(model, BaseChatModel):
model_name = getattr(model, "model_name", None)
model_name = (
getattr(model, "model_name", None)
or getattr(model, "model", None)
or getattr(model, "model_id", "")
)
try:
model_profile = model.profile
except ImportError:
pass
else:
if (
model_profile.get("structured_output")
# We make an exception for Gemini models, which currently do not support
# simultaneous tool use with structured output
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
):
return True
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
if model_name
else False
)
@@ -537,17 +566,29 @@ def create_agent( # noqa: PLR0915
visit the [Agents](https://docs.langchain.com/oss/python/langchain/agents) docs.
Args:
model: The language model for the agent. Can be a string identifier
(e.g., `"openai:gpt-4"`) or a direct chat model instance (e.g.,
[`ChatOpenAI`][langchain_openai.ChatOpenAI] or other another
[chat model](https://docs.langchain.com/oss/python/integrations/chat)).
model: The language model for the agent.
Can be a string identifier (e.g., `"openai:gpt-4"`) or a direct chat model
instance (e.g., [`ChatOpenAI`][langchain_openai.ChatOpenAI] or another
[LangChain chat model](https://docs.langchain.com/oss/python/integrations/chat)).
For a full list of supported model strings, see
[`init_chat_model`][langchain.chat_models.init_chat_model(model_provider)].
tools: A list of tools, `dicts`, or `Callable`.
!!! tip ""
See the [Models](https://docs.langchain.com/oss/python/langchain/models)
docs for more information.
tools: A list of tools, `dict`, or `Callable`.
If `None` or an empty list, the agent will consist of a model node without a
tool calling loop.
!!! tip ""
See the [Tools](https://docs.langchain.com/oss/python/langchain/tools)
docs for more information.
system_prompt: An optional system prompt for the LLM.
Prompts are converted to a
@@ -555,24 +596,34 @@ def create_agent( # noqa: PLR0915
beginning of the message list.
middleware: A sequence of middleware instances to apply to the agent.
Middleware can intercept and modify agent behavior at various stages. See
the [full guide](https://docs.langchain.com/oss/python/langchain/middleware).
Middleware can intercept and modify agent behavior at various stages.
!!! tip ""
See the [Middleware](https://docs.langchain.com/oss/python/langchain/middleware)
docs for more information.
response_format: An optional configuration for structured responses.
Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
If provided, the agent will handle structured output during the
conversation flow. Raw schemas will be wrapped in an appropriate strategy
based on model capabilities.
conversation flow.
Raw schemas will be wrapped in an appropriate strategy based on model
capabilities.
!!! tip ""
See the [Structured output](https://docs.langchain.com/oss/python/langchain/structured-output)
docs for more information.
state_schema: An optional `TypedDict` schema that extends `AgentState`.
When provided, this schema is used instead of `AgentState` as the base
schema for merging with middleware state schemas. This allows users to
add custom state fields without needing to create custom middleware.
Generally, it's recommended to use `state_schema` extensions via middleware
to keep relevant extensions scoped to corresponding hooks / tools.
The schema must be a subclass of `AgentState[ResponseT]`.
context_schema: An optional schema for runtime context.
checkpointer: An optional checkpoint saver object.
@@ -966,7 +1017,7 @@ def create_agent( # noqa: PLR0915
effective_response_format: ResponseFormat | None
if isinstance(request.response_format, AutoStrategy):
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
if _supports_provider_strategy(request.model):
if _supports_provider_strategy(request.model, tools=request.tools):
# Model supports provider strategy - use it
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
else:
@@ -987,7 +1038,7 @@ def create_agent( # noqa: PLR0915
# Bind model based on effective response format
if isinstance(effective_response_format, ProviderStrategy):
# Use provider-specific structured output
# (Backward compatibility) Use OpenAI format structured output
kwargs = effective_response_format.to_model_kwargs()
return (
request.model.bind_tools(

View File

@@ -1,15 +1,10 @@
"""Entrypoint to using [Middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).
!!! warning "Reference docs"
This page contains **reference documentation** for Middleware. See
[the docs](https://docs.langchain.com/oss/python/langchain/middleware) for conceptual
guides, tutorials, and examples on using Middleware.
""" # noqa: E501
"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501
from .context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
)
from .file_search import FilesystemFileSearchMiddleware
from .human_in_the_loop import (
HumanInTheLoopMiddleware,
InterruptOnConfig,
@@ -52,6 +47,7 @@ __all__ = [
"CodexSandboxExecutionPolicy",
"ContextEditingMiddleware",
"DockerExecutionPolicy",
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware",
"InterruptOnConfig",

View File

@@ -56,11 +56,12 @@ class BaseExecutionPolicy(abc.ABC):
"""Configuration contract for persistent shell sessions.
Concrete subclasses encapsulate how a shell process is launched and constrained.
Each policy documents its security guarantees and the operating environments in
which it is appropriate. Use :class:`HostExecutionPolicy` for trusted, same-host
execution; :class:`CodexSandboxExecutionPolicy` when the Codex CLI sandbox is
available and you want additional syscall restrictions; and
:class:`DockerExecutionPolicy` for container-level isolation using Docker.
which it is appropriate. Use `HostExecutionPolicy` for trusted, same-host execution;
`CodexSandboxExecutionPolicy` when the Codex CLI sandbox is available and you want
additional syscall restrictions; and `DockerExecutionPolicy` for container-level
isolation using Docker.
"""
command_timeout: float = 30.0
@@ -91,13 +92,13 @@ class HostExecutionPolicy(BaseExecutionPolicy):
This policy is best suited for trusted or single-tenant environments (CI jobs,
developer workstations, pre-sandboxed containers) where the agent must access the
host filesystem and tooling without additional isolation. It enforces optional CPU
and memory limits to prevent runaway commands but offers **no** filesystem or network
host filesystem and tooling without additional isolation. Enforces optional CPU and
memory limits to prevent runaway commands but offers **no** filesystem or network
sandboxing; commands can modify anything the process user can reach.
On Linux platforms resource limits are applied with ``resource.prlimit`` after the
shell starts. On macOS, where ``prlimit`` is unavailable, limits are set in a
``preexec_fn`` before ``exec``. In both cases the shell runs in its own process group
On Linux platforms resource limits are applied with `resource.prlimit` after the
shell starts. On macOS, where `prlimit` is unavailable, limits are set in a
`preexec_fn` before `exec`. In both cases the shell runs in its own process group
so timeouts can terminate the full subtree.
"""
@@ -199,9 +200,9 @@ class CodexSandboxExecutionPolicy(BaseExecutionPolicy):
(Linux) profiles. Commands still run on the host, but within the sandbox requested by
the CLI. If the Codex binary is unavailable or the runtime lacks the required
kernel features (e.g., Landlock inside some containers), process startup fails with a
:class:`RuntimeError`.
`RuntimeError`.
Configure sandbox behaviour via ``config_overrides`` to align with your Codex CLI
Configure sandbox behavior via `config_overrides` to align with your Codex CLI
profile. This policy does not add its own resource limits; combine it with
host-level guards (cgroups, container resource limits) as needed.
"""
@@ -271,17 +272,17 @@ class DockerExecutionPolicy(BaseExecutionPolicy):
"""Run the shell inside a dedicated Docker container.
Choose this policy when commands originate from untrusted users or you require
strong isolation between sessions. By default the workspace is bind-mounted only when
it refers to an existing non-temporary directory; ephemeral sessions run without a
mount to minimise host exposure. The container's network namespace is disabled by
default (``--network none``) and you can enable further hardening via
``read_only_rootfs`` and ``user``.
strong isolation between sessions. By default the workspace is bind-mounted only
when it refers to an existing non-temporary directory; ephemeral sessions run
without a mount to minimise host exposure. The container's network namespace is
disabled by default (`--network none`) and you can enable further hardening via
`read_only_rootfs` and `user`.
The security guarantees depend on your Docker daemon configuration. Run the agent on
a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and review
any additional volumes or capabilities passed through ``extra_run_args``. The default
image is ``python:3.12-alpine3.19``; supply a custom image if you need preinstalled
tooling.
a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and
review any additional volumes or capabilities passed through `extra_run_args`. The
default image is `python:3.12-alpine3.19`; supply a custom image if you need
preinstalled tooling.
"""
binary: str = "docker"

View File

@@ -1,9 +1,10 @@
"""Context editing middleware.
This middleware mirrors Anthropic's context editing capabilities by clearing
older tool results once the conversation grows beyond a configurable token
threshold. The implementation is intentionally model-agnostic so it can be used
with any LangChain chat model.
Mirrors Anthropic's context editing capabilities by clearing older tool results once the
conversation grows beyond a configurable token threshold.
The implementation is intentionally model-agnostic so it can be used with any LangChain
chat model.
"""
from __future__ import annotations
@@ -182,11 +183,13 @@ class ClearToolUsesEdit(ContextEdit):
class ContextEditingMiddleware(AgentMiddleware):
"""Automatically prunes tool results to manage context size.
"""Automatically prune tool results to manage context size.
The middleware applies a sequence of edits when the total input token count
exceeds configured thresholds. Currently the `ClearToolUsesEdit` strategy is
supported, aligning with Anthropic's `clear_tool_uses_20250919` behaviour.
The middleware applies a sequence of edits when the total input token count exceeds
configured thresholds.
Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
`clear_tool_uses_20250919` behavior [(read more)](https://docs.claude.com/en/docs/agents-and-tools/tool-use/memory-tool).
"""
edits: list[ContextEdit]
@@ -198,11 +201,12 @@ class ContextEditingMiddleware(AgentMiddleware):
edits: Iterable[ContextEdit] | None = None,
token_count_method: Literal["approximate", "model"] = "approximate", # noqa: S107
) -> None:
"""Initializes a context editing middleware instance.
"""Initialize an instance of context editing middleware.
Args:
edits: Sequence of edit strategies to apply. Defaults to a single
`ClearToolUsesEdit` mirroring Anthropic defaults.
edits: Sequence of edit strategies to apply.
Defaults to a single `ClearToolUsesEdit` mirroring Anthropic defaults.
token_count_method: Whether to use approximate token counting
(faster, less accurate) or exact counting implemented by the
chat model (potentially slower, more accurate).
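A minimal sketch of wiring the middleware into an agent using the two constructor arguments shown above; the model name is illustrative and `ClearToolUsesEdit()` is assumed to be constructible with its Anthropic-mirroring defaults.

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ClearToolUsesEdit, ContextEditingMiddleware

context_editing = ContextEditingMiddleware(
    edits=[ClearToolUsesEdit()],       # defaults mirror Anthropic's clear_tool_uses_20250919
    token_count_method="approximate",  # faster, less accurate than model-based counting
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[context_editing],
)
```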

View File

@@ -21,7 +21,7 @@ from langchain.agents.middleware.types import AgentMiddleware
def _expand_include_patterns(pattern: str) -> list[str] | None:
"""Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
"""Expand brace patterns like `*.{py,pyi}` into a list of globs."""
if "}" in pattern and "{" not in pattern:
return None
@@ -88,6 +88,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
"""Provides Glob and Grep search over filesystem files.
This middleware adds two tools that search through local filesystem:
- Glob: Fast file pattern matching by file path
- Grep: Fast content search using ripgrep or Python fallback
@@ -100,7 +101,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
agent = create_agent(
model=model,
tools=[],
tools=[], # Add tools as needed
middleware=[
FilesystemFileSearchMiddleware(root_path="/workspace"),
],
@@ -119,9 +120,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
Args:
root_path: Root directory to search.
use_ripgrep: Whether to use ripgrep for search (default: True).
Falls back to Python if ripgrep unavailable.
max_file_size_mb: Maximum file size to search in MB (default: 10).
use_ripgrep: Whether to use `ripgrep` for search.
Falls back to Python if `ripgrep` unavailable.
max_file_size_mb: Maximum file size to search in MB.
"""
self.root_path = Path(root_path).resolve()
self.use_ripgrep = use_ripgrep
@@ -132,8 +134,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
def glob_search(pattern: str, path: str = "/") -> str:
"""Fast file pattern matching tool that works with any codebase size.
Supports glob patterns like **/*.js or src/**/*.ts.
Supports glob patterns like `**/*.js` or `src/**/*.ts`.
Returns matching file paths sorted by modification time.
Use this tool when you need to find files by name patterns.
Args:
@@ -142,7 +146,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
Returns:
Newline-separated list of matching file paths, sorted by modification
time (most recently modified first). Returns "No files found" if no
time (most recently modified first). Returns `'No files found'` if no
matches.
"""
try:
@@ -184,15 +188,16 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
Args:
pattern: The regular expression pattern to search for in file contents.
path: The directory to search in. If not specified, searches from root.
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
output_mode: Output format:
- "files_with_matches": Only file paths containing matches (default)
- "content": Matching lines with file:line:content format
- "count": Count of matches per file
- `'files_with_matches'`: Only file paths containing matches
- `'content'`: Matching lines with `file:line:content` format
- `'count'`: Count of matches per file
Returns:
Search results formatted according to output_mode. Returns "No matches
found" if no results.
Search results formatted according to `output_mode`.
Returns `'No matches found'` if no results.
"""
# Compile regex pattern (for validation)
try:
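A minimal sketch of the constructor arguments documented above (the root path, model name, and limits are illustrative):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import FilesystemFileSearchMiddleware

file_search = FilesystemFileSearchMiddleware(
    root_path="/workspace",  # directory the Glob/Grep tools are rooted at
    use_ripgrep=True,        # falls back to the pure-Python search if ripgrep is missing
    max_file_size_mb=10,     # skip files larger than this when grepping
)

agent = create_agent(model="openai:gpt-4o", tools=[], middleware=[file_search])
```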

View File

@@ -14,10 +14,10 @@ class Action(TypedDict):
"""Represents an action with a name and args."""
name: str
"""The type or name of action being requested (e.g., "add_numbers")."""
"""The type or name of action being requested (e.g., `'add_numbers'`)."""
args: dict[str, Any]
"""Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
class ActionRequest(TypedDict):
@@ -27,7 +27,7 @@ class ActionRequest(TypedDict):
"""The name of the action being requested."""
args: dict[str, Any]
"""Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
description: NotRequired[str]
"""The description of the action to be reviewed."""
@@ -169,18 +169,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
Args:
interrupt_on: Mapping of tool name to allowed actions.
If a tool doesn't have an entry, it's auto-approved by default.
* `True` indicates all decisions are allowed: approve, edit, and reject.
* `False` indicates that the tool is auto-approved.
* `InterruptOnConfig` indicates the specific decisions allowed for this
tool.
The InterruptOnConfig can include a `description` field (`str` or
The `InterruptOnConfig` can include a `description` field (`str` or
`Callable`) for custom formatting of the interrupt description.
description_prefix: The prefix to use when constructing action requests.
This is used to provide context about the tool call and the action being
requested. Not used if a tool has a `description` in its
`InterruptOnConfig`.
requested.
Not used if a tool has a `description` in its `InterruptOnConfig`.
"""
super().__init__()
resolved_configs: dict[str, InterruptOnConfig] = {}
@@ -349,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
last_ai_msg.tool_calls = revised_tool_calls
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
return self.after_model(state, runtime)
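A minimal sketch of the `interrupt_on` mapping described above. The tool stubs and model name are illustrative, and actually pausing for decisions additionally requires running the agent with a checkpointer.

```python
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware
from langchain.tools import tool


@tool
def read_file(path: str) -> str:
    """Read a file (illustrative stub)."""
    return f"<contents of {path}>"


@tool
def delete_file(path: str) -> str:
    """Delete a file (illustrative stub)."""
    return f"deleted {path}"


hitl = HumanInTheLoopMiddleware(
    interrupt_on={
        "delete_file": True,  # all decisions allowed: approve, edit, reject
        "read_file": False,   # auto-approved, never interrupts
    },
    description_prefix="Tool execution pending approval",
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[read_file, delete_file],
    middleware=[hitl],
)
```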

View File

@@ -20,9 +20,9 @@ if TYPE_CHECKING:
class ModelCallLimitState(AgentState):
"""State schema for ModelCallLimitMiddleware.
"""State schema for `ModelCallLimitMiddleware`.
Extends AgentState with model call tracking fields.
Extends `AgentState` with model call tracking fields.
"""
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
@@ -58,8 +58,8 @@ def _build_limit_exceeded_message(
class ModelCallLimitExceededError(Exception):
"""Exception raised when model call limits are exceeded.
This exception is raised when the configured exit behavior is 'error'
and either the thread or run model call limit has been exceeded.
This exception is raised when the configured exit behavior is `'error'` and either
the thread or run model call limit has been exceeded.
"""
def __init__(
@@ -127,13 +127,17 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
Args:
thread_limit: Maximum number of model calls allowed per thread.
None means no limit.
`None` means no limit.
run_limit: Maximum number of model calls allowed per run.
None means no limit.
`None` means no limit.
exit_behavior: What to do when limits are exceeded.
- "end": Jump to the end of the agent execution and
inject an artificial AI message indicating that the limit was exceeded.
- "error": Raise a `ModelCallLimitExceededError`
- `'end'`: Jump to the end of the agent execution and
inject an artificial AI message indicating that the limit was
exceeded.
- `'error'`: Raise a `ModelCallLimitExceededError`
Raises:
ValueError: If both limits are `None` or if `exit_behavior` is invalid.
@@ -161,12 +165,13 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is "end", returns
a Command to jump to the end with a limit exceeded message. Otherwise returns None.
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and exit_behavior
is "error".
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
thread_count = state.get("thread_model_call_count", 0)
run_count = state.get("run_model_call_count", 0)
@@ -194,6 +199,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check model call limits before making a model call.
Args:
state: The current agent state containing call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
return self.before_model(state, runtime)
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment model call counts after a model call.
@@ -208,3 +236,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
}
async def aafter_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async increment model call counts after a model call.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented call counts.
"""
return self.after_model(state, runtime)
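A minimal sketch of the limits and exit behavior documented above (the numbers and model name are illustrative; assumes the class is exported from `langchain.agents.middleware`):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware

call_limits = ModelCallLimitMiddleware(
    thread_limit=25,      # max model calls across the whole thread; None = unlimited
    run_limit=5,          # max model calls per run; None = unlimited
    exit_behavior="end",  # inject a limit-exceeded AI message instead of raising
)

agent = create_agent(model="openai:gpt-4o", tools=[], middleware=[call_limits])
```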

View File

@@ -22,7 +22,7 @@ class ModelFallbackMiddleware(AgentMiddleware):
"""Automatic fallback to alternative models on errors.
Retries failed model calls with alternative models in sequence until
success or all models exhausted. Primary model specified in create_agent().
success or all models exhausted. Primary model specified in `create_agent`.
Example:
```python

View File

@@ -27,24 +27,26 @@ if TYPE_CHECKING:
class PIIMiddleware(AgentMiddleware):
"""Detect and handle Personally Identifiable Information (PII) in agent conversations.
"""Detect and handle Personally Identifiable Information (PII) in conversations.
This middleware detects common PII types and applies configurable strategies
to handle them. It can detect emails, credit cards, IP addresses,
MAC addresses, and URLs in both user input and agent output.
to handle them. It can detect emails, credit cards, IP addresses, MAC addresses, and
URLs in both user input and agent output.
Built-in PII types:
- `email`: Email addresses
- `credit_card`: Credit card numbers (validated with Luhn algorithm)
- `ip`: IP addresses (validated with stdlib)
- `mac_address`: MAC addresses
- `url`: URLs (both `http`/`https` and bare URLs)
- `email`: Email addresses
- `credit_card`: Credit card numbers (validated with Luhn algorithm)
- `ip`: IP addresses (validated with stdlib)
- `mac_address`: MAC addresses
- `url`: URLs (both `http`/`https` and bare URLs)
Strategies:
- `block`: Raise an exception when PII is detected
- `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
- `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
- `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)
- `block`: Raise an exception when PII is detected
- `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
- `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
- `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)
Strategy Selection Guide:
@@ -101,12 +103,15 @@ class PIIMiddleware(AgentMiddleware):
"""Initialize the PII detection middleware.
Args:
pii_type: Type of PII to detect. Can be a built-in type
(`email`, `credit_card`, `ip`, `mac_address`, `url`)
or a custom type name.
strategy: How to handle detected PII:
pii_type: Type of PII to detect.
* `block`: Raise PIIDetectionError when PII is detected
Can be a built-in type (`email`, `credit_card`, `ip`, `mac_address`,
`url`) or a custom type name.
strategy: How to handle detected PII.
Options:
* `block`: Raise `PIIDetectionError` when PII is detected
* `redact`: Replace with `[REDACTED_TYPE]` placeholders
* `mask`: Partially mask PII (show last few characters)
* `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)
@@ -114,16 +119,15 @@ class PIIMiddleware(AgentMiddleware):
detector: Custom detector function or regex pattern.
* If `Callable`: Function that takes content string and returns
list of PIIMatch objects
list of `PIIMatch` objects
* If `str`: Regex pattern to match PII
* If `None`: Uses built-in detector for the pii_type
* If `None`: Uses built-in detector for the `pii_type`
apply_to_input: Whether to check user messages before model call.
apply_to_output: Whether to check AI messages after model call.
apply_to_tool_results: Whether to check tool result messages after tool execution.
Raises:
ValueError: If pii_type is not built-in and no detector is provided.
ValueError: If `pii_type` is not built-in and no detector is provided.
"""
super().__init__()
@@ -166,10 +170,11 @@ class PIIMiddleware(AgentMiddleware):
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII detected.
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is "block".
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
if not self.apply_to_input and not self.apply_to_tool_results:
return None
@@ -247,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check user messages and tool results for PII before model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.before_model(state, runtime)
def after_model(
self,
state: AgentState,
@@ -259,10 +285,11 @@ class PIIMiddleware(AgentMiddleware):
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII detected.
Updated state with PII handled according to strategy, or None if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is "block".
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
if not self.apply_to_output:
return None
@@ -305,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
return {"messages": new_messages}
async def aafter_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check AI messages for PII after model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.after_model(state, runtime)
__all__ = [
"PIIDetectionError",

View File

@@ -11,17 +11,17 @@ import subprocess
import tempfile
import threading
import time
import typing
import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool, ToolException
from langchain_core.tools.base import ToolException
from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import NotRequired
from langchain.agents.middleware._execution import (
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
ResolvedRedactionRule,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
from langchain.tools import ToolRuntime, tool
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware.types import ToolCallRequest
LOGGER = logging.getLogger(__name__)
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
"session remains stable. Outputs may be truncated when they become very large, and long "
"running commands will be terminated once their configured timeout elapses."
)
SHELL_TOOL_NAME = "shell"
def _cleanup_resources(
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
"""Input schema for the persistent shell tool."""
command: str | None = None
"""The shell command to execute."""
restart: bool | None = None
"""Whether to restart the shell session."""
runtime: Annotated[Any, SkipJsonSchema()] = None
"""The runtime for the shell tool.
Included as a workaround because `args_schema` currently doesn't work with
injected `ToolRuntime`.
"""
@model_validator(mode="after")
def validate_payload(self) -> _ShellToolInput:
@@ -347,38 +357,21 @@ class _ShellToolInput(BaseModel):
return self
class _PersistentShellTool(BaseTool):
"""Tool wrapper that relies on middleware interception for execution."""
name: str = "shell"
description: str = DEFAULT_TOOL_DESCRIPTION
args_schema: type[BaseModel] = _ShellToolInput
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
super().__init__()
self._middleware = middleware
if description is not None:
self.description = description
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
raise RuntimeError(msg)
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""Middleware that registers a persistent shell tool for agents.
The middleware exposes a single long-lived shell session. Use the execution policy to
match your deployment's security posture:
The middleware exposes a single long-lived shell session. Use the execution policy
to match your deployment's security posture:
* ``HostExecutionPolicy`` - full host access; best for trusted environments where the
agent already runs inside a container or VM that provides isolation.
* ``CodexSandboxExecutionPolicy`` - reuses the Codex CLI sandbox for additional
syscall/filesystem restrictions when the CLI is available.
* ``DockerExecutionPolicy`` - launches a separate Docker container for each agent run,
providing harder isolation, optional read-only root filesystems, and user remapping.
* `HostExecutionPolicy` - full host access; best for trusted environments where the
agent already runs inside a container or VM that provides isolation.
* `CodexSandboxExecutionPolicy` - reuses the Codex CLI sandbox for additional
syscall/filesystem restrictions when the CLI is available.
* `DockerExecutionPolicy` - launches a separate Docker container for each agent run,
providing harder isolation, optional read-only root filesystems, and user
remapping.
When no policy is provided the middleware defaults to ``HostExecutionPolicy``.
When no policy is provided the middleware defaults to `HostExecutionPolicy`.
"""
state_schema = ShellToolState
@@ -392,29 +385,43 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
execution_policy: BaseExecutionPolicy | None = None,
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
tool_description: str | None = None,
tool_name: str = SHELL_TOOL_NAME,
shell_command: Sequence[str] | str | None = None,
env: Mapping[str, Any] | None = None,
) -> None:
"""Initialize the middleware.
"""Initialize an instance of `ShellToolMiddleware`.
Args:
workspace_root: Base directory for the shell session. If omitted, a temporary
directory is created when the agent starts and removed when it ends.
startup_commands: Optional commands executed sequentially after the session starts.
workspace_root: Base directory for the shell session.
If omitted, a temporary directory is created when the agent starts and
removed when it ends.
startup_commands: Optional commands executed sequentially after the session
starts.
shutdown_commands: Optional commands executed before the session shuts down.
execution_policy: Execution policy controlling timeouts, output limits, and resource
configuration. Defaults to :class:`HostExecutionPolicy` for native execution.
execution_policy: Execution policy controlling timeouts, output limits, and
resource configuration.
Defaults to `HostExecutionPolicy` for native execution.
redaction_rules: Optional redaction rules to sanitize command output before
returning it to the model.
tool_description: Optional override for the registered shell tool description.
shell_command: Optional shell executable (string) or argument sequence used to
launch the persistent session. Defaults to an implementation-defined bash command.
env: Optional environment variables to supply to the shell session. Values are
coerced to strings before command execution. If omitted, the session inherits the
parent process environment.
tool_description: Optional override for the registered shell tool
description.
tool_name: Name for the registered shell tool.
Defaults to `"shell"`.
shell_command: Optional shell executable (string) or argument sequence used
to launch the persistent session.
Defaults to an implementation-defined bash command.
env: Optional environment variables to supply to the shell session.
Values are coerced to strings before command execution. If omitted, the
session inherits the parent process environment.
"""
super().__init__()
self._workspace_root = Path(workspace_root) if workspace_root else None
self._tool_name = tool_name
self._shell_command = self._normalize_shell_command(shell_command)
self._environment = self._normalize_env(env)
if execution_policy is not None:
@@ -428,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
self._startup_commands = self._normalize_commands(startup_commands)
self._shutdown_commands = self._normalize_commands(shutdown_commands)
# Create a proper tool that executes directly (no interception needed)
description = tool_description or DEFAULT_TOOL_DESCRIPTION
self._tool = _PersistentShellTool(self, description=description)
self.tools = [self._tool]
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
def shell_tool(
*,
runtime: ToolRuntime[None, ShellToolState],
command: str | None = None,
restart: bool = False,
) -> ToolMessage | str:
resources = self._get_or_create_resources(runtime.state)
return self._run_shell_tool(
resources,
{"command": command, "restart": restart},
tool_call_id=runtime.tool_call_id,
)
self._shell_tool = shell_tool
self.tools = [self._shell_tool]
@staticmethod
def _normalize_commands(
@@ -468,36 +491,48 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Start the shell session and run startup commands."""
resources = self._create_resources()
resources = self._get_or_create_resources(state)
return {"shell_session_resources": resources}
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Async counterpart to `before_agent`."""
"""Async start the shell session and run startup commands."""
return self.before_agent(state, runtime)
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
"""Run shutdown commands and release resources when an agent completes."""
resources = self._ensure_resources(state)
resources = state.get("shell_session_resources")
if not isinstance(resources, _SessionResources):
# Resources were never created, nothing to clean up
return
try:
self._run_shutdown_commands(resources.session)
finally:
resources._finalizer()
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async counterpart to `after_agent`."""
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
"""Get existing resources from state or create new ones if they don't exist.
This method enables resumability by checking if resources already exist in the state
(e.g., after an interrupt), and only creating new resources if they're not present.
Args:
state: The agent state which may contain shell session resources.
Returns:
Session resources, either retrieved from state or newly created.
"""
resources = state.get("shell_session_resources")
if resources is not None and not isinstance(resources, _SessionResources):
resources = None
if resources is None:
msg = (
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
"before invoking the shell tool."
)
raise ToolException(msg)
return resources
if isinstance(resources, _SessionResources):
return resources
new_resources = self._create_resources()
# Cast needed to make state dict-like for mutation
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
return new_resources
def _create_resources(self) -> _SessionResources:
workspace = self._workspace_root
@@ -659,36 +694,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
artifact=artifact,
)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept local shell tool calls and execute them via the managed session."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return handler(request)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async interception mirroring the synchronous tool handler."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return await handler(request)
def _format_tool_message(
self,
content: str,
@@ -703,7 +708,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=self._tool.name,
name=self._tool_name,
status=status,
artifact=artifact,
)
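A minimal sketch of registering the shell tool with a Docker-backed policy. The workspace path and model name are illustrative; `DockerExecutionPolicy` is exported from `langchain.agents.middleware` per the `__all__` changes above, and `ShellToolMiddleware` is assumed to be importable from the same package.

```python
from langchain.agents import create_agent
from langchain.agents.middleware import DockerExecutionPolicy, ShellToolMiddleware

shell = ShellToolMiddleware(
    workspace_root="/tmp/agent-workspace",     # omit to use a throwaway temp directory
    execution_policy=DockerExecutionPolicy(),  # container isolation, network off by default
    tool_name="shell",                         # the new constructor argument shown above
)

agent = create_agent(model="openai:gpt-4o", tools=[], middleware=[shell])
```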

View File

@@ -1,8 +1,9 @@
"""Summarization middleware."""
import uuid
from collections.abc import Callable, Iterable
from typing import Any, cast
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, cast
from langchain_core.messages import (
AIMessage,
@@ -51,13 +52,26 @@ Messages to summarize:
{messages}
</messages>""" # noqa: E501
SUMMARY_PREFIX = "## Previous conversation summary:"
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
ContextFraction = tuple[Literal["fraction"], float]
"""Tuple specifying context size as a fraction of the model's context window."""
ContextTokens = tuple[Literal["tokens"], int]
"""Tuple specifying context size as a number of tokens."""
ContextMessages = tuple[Literal["messages"], int]
"""Tuple specifying context size as a number of messages."""
ContextSize = ContextFraction | ContextTokens | ContextMessages
"""Context size tuple to specify how much history to preserve."""
ContextCondition = ContextSize | list[ContextSize | list[ContextSize]]
"""Recursive type to support nested AND/OR conditions
Top-level list = OR logic, nested list = AND logic."""
class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.
@@ -70,34 +84,100 @@ class SummarizationMiddleware(AgentMiddleware):
def __init__(
self,
model: str | BaseChatModel,
max_tokens_before_summary: int | None = None,
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
*,
trigger: ContextCondition | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
summary_prefix: str = SUMMARY_PREFIX,
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
**deprecated_kwargs: Any,
) -> None:
"""Initialize the summarization middleware.
"""Initialize summarization middleware.
Args:
model: The language model to use for generating summaries.
max_tokens_before_summary: Token threshold to trigger summarization.
If `None`, summarization is disabled.
messages_to_keep: Number of recent messages to preserve after summarization.
trigger: One or more thresholds that trigger summarization. Supports flexible
AND/OR logic via nested lists. Top-level list items are combined with OR,
nested lists are combined with AND. Examples:
- Single condition: `("messages", 50)`
- OR conditions: `[("tokens", 3000), ("messages", 100)]` (triggers when
tokens >= 3000 OR messages >= 100)
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as a nested list
within the top-level list
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
(triggers when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8))
keep: Context retention policy applied after summarization.
Provide a `ContextSize` tuple to specify how much history to preserve.
Defaults to keeping the most recent 20 messages.
Examples: `("messages", 20)`, `("tokens", 3000)`, or
`("fraction", 0.3)`.
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
summary_prefix: Prefix added to system message when including summary.
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
the summarization call.
Pass `None` to skip trimming entirely.
"""
# Handle deprecated parameters
if "max_tokens_before_summary" in deprecated_kwargs:
value = deprecated_kwargs["max_tokens_before_summary"]
warnings.warn(
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if trigger is None and value is not None:
trigger = ("tokens", value)
if "messages_to_keep" in deprecated_kwargs:
value = deprecated_kwargs["messages_to_keep"]
warnings.warn(
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
keep = ("messages", value)
super().__init__()
if isinstance(model, str):
model = init_chat_model(model)
self.model = model
self.max_tokens_before_summary = max_tokens_before_summary
self.messages_to_keep = messages_to_keep
if trigger is None:
self.trigger: ContextCondition | None = None
trigger_conditions: list[ContextSize | list[ContextSize]] = []
elif isinstance(trigger, list):
# Validate and normalize nested structure
validated_list = self._validate_trigger_conditions(trigger)
self.trigger = validated_list
trigger_conditions = validated_list
else:
# Single ContextSize tuple
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
self._trigger_conditions = trigger_conditions
self.keep = self._validate_context_size(keep, "keep")
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
self.trim_tokens_to_summarize = trim_tokens_to_summarize
requires_profile = self._requires_profile(self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
msg = (
"Model profile information is required to use fractional token limits. "
'pip install "langchain[model-profiles]" or use absolute token counts '
"instead."
)
raise ValueError(msg)
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
@@ -105,13 +185,10 @@ class SummarizationMiddleware(AgentMiddleware):
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if (
self.max_tokens_before_summary is not None
and total_tokens < self.max_tokens_before_summary
):
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._find_safe_cutoff(messages)
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
@@ -129,6 +206,218 @@ class SummarizationMiddleware(AgentMiddleware):
]
}
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
"""Determine whether summarization should run for the current token usage.
Evaluates trigger conditions with AND/OR logic:
- Top-level items are OR'd together
- Nested lists are AND'd together
"""
if not self._trigger_conditions:
return False
# OR logic across top-level conditions
for condition in self._trigger_conditions:
if isinstance(condition, list):
# AND group - all must be satisfied
if self._check_and_group(condition, messages, total_tokens):
return True
elif self._check_single_condition(condition, messages, total_tokens):
# Single condition
return True
return False
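A standalone sketch of the OR-of-AND evaluation above, using hypothetical conditions and a simplified `fraction` check (the real implementation compares against the model profile's `max_input_tokens`):

```python
conditions = [("messages", 10), [("tokens", 3000), ("fraction", 0.8)]]  # hypothetical

def satisfied(cond, n_messages, n_tokens, frac_used):
    kind, value = cond
    if kind == "messages":
        return n_messages >= value
    if kind == "tokens":
        return n_tokens >= value
    return frac_used >= value  # simplified "fraction" check

def should_summarize(conds, n_messages, n_tokens, frac_used):
    # Top-level items are OR'd together; nested lists are AND groups.
    return any(
        all(satisfied(c, n_messages, n_tokens, frac_used) for c in item)
        if isinstance(item, list)
        else satisfied(item, n_messages, n_tokens, frac_used)
        for item in conds
    )

assert should_summarize(conditions, n_messages=12, n_tokens=100, frac_used=0.1)      # messages condition fires
assert not should_summarize(conditions, n_messages=5, n_tokens=3500, frac_used=0.5)  # AND group not fully satisfied
```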
def _check_and_group(
self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
) -> bool:
"""Check if all conditions in an AND group are satisfied."""
for condition in and_group:
if not self._check_single_condition(condition, messages, total_tokens):
return False
return True
def _check_single_condition(
self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
) -> bool:
"""Check if a single condition is satisfied."""
kind, value = condition
if kind == "messages":
return len(messages) >= value
if kind == "tokens":
return total_tokens >= value
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return False
threshold = int(max_input_tokens * value)
if threshold <= 0:
threshold = 1
return total_tokens >= threshold
return False
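Worked numbers for the fractional threshold above (the 200k context window is a hypothetical profile value):

```python
max_input_tokens = 200_000  # hypothetical max_input_tokens from the model profile
fraction = 0.8

threshold = max(1, int(max_input_tokens * fraction))  # 160_000
# With ("fraction", 0.8), summarization triggers once the conversation
# reaches 160,000 tokens or more.
```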
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
"""Choose cutoff index respecting retention configuration."""
kind, value = self.keep
if kind in {"tokens", "fraction"}:
token_based_cutoff = self._find_token_based_cutoff(messages)
if token_based_cutoff is not None:
return token_based_cutoff
# None cutoff -> model profile data not available (caught in __init__ but
# here for safety), fallback to message count
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
return self._find_safe_cutoff(messages, cast("int", value))
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
"""Find cutoff index based on target token retention."""
if not messages:
return 0
kind, value = self.keep
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return None
target_token_count = int(max_input_tokens * value)
elif kind == "tokens":
target_token_count = int(value)
else:
return None
if target_token_count <= 0:
target_token_count = 1
if self.token_counter(messages) <= target_token_count:
return 0
# Use binary search to identify the earliest message index that keeps the
# suffix within the token budget.
left, right = 0, len(messages)
cutoff_candidate = len(messages)
max_iterations = len(messages).bit_length() + 1
for _ in range(max_iterations):
if left >= right:
break
mid = (left + right) // 2
if self.token_counter(messages[mid:]) <= target_token_count:
cutoff_candidate = mid
right = mid
else:
left = mid + 1
if cutoff_candidate == len(messages):
cutoff_candidate = left
if cutoff_candidate >= len(messages):
if len(messages) == 1:
return 0
cutoff_candidate = len(messages) - 1
for i in range(cutoff_candidate, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i
return 0
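A self-contained sketch of the suffix binary search above, with a toy whitespace token counter standing in for `token_counter`:

```python
def token_count(messages: list[str]) -> int:
    # Toy counter: whitespace-delimited words.
    return sum(len(m.split()) for m in messages)

def earliest_cutoff(messages: list[str], budget: int) -> int:
    """Smallest index i such that token_count(messages[i:]) <= budget."""
    left, right = 0, len(messages)
    while left < right:
        mid = (left + right) // 2
        if token_count(messages[mid:]) <= budget:
            right = mid      # suffix fits; try to keep more history
        else:
            left = mid + 1   # suffix too large; drop more from the front
    return left

msgs = ["a b c", "d e", "f g h i", "j"]
assert earliest_cutoff(msgs, budget=5) == 2  # keep "f g h i" + "j" -> 5 tokens
```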
def _get_profile_limits(self) -> int | None:
"""Retrieve max input token limit from the model profile."""
try:
profile = self.model.profile
except (AttributeError, ImportError):
return None
if not isinstance(profile, Mapping):
return None
max_input_tokens = profile.get("max_input_tokens")
if not isinstance(max_input_tokens, int):
return None
return max_input_tokens
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
"""Validate context configuration tuples."""
kind, value = context
if kind == "fraction":
if not 0 < value <= 1:
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
raise ValueError(msg)
elif kind in {"tokens", "messages"}:
if value <= 0:
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
raise ValueError(msg)
else:
msg = f"Unsupported context size type {kind} for {parameter_name}."
raise ValueError(msg)
return context
def _validate_trigger_conditions(
self, conditions: list[Any]
) -> list[ContextSize | list[ContextSize]]:
"""Validate and normalize trigger conditions with nested AND/OR logic.
Args:
conditions: List of ContextSize tuples or nested lists of ContextSize tuples.
Returns:
Validated list where top-level items are OR'd and nested lists are AND'd.
"""
validated: list[ContextSize | list[ContextSize]] = []
for item in conditions:
if isinstance(item, tuple):
# Single condition (tuple)
validated.append(self._validate_context_size(item, "trigger"))
elif isinstance(item, list):
# AND group (nested list)
if not item:
msg = "Empty AND groups are not allowed in trigger conditions."
raise ValueError(msg)
and_group = [self._validate_context_size(cond, "trigger") for cond in item]
validated.append(and_group)
else:
msg = f"Trigger conditions must be tuples or lists, got {type(item).__name__}."
raise ValueError(msg)
return validated
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
"""Check if any condition requires model profile information."""
for condition in conditions:
if isinstance(condition, list):
# AND group
if any(c[0] == "fraction" for c in condition):
return True
elif condition[0] == "fraction":
return True
return False
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
@@ -151,16 +440,16 @@ class SummarizationMiddleware(AgentMiddleware):
return messages_to_summarize, preserved_messages
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.
Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
related AI and Tool messages. Returns `0` if no safe cutoff is found.
"""
if len(messages) <= self.messages_to_keep:
if len(messages) <= messages_to_keep:
return 0
target_cutoff = len(messages) - self.messages_to_keep
target_cutoff = len(messages) - messages_to_keep
for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
@@ -229,16 +518,35 @@ class SummarizationMiddleware(AgentMiddleware):
try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
return cast("str", response.content).strip()
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
if not messages_to_summarize:
return "No previous conversation history."
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return "Previous conversation was too long to summarize."
try:
response = await self.model.ainvoke(
self.summary_prompt.format(messages=trimmed_messages)
)
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
"""Trim messages to fit within summary generation limits."""
try:
if self.trim_tokens_to_summarize is None:
return messages
return trim_messages(
messages,
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",

View File

@@ -150,12 +150,6 @@ class TodoListMiddleware(AgentMiddleware):
print(result["todos"]) # Array of todo items with status tracking
```
Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
tool_description: Custom description for the write_todos tool.
If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
"""
state_schema = PlanningState
@@ -166,11 +160,12 @@ class TodoListMiddleware(AgentMiddleware):
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
) -> None:
"""Initialize the TodoListMiddleware with optional custom prompts.
"""Initialize the `TodoListMiddleware` with optional custom prompts.
Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
tool_description: Custom description for the write_todos tool.
system_prompt: Custom system prompt to guide the agent on using the todo
tool.
tool_description: Custom description for the `write_todos` tool.
"""
super().__init__()
self.system_prompt = system_prompt

View File

@@ -23,22 +23,23 @@ if TYPE_CHECKING:
ExitBehavior = Literal["continue", "error", "end"]
"""How to handle execution when tool call limits are exceeded.
- `"continue"`: Block exceeded tools with error messages, let other tools continue (default)
- `"error"`: Raise a `ToolCallLimitExceededError` exception
- `"end"`: Stop execution immediately, injecting a ToolMessage and an AI message
for the single tool call that exceeded the limit. Raises `NotImplementedError`
if there are other pending tool calls (due to parallel tool calling).
- `'continue'`: Block exceeded tools with error messages, let other tools continue
(default)
- `'error'`: Raise a `ToolCallLimitExceededError` exception
- `'end'`: Stop execution immediately, injecting a `ToolMessage` and an `AIMessage` for
the single tool call that exceeded the limit. Raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling).
"""
class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
"""State schema for ToolCallLimitMiddleware.
"""State schema for `ToolCallLimitMiddleware`.
Extends AgentState with tool call tracking fields.
Extends `AgentState` with tool call tracking fields.
The count fields are dictionaries mapping tool names to execution counts.
This allows multiple middleware instances to track different tools independently.
The special key "__all__" is used for tracking all tool calls globally.
The count fields are dictionaries mapping tool names to execution counts. This
allows multiple middleware instances to track different tools independently. The
special key `'__all__'` is used for tracking all tool calls globally.
"""
thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
@@ -46,13 +47,13 @@ class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
def _build_tool_message_content(tool_name: str | None) -> str:
"""Build the error message content for ToolMessage when limit is exceeded.
"""Build the error message content for `ToolMessage` when limit is exceeded.
This message is sent to the model, so it should not reference thread/run concepts
that the model has no notion of.
Args:
tool_name: Tool name being limited (if specific tool), or None for all tools.
tool_name: Tool name being limited (if specific tool), or `None` for all tools.
Returns:
A concise message instructing the model not to call the tool again.
@@ -70,7 +71,7 @@ def _build_final_ai_message_content(
run_limit: int | None,
tool_name: str | None,
) -> str:
"""Build the final AI message content for 'end' behavior.
"""Build the final AI message content for `'end'` behavior.
This message is displayed to the user, so it should include detailed information
about which limits were exceeded.
@@ -80,7 +81,7 @@ def _build_final_ai_message_content(
run_count: Current run tool call count.
thread_limit: Thread tool call limit (if set).
run_limit: Run tool call limit (if set).
tool_name: Tool name being limited (if specific tool), or None for all tools.
tool_name: Tool name being limited (if specific tool), or `None` for all tools.
Returns:
A formatted message describing which limits were exceeded.
@@ -100,8 +101,8 @@ def _build_final_ai_message_content(
class ToolCallLimitExceededError(Exception):
"""Exception raised when tool call limits are exceeded.
This exception is raised when the configured exit behavior is 'error'
and either the thread or run tool call limit has been exceeded.
This exception is raised when the configured exit behavior is `'error'` and either
the thread or run tool call limit has been exceeded.
"""
def __init__(
@@ -145,48 +146,53 @@ class ToolCallLimitMiddleware(
Configuration:
- `exit_behavior`: How to handle when limits are exceeded
- `"continue"`: Block exceeded tools, let execution continue (default)
- `"error"`: Raise an exception
- `"end"`: Stop immediately with a ToolMessage + AI message for the single
tool call that exceeded the limit (raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling).
- `'continue'`: Block exceeded tools, let execution continue (default)
- `'error'`: Raise an exception
- `'end'`: Stop immediately with a `ToolMessage` + AI message for the single
tool call that exceeded the limit (raises `NotImplementedError` if there
are other pending tool calls due to parallel tool calling).
Examples:
Continue execution with blocked tools (default):
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
!!! example "Continue execution with blocked tools (default)"
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)
Stop immediately when limit exceeded:
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
!!! example "Stop immediately when limit exceeded"
Raise exception on limit:
```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
try:
result = await agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
!!! example "Raise exception on limit"
```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, exit_behavior="error"
)
agent = create_agent("openai:gpt-4o", middleware=[limiter])
try:
result = await agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
"""
@@ -204,23 +210,24 @@ class ToolCallLimitMiddleware(
Args:
tool_name: Name of the specific tool to limit. If `None`, limits apply
to all tools. Defaults to `None`.
to all tools.
thread_limit: Maximum number of tool calls allowed per thread.
`None` means no limit. Defaults to `None`.
`None` means no limit.
run_limit: Maximum number of tool calls allowed per run.
`None` means no limit. Defaults to `None`.
`None` means no limit.
exit_behavior: How to handle when limits are exceeded.
- `"continue"`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end. (default)
- `"error"`: Raise a `ToolCallLimitExceededError` exception
- `"end"`: Stop execution immediately with a ToolMessage + AI message
for the single tool call that exceeded the limit. Raises
`NotImplementedError` if there are multiple parallel tool
calls to other tools or multiple pending tool calls.
- `'continue'`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end.
- `'error'`: Raise a `ToolCallLimitExceededError` exception
- `'end'`: Stop execution immediately with a `ToolMessage` + AI message
for the single tool call that exceeded the limit. Raises
`NotImplementedError` if there are multiple parallel tool
calls to other tools or multiple pending tool calls.
Raises:
ValueError: If both limits are `None`, if exit_behavior is invalid,
or if run_limit exceeds thread_limit.
ValueError: If both limits are `None`, if `exit_behavior` is invalid,
or if `run_limit` exceeds `thread_limit`.
"""
super().__init__()
@@ -293,7 +300,8 @@ class ToolCallLimitMiddleware(
run_count: Current run call count.
Returns:
Tuple of (allowed_calls, blocked_calls, final_thread_count, final_run_count).
Tuple of `(allowed_calls, blocked_calls, final_thread_count,
final_run_count)`.
"""
allowed_calls: list[ToolCall] = []
blocked_calls: list[ToolCall] = []
@@ -327,13 +335,13 @@ class ToolCallLimitMiddleware(
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is "end", also includes a jump to end with a ToolMessage
and AI message for the single exceeded tool call.
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
is "error".
NotImplementedError: If limits are exceeded, exit_behavior is "end",
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
# Get the last AIMessage to check for tool calls
@@ -452,3 +460,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}
@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

View File

@@ -23,39 +23,44 @@ class LLMToolEmulator(AgentMiddleware):
"""Emulates specified tools using an LLM instead of executing them.
This middleware allows selective emulation of tools for testing purposes.
By default (when tools=None), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or BaseTool instances.
By default (when `tools=None`), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or `BaseTool` instances.
Examples:
Emulate all tools (default behavior):
```python
from langchain.agents.middleware import LLMToolEmulator
!!! example "Emulate all tools (default behavior)"
middleware = LLMToolEmulator()
```python
from langchain.agents.middleware import LLMToolEmulator
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
middleware = LLMToolEmulator()
Emulate specific tools by name:
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
Use a custom model for emulation:
```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by name"
Emulate specific tools by passing tool instances:
```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
!!! example "Use a custom model for emulation"
```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by passing tool instances"
```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
"""
def __init__(
@@ -67,12 +72,16 @@ class LLMToolEmulator(AgentMiddleware):
"""Initialize the tool emulator.
Args:
tools: List of tool names (str) or BaseTool instances to emulate.
If None (default), ALL tools will be emulated.
tools: List of tool names (`str`) or `BaseTool` instances to emulate.
If `None`, ALL tools will be emulated.
If an empty list, no tools will be emulated.
model: Model to use for emulation.
Defaults to "anthropic:claude-sonnet-4-5-20250929".
Can be a model identifier string or BaseChatModel instance.
Defaults to `'anthropic:claude-sonnet-4-5-20250929'`.
Can be a model identifier string or `BaseChatModel` instance.
"""
super().__init__()
@@ -110,7 +119,7 @@ class LLMToolEmulator(AgentMiddleware):
Returns:
ToolMessage with emulated response if tool should be emulated,
otherwise calls handler for normal execution.
otherwise calls handler for normal execution.
"""
tool_name = request.tool_call["name"]
@@ -152,7 +161,7 @@ class LLMToolEmulator(AgentMiddleware):
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async version of wrap_tool_call.
"""Async version of `wrap_tool_call`.
Emulate tool execution using LLM if tool should be emulated.
@@ -162,7 +171,7 @@ class LLMToolEmulator(AgentMiddleware):
Returns:
ToolMessage with emulated response if tool should be emulated,
otherwise calls handler for normal execution.
otherwise calls handler for normal execution.
"""
tool_name = request.tool_call["name"]

View File

@@ -26,89 +26,96 @@ class ToolRetryMiddleware(AgentMiddleware):
Supports retrying on specific exceptions and exponential backoff.
Examples:
Basic usage with default settings (2 retries, exponential backoff):
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
Retry specific exceptions only:
```python
from requests.exceptions import RequestException, Timeout
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
!!! example "Retry specific exceptions only"
Custom exception filtering:
```python
from requests.exceptions import HTTPError
```python
from requests.exceptions import RequestException, Timeout
retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
!!! example "Custom exception filtering"
```python
from requests.exceptions import HTTPError
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
Apply to specific tools with custom error handling:
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
!!! example "Apply to specific tools with custom error handling"
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
Apply to specific tools using BaseTool instances:
```python
from langchain_core.tools import tool
!!! example "Apply to specific tools using `BaseTool` instances"
```python
from langchain_core.tools import tool
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```
Constant backoff (no exponential growth):
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
!!! example "Constant backoff (no exponential growth)"
Raise exception on failure:
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise", # Re-raise exception instead of returning message
)
```
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
!!! example "Raise exception on failure"
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise", # Re-raise exception instead of returning message
)
```
"""
def __init__(
@@ -125,34 +132,47 @@ class ToolRetryMiddleware(AgentMiddleware):
max_delay: float = 60.0,
jitter: bool = True,
) -> None:
"""Initialize ToolRetryMiddleware.
"""Initialize `ToolRetryMiddleware`.
Args:
max_retries: Maximum number of retry attempts after the initial call.
Default is 2 retries (3 total attempts). Must be >= 0.
Default is `2` retries (`3` total attempts).
Must be `>= 0`.
tools: Optional list of tools or tool names to apply retry logic to.
Can be a list of `BaseTool` instances or tool name strings.
If `None`, applies to all tools. Default is `None`.
If `None`, applies to all tools.
retry_on: Either a tuple of exception types to retry on, or a callable
that takes an exception and returns `True` if it should be retried.
Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted. Options:
- `"return_message"` (default): Return a ToolMessage with error details,
allowing the LLM to handle the failure and potentially recover.
- `"raise"`: Re-raise the exception, stopping agent execution.
- Custom callable: Function that takes the exception and returns a string
for the ToolMessage content, allowing custom error formatting.
backoff_factor: Multiplier for exponential backoff. Each retry waits
`initial_delay * (backoff_factor ** retry_number)` seconds.
Set to 0.0 for constant delay. Default is 2.0.
initial_delay: Initial delay in seconds before first retry. Default is 1.0.
max_delay: Maximum delay in seconds between retries. Caps exponential
backoff growth. Default is 60.0.
jitter: Whether to add random jitter (±25%) to delay to avoid thundering herd.
Default is `True`.
on_failure: Behavior when all retries are exhausted.
Options:
- `'return_message'`: Return a `ToolMessage` with error details,
allowing the LLM to handle the failure and potentially recover.
- `'raise'`: Re-raise the exception, stopping agent execution.
- **Custom callable:** Function that takes the exception and returns a
string for the `ToolMessage` content, allowing custom error
formatting.
backoff_factor: Multiplier for exponential backoff.
Each retry waits `initial_delay * (backoff_factor ** retry_number)`
seconds.
Set to `0.0` for constant delay.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.
Caps exponential backoff growth.
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
Raises:
ValueError: If max_retries < 0 or delays are negative.
ValueError: If `max_retries < 0` or delays are negative.
"""
super().__init__()
@@ -260,15 +280,15 @@ class ToolRetryMiddleware(AgentMiddleware):
Args:
tool_name: Name of the tool that failed.
tool_call_id: ID of the tool call (may be None).
tool_call_id: ID of the tool call (may be `None`).
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
ToolMessage with error details.
`ToolMessage` with error details.
Raises:
Exception: If on_failure is "raise", re-raises the exception.
Exception: If `on_failure` is `'raise'`, re-raises the exception.
"""
if self.on_failure == "raise":
raise exc
@@ -293,11 +313,11 @@ class ToolRetryMiddleware(AgentMiddleware):
"""Intercept tool execution and retry on failure.
Args:
request: Tool call request with call dict, BaseTool, state, and runtime.
request: Tool call request with call dict, `BaseTool`, state, and runtime.
handler: Callable to execute the tool (can be called multiple times).
Returns:
ToolMessage or Command (the final result).
`ToolMessage` or `Command` (the final result).
"""
tool_name = request.tool.name if request.tool else request.tool_call["name"]
@@ -342,11 +362,12 @@ class ToolRetryMiddleware(AgentMiddleware):
"""Intercept and control async tool execution with retry logic.
Args:
request: Tool call request with call dict, BaseTool, state, and runtime.
handler: Async callable to execute the tool and returns ToolMessage or Command.
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
handler: Async callable that executes the tool and returns a `ToolMessage` or
`Command`.
Returns:
ToolMessage or Command (the final result).
`ToolMessage` or `Command` (the final result).
"""
tool_name = request.tool.name if request.tool else request.tool_call["name"]

View File

@@ -49,7 +49,8 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
tools: Available tools to include in the schema.
Returns:
TypeAdapter for a schema where each tool name is a Literal with its description.
`TypeAdapter` for a schema where each tool name is a `Literal` with its
description.
"""
if not tools:
msg = "Invalid usage: tools must be non-empty"
@@ -92,23 +93,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
and helps the main model focus on the right tools.
Examples:
Limit to 3 tools:
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware
!!! example "Limit to 3 tools"
middleware = LLMToolSelectorMiddleware(max_tools=3)
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware
agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```
middleware = LLMToolSelectorMiddleware(max_tools=3)
Use a smaller model for selection:
```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```
!!! example "Use a smaller model for selection"
```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
"""
def __init__(
@@ -122,13 +125,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
"""Initialize the tool selector.
Args:
model: Model to use for selection. If not provided, uses the agent's main model.
Can be a model identifier string or BaseChatModel instance.
model: Model to use for selection.
If not provided, uses the agent's main model.
Can be a model identifier string or `BaseChatModel` instance.
system_prompt: Instructions for the selection model.
max_tools: Maximum number of tools to select. If the model selects more,
only the first max_tools will be used. No limit if not specified.
max_tools: Maximum number of tools to select.
If the model selects more, only the first `max_tools` will be used.
If not specified, there is no limit.
always_include: Tool names to always include regardless of selection.
These do not count against the max_tools limit.
These do not count against the `max_tools` limit.
"""
super().__init__()
self.system_prompt = system_prompt
@@ -144,7 +154,8 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
"""Prepare inputs for tool selection.
Returns:
SelectionRequest with prepared inputs, or None if no selection is needed.
`SelectionRequest` with prepared inputs, or `None` if no selection is
needed.
"""
# If no tools available, return None
if not request.tools or len(request.tools) == 0:
@@ -211,7 +222,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
valid_tool_names: list[str],
request: ModelRequest,
) -> ModelRequest:
"""Process the selection response and return filtered ModelRequest."""
"""Process the selection response and return filtered `ModelRequest`."""
selected_tool_names: list[str] = []
invalid_tool_selections = []

File diff suppressed because it is too large

View File

@@ -1,10 +1,4 @@
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.
!!! warning "Reference docs"
This page contains **reference documentation** for chat models. See
[the docs](https://docs.langchain.com/oss/python/langchain/models) for conceptual
guides, tutorials, and examples on using chat models.
""" # noqa: E501
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.""" # noqa: E501
from langchain_core.language_models import BaseChatModel

View File

@@ -87,6 +87,21 @@ def init_chat_model(
You can also specify model and model provider in a single argument using
`'{model_provider}:{model}'` format, e.g. `'openai:o1'`.
Will attempt to infer `model_provider` from model if not specified.
The following providers will be inferred based on these model prefixes (see the sketch after this list):
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
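For instance, prefix inference lets the provider be omitted entirely (a hedged sketch; the corresponding provider packages must be installed):

```python
from langchain.chat_models import init_chat_model

gpt = init_chat_model("gpt-4o", temperature=0)          # inferred provider: openai
claude = init_chat_model("claude-sonnet-4-5-20250929")  # inferred provider: anthropic
o1 = init_chat_model("openai:o1")                       # provider given explicitly
```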
model_provider: The model provider if not specified as part of the model arg
(see above).
@@ -110,24 +125,11 @@ def init_chat_model(
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
Will attempt to infer `model_provider` from model if not specified. The
following providers will be inferred based on these model prefixes:
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
configurable_fields: Which model parameters are configurable at runtime:
- `None`: No configurable fields (i.e., a fixed model).
@@ -142,6 +144,7 @@ def init_chat_model(
If `model` is not specified, then defaults to `("model", "model_provider")`.
!!! warning "Security note"
Setting `configurable_fields="any"` means fields like `api_key`,
`base_url`, etc., can be altered at runtime, potentially redirecting
model requests to a different service/user.

View File

@@ -1,10 +1,5 @@
"""Embeddings models.
!!! warning "Reference docs"
This page contains **reference documentation** for Embeddings. See
[the docs](https://docs.langchain.com/oss/python/langchain/retrieval#embedding-models)
for conceptual guides, tutorials, and examples on using Embeddings.
!!! warning "Modules moved"
With the release of `langchain 1.0.0`, several embeddings modules were moved to
`langchain-classic`, such as `CacheBackedEmbeddings` and all community

View File

@@ -2,11 +2,6 @@
Includes message types for different roles (e.g., human, AI, system), as well as types
for message content blocks (e.g., text, image, audio) and tool calls.
!!! warning "Reference docs"
This page contains **reference documentation** for Messages. See
[the docs](https://docs.langchain.com/oss/python/langchain/messages) for conceptual
guides, tutorials, and examples on using Messages.
"""
from langchain_core.messages import (

View File

@@ -1,10 +1,4 @@
"""Tools.
!!! warning "Reference docs"
This page contains **reference documentation** for Tools. See
[the docs](https://docs.langchain.com/oss/python/langchain/tools) for conceptual
guides, tutorials, and examples on using Tools.
"""
"""Tools."""
from langchain_core.tools import (
BaseTool,

View File

@@ -9,10 +9,10 @@ license = { text = "MIT" }
readme = "README.md"
authors = []
version = "1.0.4"
version = "1.0.5"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langchain-core>=1.0.2,<2.0.0",
"langchain-core>=1.0.4,<2.0.0",
"langgraph>=1.0.2,<1.1.0",
"pydantic>=2.7.4,<3.0.0",
]
@@ -57,6 +57,7 @@ test = [
"pytest-mock",
"syrupy>=4.0.2,<5.0.0",
"toml>=0.10.2,<1.0.0",
"langchain-model-profiles",
"langchain-tests",
"langchain-openai",
]
@@ -75,6 +76,7 @@ test_integration = [
"cassio>=0.1.0,<1.0.0",
"langchainhub>=0.1.16,<1.0.0",
"langchain-core",
"langchain-model-profiles",
"langchain-text-splitters",
]
@@ -83,6 +85,7 @@ prerelease = "allow"
[tool.uv.sources]
langchain-core = { path = "../core", editable = true }
langchain-model-profiles = { path = "../model-profiles", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }

View File

@@ -1,79 +0,0 @@
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy
class WeatherBaseModel(BaseModel):
"""Weather response."""
temperature: float = Field(description="The temperature in fahrenheit")
condition: str = Field(description="Weather condition")
def get_weather(city: str) -> str: # noqa: ARG001
"""Get the weather for a city."""
return "The weather is sunny and 75°F."
@pytest.mark.requires("langchain_openai")
def test_inference_to_native_output() -> None:
"""Test that native output is inferred when a model supports it."""
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-5")
agent = create_agent(
model,
system_prompt=(
"You are a helpful weather assistant. Please call the get_weather tool, "
"then use the WeatherReport tool to generate the final response."
),
tools=[get_weather],
response_format=WeatherBaseModel,
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert isinstance(response["structured_response"], WeatherBaseModel)
assert response["structured_response"].temperature == 75.0
assert response["structured_response"].condition.lower() == "sunny"
assert len(response["messages"]) == 4
assert [m.type for m in response["messages"]] == [
"human", # "What's the weather?"
"ai", # "What's the weather?"
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
]
@pytest.mark.requires("langchain_openai")
def test_inference_to_tool_output() -> None:
"""Test that tool output is inferred when a model supports it."""
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-4")
agent = create_agent(
model,
system_prompt=(
"You are a helpful weather assistant. Please call the get_weather tool, "
"then use the WeatherReport tool to generate the final response."
),
tools=[get_weather],
response_format=ToolStrategy(WeatherBaseModel),
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert isinstance(response["structured_response"], WeatherBaseModel)
assert response["structured_response"].temperature == 75.0
assert response["structured_response"].condition.lower() == "sunny"
assert len(response["messages"]) == 5
assert [m.type for m in response["messages"]] == [
"human", # "What's the weather?"
"ai", # "What's the weather?"
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
"tool", # artificial tool message
]

View File

@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,95 @@
# serializer version: 1
# name: test_async_middleware_with_can_jump_to_graph_snapshot
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_with_jump\2ebefore_model;
async_before_with_jump\2ebefore_model -.-> __end__;
async_before_with_jump\2ebefore_model -.-> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
async_after_with_jump\2eafter_model -.-> __end__;
async_after_with_jump\2eafter_model -.-> model;
model --> async_after_with_jump\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
async_after_retry\2eafter_model(async_after_retry.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_early_exit\2ebefore_model;
async_after_retry\2eafter_model -.-> __end__;
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
async_before_early_exit\2ebefore_model -.-> __end__;
async_before_early_exit\2ebefore_model -.-> model;
model --> async_after_retry\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> sync_before_with_jump\2ebefore_model;
async_after_with_jumps\2eafter_model -.-> __end__;
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
model --> async_after_with_jumps\2eafter_model;
sync_before_with_jump\2ebefore_model -.-> __end__;
sync_before_with_jump\2ebefore_model -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,289 @@
# serializer version: 1
# name: test_create_agent_diagram
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.10
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
__end__([<p>__end__</p>]):::last
NoopTen\2ebefore_model --> model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopTen\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.11
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2ebefore_model(NoopEleven.before_model)
NoopEleven\2eafter_model(NoopEleven.after_model)
__end__([<p>__end__</p>]):::last
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
NoopEleven\2ebefore_model --> model;
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopEleven\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopTwo\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
NoopThree\2ebefore_model(NoopThree.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopThree\2ebefore_model --> model;
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.4
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> NoopFour\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.5
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
__start__ --> model;
model --> NoopFive\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.6
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
NoopSix\2eafter_model(NoopSix.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
NoopSix\2eafter_model --> NoopFive\2eafter_model;
__start__ --> model;
model --> NoopSix\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.7
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
__end__([<p>__end__</p>]):::last
NoopSeven\2ebefore_model --> model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopSeven\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.8
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.9
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
NoopNine\2eafter_model --> NoopEight\2eafter_model;
NoopNine\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopNine\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
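The snapshot blocks in these .ambr files are the Mermaid sources syrupy records for each agent configuration; the `\2e` sequences are simply Mermaid-escaped dots in the node ids. A minimal sketch of how one of these snapshots is produced, following the test file later in this diff (FakeToolCallingModel is the suite's fake chat model, not a public API):

```python
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware
from tests.unit_tests.agents.model import FakeToolCallingModel  # test helper

class NoopOne(AgentMiddleware):
    def before_model(self, state, runtime):
        pass

agent = create_agent(
    model=FakeToolCallingModel(),
    tools=[],
    system_prompt="You are a helpful assistant.",
    middleware=[NoopOne()],
)
# Emits the `graph TD;` source captured in the .ambr snapshots above and below.
print(agent.get_graph().draw_mermaid())
```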

View File

@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
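The test_simple_agent_graph entry above is the baseline tool-calling loop (model -.-> tools, tools -.-> model). A rough sketch of an agent that yields that shape, again leaning on the suite's FakeToolCallingModel as a stand-in for any tool-calling chat model:

```python
from langchain_core.tools import tool
from langchain.agents.factory import create_agent
from tests.unit_tests.agents.model import FakeToolCallingModel  # test helper

@tool
def get_weather(city: str) -> str:
    """Return a canned forecast."""
    return f"Sunny in {city}"

agent = create_agent(model=FakeToolCallingModel(), tools=[get_weather])
print(agent.get_graph().draw_mermaid())
# Expected shape: __start__ --> model; model -.-> tools / __end__; tools -.-> model
```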

View File

@@ -0,0 +1,95 @@
# serializer version: 1
# name: test_async_middleware_with_can_jump_to_graph_snapshot
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_with_jump\2ebefore_model;
async_before_with_jump\2ebefore_model -.-> __end__;
async_before_with_jump\2ebefore_model -.-> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
async_after_with_jump\2eafter_model -.-> __end__;
async_after_with_jump\2eafter_model -.-> model;
model --> async_after_with_jump\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
async_after_retry\2eafter_model(async_after_retry.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_early_exit\2ebefore_model;
async_after_retry\2eafter_model -.-> __end__;
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
async_before_early_exit\2ebefore_model -.-> __end__;
async_before_early_exit\2ebefore_model -.-> model;
model --> async_after_retry\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> sync_before_with_jump\2ebefore_model;
async_after_with_jumps\2eafter_model -.-> __end__;
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
model --> async_after_with_jumps\2eafter_model;
sync_before_with_jump\2ebefore_model -.-> __end__;
sync_before_with_jump\2ebefore_model -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
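Dotted -.-> edges in these snapshots are conditional edges: they appear when a hook advertises jump targets (see the _get_can_jump_to import in the factory diff below). A very rough sketch of a middleware that could produce the first diagram; the hook_config(can_jump_to=...) usage and the {"jump_to": ...} return convention are assumptions inferred from the imports and test names in this PR, not verified API:

```python
from langchain.agents.middleware.types import AgentMiddleware, hook_config

class async_before_with_jump(AgentMiddleware):
    # Assumed: hook_config(can_jump_to=[...]) is what lets the factory draw the
    # dotted edges from this hook to both `model` and `__end__`.
    @hook_config(can_jump_to=["end"])
    async def abefore_model(self, state, runtime):
        # Assumed convention: returning {"jump_to": "end"} takes the dotted edge
        # straight to __end__, while returning None continues on to `model`.
        return None
```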

View File

@@ -0,0 +1,289 @@
# serializer version: 1
# name: test_create_agent_diagram
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.10
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
__end__([<p>__end__</p>]):::last
NoopTen\2ebefore_model --> model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopTen\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.11
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2ebefore_model(NoopEleven.before_model)
NoopEleven\2eafter_model(NoopEleven.after_model)
__end__([<p>__end__</p>]):::last
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
NoopEleven\2ebefore_model --> model;
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopEleven\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopTwo\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
NoopThree\2ebefore_model(NoopThree.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopThree\2ebefore_model --> model;
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.4
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> NoopFour\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.5
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
__start__ --> model;
model --> NoopFive\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.6
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
NoopSix\2eafter_model(NoopSix.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
NoopSix\2eafter_model --> NoopFive\2eafter_model;
__start__ --> model;
model --> NoopSix\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.7
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
__end__([<p>__end__</p>]):::last
NoopSeven\2ebefore_model --> model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopSeven\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.8
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_diagram.9
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
NoopNine\2eafter_model --> NoopEight\2eafter_model;
NoopNine\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopNine\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -22,7 +22,7 @@ from langchain.agents.middleware.types import (
hook_config,
)
from langchain.agents.factory import create_agent, _get_can_jump_to
from .model import FakeToolCallingModel
from ...model import FakeToolCallingModel
class CustomState(AgentState):

View File

@@ -0,0 +1,193 @@
from collections.abc import Callable
from syrupy.assertion import SnapshotAssertion
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain_core.messages import AIMessage
from ...model import FakeToolCallingModel
def test_create_agent_diagram(
snapshot: SnapshotAssertion,
):
class NoopOne(AgentMiddleware):
def before_model(self, state, runtime):
pass
class NoopTwo(AgentMiddleware):
def before_model(self, state, runtime):
pass
class NoopThree(AgentMiddleware):
def before_model(self, state, runtime):
pass
class NoopFour(AgentMiddleware):
def after_model(self, state, runtime):
pass
class NoopFive(AgentMiddleware):
def after_model(self, state, runtime):
pass
class NoopSix(AgentMiddleware):
def after_model(self, state, runtime):
pass
class NoopSeven(AgentMiddleware):
def before_model(self, state, runtime):
pass
def after_model(self, state, runtime):
pass
class NoopEight(AgentMiddleware):
def before_model(self, state, runtime):
pass
def after_model(self, state, runtime):
pass
class NoopNine(AgentMiddleware):
def before_model(self, state, runtime):
pass
def after_model(self, state, runtime):
pass
class NoopTen(AgentMiddleware):
def before_model(self, state, runtime):
pass
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
def after_model(self, state, runtime):
pass
class NoopEleven(AgentMiddleware):
def before_model(self, state, runtime):
pass
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
def after_model(self, state, runtime):
pass
agent_zero = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
)
assert agent_zero.get_graph().draw_mermaid() == snapshot
agent_one = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopOne()],
)
assert agent_one.get_graph().draw_mermaid() == snapshot
agent_two = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopOne(), NoopTwo()],
)
assert agent_two.get_graph().draw_mermaid() == snapshot
agent_three = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopOne(), NoopTwo(), NoopThree()],
)
assert agent_three.get_graph().draw_mermaid() == snapshot
agent_four = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopFour()],
)
assert agent_four.get_graph().draw_mermaid() == snapshot
agent_five = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopFour(), NoopFive()],
)
assert agent_five.get_graph().draw_mermaid() == snapshot
agent_six = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopFour(), NoopFive(), NoopSix()],
)
assert agent_six.get_graph().draw_mermaid() == snapshot
agent_seven = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven()],
)
assert agent_seven.get_graph().draw_mermaid() == snapshot
agent_eight = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven(), NoopEight()],
)
assert agent_eight.get_graph().draw_mermaid() == snapshot
agent_nine = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopSeven(), NoopEight(), NoopNine()],
)
assert agent_nine.get_graph().draw_mermaid() == snapshot
agent_ten = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopTen()],
)
assert agent_ten.get_graph().draw_mermaid() == snapshot
agent_eleven = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[NoopTen(), NoopEleven()],
)
assert agent_eleven.get_graph().draw_mermaid() == snapshot
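Each `== snapshot` comparison above appends a numbered entry to the .ambr file, which is where the suffixes test_create_agent_diagram, .1, ... .11 come from. The wiring those entries assert: before_model hooks are chained in middleware list order and after_model hooks in reverse, as in snapshot .8. A recap, reusing the classes defined in this test:

```python
agent_eight = create_agent(
    model=FakeToolCallingModel(),
    tools=[],
    system_prompt="You are a helpful assistant.",
    middleware=[NoopSeven(), NoopEight()],
)
print(agent_eight.get_graph().draw_mermaid())
# __start__ --> NoopSeven.before_model --> NoopEight.before_model --> model
#           --> NoopEight.after_model  --> NoopSeven.after_model  --> __end__
```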

File diff suppressed because it is too large

View File

@@ -14,7 +14,7 @@ from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, wrap_tool_call
from langchain.agents.middleware.types import ToolCallRequest
from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
from tests.unit_tests.agents.model import FakeToolCallingModel
@tool

View File

@@ -9,7 +9,7 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState, Model
from langgraph.prebuilt.tool_node import ToolNode
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from .model import FakeToolCallingModel
from tests.unit_tests.agents.model import FakeToolCallingModel
def test_model_request_tools_are_base_tools() -> None:

View File

@@ -1,18 +1,30 @@
"""Unit tests for wrap_model_call middleware generator protocol."""
"""Unit tests for wrap_model_call hook and @wrap_model_call decorator.
This module tests the wrap_model_call functionality in three forms:
1. As a middleware method (AgentMiddleware.wrap_model_call)
2. As a decorator (@wrap_model_call)
3. Async variant (AgentMiddleware.awrap_model_call)
"""
from collections.abc import Awaitable, Callable
import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain.agents import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
wrap_model_call,
)
from ...model import FakeToolCallingModel
class TestBasicOnModelCall:
class TestBasicWrapModelCall:
"""Test basic wrap_model_call functionality."""
def test_passthrough_middleware(self) -> None:
@@ -70,7 +82,7 @@ class TestBasicOnModelCall:
assert counter.call_count == 1
class TestRetryMiddleware:
class TestRetryLogic:
"""Test retry logic with wrap_model_call."""
def test_simple_retry_on_error(self) -> None:
@@ -91,12 +103,10 @@ class TestRetryMiddleware:
def wrap_model_call(self, request, handler):
try:
result = handler(request)
return result
return handler(request)
except Exception:
self.retry_count += 1
result = handler(request)
return result
return handler(request)
retry_middleware = RetryOnceMiddleware()
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
@@ -125,8 +135,7 @@ class TestRetryMiddleware:
for attempt in range(self.max_retries):
self.attempts.append(attempt + 1)
try:
result = handler(request)
return result
return handler(request)
except Exception as e:
last_exception = e
continue
@@ -143,6 +152,75 @@ class TestRetryMiddleware:
assert retry_middleware.attempts == [1, 2, 3]
def test_no_retry_propagates_error(self) -> None:
"""Test that error is propagated when middleware doesn't retry."""
class FailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
raise ValueError("Model error")
@property
def _llm_type(self):
return "failing"
class NoRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
return handler(request)
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
with pytest.raises(ValueError, match="Model error"):
agent.invoke({"messages": [HumanMessage("Test")]})
def test_max_attempts_limit(self) -> None:
"""Test that middleware controls termination via retry limits."""
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
raise ValueError("Always fails")
@property
def _llm_type(self):
return "always_failing"
class LimitedRetryMiddleware(AgentMiddleware):
"""Middleware that limits its own retries."""
def __init__(self, max_retries: int = 10):
super().__init__()
self.max_retries = max_retries
self.attempt_count = 0
def wrap_model_call(self, request, handler):
last_exception = None
for attempt in range(self.max_retries):
self.attempt_count += 1
try:
return handler(request)
except Exception as e:
last_exception = e
# Continue to retry
# All retries exhausted, re-raise the last error
if last_exception:
raise last_exception
model = AlwaysFailingModel()
middleware = LimitedRetryMiddleware(max_retries=10)
agent = create_agent(model=model, middleware=[middleware])
# Should fail with the model's error after middleware stops retrying
with pytest.raises(ValueError, match="Always fails"):
agent.invoke({"messages": [HumanMessage("Test")]})
# Should have attempted exactly 10 times as configured
assert middleware.attempt_count == 10
class TestResponseRewriting:
"""Test response content rewriting with wrap_model_call."""
@@ -185,6 +263,28 @@ class TestResponseRewriting:
assert result["messages"][1].content == "[BOT]: Response"
def test_multi_stage_transformation(self) -> None:
"""Test middleware applying multiple transformations."""
class MultiTransformMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
# First transformation: uppercase
content = ai_message.content.upper()
# Second transformation: add prefix and suffix
content = f"[START] {content} [END]"
return AIMessage(content=content)
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))
agent = create_agent(model=model, middleware=[MultiTransformMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert result["messages"][1].content == "[START] HELLO [END]"
class TestErrorHandling:
"""Test error handling with wrap_model_call."""
@@ -200,9 +300,8 @@ class TestErrorHandling:
def wrap_model_call(self, request, handler):
try:
return handler(request)
except Exception:
fallback = AIMessage(content="Error handled gracefully")
return fallback
except Exception as e:
return AIMessage(content=f"Error occurred: {e}. Using fallback response.")
model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[ErrorToSuccessMiddleware()])
@@ -210,7 +309,8 @@ class TestErrorHandling:
# Should not raise, middleware converts error to response
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert "Error handled gracefully" in result["messages"][1].content
assert "Error occurred" in result["messages"][1].content
assert "fallback response" in result["messages"][1].content
def test_selective_error_handling(self) -> None:
"""Test middleware that only handles specific errors."""
@@ -224,8 +324,7 @@ class TestErrorHandling:
try:
return handler(request)
except ConnectionError:
fallback = AIMessage(content="Network issue, try again later")
return fallback
return AIMessage(content="Network issue, try again later")
model = SpecificErrorModel(messages=iter([]))
agent = create_agent(model=model, middleware=[SelectiveErrorMiddleware()])
@@ -247,8 +346,7 @@ class TestErrorHandling:
return result
except Exception:
call_log.append("caught-error")
fallback = AIMessage(content="Recovered from error")
return fallback
return AIMessage(content="Recovered from error")
# Test 1: Success path
call_log.clear()
@@ -403,7 +501,6 @@ class TestStateAndRuntime:
for attempt in range(max_retries):
try:
return handler(request)
break # Success
except Exception:
if attempt == max_retries - 1:
raise
@@ -460,6 +557,49 @@ class TestMiddlewareComposition:
"outer-after",
]
def test_three_middleware_composition(self) -> None:
"""Test composition of three middleware."""
execution_order = []
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("first-before")
response = handler(request)
execution_order.append("first-after")
return response
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("second-before")
response = handler(request)
execution_order.append("second-after")
return response
class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("third-before")
response = handler(request)
execution_order.append("third-after")
return response
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
model=model,
middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
)
agent.invoke({"messages": [HumanMessage("Test")]})
# First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
assert execution_order == [
"first-before",
"second-before",
"third-before",
"third-after",
"second-after",
"first-after",
]
def test_retry_with_logging(self) -> None:
"""Test retry middleware composed with logging middleware."""
call_count = {"value": 0}
@@ -549,11 +689,9 @@ class TestMiddlewareComposition:
class RetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
result = handler(request)
return result
return handler(request)
except Exception:
result = handler(request)
return result
return handler(request)
class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
@@ -571,49 +709,6 @@ class TestMiddlewareComposition:
# Should retry and uppercase the result
assert result["messages"][1].content == "SUCCESS"
def test_three_middleware_composition(self) -> None:
"""Test composition of three middleware."""
execution_order = []
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("first-before")
response = handler(request)
execution_order.append("first-after")
return response
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("second-before")
response = handler(request)
execution_order.append("second-after")
return response
class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("third-before")
response = handler(request)
execution_order.append("third-after")
return response
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
model=model,
middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
)
agent.invoke({"messages": [HumanMessage("Test")]})
# First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
assert execution_order == [
"first-before",
"second-before",
"third-before",
"third-after",
"second-after",
"first-after",
]
def test_middle_retry_middleware(self) -> None:
"""Test that middle middleware doing retry causes inner to execute twice."""
execution_order = []
@@ -674,7 +769,306 @@ class TestMiddlewareComposition:
assert len(model_calls) == 2
class TestAsyncOnModelCall:
class TestWrapModelCallDecorator:
"""Test the @wrap_model_call decorator for creating middleware."""
def test_basic_decorator_usage(self) -> None:
"""Test basic decorator usage without parameters."""
@wrap_model_call
def passthrough_middleware(request, handler):
return handler(request)
# Should return an AgentMiddleware instance
assert isinstance(passthrough_middleware, AgentMiddleware)
# Should work in agent
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[passthrough_middleware])
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Hello"
def test_decorator_with_custom_name(self) -> None:
"""Test decorator with custom middleware name."""
@wrap_model_call(name="CustomMiddleware")
def my_middleware(request, handler):
return handler(request)
assert isinstance(my_middleware, AgentMiddleware)
assert my_middleware.__class__.__name__ == "CustomMiddleware"
def test_decorator_retry_logic(self) -> None:
"""Test decorator for implementing retry logic."""
call_count = {"value": 0}
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
call_count["value"] += 1
if call_count["value"] == 1:
raise ValueError("First call fails")
return super()._generate(messages, **kwargs)
@wrap_model_call
def retry_once(request, handler):
try:
return handler(request)
except Exception:
# Retry once
return handler(request)
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
agent = create_agent(model=model, middleware=[retry_once])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert call_count["value"] == 2
assert result["messages"][1].content == "Success"
def test_decorator_response_rewriting(self) -> None:
"""Test decorator for rewriting responses."""
@wrap_model_call
def uppercase_responses(request, handler):
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
return AIMessage(content=ai_message.content.upper())
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
agent = create_agent(model=model, middleware=[uppercase_responses])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert result["messages"][1].content == "HELLO WORLD"
def test_decorator_error_handling(self) -> None:
"""Test decorator for error recovery."""
class AlwaysFailModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
raise ValueError("Model error")
@wrap_model_call
def error_to_fallback(request, handler):
try:
return handler(request)
except Exception:
return AIMessage(content="Fallback response")
model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[error_to_fallback])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert result["messages"][1].content == "Fallback response"
def test_decorator_with_state_access(self) -> None:
"""Test decorator accessing agent state."""
state_values = []
@wrap_model_call
def log_state(request, handler):
state_values.append(request.state.get("messages"))
return handler(request)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[log_state])
agent.invoke({"messages": [HumanMessage("Test")]})
# State should contain the user message
assert len(state_values) == 1
assert len(state_values[0]) == 1
assert state_values[0][0].content == "Test"
def test_multiple_decorated_middleware(self) -> None:
"""Test composition of multiple decorated middleware."""
execution_order = []
@wrap_model_call
def outer_middleware(request, handler):
execution_order.append("outer-before")
result = handler(request)
execution_order.append("outer-after")
return result
@wrap_model_call
def inner_middleware(request, handler):
execution_order.append("inner-before")
result = handler(request)
execution_order.append("inner-after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[outer_middleware, inner_middleware])
agent.invoke({"messages": [HumanMessage("Test")]})
assert execution_order == [
"outer-before",
"inner-before",
"inner-after",
"outer-after",
]
def test_decorator_with_custom_state_schema(self) -> None:
"""Test decorator with custom state schema."""
from typing_extensions import TypedDict
class CustomState(TypedDict):
messages: list
custom_field: str
@wrap_model_call(state_schema=CustomState)
def middleware_with_schema(request, handler):
return handler(request)
assert isinstance(middleware_with_schema, AgentMiddleware)
# Custom state schema should be set
assert middleware_with_schema.state_schema == CustomState
def test_decorator_with_tools_parameter(self) -> None:
"""Test decorator with tools parameter."""
from langchain_core.tools import tool
@tool
def test_tool(query: str) -> str:
"""A test tool."""
return f"Result: {query}"
@wrap_model_call(tools=[test_tool])
def middleware_with_tools(request, handler):
return handler(request)
assert isinstance(middleware_with_tools, AgentMiddleware)
assert len(middleware_with_tools.tools) == 1
assert middleware_with_tools.tools[0].name == "test_tool"
def test_decorator_parentheses_optional(self) -> None:
"""Test that decorator works both with and without parentheses."""
# Without parentheses
@wrap_model_call
def middleware_no_parens(request, handler):
return handler(request)
# With parentheses
@wrap_model_call()
def middleware_with_parens(request, handler):
return handler(request)
assert isinstance(middleware_no_parens, AgentMiddleware)
assert isinstance(middleware_with_parens, AgentMiddleware)
def test_decorator_preserves_function_name(self) -> None:
"""Test that decorator uses function name for class name."""
@wrap_model_call
def my_custom_middleware(request, handler):
return handler(request)
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
def test_decorator_mixed_with_class_middleware(self) -> None:
"""Test decorated middleware mixed with class-based middleware."""
execution_order = []
@wrap_model_call
def decorated_middleware(request, handler):
execution_order.append("decorated-before")
result = handler(request)
execution_order.append("decorated-after")
return result
class ClassMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("class-before")
result = handler(request)
execution_order.append("class-after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
model=model,
middleware=[decorated_middleware, ClassMiddleware()],
)
agent.invoke({"messages": [HumanMessage("Test")]})
# Decorated is outer, class-based is inner
assert execution_order == [
"decorated-before",
"class-before",
"class-after",
"decorated-after",
]
def test_decorator_complex_retry_logic(self) -> None:
"""Test decorator with complex retry logic and backoff."""
attempts = []
call_count = {"value": 0}
class UnreliableModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
call_count["value"] += 1
if call_count["value"] <= 2:
raise ValueError(f"Attempt {call_count['value']} failed")
return super()._generate(messages, **kwargs)
@wrap_model_call
def retry_with_tracking(request, handler):
max_retries = 3
for attempt in range(max_retries):
attempts.append(attempt + 1)
try:
return handler(request)
except Exception:
# On error, continue to next attempt
if attempt < max_retries - 1:
continue # Retry
else:
raise # All retries failed
model = UnreliableModel(messages=iter([AIMessage(content="Finally worked")]))
agent = create_agent(model=model, middleware=[retry_with_tracking])
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert attempts == [1, 2, 3]
assert result["messages"][1].content == "Finally worked"
def test_decorator_request_modification(self) -> None:
"""Test decorator modifying request before execution."""
modified_prompts = []
@wrap_model_call
def add_system_prompt(request, handler):
# Modify request to add system prompt
modified_request = ModelRequest(
messages=request.messages,
model=request.model,
system_prompt="You are a helpful assistant",
tool_choice=request.tool_choice,
tools=request.tools,
response_format=request.response_format,
state={},
runtime=None,
)
modified_prompts.append(modified_request.system_prompt)
return handler(modified_request)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[add_system_prompt])
agent.invoke({"messages": [HumanMessage("Test")]})
assert modified_prompts == ["You are a helpful assistant"]
class TestAsyncWrapModelCall:
"""Test async execution with wrap_model_call."""
async def test_async_model_with_middleware(self) -> None:
@@ -686,7 +1080,6 @@ class TestAsyncOnModelCall:
log.append("before")
result = await handler(request)
log.append("after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
@@ -723,6 +1116,92 @@ class TestAsyncOnModelCall:
assert call_count["value"] == 2
assert result["messages"][1].content == "Async success"
async def test_decorator_with_async_agent(self) -> None:
"""Test that decorated middleware works with async agent invocation."""
call_log = []
@wrap_model_call
async def logging_middleware(request, handler):
call_log.append("before")
result = await handler(request)
call_log.append("after")
return result
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
agent = create_agent(model=model, middleware=[logging_middleware])
result = await agent.ainvoke({"messages": [HumanMessage("Test")]})
assert call_log == ["before", "after"]
assert result["messages"][1].content == "Async response"
class TestSyncAsyncInterop:
"""Test sync/async interoperability."""
def test_sync_invoke_with_only_async_middleware_raises_error(self) -> None:
"""Test that sync invoke with only async middleware raises error."""
class AsyncOnlyMiddleware(AgentMiddleware):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
return await handler(request)
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncOnlyMiddleware()],
)
with pytest.raises(NotImplementedError):
agent.invoke({"messages": [HumanMessage("hello")]})
def test_sync_invoke_with_mixed_middleware(self) -> None:
"""Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
calls = []
class MixedMiddleware(AgentMiddleware):
def before_model(self, state, runtime) -> None:
calls.append("MixedMiddleware.before_model")
async def abefore_model(self, state, runtime) -> None:
calls.append("MixedMiddleware.abefore_model")
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
calls.append("MixedMiddleware.wrap_model_call")
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
calls.append("MixedMiddleware.awrap_model_call")
return await handler(request)
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[MixedMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage("hello")]})
# In sync mode, only sync methods should be called
assert calls == [
"MixedMiddleware.before_model",
"MixedMiddleware.wrap_model_call",
]
class TestEdgeCases:
"""Test edge cases and error conditions."""
@@ -753,12 +1232,10 @@ class TestEdgeCases:
def wrap_model_call(self, request, handler):
attempts.append("first-attempt")
try:
result = handler(request)
return result
return handler(request)
except Exception:
attempts.append("retry-attempt")
result = handler(request)
return result
return handler(request)
call_count = {"value": 0}

View File

@@ -14,7 +14,7 @@ from langgraph.types import Command
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import wrap_tool_call
from langchain.agents.middleware.types import ToolCallRequest
from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
from tests.unit_tests.agents.model import FakeToolCallingModel
@tool

View File

@@ -7,6 +7,9 @@ import pytest
from langchain.agents.middleware.file_search import (
FilesystemFileSearchMiddleware,
_expand_include_patterns,
_is_valid_include_pattern,
_match_include_pattern,
)
@@ -259,3 +262,105 @@ class TestPathTraversalSecurity:
assert result == "No matches found"
assert "secret" not in result
class TestExpandIncludePatterns:
"""Tests for _expand_include_patterns helper function."""
def test_expand_patterns_basic_brace_expansion(self) -> None:
"""Test basic brace expansion with multiple options."""
result = _expand_include_patterns("*.{py,txt}")
assert result == ["*.py", "*.txt"]
def test_expand_patterns_nested_braces(self) -> None:
"""Test nested brace expansion."""
result = _expand_include_patterns("test.{a,b}.{c,d}")
assert result is not None
assert len(result) == 4
assert "test.a.c" in result
assert "test.b.d" in result
@pytest.mark.parametrize(
"pattern",
[
"*.py}", # closing brace without opening
"*.{}", # empty braces
"*.{py", # unclosed brace
],
)
def test_expand_patterns_invalid_braces(self, pattern: str) -> None:
"""Test patterns with invalid brace syntax return None."""
result = _expand_include_patterns(pattern)
assert result is None
class TestValidateIncludePattern:
"""Tests for _is_valid_include_pattern helper function."""
@pytest.mark.parametrize(
"pattern",
[
"", # empty pattern
"*.py\x00", # null byte
"*.py\n", # newline
],
)
def test_validate_invalid_patterns(self, pattern: str) -> None:
"""Test that invalid patterns are rejected."""
assert not _is_valid_include_pattern(pattern)
class TestMatchIncludePattern:
"""Tests for _match_include_pattern helper function."""
def test_match_pattern_with_braces(self) -> None:
"""Test matching with brace expansion."""
assert _match_include_pattern("test.py", "*.{py,txt}")
assert _match_include_pattern("test.txt", "*.{py,txt}")
assert not _match_include_pattern("test.md", "*.{py,txt}")
def test_match_pattern_invalid_expansion(self) -> None:
"""Test matching with pattern that cannot be expanded returns False."""
assert not _match_include_pattern("test.py", "*.{}")
class TestGrepEdgeCases:
"""Tests for edge cases in grep search."""
def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None:
"""Test grep with special characters in pattern."""
(tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
result = middleware.grep_search.func(pattern="def.*:")
assert "/test.py" in result
def test_grep_case_insensitive(self, tmp_path: Path) -> None:
"""Test grep with case-insensitive search."""
(tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
result = middleware.grep_search.func(pattern="(?i)hello")
assert "/test.py" in result
def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None:
"""Test that grep skips files larger than max_file_size_mb."""
# Create a file larger than 1MB
large_content = "x" * (2 * 1024 * 1024) # 2MB
(tmp_path / "large.txt").write_text(large_content, encoding="utf-8")
(tmp_path / "small.txt").write_text("x", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(
root_path=str(tmp_path),
use_ripgrep=False,
max_file_size_mb=1, # 1MB limit
)
result = middleware.grep_search.func(pattern="x")
# Large file should be skipped
assert "/small.txt" in result

View File

@@ -0,0 +1,575 @@
from typing import Any
from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime
from langchain.agents.middleware.human_in_the_loop import (
Action,
HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.types import AgentState
def test_human_in_the_loop_middleware_initialization() -> None:
"""Test HumanInTheLoopMiddleware initialization."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
description_prefix="Custom prefix",
)
assert middleware.interrupt_on == {
"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}
}
assert middleware.description_prefix == "Custom prefix"
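A hedged sketch of wiring this middleware into an agent, mirroring the constructor arguments exercised here; my_model and test_tool are placeholders, and this composition is an assumption rather than something this file verifies:

```python
from langchain.agents.factory import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware

hitl = HumanInTheLoopMiddleware(
    interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
    description_prefix="Tool execution requires approval",
)
# my_model / test_tool stand in for a real chat model and tool.
agent = create_agent(model=my_model, tools=[test_tool], middleware=[hitl])
# after_model then interrupts on matching tool calls and resumes with payloads
# shaped like {"decisions": [{"type": "approve"}]} (see the mocks below).
```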
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
# Test with no messages
state: dict[str, Any] = {"messages": []}
result = middleware.after_model(state, None)
assert result is None
# Test with message but no tool calls
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
result = middleware.after_model(state, None)
assert result is None
# Test with tool calls that don't require interrupts
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
result = middleware.after_model(state, None)
assert result is None
def test_human_in_the_loop_middleware_single_tool_accept() -> None:
"""Test HumanInTheLoopMiddleware with single tool accept response."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return {"decisions": [{"type": "approve"}]}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0] == ai_message
assert result["messages"][0].tool_calls == ai_message.tool_calls
state["messages"].append(
ToolMessage(content="Tool message", name="test_tool", tool_call_id="1")
)
state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))
result = middleware.after_model(state, None)
# No interrupts needed
assert result is None
def test_human_in_the_loop_middleware_single_tool_edit() -> None:
"""Test HumanInTheLoopMiddleware with single tool edit response."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "edited"},
),
}
]
}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved
def test_human_in_the_loop_middleware_single_tool_response() -> None:
"""Test HumanInTheLoopMiddleware with single tool response with custom message."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_response(requests):
return {"decisions": [{"type": "reject", "message": "Custom response message"}]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 2
assert isinstance(result["messages"][0], AIMessage)
assert isinstance(result["messages"][1], ToolMessage)
assert result["messages"][1].content == "Custom response message"
assert result["messages"][1].name == "test_tool"
assert result["messages"][1].tool_call_id == "1"
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
"get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_mixed_responses(requests):
return {
"decisions": [
{"type": "approve"},
{"type": "reject", "message": "User rejected this tool call"},
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert (
len(result["messages"]) == 2
) # AI message with accepted tool call + tool message for rejected
# First message should be the AI message with both tool calls
updated_ai_message = result["messages"][0]
assert len(updated_ai_message.tool_calls) == 2 # Both tool calls remain
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
assert updated_ai_message.tool_calls[1]["name"] == "get_temperature" # Got response
# Second message should be the tool message for the rejected tool call
tool_message = result["messages"][1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "User rejected this tool call"
assert tool_message.name == "get_temperature"
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
"get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_edit_responses(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="get_forecast",
args={"location": "New York"},
),
},
{
"type": "edit",
"edited_action": Action(
name="get_temperature",
args={"location": "New York"},
),
},
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit_with_args(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "modified"},
),
}
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_edit_with_args,
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
# Should have modified args
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
"""Test HumanInTheLoopMiddleware with unknown response type."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_unknown(requests):
return {"decisions": [{"type": "unknown"}]}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
with pytest.raises(
ValueError,
match=r"Unexpected human decision: {'type': 'unknown'}. Decision type 'unknown' is not allowed for tool 'test_tool'. Expected one of \['approve', 'edit', 'reject'\] based on the tool's configuration.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_disallowed_action() -> None:
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
# edit is not allowed by tool config
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_disallowed_action(requests):
return {
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "modified"},
),
}
]
}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_disallowed_action,
):
with pytest.raises(
ValueError,
match=r"Unexpected human decision: {'type': 'edit', 'edited_action': {'name': 'test_tool', 'args': {'input': 'modified'}}}. Decision type 'edit' is not allowed for tool 'test_tool'. Expected one of \['approve', 'reject'\] based on the tool's configuration.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"interrupt_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return {"decisions": [{"type": "approve"}]}
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
# Should have both tools: auto-approved first, then interrupt tool
assert len(updated_ai_message.tool_calls) == 2
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
"""Test that interrupt requests are structured correctly."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
description_prefix="Custom prefix",
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
captured_request = None
def mock_capture_requests(request):
nonlocal captured_request
captured_request = request
return {"decisions": [{"type": "approve"}]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
):
middleware.after_model(state, None)
assert captured_request is not None
assert "action_requests" in captured_request
assert "review_configs" in captured_request
assert len(captured_request["action_requests"]) == 1
action_request = captured_request["action_requests"][0]
assert action_request["name"] == "test_tool"
assert action_request["args"] == {"input": "test", "location": "SF"}
assert "Custom prefix" in action_request["description"]
assert "Tool: test_tool" in action_request["description"]
assert "Args: {'input': 'test', 'location': 'SF'}" in action_request["description"]
assert len(captured_request["review_configs"]) == 1
review_config = captured_request["review_configs"][0]
assert review_config["action_name"] == "test_tool"
assert review_config["allowed_decisions"] == ["approve", "edit", "reject"]
def test_human_in_the_loop_middleware_boolean_configs() -> None:
"""Test HITL middleware with boolean tool configs."""
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test accept
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": [{"type": "approve"}]},
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls == ai_message.tool_calls
# Test edit
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={
"decisions": [
{
"type": "edit",
"edited_action": Action(
name="test_tool",
args={"input": "edited"},
),
}
]
},
):
result = middleware.after_model(state, None)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})
result = middleware.after_model(state, None)
# No interruption should occur
assert result is None
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
"""Test that sequence mismatch in resume raises an error."""
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test with too few responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": []}, # No responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human decisions \(0\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state, None)
# Test with too many responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={
"decisions": [
{"type": "approve"},
{"type": "approve"},
]
}, # 2 responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human decisions \(2\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state, None)
def test_human_in_the_loop_middleware_description_as_callable() -> None:
"""Test that description field accepts both string and callable."""
def custom_description(tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
"""Generate a custom description."""
return f"Custom: {tool_call['name']} with args {tool_call['args']}"
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"tool_with_callable": {
"allowed_decisions": ["approve"],
"description": custom_description,
},
"tool_with_string": {
"allowed_decisions": ["approve"],
"description": "Static description",
},
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[
{"name": "tool_with_callable", "args": {"x": 1}, "id": "1"},
{"name": "tool_with_string", "args": {"y": 2}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
captured_request = None
def mock_capture_requests(request):
nonlocal captured_request
captured_request = request
return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
):
middleware.after_model(state, None)
assert captured_request is not None
assert "action_requests" in captured_request
assert len(captured_request["action_requests"]) == 2
# Check callable description
assert (
captured_request["action_requests"][0]["description"]
== "Custom: tool_with_callable with args {'x': 1}"
)
# Check string description
assert captured_request["action_requests"][1]["description"] == "Static description"

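For orientation, here is a minimal usage sketch (not part of the diff) of how the middleware exercised above is typically attached to an agent and then resumed. The model identifier, the `get_weather` tool, and the exact resume flow via `Command(resume=...)` are assumptions inferred from the mocked `decisions` payloads in these tests, not taken from this PR.

```python
from langchain.agents.factory import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command


@tool
def get_weather(location: str) -> str:
    """Look up the weather for a location."""
    return f"Sunny in {location}"


agent = create_agent(
    model=init_chat_model("openai:gpt-4o-mini"),  # illustrative; any tool-calling chat model
    tools=[get_weather],
    middleware=[
        HumanInTheLoopMiddleware(
            interrupt_on={"get_weather": {"allowed_decisions": ["approve", "edit", "reject"]}}
        )
    ],
    checkpointer=InMemorySaver(),  # a checkpointer is required so the run can pause and resume
)

config = {"configurable": {"thread_id": "thread-1"}}
# The first invocation pauses at the pending tool call and surfaces an interrupt.
agent.invoke({"messages": [HumanMessage("What's the weather in SF?")]}, config)
# Resume with one decision per hanging tool call, mirroring the payloads mocked above.
agent.invoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
```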

@@ -0,0 +1,224 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import (
ModelCallLimitMiddleware,
ModelCallLimitExceededError,
)
from ...model import FakeToolCallingModel
@tool
def simple_tool(input: str) -> str:
"""A simple tool"""
return input
def test_middleware_unit_functionality():
"""Test that the middleware works as expected in isolation."""
# Test with end behavior
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
# Mock runtime (not used in current implementation)
runtime = None
# Test when limits are not exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 0}
result = middleware.before_model(state, runtime)
assert result is None
# Test when thread limit is exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
result = middleware.before_model(state, runtime)
assert result is not None
assert result["jump_to"] == "end"
assert "messages" in result
assert len(result["messages"]) == 1
assert "thread limit (2/2)" in result["messages"][0].content
# Test when run limit is exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
result = middleware.before_model(state, runtime)
assert result is not None
assert result["jump_to"] == "end"
assert "messages" in result
assert len(result["messages"]) == 1
assert "run limit (1/1)" in result["messages"][0].content
# Test with error behavior
middleware_exception = ModelCallLimitMiddleware(
thread_limit=2, run_limit=1, exit_behavior="error"
)
# Test exception when thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime)
assert "thread limit (2/2)" in str(exc_info.value)
# Test exception when run limit exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime)
assert "run limit (1/1)" in str(exc_info.value)
def test_thread_limit_with_create_agent():
"""Test that thread limits work correctly with create_agent."""
model = FakeToolCallingModel()
# Set thread limit to 1 (should be exceeded after 1 call)
agent = create_agent(
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(thread_limit=1)],
checkpointer=InMemorySaver(),
)
# First invocation should work - 1 model call, within thread limit
result = agent.invoke(
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
)
# Should complete successfully with 1 model call
assert "messages" in result
assert len(result["messages"]) == 2 # Human + AI messages
# Second invocation in same thread should hit thread limit
# The agent should jump to end after detecting the limit
result2 = agent.invoke(
{"messages": [HumanMessage("Hello again")]}, {"configurable": {"thread_id": "thread1"}}
)
assert "messages" in result2
# The agent should have detected the limit and jumped to end with a limit exceeded message
# So we should have: previous messages + new human message + limit exceeded AI message
assert len(result2["messages"]) == 4 # Previous Human + AI + New Human + Limit AI
assert isinstance(result2["messages"][0], HumanMessage) # First human
assert isinstance(result2["messages"][1], AIMessage) # First AI response
assert isinstance(result2["messages"][2], HumanMessage) # Second human
assert isinstance(result2["messages"][3], AIMessage) # Limit exceeded message
assert "thread limit" in result2["messages"][3].content
def test_run_limit_with_create_agent():
"""Test that run limits work correctly with create_agent."""
# Create a model that will make 2 calls
model = FakeToolCallingModel(
tool_calls=[
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
[], # No tool calls on second call
]
)
# Set run limit to 1 (should be exceeded after 1 call)
agent = create_agent(
model=model,
tools=[simple_tool],
middleware=[ModelCallLimitMiddleware(run_limit=1)],
checkpointer=InMemorySaver(),
)
# This should hit the run limit after the first model call
result = agent.invoke(
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
)
assert "messages" in result
# The agent should have made 1 model call then jumped to end with limit exceeded message
# So we should have: Human + AI + Tool + Limit exceeded AI message
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
assert isinstance(result["messages"][0], HumanMessage)
assert isinstance(result["messages"][1], AIMessage)
assert isinstance(result["messages"][2], ToolMessage)
assert isinstance(result["messages"][3], AIMessage) # Limit exceeded message
assert "run limit" in result["messages"][3].content
def test_middleware_initialization_validation():
"""Test that middleware initialization validates parameters correctly."""
# Test that at least one limit must be specified
with pytest.raises(ValueError, match="At least one limit must be specified"):
ModelCallLimitMiddleware()
# Test invalid exit behavior
with pytest.raises(ValueError, match="Invalid exit_behavior"):
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")
# Test valid initialization
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
assert middleware.thread_limit == 5
assert middleware.run_limit == 3
assert middleware.exit_behavior == "end"
# Test with only thread limit
middleware = ModelCallLimitMiddleware(thread_limit=5)
assert middleware.thread_limit == 5
assert middleware.run_limit is None
# Test with only run limit
middleware = ModelCallLimitMiddleware(run_limit=3)
assert middleware.thread_limit is None
assert middleware.run_limit == 3
def test_exception_error_message():
"""Test that the exception provides clear error messages."""
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
# Test thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
# Test run limit exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "run limit (1/1)" in error_msg
# Test both limits exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 1}
with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None)
error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg
assert "run limit (1/1)" in error_msg
def test_run_limit_resets_between_invocations() -> None:
"""Test that run_model_call_count resets between invocations, but thread_model_call_count accumulates."""
# No tool calls per invocation, so each run makes exactly one model call
middleware = ModelCallLimitMiddleware(thread_limit=3, run_limit=1, exit_behavior="error")
model = FakeToolCallingModel(
tool_calls=[[], [], [], []]
) # No tool calls, so only one model call per run
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
thread_config = {"configurable": {"thread_id": "test_thread"}}
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
# Fourth run: should raise, thread_model_call_count == 3 (limit)
with pytest.raises(ModelCallLimitExceededError) as exc_info:
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
error_msg = str(exc_info.value)
assert "thread limit (3/3)" in error_msg

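As a companion to the tests above, this is a minimal sketch (not part of the diff) of the `exit_behavior="error"` path from the caller's side. The model identifier and limit values are illustrative assumptions; the middleware and exception names come from the imports in this test file.

```python
from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import (
    ModelCallLimitExceededError,
    ModelCallLimitMiddleware,
)
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver

agent = create_agent(
    model=init_chat_model("openai:gpt-4o-mini"),  # illustrative model choice
    middleware=[
        ModelCallLimitMiddleware(thread_limit=10, run_limit=3, exit_behavior="error")
    ],
    checkpointer=InMemorySaver(),  # thread counts accumulate per thread_id
)

config = {"configurable": {"thread_id": "budget-demo"}}
try:
    agent.invoke({"messages": [HumanMessage("Summarize our conversation")]}, config)
except ModelCallLimitExceededError as exc:
    # With exit_behavior="error" the run raises instead of appending a limit message.
    print(f"Stopping early: {exc}")
```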

@@ -5,13 +5,18 @@ from __future__ import annotations
from typing import cast
import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime
from ...model import FakeToolCallingModel
def _fake_runtime() -> Runtime:
return cast(Runtime, object())
@@ -213,3 +218,90 @@ async def test_all_models_fail_async() -> None:
with pytest.raises(ValueError, match="Model failed"):
await middleware.awrap_model_call(request, mock_handler)
def test_model_fallback_middleware_with_agent() -> None:
"""Test ModelFallbackMiddleware with agent.invoke and fallback models only."""
class FailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
raise ValueError("Primary model failed")
@property
def _llm_type(self):
return "failing"
class SuccessModel(BaseChatModel):
"""Model that succeeds."""
def _generate(self, messages, **kwargs):
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content="Fallback success"))]
)
@property
def _llm_type(self):
return "success"
primary = FailingModel()
fallback = SuccessModel()
# Only pass fallback models to middleware (not the primary)
fallback_middleware = ModelFallbackMiddleware(fallback)
agent = create_agent(model=primary, middleware=[fallback_middleware])
result = agent.invoke({"messages": [HumanMessage("Test")]})
# Should have succeeded with fallback model
assert len(result["messages"]) == 2
assert result["messages"][1].content == "Fallback success"
def test_model_fallback_middleware_exhausted_with_agent() -> None:
"""Test ModelFallbackMiddleware with agent.invoke when all models fail."""
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
def __init__(self, name: str):
super().__init__()
self.name = name
def _generate(self, messages, **kwargs):
raise ValueError(f"{self.name} failed")
@property
def _llm_type(self):
return self.name
primary = AlwaysFailingModel("primary")
fallback1 = AlwaysFailingModel("fallback1")
fallback2 = AlwaysFailingModel("fallback2")
# Primary fails (attempt 1), then fallback1 (attempt 2), then fallback2 (attempt 3)
fallback_middleware = ModelFallbackMiddleware(fallback1, fallback2)
agent = create_agent(model=primary, middleware=[fallback_middleware])
# Should fail with the last fallback's error
with pytest.raises(ValueError, match="fallback2 failed"):
agent.invoke({"messages": [HumanMessage("Test")]})
def test_model_fallback_middleware_initialization() -> None:
"""Test ModelFallbackMiddleware initialization."""
# Test with no models - now a TypeError (missing required argument)
with pytest.raises(TypeError):
ModelFallbackMiddleware() # type: ignore[call-arg]
# Test with one fallback model (valid)
middleware = ModelFallbackMiddleware(FakeToolCallingModel())
assert len(middleware.models) == 1
# Test with multiple fallback models
middleware = ModelFallbackMiddleware(FakeToolCallingModel(), FakeToolCallingModel())
assert len(middleware.models) == 2

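Finally, a hedged sketch (not part of the diff) of how the fallback middleware can be stacked with the call-limit middleware covered earlier in this changeset. The model identifiers and ordering are assumptions for illustration; the middleware retries the configured fallbacks in order after the primary model fails.

```python
from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddleware
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage

# Fallbacks only; the primary model is supplied to create_agent below.
fallbacks = ModelFallbackMiddleware(
    init_chat_model("openai:gpt-4o-mini"),        # illustrative first fallback
    init_chat_model("anthropic:claude-sonnet-4-5"),  # illustrative second fallback
)

agent = create_agent(
    model=init_chat_model("openai:gpt-4.1"),  # illustrative primary model
    middleware=[fallbacks, ModelCallLimitMiddleware(run_limit=5)],
)

result = agent.invoke({"messages": [HumanMessage("Hello")]})
print(result["messages"][-1].content)
```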