mirror of https://github.com/hwchase17/langchain.git
synced 2026-01-23 05:09:12 +00:00
Compare commits
53 Commits
langchain-... ... sr/looser-...
| SHA1 |
|---|
| 6ba1177f4f |
| 52b1516d44 |
| 8a3bb73c05 |
| 099c042395 |
| 2d4f00a451 |
| 9bd401a6d4 |
| 62c05e09c1 |
| 83b9d9f810 |
| 6aa3794b74 |
| 6deee23d8d |
| 690aabe8d4 |
| 80554df1e6 |
| 72c45e65e8 |
| 189dcf7295 |
| 1bc88028e6 |
| d2942351ce |
| 83c078f363 |
| 26d39ffc4a |
| 421e2ceeee |
| 275dcbf69f |
| 9f87b27a5b |
| b2e1196e29 |
| 2dc1396380 |
| 77941ab3ce |
| ee19a30dde |
| 5d799b3174 |
| 8f33a985a2 |
| 78eeccef0e |
| 3d415441e8 |
| 74385e0ebd |
| 2bfbc29ccc |
| ef79c26f18 |
| fbe32c8e89 |
| 2511c28f92 |
| 637bb1cbbc |
| 3dfea96ec1 |
| 68643153e5 |
| 462762f75b |
| 4f3729c004 |
| ba428cdf54 |
| 69c7d1b01b |
| 733299ec13 |
| e1adf781c6 |
| 31b5e4810c |
| c6801fe159 |
| 1b563067f8 |
| 1996d81d72 |
| ab0677c6f1 |
| bdb53c93cc |
| 94d5271cb5 |
| e499db4266 |
| cc3af82b47 |
| 9383b78be1 |
77 .github/ISSUE_TEMPLATE/bug-report.yml (vendored)
@@ -8,16 +8,15 @@ body:
value: |
Thank you for taking the time to file a bug report.

Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).

Relevant links to check before filing a bug report to see if your issue has already been reported, fixed or
if there's another way to solve your problem:
Check these before submitting to see if your issue has already been reported, fixed or if there's another way to solve your problem:

* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -36,16 +35,48 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- label: I read what a minimal reproducible example is (https://stackoverflow.com/help/minimal-reproducible-example).
required: true
- label: I posted a self-contained, minimal, reproducible example. A maintainer can copy it and run it AS IS.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this bug related to? Select at least one.

Note that if the package you are reporting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).

Please report issues for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Example Code
label: Example Code (Python)
description: |
Please add a self-contained, [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example) with your use case.

@@ -53,15 +84,12 @@ body:

**Important!**

* Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.
* Use code tags (e.g., ```python ... ```) to correctly [format your code](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting).
* INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).
* Avoid screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible.

(This will be automatically formatted into code, so no need for backticks.)
render: python
placeholder: |
The following code:

```python
from langchain_core.runnables import RunnableLambda

def bad_code(inputs) -> int:
@@ -69,17 +97,14 @@ body:

chain = RunnableLambda(bad_code)
chain.invoke('Hello!')
```
- type: textarea
id: error
validations:
required: false
attributes:
label: Error Message and Stack Trace (if applicable)
description: |
If you are reporting an error, please include the full error message and stack trace.
placeholder: |
Exception + full stack trace
If you are reporting an error, please copy and paste the full error message and
stack trace.
(This will be automatically formatted into code, so no need for backticks.)
render: shell
- type: textarea
id: description
attributes:
@@ -99,9 +124,7 @@ body:
attributes:
label: System Info
description: |
Please share your system info with us. Do NOT skip this step and please don't trim
the output. Most users don't include enough information here and it makes it harder
for us to help you.
Please share your system info with us.

Run the following command in your terminal and paste the output here:

@@ -113,8 +136,6 @@ body:
from langchain_core import sys_info
sys_info.print_sys_info()
```

alternatively, put the entire output of `pip freeze` here.
placeholder: |
python -m langchain_core.sys_info
validations:
11 .github/ISSUE_TEMPLATE/config.yml (vendored)
@@ -1,9 +1,18 @@
blank_issues_enabled: false
version: 2.1
contact_links:
- name: 📚 Documentation
- name: 📚 Documentation issue
url: https://github.com/langchain-ai/docs/issues/new?template=01-langchain.yml
about: Report an issue related to the LangChain documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: General community discussions and support
- name: 📚 LangChain Documentation
url: https://docs.langchain.com/oss/python/langchain/overview
about: View the official LangChain documentation
- name: 📚 API Reference Documentation
url: https://reference.langchain.com/python/
about: View the official LangChain API reference documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: Ask questions and get help from the community
40 .github/ISSUE_TEMPLATE/feature-request.yml (vendored)
@@ -13,11 +13,11 @@ body:
Relevant links to check before filing a feature request to see if your request has already been made or
if there's another way to achieve what you want:

* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -34,6 +34,40 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this request related to? Select at least one.

Note that if the package you are requesting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in `langchain-ai/langchain`).

Please submit feature requests for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: feature-description
validations:
30 .github/ISSUE_TEMPLATE/privileged.yml (vendored)
@@ -18,3 +18,33 @@ body:
attributes:
label: Issue Content
description: Add the content of the issue here.
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this issue is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
48 .github/ISSUE_TEMPLATE/task.yml (vendored)
@@ -25,13 +25,13 @@ body:
label: Task Description
description: |
Provide a clear and detailed description of the task.

What needs to be done? Be specific about the scope and requirements.
placeholder: |
This task involves...

The goal is to...

Specific requirements:
- ...
- ...
@@ -43,7 +43,7 @@ body:
label: Acceptance Criteria
description: |
Define the criteria that must be met for this task to be considered complete.

What are the specific deliverables or outcomes expected?
placeholder: |
This task will be complete when:
@@ -58,15 +58,15 @@ body:
label: Context and Background
description: |
Provide any relevant context, background information, or links to related issues/PRs.

Why is this task needed? What problem does it solve?
placeholder: |
Background:
- ...

Related issues/PRs:
- #...

Additional context:
- ...
validations:
@@ -77,15 +77,45 @@ body:
label: Dependencies
description: |
List any dependencies or blockers for this task.

Are there other tasks, issues, or external factors that need to be completed first?
placeholder: |
This task depends on:
- [ ] Issue #...
- [ ] PR #...
- [ ] External dependency: ...

Blocked by:
- ...
validations:
required: false
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this task is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
2 .github/scripts/get_min_versions.py (vendored)
@@ -98,7 +98,7 @@ def _check_python_version_from_requirement(
        return True
    else:
        marker_str = str(requirement.marker)
        if "python_version" or "python_full_version" in marker_str:
        if "python_version" in marker_str or "python_full_version" in marker_str:
            python_version_str = "".join(
                char
                for char in marker_str
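The bug fixed here is easy to miss: in Python, `A or B in s` parses as `A or (B in s)`, and a non-empty string literal is always truthy, so the old condition was true for every marker string. A minimal illustration:

```python
marker_str = "platform_system == 'Windows'"  # no python_version marker at all

# Old form: `"python_version" or ...` is a non-empty string, hence always truthy.
buggy = bool("python_version" or "python_full_version" in marker_str)

# Fixed form: each membership test is written out explicitly.
fixed = "python_version" in marker_str or "python_full_version" in marker_str

print(buggy, fixed)  # True False
```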
107 .github/workflows/auto-label-by-package.yml (vendored, new file)
@@ -0,0 +1,107 @@
name: Auto Label Issues by Package

on:
  issues:
    types: [opened, edited]

jobs:
  label-by-package:
    permissions:
      issues: write
    runs-on: ubuntu-latest

    steps:
      - name: Sync package labels
        uses: actions/github-script@v6
        with:
          script: |
            const body = context.payload.issue.body || "";

            // Extract text under "### Package"
            const match = body.match(/### Package\s+([\s\S]*?)\n###/i);
            if (!match) return;

            const packageSection = match[1].trim();

            // Mapping table for package names to labels
            const mapping = {
              "langchain": "langchain",
              "langchain-openai": "openai",
              "langchain-anthropic": "anthropic",
              "langchain-classic": "langchain-classic",
              "langchain-core": "core",
              "langchain-cli": "cli",
              "langchain-model-profiles": "model-profiles",
              "langchain-tests": "standard-tests",
              "langchain-text-splitters": "text-splitters",
              "langchain-chroma": "chroma",
              "langchain-deepseek": "deepseek",
              "langchain-exa": "exa",
              "langchain-fireworks": "fireworks",
              "langchain-groq": "groq",
              "langchain-huggingface": "huggingface",
              "langchain-mistralai": "mistralai",
              "langchain-nomic": "nomic",
              "langchain-ollama": "ollama",
              "langchain-perplexity": "perplexity",
              "langchain-prompty": "prompty",
              "langchain-qdrant": "qdrant",
              "langchain-xai": "xai",
            };

            // All possible package labels we manage
            const allPackageLabels = Object.values(mapping);
            const selectedLabels = [];

            // Check if this is checkbox format (multiple selection)
            const checkboxMatches = packageSection.match(/- \[x\]\s+([^\n\r]+)/gi);
            if (checkboxMatches) {
              // Handle checkbox format
              for (const match of checkboxMatches) {
                const packageName = match.replace(/- \[x\]\s+/i, '').trim();
                const label = mapping[packageName];
                if (label && !selectedLabels.includes(label)) {
                  selectedLabels.push(label);
                }
              }
            } else {
              // Handle dropdown format (single selection)
              const label = mapping[packageSection];
              if (label) {
                selectedLabels.push(label);
              }
            }

            // Get current issue labels
            const issue = await github.rest.issues.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: context.issue.number
            });

            const currentLabels = issue.data.labels.map(label => label.name);
            const currentPackageLabels = currentLabels.filter(label => allPackageLabels.includes(label));

            // Determine labels to add and remove
            const labelsToAdd = selectedLabels.filter(label => !currentPackageLabels.includes(label));
            const labelsToRemove = currentPackageLabels.filter(label => !selectedLabels.includes(label));

            // Add new labels
            if (labelsToAdd.length > 0) {
              await github.rest.issues.addLabels({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: context.issue.number,
                labels: labelsToAdd
              });
            }

            // Remove old labels
            for (const label of labelsToRemove) {
              await github.rest.issues.removeLabel({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: context.issue.number,
                name: label
              });
            }
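For readers more comfortable in Python, here is a rough sketch of what the workflow's extraction step does. The issue body below is a made-up example of what GitHub renders from the checkbox template; the real workflow runs the equivalent JavaScript via `actions/github-script`:

```python
import re

# Hypothetical issue body rendered from the "Package (Required)" checkboxes.
body = """### Package

- [x] langchain-core
- [ ] langchain-openai

### Description
Something is broken."""

# Same two regexes as the workflow: grab the "### Package" section, then the
# checked boxes inside it.
section_match = re.search(r"### Package\s+([\s\S]*?)\n###", body)
if section_match:
    section = section_match.group(1)
    checked = [m.strip() for m in re.findall(r"- \[x\]\s+([^\n\r]+)", section, re.IGNORECASE)]
    print(checked)  # ['langchain-core']
```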
4 .github/workflows/pr_lint.yml (vendored)
@@ -26,12 +26,14 @@
# * revert — reverts a previous commit
# * release — prepare a new release
#
# Allowed Scopes (optional):
# Allowed Scope(s) (optional):
# core, cli, langchain, langchain_v1, langchain-classic, standard-tests,
# text-splitters, docs, anthropic, chroma, deepseek, exa, fireworks, groq,
# huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant,
# xai, infra, deps
#
# Multiple scopes can be used by separating them with a comma.
#
# Rules:
# 1. The 'Type' must start with a lowercase letter.
# 2. Breaking changes: append "!" after type/scope (e.g., feat!: drop x support)
54 README.md
@@ -1,38 +1,26 @@
<p align="center">
  <picture>
    <source media="(prefers-color-scheme: light)" srcset=".github/images/logo-dark.svg">
    <source media="(prefers-color-scheme: dark)" srcset=".github/images/logo-light.svg">
    <img alt="LangChain Logo" src=".github/images/logo-dark.svg" width="80%">
  </picture>
</p>
<div align="center">
  <a href="https://www.langchain.com/">
    <picture>
      <source media="(prefers-color-scheme: light)" srcset=".github/images/logo-dark.svg">
      <source media="(prefers-color-scheme: dark)" srcset=".github/images/logo-light.svg">
      <img alt="LangChain Logo" src=".github/images/logo-dark.svg" width="80%">
    </picture>
  </a>
</div>

<p align="center">
  The platform for reliable agents.
</p>
<div align="center">
  <h3>The platform for reliable agents.</h3>
</div>

<p align="center">
  <a href="https://opensource.org/licenses/MIT" target="_blank">
    <img src="https://img.shields.io/pypi/l/langchain" alt="PyPI - License">
  </a>
  <a href="https://pypistats.org/packages/langchain" target="_blank">
    <img src="https://img.shields.io/pepy/dt/langchain" alt="PyPI - Downloads">
  </a>
  <a href="https://pypi.org/project/langchain/#history" target="_blank">
    <img src="https://img.shields.io/pypi/v/langchain?label=%20" alt="Version">
  </a>
  <a href="https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain" target="_blank">
    <img src="https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode" alt="Open in Dev Containers">
  </a>
  <a href="https://codespaces.new/langchain-ai/langchain" target="_blank">
    <img src="https://github.com/codespaces/badge.svg" alt="Open in Github Codespace" title="Open in Github Codespace" width="150" height="20">
  </a>
  <a href="https://codspeed.io/langchain-ai/langchain" target="_blank">
    <img src="https://img.shields.io/endpoint?url=https://codspeed.io/badge.json" alt="CodSpeed Badge">
  </a>
  <a href="https://twitter.com/langchainai" target="_blank">
    <img src="https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI" alt="Twitter / X">
  </a>
</p>
<div align="center">
  <a href="https://opensource.org/licenses/MIT" target="_blank"><img src="https://img.shields.io/pypi/l/langchain" alt="PyPI - License"></a>
  <a href="https://pypistats.org/packages/langchain" target="_blank"><img src="https://img.shields.io/pepy/dt/langchain" alt="PyPI - Downloads"></a>
  <a href="https://pypi.org/project/langchain/#history" target="_blank"><img src="https://img.shields.io/pypi/v/langchain?label=%20" alt="Version"></a>
  <a href="https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain" target="_blank"><img src="https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode" alt="Open in Dev Containers"></a>
  <a href="https://codespaces.new/langchain-ai/langchain" target="_blank"><img src="https://github.com/codespaces/badge.svg" alt="Open in Github Codespace" title="Open in Github Codespace" width="150" height="20"></a>
  <a href="https://codspeed.io/langchain-ai/langchain" target="_blank"><img src="https://img.shields.io/endpoint?url=https://codspeed.io/badge.json" alt="CodSpeed Badge"></a>
  <a href="https://twitter.com/langchainai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI" alt="Twitter / X"></a>
</div>

LangChain is a framework for building agents and LLM-powered applications. It helps you chain together interoperable components and third-party integrations to simplify AI application development – all while future-proofing decisions as the underlying technology evolves.
@@ -6,9 +6,8 @@ import hashlib
import logging
import re
import shutil
from collections.abc import Sequence
from pathlib import Path
from typing import Any, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict

from git import Repo

@@ -18,6 +17,9 @@ from langchain_cli.constants import (
    DEFAULT_GIT_SUBDIRECTORY,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

logger = logging.getLogger(__name__)
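This hunk — and most of the import hunks that follow — applies the same idiom: imports used only in annotations move under `if TYPE_CHECKING:`, so type checkers still see them while the runtime skips the import (and any circular-import risk). A minimal sketch of the pattern; the module and function here are illustrative:

```python
from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; never imported when the module actually runs.
    from collections.abc import Sequence


def first_item(items: Sequence[str]) -> str:
    # At runtime the annotation is just the string "Sequence[str]".
    return items[0]
```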
@@ -1,9 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from .file import File
from .folder import Folder
if TYPE_CHECKING:
    from .file import File
    from .folder import Folder


@dataclass
@@ -1,9 +1,12 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from .file import File

if TYPE_CHECKING:
    from pathlib import Path


class Folder:
    def __init__(self, name: str, *files: Folder | File) -> None:
@@ -34,7 +34,7 @@ The LangChain ecosystem is built on top of `langchain-core`. Some of the benefit

## 📖 Documentation

For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).

## 📕 Releases & Versioning
@@ -5,13 +5,12 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any

from typing_extensions import Self

if TYPE_CHECKING:
    from collections.abc import Sequence
    from uuid import UUID

    from tenacity import RetryCallState
    from typing_extensions import Self

    from langchain_core.agents import AgentAction, AgentFinish
    from langchain_core.documents import Document

@@ -39,7 +39,6 @@ from langchain_core.tracers.context import (
    tracing_v2_callback_var,
)
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.schemas import Run
from langchain_core.tracers.stdout import ConsoleCallbackHandler
from langchain_core.utils.env import env_var_is_set

@@ -52,6 +51,7 @@ if TYPE_CHECKING:
    from langchain_core.documents import Document
    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
    from langchain_core.runnables.config import RunnableConfig
    from langchain_core.tracers.schemas import Run

logger = logging.getLogger(__name__)
@@ -6,16 +6,9 @@ import hashlib
import json
import uuid
import warnings
from collections.abc import (
    AsyncIterable,
    AsyncIterator,
    Callable,
    Iterable,
    Iterator,
    Sequence,
)
from itertools import islice
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypedDict,
@@ -29,6 +22,16 @@ from langchain_core.exceptions import LangChainException
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
    from collections.abc import (
        AsyncIterable,
        AsyncIterator,
        Callable,
        Iterable,
        Iterator,
        Sequence,
    )

# Magic UUID to use as a namespace for hashing.
# Used to try and generate a unique UUID for each document
# from hashing the document content and metadata.
@@ -299,6 +299,9 @@ class BaseLanguageModel(

Useful for checking if an input fits in a model's context window.

This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.

Args:
    text: The string input to tokenize.

@@ -317,9 +320,17 @@

Useful for checking if an input fits in a model's context window.

This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.

!!! note
    The base implementation of `get_num_tokens_from_messages` ignores tool
    schemas.

* The base implementation of `get_num_tokens_from_messages` ignores tool
  schemas.
* The base implementation of `get_num_tokens_from_messages` adds additional
  prefixes to messages to represent user roles, which will add to the
  overall token count. Model-specific implementations may choose to
  handle this differently.

Args:
    messages: The message inputs to tokenize.
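A rough usage sketch of these counters. The fake model comes from `langchain_core`'s test utilities; the base implementation falls back to a GPT-2 tokenizer, so the optional `transformers` dependency must be installed and the counts are only approximate:

```python
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langchain_core.messages import HumanMessage

model = FakeListChatModel(responses=["ok"])

# Default implementations; real providers override both with exact tokenizers.
print(model.get_num_tokens("Does this prompt fit in the context window?"))
print(model.get_num_tokens_from_messages([HumanMessage("hi")]))  # ignores tool schemas
```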
@@ -91,7 +91,10 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
        try:
            metadata["body"] = response.json()
        except Exception:
            metadata["body"] = getattr(response, "text", None)
            try:
                metadata["body"] = getattr(response, "text", None)
            except Exception:
                metadata["body"] = None
        if hasattr(response, "headers"):
            try:
                metadata["headers"] = dict(response.headers)
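The extra inner `try` exists because on some HTTP clients `.text` is a property that itself raises when the body was never read (for example `httpx`'s `ResponseNotRead` on streaming responses), so `getattr` with a default is not enough. A standalone sketch of the resulting fallback chain; the helper name is made up:

```python
def _best_effort_body(response: object) -> object:
    """Extract an error body without ever raising (illustrative helper)."""
    try:
        return response.json()  # type: ignore[attr-defined]
    except Exception:
        try:
            # Accessing `.text` can raise too, e.g. on unread streaming bodies.
            return getattr(response, "text", None)
        except Exception:
            return None
```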
@@ -61,13 +61,15 @@ class Reviver:
"""Initialize the reviver.

Args:
    secrets_map: A map of secrets to load. If a secret is not found in
        the map, it will be loaded from the environment if `secrets_from_env`
        is True.
    secrets_map: A map of secrets to load.

        If a secret is not found in the map, it will be loaded from the
        environment if `secrets_from_env` is `True`.
    valid_namespaces: A list of additional namespaces (modules)
        to allow to be deserialized.
    secrets_from_env: Whether to load secrets from the environment.
    additional_import_mappings: A dictionary of additional namespace mappings

        You can use this to override default mappings or add new mappings.
    ignore_unserializable_fields: Whether to ignore unserializable fields.
"""
@@ -195,13 +197,15 @@ def loads(

Args:
    text: The string to load.
    secrets_map: A map of secrets to load. If a secret is not found in
        the map, it will be loaded from the environment if `secrets_from_env`
        is True.
    secrets_map: A map of secrets to load.

        If a secret is not found in the map, it will be loaded from the environment
        if `secrets_from_env` is `True`.
    valid_namespaces: A list of additional namespaces (modules)
        to allow to be deserialized.
    secrets_from_env: Whether to load secrets from the environment.
    additional_import_mappings: A dictionary of additional namespace mappings

        You can use this to override default mappings or add new mappings.
    ignore_unserializable_fields: Whether to ignore unserializable fields.

@@ -237,13 +241,15 @@ def load(

Args:
    obj: The object to load.
    secrets_map: A map of secrets to load. If a secret is not found in
        the map, it will be loaded from the environment if `secrets_from_env`
        is True.
    secrets_map: A map of secrets to load.

        If a secret is not found in the map, it will be loaded from the environment
        if `secrets_from_env` is `True`.
    valid_namespaces: A list of additional namespaces (modules)
        to allow to be deserialized.
    secrets_from_env: Whether to load secrets from the environment.
    additional_import_mappings: A dictionary of additional namespace mappings

        You can use this to override default mappings or add new mappings.
    ignore_unserializable_fields: Whether to ignore unserializable fields.
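A short usage sketch of the `secrets_map` behavior documented above (the round-trip object and key name are made up for illustration):

```python
from langchain_core.load import dumps, loads
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages([("human", "{question}")])
serialized = dumps(prompt)

# Secrets present in `secrets_map` are injected directly; anything missing is
# looked up in the environment only while `secrets_from_env=True`.
restored = loads(serialized, secrets_map={"EXAMPLE_API_KEY": "sk-not-real"})
print(restored == prompt)  # True
```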
@@ -5,11 +5,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast, overload

from pydantic import ConfigDict, Field
from typing_extensions import Self

from langchain_core._api.deprecation import warn_deprecated
from langchain_core.load.serializable import Serializable
from langchain_core.messages import content as types
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.interactive_env import is_interactive_env
@@ -17,6 +15,9 @@ from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from langchain_core.messages import content as types
    from langchain_core.prompts.chat import ChatPromptTemplate
@@ -12,10 +12,11 @@ the implementation in `BaseMessage`.

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable

    from langchain_core.messages import AIMessage, AIMessageChunk
    from langchain_core.messages import content as types
@@ -4,7 +4,6 @@ from __future__ import annotations

import json
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Literal, cast

from langchain_core.language_models._utils import (
@@ -14,6 +13,8 @@ from langchain_core.language_models._utils import (
from langchain_core.messages import content as types

if TYPE_CHECKING:
    from collections.abc import Iterable

    from langchain_core.messages import AIMessage, AIMessageChunk
@@ -2,15 +2,17 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

from pydantic import model_validator
from typing_extensions import Self

from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.utils._merge import merge_dicts

if TYPE_CHECKING:
    from typing_extensions import Self


class ChatGeneration(Generation):
    """A single chat generation output.
@@ -6,7 +6,7 @@ import contextlib
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
@@ -33,6 +33,8 @@ from langchain_core.runnables.config import ensure_config
from langchain_core.utils.pydantic import create_model_v2

if TYPE_CHECKING:
    from collections.abc import Callable

    from langchain_core.documents import Document
@@ -6,10 +6,10 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from langchain_core.load import Serializable
from langchain_core.messages import BaseMessage
from langchain_core.utils.interactive_env import is_interactive_env

if TYPE_CHECKING:
    from langchain_core.messages import BaseMessage
    from langchain_core.prompts.chat import ChatPromptTemplate
@@ -4,9 +4,8 @@ from __future__ import annotations

import warnings
from abc import ABC
from collections.abc import Callable, Sequence
from string import Formatter
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, create_model

@@ -16,6 +15,9 @@ from langchain_core.utils import get_colored_text, mustache
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

try:
    from jinja2 import Environment, meta
    from jinja2.sandbox import SandboxedEnvironment
@@ -4,7 +4,6 @@ from __future__ import annotations

import inspect
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@@ -22,7 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass

if TYPE_CHECKING:
    from collections.abc import Sequence
    from collections.abc import Callable, Sequence

    from pydantic import BaseModel
@@ -7,7 +7,6 @@ from __future__ import annotations

import math
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

try:
@@ -20,6 +19,8 @@ except ImportError:
    _HAS_GRANDALF = False

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    from langchain_core.runnables.graph import Edge as LangEdge
@@ -7,8 +7,7 @@ import asyncio
import inspect
import sys
import textwrap
from collections.abc import Callable, Mapping, Sequence
from contextvars import Context
from collections.abc import Mapping, Sequence
from functools import lru_cache
from inspect import signature
from itertools import groupby
@@ -31,9 +30,11 @@ if TYPE_CHECKING:
        AsyncIterable,
        AsyncIterator,
        Awaitable,
        Callable,
        Coroutine,
        Iterable,
    )
    from contextvars import Context

    from langchain_core.runnables.schema import StreamEvent
@@ -15,12 +15,6 @@ from typing import (

from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    GenerationChunk,
    LLMResult,
)
from langchain_core.tracers.schemas import Run

if TYPE_CHECKING:
@@ -31,6 +25,12 @@ if TYPE_CHECKING:

    from langchain_core.documents import Document
    from langchain_core.messages import BaseMessage
    from langchain_core.outputs import (
        ChatGeneration,
        ChatGenerationChunk,
        GenerationChunk,
        LLMResult,
    )

logger = logging.getLogger(__name__)
@@ -8,7 +8,6 @@ import logging
import types
import typing
import uuid
from collections.abc import Callable
from typing import (
    TYPE_CHECKING,
    Annotated,
@@ -33,6 +32,8 @@ from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass

if TYPE_CHECKING:
    from collections.abc import Callable

    from langchain_core.tools import BaseTool

logger = logging.getLogger(__name__)
@@ -4,11 +4,13 @@ from __future__ import annotations

import json
import re
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any

from langchain_core.exceptions import OutputParserException

if TYPE_CHECKING:
    from collections.abc import Callable


def _replace_new_line(match: re.Match[str]) -> str:
    value = match.group(2)
@@ -5,7 +5,6 @@ from __future__ import annotations
import inspect
import textwrap
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from functools import lru_cache, wraps
from types import GenericAlias
@@ -41,10 +40,12 @@ from pydantic.json_schema import (
)
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import create_model as create_model_v1
from pydantic.v1.fields import ModelField
from typing_extensions import deprecated, override

if TYPE_CHECKING:
    from collections.abc import Callable

    from pydantic.v1.fields import ModelField
    from pydantic_core import core_schema

PYDANTIC_VERSION = version.parse(pydantic.__version__)
@@ -11,7 +11,6 @@ import logging
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from itertools import cycle
from typing import (
    TYPE_CHECKING,
@@ -29,7 +28,7 @@ from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable, Iterator, Sequence
    from collections.abc import Callable, Collection, Iterable, Iterator, Sequence

    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForRetrieverRun,
@@ -4,7 +4,6 @@ from __future__ import annotations

import json
import uuid
from collections.abc import Callable
from pathlib import Path
from typing import (
    TYPE_CHECKING,
@@ -20,7 +19,7 @@ from langchain_core.vectorstores.utils import _cosine_similarity as cosine_simil
from langchain_core.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence
    from collections.abc import Callable, Iterator, Sequence

    from langchain_core.embeddings import Embeddings
@@ -1,3 +1,3 @@
"""langchain-core version information and utilities."""

VERSION = "1.0.4"
VERSION = "1.0.5"
@@ -9,7 +9,7 @@ license = {text = "MIT"}
readme = "README.md"
authors = []

version = "1.0.4"
version = "1.0.5"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
    "langsmith>=0.3.45,<1.0.0",
@@ -18,6 +18,7 @@ from langchain_core.language_models import (
    ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.chat_models import _generate_response_from_error
from langchain_core.language_models.fake_chat_models import (
    FakeListChatModelError,
    GenericFakeChatModel,
@@ -1234,3 +1235,93 @@ def test_model_profiles() -> None:
    model = MyModel(messages=iter([]))
    profile = model.profile
    assert profile


class MockResponse:
    """Mock response for testing _generate_response_from_error."""

    def __init__(
        self,
        status_code: int = 400,
        headers: dict[str, str] | None = None,
        json_data: dict[str, Any] | None = None,
        json_raises: type[Exception] | None = None,
        text_raises: type[Exception] | None = None,
    ):
        self.status_code = status_code
        self.headers = headers or {}
        self._json_data = json_data
        self._json_raises = json_raises
        self._text_raises = text_raises

    def json(self) -> dict[str, Any]:
        if self._json_raises:
            msg = "JSON parsing failed"
            raise self._json_raises(msg)
        return self._json_data or {}

    @property
    def text(self) -> str:
        if self._text_raises:
            msg = "Text access failed"
            raise self._text_raises(msg)
        return ""


class MockAPIError(Exception):
    """Mock API error with response attribute."""

    def __init__(self, message: str, response: MockResponse | None = None):
        super().__init__(message)
        self.message = message
        if response is not None:
            self.response = response


def test_generate_response_from_error_with_valid_json() -> None:
    """Test `_generate_response_from_error` with valid JSON response."""
    response = MockResponse(
        status_code=400,
        headers={"content-type": "application/json"},
        json_data={"error": {"message": "Bad request", "type": "invalid_request"}},
    )
    error = MockAPIError("API Error", response=response)

    generations = _generate_response_from_error(error)

    assert len(generations) == 1
    generation = generations[0]
    assert isinstance(generation, ChatGeneration)
    assert isinstance(generation.message, AIMessage)
    assert generation.message.content == ""

    metadata = generation.message.response_metadata
    assert metadata["body"] == {
        "error": {"message": "Bad request", "type": "invalid_request"}
    }
    assert metadata["headers"] == {"content-type": "application/json"}
    assert metadata["status_code"] == 400


def test_generate_response_from_error_handles_streaming_response_failure() -> None:
    # Simulates scenario where accessing response.json() or response.text
    # raises ResponseNotRead on streaming responses
    response = MockResponse(
        status_code=400,
        headers={"content-type": "application/json"},
        json_raises=Exception,  # Simulates ResponseNotRead or similar
        text_raises=Exception,
    )
    error = MockAPIError("API Error", response=response)

    # This should NOT raise an exception, but should handle it gracefully
    generations = _generate_response_from_error(error)

    assert len(generations) == 1
    generation = generations[0]
    metadata = generation.message.response_metadata

    # When both fail, body should be None instead of raising an exception
    assert metadata["body"] is None
    assert metadata["headers"] == {"content-type": "application/json"}
    assert metadata["status_code"] == 400
@@ -0,0 +1,140 @@
"""Test groq block translator."""

from typing import cast

import pytest

from langchain_core.messages import AIMessage
from langchain_core.messages import content as types
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
from langchain_core.messages.block_translators import PROVIDER_TRANSLATORS
from langchain_core.messages.block_translators.groq import (
    _parse_code_json,
    translate_content,
)


def test_groq_translator_registered() -> None:
    """Test that groq translator is properly registered."""
    assert "groq" in PROVIDER_TRANSLATORS
    assert "translate_content" in PROVIDER_TRANSLATORS["groq"]
    assert "translate_content_chunk" in PROVIDER_TRANSLATORS["groq"]


def test_extract_reasoning_from_additional_kwargs_exists() -> None:
    """Test that _extract_reasoning_from_additional_kwargs can be imported."""
    # Verify it's callable
    assert callable(_extract_reasoning_from_additional_kwargs)


def test_groq_translate_content_basic() -> None:
    """Test basic groq content translation."""
    # Test with simple text message
    message = AIMessage(content="Hello world")
    blocks = translate_content(message)

    assert isinstance(blocks, list)
    assert len(blocks) == 1
    assert blocks[0]["type"] == "text"
    assert blocks[0]["text"] == "Hello world"


def test_groq_translate_content_with_reasoning() -> None:
    """Test groq content translation with reasoning content."""
    # Test with reasoning content in additional_kwargs
    message = AIMessage(
        content="Final answer",
        additional_kwargs={"reasoning_content": "Let me think about this..."},
    )
    blocks = translate_content(message)

    assert isinstance(blocks, list)
    assert len(blocks) == 2

    # First block should be reasoning
    assert blocks[0]["type"] == "reasoning"
    assert blocks[0]["reasoning"] == "Let me think about this..."

    # Second block should be text
    assert blocks[1]["type"] == "text"
    assert blocks[1]["text"] == "Final answer"


def test_groq_translate_content_with_tool_calls() -> None:
    """Test groq content translation with tool calls."""
    # Test with tool calls
    message = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "search",
                "args": {"query": "test"},
                "id": "call_123",
            }
        ],
    )
    blocks = translate_content(message)

    assert isinstance(blocks, list)
    assert len(blocks) == 1
    assert blocks[0]["type"] == "tool_call"
    assert blocks[0]["name"] == "search"
    assert blocks[0]["args"] == {"query": "test"}
    assert blocks[0]["id"] == "call_123"


def test_groq_translate_content_with_executed_tools() -> None:
    """Test groq content translation with executed tools (built-in tools)."""
    # Test with executed_tools in additional_kwargs (Groq built-in tools)
    message = AIMessage(
        content="",
        additional_kwargs={
            "executed_tools": [
                {
                    "type": "python",
                    "arguments": '{"code": "print(\\"hello\\")"}',
                    "output": "hello\\n",
                }
            ]
        },
    )
    blocks = translate_content(message)

    assert isinstance(blocks, list)
    # Should have server_tool_call and server_tool_result
    assert len(blocks) >= 2

    # Check for server_tool_call
    tool_call_blocks = [
        cast("types.ServerToolCall", b)
        for b in blocks
        if b.get("type") == "server_tool_call"
    ]
    assert len(tool_call_blocks) == 1
    assert tool_call_blocks[0]["name"] == "code_interpreter"
    assert "code" in tool_call_blocks[0]["args"]

    # Check for server_tool_result
    tool_result_blocks = [
        cast("types.ServerToolResult", b)
        for b in blocks
        if b.get("type") == "server_tool_result"
    ]
    assert len(tool_result_blocks) == 1
    assert tool_result_blocks[0]["output"] == "hello\\n"
    assert tool_result_blocks[0]["status"] == "success"


def test_parse_code_json() -> None:
    """Test the _parse_code_json helper function."""
    # Test valid code JSON
    result = _parse_code_json('{"code": "print(\'hello\')"}')
    assert result == {"code": "print('hello')"}

    # Test code with unescaped quotes (Groq format)
    result = _parse_code_json('{"code": "print("hello")"}')
    assert result == {"code": 'print("hello")'}

    # Test invalid format raises ValueError
    with pytest.raises(ValueError, match="Could not extract Python code"):
        _parse_code_json('{"invalid": "format"}')
@@ -3,12 +3,14 @@

import asyncio
import time
from threading import Lock
from typing import Any
from typing import TYPE_CHECKING, Any

import pytest

from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable

if TYPE_CHECKING:
    from langchain_core.runnables.base import Runnable


@pytest.mark.asyncio
@@ -3,21 +3,25 @@ from __future__ import annotations
import json
import sys
import uuid
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
from inspect import isasyncgenfunction
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
from unittest.mock import MagicMock, patch

import pytest
from langsmith import Client, get_current_run_tree, traceable
from langsmith.run_helpers import tracing_context
from langsmith.run_trees import RunTree
from langsmith.utils import get_env_var

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Callable, Coroutine, Generator

    from langsmith.run_trees import RunTree

    from langchain_core.callbacks import BaseCallbackHandler


def _get_posts(client: Client) -> list:
    mock_calls = client.session.request.mock_calls  # type: ignore[attr-defined]
@@ -1,10 +1,9 @@
import os
import re
import sys
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from typing import Any
from typing import TYPE_CHECKING, Any
from unittest.mock import patch

import pytest
@@ -23,6 +22,9 @@ from langchain_core.utils import (
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.utils import secret_from_env

if TYPE_CHECKING:
    from collections.abc import Callable


@pytest.mark.parametrize(
    ("package", "check_kwargs", "actual_version", "expected"),
6 libs/core/uv.lock (generated)
@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [
    "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
@@ -960,7 +960,7 @@ wheels = [

[[package]]
name = "langchain-core"
version = "1.0.4"
version = "1.0.5"
source = { editable = "." }
dependencies = [
    { name = "jsonpatch" },
@@ -1057,7 +1057,7 @@ typing = [

[[package]]
name = "langchain-model-profiles"
version = "0.0.3"
version = "0.0.4"
source = { directory = "../model-profiles" }
dependencies = [
    { name = "tomli", marker = "python_full_version < '3.11'" },
@@ -24,7 +24,7 @@ In most cases, you should be using the main [`langchain`](https://pypi.org/proje

## 📖 Documentation

For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic).
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).

## 📕 Releases & Versioning
@@ -100,6 +100,21 @@ def init_chat_model(

     You can also specify model and model provider in a single argument using
     `'{model_provider}:{model}'` format, e.g. `'openai:o1'`.

+    Will attempt to infer `model_provider` from model if not specified.
+
+    The following providers will be inferred based on these model prefixes:
+
+    - `gpt-...` | `o1...` | `o3...` -> `openai`
+    - `claude...` -> `anthropic`
+    - `amazon...` -> `bedrock`
+    - `gemini...` -> `google_vertexai`
+    - `command...` -> `cohere`
+    - `accounts/fireworks...` -> `fireworks`
+    - `mistral...` -> `mistralai`
+    - `deepseek...` -> `deepseek`
+    - `grok...` -> `xai`
+    - `sonar...` -> `perplexity`
     model_provider: The model provider if not specified as part of the model arg
         (see above).
+
@@ -123,24 +138,10 @@ def init_chat_model(
 - `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
 - `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
 - `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
-- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
+- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
 - `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
 - `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
 - `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
-
-Will attempt to infer `model_provider` from model if not specified. The
-following providers will be inferred based on these model prefixes:
-
-- `gpt-...` | `o1...` | `o3...` -> `openai`
-- `claude...` -> `anthropic`
-- `amazon...` -> `bedrock`
-- `gemini...` -> `google_vertexai`
-- `command...` -> `cohere`
-- `accounts/fireworks...` -> `fireworks`
-- `mistral...` -> `mistralai`
-- `deepseek...` -> `deepseek`
-- `grok...` -> `xai`
-- `sonar...` -> `perplexity`
 configurable_fields: Which model parameters are configurable at runtime:

 - `None`: No configurable fields (i.e., a fixed model).
@@ -155,6 +156,7 @@ def init_chat_model(
 If `model` is not specified, then defaults to `("model", "model_provider")`.

 !!! warning "Security note"
+
     Setting `configurable_fields="any"` means fields like `api_key`,
     `base_url`, etc., can be altered at runtime, potentially redirecting
     model requests to a different service/user.
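A minimal usage sketch of the provider inference described in the docstring above (the model names are illustrative, and the matching provider package, e.g. `langchain-anthropic`, is assumed to be installed):

```python
from langchain.chat_models import init_chat_model

# Provider inferred from the "claude..." prefix -> anthropic
model = init_chat_model("claude-sonnet-4-5")

# Same model, with the provider pinned explicitly in one string
model = init_chat_model("anthropic:claude-sonnet-4-5")
```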
@@ -1,4 +1,4 @@
|
||||
.PHONY: all start_services stop_services coverage test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
|
||||
.PHONY: all start_services stop_services coverage coverage_agents test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
@@ -27,8 +27,17 @@ coverage:
|
||||
--cov-report term-missing:skip-covered \
|
||||
$(TEST_FILE)
|
||||
|
||||
# Run middleware and agent tests with coverage report.
|
||||
coverage_agents:
|
||||
uv run --group test pytest \
|
||||
tests/unit_tests/agents/middleware/ \
|
||||
tests/unit_tests/agents/test_*.py \
|
||||
--cov=langchain.agents \
|
||||
--cov-report=term-missing \
|
||||
--cov-report=html:htmlcov \
|
||||
|
||||
test:
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered --snapshot-update; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
@@ -93,6 +102,7 @@ help:
|
||||
@echo 'lint - run linters'
|
||||
@echo '-- TESTS --'
|
||||
@echo 'coverage - run unit tests and generate coverage report'
|
||||
@echo 'coverage_agents - run middleware and agent tests with coverage report'
|
||||
@echo 'test - run unit tests with all services'
|
||||
@echo 'test_fast - run unit tests with in-memory services only'
|
||||
@echo 'tests - run unit tests (alias for "make test")'
|
||||
|
||||
@@ -26,7 +26,7 @@ LangChain [agents](https://docs.langchain.com/oss/python/langchain/agents) are b

 ## 📖 Documentation

-For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/).
+For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).

 ## 📕 Releases & Versioning

@@ -1,3 +1,3 @@
 """Main entrypoint into LangChain."""

-__version__ = "1.0.4"
+__version__ = "1.0.5"
@@ -1,10 +1,4 @@
-"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain.
-
-!!! warning "Reference docs"
-    This page contains **reference documentation** for Agents. See
-    [the docs](https://docs.langchain.com/oss/python/langchain/agents) for conceptual
-    guides, tutorials, and examples on using Agents.
-"""  # noqa: E501
+"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain."""  # noqa: E501

 from langchain.agents.factory import create_agent
 from langchain.agents.middleware.types import AgentState
@@ -63,6 +63,18 @@ if TYPE_CHECKING:

 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

+FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
+    # if langchain-model-profiles is not installed, these models are assumed to support
+    # structured output
+    "grok",
+    "gpt-5",
+    "gpt-4.1",
+    "gpt-4o",
+    "gpt-oss",
+    "o3-pro",
+    "o3-mini",
+]
+

 def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
     """Normalize middleware return value to ModelResponse."""
@@ -349,11 +361,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []


-def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
+def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
     """Check if a model supports provider-specific structured output.

     Args:
         model: Model name string or `BaseChatModel` instance.
+        tools: Optional list of tools provided to the agent. Needed because some models
+            don't support structured output together with tool calling.

     Returns:
         `True` if the model supports provider-specific structured output, `False` otherwise.
@@ -362,11 +376,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
     if isinstance(model, str):
         model_name = model
     elif isinstance(model, BaseChatModel):
-        model_name = getattr(model, "model_name", None)
+        model_name = (
+            getattr(model, "model_name", None)
+            or getattr(model, "model", None)
+            or getattr(model, "model_id", "")
+        )
+        try:
+            model_profile = model.profile
+        except ImportError:
+            pass
+        else:
+            if (
+                model_profile.get("structured_output")
+                # We make an exception for Gemini models, which currently do not support
+                # simultaneous tool use with structured output
+                and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
+            ):
+                return True

     return (
-        "grok" in model_name.lower()
-        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
+        any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
+        if model_name
+        else False
     )
@@ -537,17 +566,29 @@ def create_agent(  # noqa: PLR0915
     visit the [Agents](https://docs.langchain.com/oss/python/langchain/agents) docs.

     Args:
-        model: The language model for the agent. Can be a string identifier
-            (e.g., `"openai:gpt-4"`) or a direct chat model instance (e.g.,
-            [`ChatOpenAI`][langchain_openai.ChatOpenAI] or other another
-            [chat model](https://docs.langchain.com/oss/python/integrations/chat)).
+        model: The language model for the agent.
+
+            Can be a string identifier (e.g., `"openai:gpt-4"`) or a direct chat model
+            instance (e.g., [`ChatOpenAI`][langchain_openai.ChatOpenAI] or other another
+            [LangChain chat model](https://docs.langchain.com/oss/python/integrations/chat)).
+
+            For a full list of supported model strings, see
+            [`init_chat_model`][langchain.chat_models.init_chat_model(model_provider)].
+
+            !!! tip ""
+
+                See the [Models](https://docs.langchain.com/oss/python/langchain/models)
+                docs for more information.
-        tools: A list of tools, `dicts`, or `Callable`.
+        tools: A list of tools, `dict`, or `Callable`.

             If `None` or an empty list, the agent will consist of a model node without a
             tool calling loop.

+            !!! tip ""
+
+                See the [Tools](https://docs.langchain.com/oss/python/langchain/tools)
+                docs for more information.
         system_prompt: An optional system prompt for the LLM.

             Prompts are converted to a
@@ -555,24 +596,34 @@ def create_agent(  # noqa: PLR0915
             beginning of the message list.
         middleware: A sequence of middleware instances to apply to the agent.

-            Middleware can intercept and modify agent behavior at various stages. See
-            the [full guide](https://docs.langchain.com/oss/python/langchain/middleware).
+            Middleware can intercept and modify agent behavior at various stages.
+
+            !!! tip ""
+
+                See the [Middleware](https://docs.langchain.com/oss/python/langchain/middleware)
+                docs for more information.
         response_format: An optional configuration for structured responses.

             Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.

             If provided, the agent will handle structured output during the
-            conversation flow. Raw schemas will be wrapped in an appropriate strategy
-            based on model capabilities.
+            conversation flow.
+
+            Raw schemas will be wrapped in an appropriate strategy based on model
+            capabilities.
+
+            !!! tip ""
+
+                See the [Structured output](https://docs.langchain.com/oss/python/langchain/structured-output)
+                docs for more information.
         state_schema: An optional `TypedDict` schema that extends `AgentState`.

             When provided, this schema is used instead of `AgentState` as the base
             schema for merging with middleware state schemas. This allows users to
             add custom state fields without needing to create custom middleware.

             Generally, it's recommended to use `state_schema` extensions via middleware
             to keep relevant extensions scoped to corresponding hooks / tools.

             The schema must be a subclass of `AgentState[ResponseT]`.
         context_schema: An optional schema for runtime context.
         checkpointer: An optional checkpoint saver object.

@@ -966,7 +1017,7 @@ def create_agent(  # noqa: PLR0915
     effective_response_format: ResponseFormat | None
     if isinstance(request.response_format, AutoStrategy):
         # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
-        if _supports_provider_strategy(request.model):
+        if _supports_provider_strategy(request.model, tools=request.tools):
            # Model supports provider strategy - use it
            effective_response_format = ProviderStrategy(schema=request.response_format.schema)
        else:
@@ -987,7 +1038,7 @@ def create_agent(  # noqa: PLR0915

     # Bind model based on effective response format
     if isinstance(effective_response_format, ProviderStrategy):
-        # Use provider-specific structured output
+        # (Backward compatibility) Use OpenAI format structured output
         kwargs = effective_response_format.to_model_kwargs()
         return (
             request.model.bind_tools(

@@ -1,15 +1,10 @@
-"""Entrypoint to using [Middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).
-
-!!! warning "Reference docs"
-    This page contains **reference documentation** for Middleware. See
-    [the docs](https://docs.langchain.com/oss/python/langchain/middleware) for conceptual
-    guides, tutorials, and examples on using Middleware.
-"""  # noqa: E501
+"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents)."""  # noqa: E501

 from .context_editing import (
     ClearToolUsesEdit,
     ContextEditingMiddleware,
 )
+from .file_search import FilesystemFileSearchMiddleware
 from .human_in_the_loop import (
     HumanInTheLoopMiddleware,
     InterruptOnConfig,
@@ -52,6 +47,7 @@ __all__ = [
     "CodexSandboxExecutionPolicy",
     "ContextEditingMiddleware",
     "DockerExecutionPolicy",
+    "FilesystemFileSearchMiddleware",
     "HostExecutionPolicy",
     "HumanInTheLoopMiddleware",
     "InterruptOnConfig",
@@ -56,11 +56,12 @@ class BaseExecutionPolicy(abc.ABC):
     """Configuration contract for persistent shell sessions.

     Concrete subclasses encapsulate how a shell process is launched and constrained.
+
     Each policy documents its security guarantees and the operating environments in
-    which it is appropriate. Use :class:`HostExecutionPolicy` for trusted, same-host
-    execution; :class:`CodexSandboxExecutionPolicy` when the Codex CLI sandbox is
-    available and you want additional syscall restrictions; and
-    :class:`DockerExecutionPolicy` for container-level isolation using Docker.
+    which it is appropriate. Use `HostExecutionPolicy` for trusted, same-host execution;
+    `CodexSandboxExecutionPolicy` when the Codex CLI sandbox is available and you want
+    additional syscall restrictions; and `DockerExecutionPolicy` for container-level
+    isolation using Docker.
     """

     command_timeout: float = 30.0
@@ -91,13 +92,13 @@ class HostExecutionPolicy(BaseExecutionPolicy):

     This policy is best suited for trusted or single-tenant environments (CI jobs,
     developer workstations, pre-sandboxed containers) where the agent must access the
-    host filesystem and tooling without additional isolation. It enforces optional CPU
-    and memory limits to prevent runaway commands but offers **no** filesystem or network
+    host filesystem and tooling without additional isolation. Enforces optional CPU and
+    memory limits to prevent runaway commands but offers **no** filesystem or network
     sandboxing; commands can modify anything the process user can reach.

-    On Linux platforms resource limits are applied with ``resource.prlimit`` after the
-    shell starts. On macOS, where ``prlimit`` is unavailable, limits are set in a
-    ``preexec_fn`` before ``exec``. In both cases the shell runs in its own process group
+    On Linux platforms resource limits are applied with `resource.prlimit` after the
+    shell starts. On macOS, where `prlimit` is unavailable, limits are set in a
+    `preexec_fn` before `exec`. In both cases the shell runs in its own process group
     so timeouts can terminate the full subtree.
     """
@@ -199,9 +200,9 @@ class CodexSandboxExecutionPolicy(BaseExecutionPolicy):
     (Linux) profiles. Commands still run on the host, but within the sandbox requested by
     the CLI. If the Codex binary is unavailable or the runtime lacks the required
     kernel features (e.g., Landlock inside some containers), process startup fails with a
-    :class:`RuntimeError`.
+    `RuntimeError`.

-    Configure sandbox behaviour via ``config_overrides`` to align with your Codex CLI
+    Configure sandbox behavior via `config_overrides` to align with your Codex CLI
     profile. This policy does not add its own resource limits; combine it with
     host-level guards (cgroups, container resource limits) as needed.
     """
@@ -271,17 +272,17 @@ class DockerExecutionPolicy(BaseExecutionPolicy):
     """Run the shell inside a dedicated Docker container.

     Choose this policy when commands originate from untrusted users or you require
-    strong isolation between sessions. By default the workspace is bind-mounted only when
-    it refers to an existing non-temporary directory; ephemeral sessions run without a
-    mount to minimise host exposure. The container's network namespace is disabled by
-    default (``--network none``) and you can enable further hardening via
-    ``read_only_rootfs`` and ``user``.
+    strong isolation between sessions. By default the workspace is bind-mounted only
+    when it refers to an existing non-temporary directory; ephemeral sessions run
+    without a mount to minimise host exposure. The container's network namespace is
+    disabled by default (`--network none`) and you can enable further hardening via
+    `read_only_rootfs` and `user`.

     The security guarantees depend on your Docker daemon configuration. Run the agent on
-    a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and review
-    any additional volumes or capabilities passed through ``extra_run_args``. The default
-    image is ``python:3.12-alpine3.19``; supply a custom image if you need preinstalled
-    tooling.
+    a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and
+    review any additional volumes or capabilities passed through ``extra_run_args``. The
+    default image is `python:3.12-alpine3.19`; supply a custom image if you need
+    preinstalled tooling.
     """

     binary: str = "docker"

@@ -1,9 +1,10 @@
 """Context editing middleware.

-This middleware mirrors Anthropic's context editing capabilities by clearing
-older tool results once the conversation grows beyond a configurable token
-threshold. The implementation is intentionally model-agnostic so it can be used
-with any LangChain chat model.
+Mirrors Anthropic's context editing capabilities by clearing older tool results once the
+conversation grows beyond a configurable token threshold.
+
+The implementation is intentionally model-agnostic so it can be used with any LangChain
+chat model.
 """

 from __future__ import annotations
@@ -182,11 +183,13 @@ class ClearToolUsesEdit(ContextEdit):


 class ContextEditingMiddleware(AgentMiddleware):
-    """Automatically prunes tool results to manage context size.
+    """Automatically prune tool results to manage context size.

-    The middleware applies a sequence of edits when the total input token count
-    exceeds configured thresholds. Currently the `ClearToolUsesEdit` strategy is
-    supported, aligning with Anthropic's `clear_tool_uses_20250919` behaviour.
+    The middleware applies a sequence of edits when the total input token count exceeds
+    configured thresholds.
+
+    Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
+    `clear_tool_uses_20250919` behavior [(read more)](https://docs.claude.com/en/docs/agents-and-tools/tool-use/memory-tool).
     """

     edits: list[ContextEdit]
@@ -198,11 +201,12 @@ class ContextEditingMiddleware(AgentMiddleware):
         edits: Iterable[ContextEdit] | None = None,
         token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
     ) -> None:
-        """Initializes a context editing middleware instance.
+        """Initialize an instance of context editing middleware.

         Args:
-            edits: Sequence of edit strategies to apply. Defaults to a single
-                `ClearToolUsesEdit` mirroring Anthropic defaults.
+            edits: Sequence of edit strategies to apply.
+
+                Defaults to a single `ClearToolUsesEdit` mirroring Anthropic defaults.
             token_count_method: Whether to use approximate token counting
                 (faster, less accurate) or exact counting implemented by the
                 chat model (potentially slower, more accurate).
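A minimal sketch of wiring this middleware into an agent, assuming `ContextEditingMiddleware` is exported from `langchain.agents.middleware` as the `__init__` diff above shows, and an illustrative model string:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ContextEditingMiddleware

agent = create_agent(
    model="anthropic:claude-sonnet-4-5",  # illustrative model string
    middleware=[
        # With no edits given, defaults to a single ClearToolUsesEdit,
        # mirroring Anthropic's clear_tool_uses_20250919 behavior.
        ContextEditingMiddleware(token_count_method="approximate"),
    ],
)
```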
@@ -21,7 +21,7 @@ from langchain.agents.middleware.types import AgentMiddleware


 def _expand_include_patterns(pattern: str) -> list[str] | None:
-    """Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
+    """Expand brace patterns like `*.{py,pyi}` into a list of globs."""
     if "}" in pattern and "{" not in pattern:
         return None

@@ -88,6 +88,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
     """Provides Glob and Grep search over filesystem files.

     This middleware adds two tools that search through local filesystem:
+
     - Glob: Fast file pattern matching by file path
     - Grep: Fast content search using ripgrep or Python fallback

@@ -100,7 +101,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

         agent = create_agent(
             model=model,
-            tools=[],
+            tools=[],  # Add tools as needed
             middleware=[
                 FilesystemFileSearchMiddleware(root_path="/workspace"),
             ],
@@ -119,9 +120,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

         Args:
             root_path: Root directory to search.
-            use_ripgrep: Whether to use ripgrep for search (default: True).
-                Falls back to Python if ripgrep unavailable.
-            max_file_size_mb: Maximum file size to search in MB (default: 10).
+            use_ripgrep: Whether to use `ripgrep` for search.
+
+                Falls back to Python if `ripgrep` unavailable.
+            max_file_size_mb: Maximum file size to search in MB.
         """
         self.root_path = Path(root_path).resolve()
         self.use_ripgrep = use_ripgrep
@@ -132,8 +134,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
         def glob_search(pattern: str, path: str = "/") -> str:
             """Fast file pattern matching tool that works with any codebase size.

-            Supports glob patterns like **/*.js or src/**/*.ts.
+            Supports glob patterns like `**/*.js` or `src/**/*.ts`.
+
             Returns matching file paths sorted by modification time.
+
             Use this tool when you need to find files by name patterns.

             Args:
@@ -142,7 +146,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

             Returns:
                 Newline-separated list of matching file paths, sorted by modification
-                time (most recently modified first). Returns "No files found" if no
+                time (most recently modified first). Returns `'No files found'` if no
                 matches.
             """
             try:
@@ -184,15 +188,16 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
             Args:
                 pattern: The regular expression pattern to search for in file contents.
                 path: The directory to search in. If not specified, searches from root.
-                include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
+                include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
                 output_mode: Output format:
-                    - "files_with_matches": Only file paths containing matches (default)
-                    - "content": Matching lines with file:line:content format
-                    - "count": Count of matches per file
+
+                    - `'files_with_matches'`: Only file paths containing matches
+                    - `'content'`: Matching lines with `file:line:content` format
+                    - `'count'`: Count of matches per file

             Returns:
-                Search results formatted according to output_mode. Returns "No matches
-                found" if no results.
+                Search results formatted according to `output_mode`.
+                Returns `'No matches found'` if no results.
             """
             # Compile regex pattern (for validation)
             try:

@@ -14,10 +14,10 @@ class Action(TypedDict):
     """Represents an action with a name and args."""

     name: str
-    """The type or name of action being requested (e.g., "add_numbers")."""
+    """The type or name of action being requested (e.g., `'add_numbers'`)."""

     args: dict[str, Any]
-    """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
+    """Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""


 class ActionRequest(TypedDict):
@@ -27,7 +27,7 @@ class ActionRequest(TypedDict):
     """The name of the action being requested."""

     args: dict[str, Any]
-    """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
+    """Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""

     description: NotRequired[str]
     """The description of the action to be reviewed."""
@@ -169,18 +169,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):

         Args:
             interrupt_on: Mapping of tool name to allowed actions.
+
                 If a tool doesn't have an entry, it's auto-approved by default.
+
                 * `True` indicates all decisions are allowed: approve, edit, and reject.
                 * `False` indicates that the tool is auto-approved.
                 * `InterruptOnConfig` indicates the specific decisions allowed for this
                   tool.
-                The InterruptOnConfig can include a `description` field (`str` or
+
+                The `InterruptOnConfig` can include a `description` field (`str` or
                 `Callable`) for custom formatting of the interrupt description.
             description_prefix: The prefix to use when constructing action requests.
+
                 This is used to provide context about the tool call and the action being
-                requested. Not used if a tool has a `description` in its
-                `InterruptOnConfig`.
+                requested.
+
+                Not used if a tool has a `description` in its `InterruptOnConfig`.
         """
         super().__init__()
         resolved_configs: dict[str, InterruptOnConfig] = {}
@@ -349,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         last_ai_msg.tool_calls = revised_tool_calls

         return {"messages": [last_ai_msg, *artificial_tool_messages]}
+
+    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        """Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
+        return self.after_model(state, runtime)
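A minimal sketch of the `interrupt_on` mapping described in the docstring above; `send_email` and `search_docs` are hypothetical tools and the model string is illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware

agent = create_agent(
    model="openai:gpt-4o",  # illustrative model string
    tools=[send_email, search_docs],  # hypothetical tools
    middleware=[
        HumanInTheLoopMiddleware(
            interrupt_on={
                "send_email": True,    # approve, edit, and reject all allowed
                "search_docs": False,  # auto-approved, no interrupt
            },
        ),
    ],
)
```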
@@ -20,9 +20,9 @@ if TYPE_CHECKING:


 class ModelCallLimitState(AgentState):
-    """State schema for ModelCallLimitMiddleware.
+    """State schema for `ModelCallLimitMiddleware`.

-    Extends AgentState with model call tracking fields.
+    Extends `AgentState` with model call tracking fields.
     """

     thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
@@ -58,8 +58,8 @@ def _build_limit_exceeded_message(
 class ModelCallLimitExceededError(Exception):
     """Exception raised when model call limits are exceeded.

-    This exception is raised when the configured exit behavior is 'error'
-    and either the thread or run model call limit has been exceeded.
+    This exception is raised when the configured exit behavior is `'error'` and either
+    the thread or run model call limit has been exceeded.
     """

     def __init__(
@@ -127,13 +127,17 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):

         Args:
             thread_limit: Maximum number of model calls allowed per thread.
-                None means no limit.
+
+                `None` means no limit.
             run_limit: Maximum number of model calls allowed per run.
-                None means no limit.
+
+                `None` means no limit.
             exit_behavior: What to do when limits are exceeded.
-                - "end": Jump to the end of the agent execution and
-                inject an artificial AI message indicating that the limit was exceeded.
-                - "error": Raise a `ModelCallLimitExceededError`
+
+                - `'end'`: Jump to the end of the agent execution and
+                  inject an artificial AI message indicating that the limit was
+                  exceeded.
+                - `'error'`: Raise a `ModelCallLimitExceededError`

         Raises:
             ValueError: If both limits are `None` or if `exit_behavior` is invalid.
@@ -161,12 +165,13 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
             runtime: The langgraph runtime.

         Returns:
-            If limits are exceeded and exit_behavior is "end", returns
-            a Command to jump to the end with a limit exceeded message. Otherwise returns None.
+            If limits are exceeded and exit_behavior is `'end'`, returns
+            a `Command` to jump to the end with a limit exceeded message. Otherwise
+            returns `None`.

         Raises:
-            ModelCallLimitExceededError: If limits are exceeded and exit_behavior
-                is "error".
+            ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                is `'error'`.
         """
         thread_count = state.get("thread_model_call_count", 0)
         run_count = state.get("run_model_call_count", 0)
@@ -194,6 +199,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):

         return None

+    @hook_config(can_jump_to=["end"])
+    async def abefore_model(
+        self,
+        state: ModelCallLimitState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Async check model call limits before making a model call.
+
+        Args:
+            state: The current agent state containing call counts.
+            runtime: The langgraph runtime.
+
+        Returns:
+            If limits are exceeded and exit_behavior is `'end'`, returns
+            a `Command` to jump to the end with a limit exceeded message. Otherwise
+            returns `None`.
+
+        Raises:
+            ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                is `'error'`.
+        """
+        return self.before_model(state, runtime)
+
     def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Increment model call counts after a model call.

@@ -208,3 +236,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
             "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
             "run_model_call_count": state.get("run_model_call_count", 0) + 1,
         }
+
+    async def aafter_model(
+        self,
+        state: ModelCallLimitState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Async increment model call counts after a model call.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented call counts.
+        """
+        return self.after_model(state, runtime)
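A minimal sketch of the limit options documented above, assuming `ModelCallLimitMiddleware` is exported from `langchain.agents.middleware`; the model string is illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware

agent = create_agent(
    model="openai:gpt-4o",  # illustrative model string
    tools=[],
    middleware=[
        # Jump to the end after 10 calls per thread or 5 per run,
        # instead of raising ModelCallLimitExceededError.
        ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end"),
    ],
)
```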
@@ -22,7 +22,7 @@ class ModelFallbackMiddleware(AgentMiddleware):
     """Automatic fallback to alternative models on errors.

     Retries failed model calls with alternative models in sequence until
-    success or all models exhausted. Primary model specified in create_agent().
+    success or all models exhausted. Primary model specified in `create_agent`.

     Example:
         ```python

@@ -27,24 +27,26 @@ if TYPE_CHECKING:


 class PIIMiddleware(AgentMiddleware):
-    """Detect and handle Personally Identifiable Information (PII) in agent conversations.
+    """Detect and handle Personally Identifiable Information (PII) in conversations.

-    This middleware detects common PII types and applies configurable strategies
-    to handle them. It can detect emails, credit cards, IP addresses,
-    MAC addresses, and URLs in both user input and agent output.
+    This middleware detects common PII types and applies configurable strategies
+    to handle them. It can detect emails, credit cards, IP addresses, MAC addresses, and
+    URLs in both user input and agent output.

     Built-in PII types:
-    - `email`: Email addresses
-    - `credit_card`: Credit card numbers (validated with Luhn algorithm)
-    - `ip`: IP addresses (validated with stdlib)
-    - `mac_address`: MAC addresses
-    - `url`: URLs (both `http`/`https` and bare URLs)
+
+    - `email`: Email addresses
+    - `credit_card`: Credit card numbers (validated with Luhn algorithm)
+    - `ip`: IP addresses (validated with stdlib)
+    - `mac_address`: MAC addresses
+    - `url`: URLs (both `http`/`https` and bare URLs)

     Strategies:
-    - `block`: Raise an exception when PII is detected
-    - `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
-    - `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
-    - `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)
+
+    - `block`: Raise an exception when PII is detected
+    - `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
+    - `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
+    - `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)

     Strategy Selection Guide:

@@ -101,12 +103,15 @@ class PIIMiddleware(AgentMiddleware):
         """Initialize the PII detection middleware.

         Args:
-            pii_type: Type of PII to detect. Can be a built-in type
-                (`email`, `credit_card`, `ip`, `mac_address`, `url`)
-                or a custom type name.
-            strategy: How to handle detected PII:
+            pii_type: Type of PII to detect.
+
+                Can be a built-in type (`email`, `credit_card`, `ip`, `mac_address`,
+                `url`) or a custom type name.
+            strategy: How to handle detected PII.
+
+                Options:

-                * `block`: Raise PIIDetectionError when PII is detected
+                * `block`: Raise `PIIDetectionError` when PII is detected
                 * `redact`: Replace with `[REDACTED_TYPE]` placeholders
                 * `mask`: Partially mask PII (show last few characters)
                 * `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)
@@ -114,16 +119,15 @@ class PIIMiddleware(AgentMiddleware):
             detector: Custom detector function or regex pattern.

                 * If `Callable`: Function that takes content string and returns
-                    list of PIIMatch objects
+                    list of `PIIMatch` objects
                 * If `str`: Regex pattern to match PII
-                * If `None`: Uses built-in detector for the pii_type
-
+                * If `None`: Uses built-in detector for the `pii_type`
             apply_to_input: Whether to check user messages before model call.
             apply_to_output: Whether to check AI messages after model call.
             apply_to_tool_results: Whether to check tool result messages after tool execution.

         Raises:
-            ValueError: If pii_type is not built-in and no detector is provided.
+            ValueError: If `pii_type` is not built-in and no detector is provided.
         """
         super().__init__()

@@ -166,10 +170,11 @@ class PIIMiddleware(AgentMiddleware):
             runtime: The langgraph runtime.

         Returns:
-            Updated state with PII handled according to strategy, or None if no PII detected.
+            Updated state with PII handled according to strategy, or `None` if no PII
+            detected.

         Raises:
-            PIIDetectionError: If PII is detected and strategy is "block".
+            PIIDetectionError: If PII is detected and strategy is `'block'`.
         """
         if not self.apply_to_input and not self.apply_to_tool_results:
             return None
@@ -247,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):

         return None

+    @hook_config(can_jump_to=["end"])
+    async def abefore_model(
+        self,
+        state: AgentState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Async check user messages and tool results for PII before model invocation.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            Updated state with PII handled according to strategy, or `None` if no PII
+            detected.
+
+        Raises:
+            PIIDetectionError: If PII is detected and strategy is `'block'`.
+        """
+        return self.before_model(state, runtime)
+
     def after_model(
         self,
         state: AgentState,
@@ -259,10 +285,11 @@ class PIIMiddleware(AgentMiddleware):
             runtime: The langgraph runtime.

         Returns:
-            Updated state with PII handled according to strategy, or None if no PII detected.
+            Updated state with PII handled according to strategy, or None if no PII
+            detected.

         Raises:
-            PIIDetectionError: If PII is detected and strategy is "block".
+            PIIDetectionError: If PII is detected and strategy is `'block'`.
         """
         if not self.apply_to_output:
             return None
@@ -305,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):

         return {"messages": new_messages}

+    async def aafter_model(
+        self,
+        state: AgentState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Async check AI messages for PII after model invocation.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            Updated state with PII handled according to strategy, or None if no PII
+            detected.
+
+        Raises:
+            PIIDetectionError: If PII is detected and strategy is `'block'`.
+        """
+        return self.after_model(state, runtime)
+

 __all__ = [
     "PIIDetectionError",
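A minimal sketch combining the built-in PII types and strategies documented above, assuming `PIIMiddleware` is exported from `langchain.agents.middleware`; the model string is illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware

agent = create_agent(
    model="openai:gpt-4o",  # illustrative model string
    tools=[],
    middleware=[
        PIIMiddleware("email", strategy="redact"),      # -> [REDACTED_EMAIL]
        PIIMiddleware("credit_card", strategy="mask"),  # -> ****-****-****-1234
    ],
)
```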
@@ -11,17 +11,17 @@ import subprocess
 import tempfile
 import threading
 import time
 import typing
 import uuid
 import weakref
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Any, Literal
+from typing import TYPE_CHECKING, Annotated, Any, Literal, cast

 from langchain_core.messages import ToolMessage
-from langchain_core.tools.base import BaseTool, ToolException
+from langchain_core.tools.base import ToolException
 from langgraph.channels.untracked_value import UntrackedValue
 from pydantic import BaseModel, model_validator
 from pydantic.json_schema import SkipJsonSchema
 from typing_extensions import NotRequired

 from langchain.agents.middleware._execution import (
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
     ResolvedRedactionRule,
 )
 from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
+from langchain.tools import ToolRuntime, tool

 if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence

     from langgraph.runtime import Runtime
-    from langgraph.types import Command
-
     from langchain.agents.middleware.types import ToolCallRequest

 LOGGER = logging.getLogger(__name__)
 _DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
     "session remains stable. Outputs may be truncated when they become very large, and long "
     "running commands will be terminated once their configured timeout elapses."
 )
+SHELL_TOOL_NAME = "shell"


 def _cleanup_resources(
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
     """Input schema for the persistent shell tool."""

     command: str | None = None
     """The shell command to execute."""

     restart: bool | None = None
     """Whether to restart the shell session."""

+    runtime: Annotated[Any, SkipJsonSchema()] = None
+    """The runtime for the shell tool.
+
+    Included as a workaround at the moment bc args_schema doesn't work with
+    injected ToolRuntime.
+    """
+
     @model_validator(mode="after")
     def validate_payload(self) -> _ShellToolInput:
@@ -347,38 +357,21 @@ class _ShellToolInput(BaseModel):
         return self


-class _PersistentShellTool(BaseTool):
-    """Tool wrapper that relies on middleware interception for execution."""
-
-    name: str = "shell"
-    description: str = DEFAULT_TOOL_DESCRIPTION
-    args_schema: type[BaseModel] = _ShellToolInput
-
-    def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
-        super().__init__()
-        self._middleware = middleware
-        if description is not None:
-            self.description = description
-
-    def _run(self, **_: Any) -> Any:  # pragma: no cover - executed via middleware wrapper
-        msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
-        raise RuntimeError(msg)
-
-
 class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
     """Middleware that registers a persistent shell tool for agents.

-    The middleware exposes a single long-lived shell session. Use the execution policy to
-    match your deployment's security posture:
+    The middleware exposes a single long-lived shell session. Use the execution policy
+    to match your deployment's security posture:

-    * ``HostExecutionPolicy`` - full host access; best for trusted environments where the
-      agent already runs inside a container or VM that provides isolation.
-    * ``CodexSandboxExecutionPolicy`` - reuses the Codex CLI sandbox for additional
-      syscall/filesystem restrictions when the CLI is available.
-    * ``DockerExecutionPolicy`` - launches a separate Docker container for each agent run,
-      providing harder isolation, optional read-only root filesystems, and user remapping.
+    * `HostExecutionPolicy` – full host access; best for trusted environments where the
+      agent already runs inside a container or VM that provides isolation.
+    * `CodexSandboxExecutionPolicy` – reuses the Codex CLI sandbox for additional
+      syscall/filesystem restrictions when the CLI is available.
+    * `DockerExecutionPolicy` – launches a separate Docker container for each agent run,
+      providing harder isolation, optional read-only root filesystems, and user
+      remapping.

-    When no policy is provided the middleware defaults to ``HostExecutionPolicy``.
+    When no policy is provided the middleware defaults to `HostExecutionPolicy`.
     """

     state_schema = ShellToolState
@@ -392,29 +385,43 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
         execution_policy: BaseExecutionPolicy | None = None,
         redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
         tool_description: str | None = None,
+        tool_name: str = SHELL_TOOL_NAME,
         shell_command: Sequence[str] | str | None = None,
         env: Mapping[str, Any] | None = None,
     ) -> None:
-        """Initialize the middleware.
+        """Initialize an instance of `ShellToolMiddleware`.

         Args:
-            workspace_root: Base directory for the shell session. If omitted, a temporary
-                directory is created when the agent starts and removed when it ends.
-            startup_commands: Optional commands executed sequentially after the session starts.
+            workspace_root: Base directory for the shell session.
+
+                If omitted, a temporary directory is created when the agent starts and
+                removed when it ends.
+            startup_commands: Optional commands executed sequentially after the session
+                starts.
             shutdown_commands: Optional commands executed before the session shuts down.
-            execution_policy: Execution policy controlling timeouts, output limits, and resource
-                configuration. Defaults to :class:`HostExecutionPolicy` for native execution.
+            execution_policy: Execution policy controlling timeouts, output limits, and
+                resource configuration.
+
+                Defaults to `HostExecutionPolicy` for native execution.
             redaction_rules: Optional redaction rules to sanitize command output before
                 returning it to the model.
-            tool_description: Optional override for the registered shell tool description.
-            shell_command: Optional shell executable (string) or argument sequence used to
-                launch the persistent session. Defaults to an implementation-defined bash command.
-            env: Optional environment variables to supply to the shell session. Values are
-                coerced to strings before command execution. If omitted, the session inherits the
-                parent process environment.
+            tool_description: Optional override for the registered shell tool
+                description.
+            tool_name: Name for the registered shell tool.
+
+                Defaults to `"shell"`.
+            shell_command: Optional shell executable (string) or argument sequence used
+                to launch the persistent session.
+
+                Defaults to an implementation-defined bash command.
+            env: Optional environment variables to supply to the shell session.
+
+                Values are coerced to strings before command execution. If omitted, the
+                session inherits the parent process environment.
         """
         super().__init__()
         self._workspace_root = Path(workspace_root) if workspace_root else None
+        self._tool_name = tool_name
         self._shell_command = self._normalize_shell_command(shell_command)
         self._environment = self._normalize_env(env)
         if execution_policy is not None:
@@ -428,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
         self._startup_commands = self._normalize_commands(startup_commands)
         self._shutdown_commands = self._normalize_commands(shutdown_commands)

+        # Create a proper tool that executes directly (no interception needed)
         description = tool_description or DEFAULT_TOOL_DESCRIPTION
-        self._tool = _PersistentShellTool(self, description=description)
-        self.tools = [self._tool]
+
+        @tool(self._tool_name, args_schema=_ShellToolInput, description=description)
+        def shell_tool(
+            *,
+            runtime: ToolRuntime[None, ShellToolState],
+            command: str | None = None,
+            restart: bool = False,
+        ) -> ToolMessage | str:
+            resources = self._get_or_create_resources(runtime.state)
+            return self._run_shell_tool(
+                resources,
+                {"command": command, "restart": restart},
+                tool_call_id=runtime.tool_call_id,
+            )
+
+        self._shell_tool = shell_tool
+        self.tools = [self._shell_tool]

     @staticmethod
     def _normalize_commands(
@@ -468,36 +491,48 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):

     def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Start the shell session and run startup commands."""
-        resources = self._create_resources()
+        resources = self._get_or_create_resources(state)
         return {"shell_session_resources": resources}

     async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
-        """Async counterpart to `before_agent`."""
+        """Async start the shell session and run startup commands."""
         return self.before_agent(state, runtime)

     def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:  # noqa: ARG002
         """Run shutdown commands and release resources when an agent completes."""
-        resources = self._ensure_resources(state)
+        resources = state.get("shell_session_resources")
+        if not isinstance(resources, _SessionResources):
+            # Resources were never created, nothing to clean up
+            return
         try:
             self._run_shutdown_commands(resources.session)
         finally:
             resources._finalizer()

     async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
-        """Async counterpart to `after_agent`."""
+        """Async run shutdown commands and release resources when an agent completes."""
         return self.after_agent(state, runtime)

-    def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
+    def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
+        """Get existing resources from state or create new ones if they don't exist.
+
+        This method enables resumability by checking if resources already exist in the state
+        (e.g., after an interrupt), and only creating new resources if they're not present.
+
+        Args:
+            state: The agent state which may contain shell session resources.
+
+        Returns:
+            Session resources, either retrieved from state or newly created.
+        """
         resources = state.get("shell_session_resources")
-        if resources is not None and not isinstance(resources, _SessionResources):
-            resources = None
-        if resources is None:
-            msg = (
-                "Shell session resources are unavailable. Ensure `before_agent` ran successfully "
-                "before invoking the shell tool."
-            )
-            raise ToolException(msg)
-        return resources
+        if isinstance(resources, _SessionResources):
+            return resources
+
+        new_resources = self._create_resources()
+        # Cast needed to make state dict-like for mutation
+        cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
+        return new_resources

     def _create_resources(self) -> _SessionResources:
         workspace = self._workspace_root
@@ -659,36 +694,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
             artifact=artifact,
         )

-    def wrap_tool_call(
-        self,
-        request: ToolCallRequest,
-        handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
-        """Intercept local shell tool calls and execute them via the managed session."""
-        if isinstance(request.tool, _PersistentShellTool):
-            resources = self._ensure_resources(request.state)
-            return self._run_shell_tool(
-                resources,
-                request.tool_call["args"],
-                tool_call_id=request.tool_call.get("id"),
-            )
-        return handler(request)
-
-    async def awrap_tool_call(
-        self,
-        request: ToolCallRequest,
-        handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
-        """Async interception mirroring the synchronous tool handler."""
-        if isinstance(request.tool, _PersistentShellTool):
-            resources = self._ensure_resources(request.state)
-            return self._run_shell_tool(
-                resources,
-                request.tool_call["args"],
-                tool_call_id=request.tool_call.get("id"),
-            )
-        return await handler(request)
-
     def _format_tool_message(
         self,
         content: str,
@@ -703,7 +708,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
         return ToolMessage(
             content=content,
             tool_call_id=tool_call_id,
-            name=self._tool.name,
+            name=self._tool_name,
             status=status,
             artifact=artifact,
         )
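A minimal sketch pairing this middleware with the `DockerExecutionPolicy` described earlier, assuming both names are exported from `langchain.agents.middleware`; the workspace path and model string are illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import DockerExecutionPolicy, ShellToolMiddleware

agent = create_agent(
    model="openai:gpt-4o",  # illustrative model string
    middleware=[
        ShellToolMiddleware(
            workspace_root="/tmp/agent-workspace",      # illustrative path
            execution_policy=DockerExecutionPolicy(),   # container-level isolation
        ),
    ],
)
```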
@@ -1,8 +1,9 @@
|
||||
"""Summarization middleware."""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -51,13 +52,26 @@ Messages to summarize:
|
||||
{messages}
|
||||
</messages>""" # noqa: E501
|
||||
|
||||
SUMMARY_PREFIX = "## Previous conversation summary:"
|
||||
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Tuple specifying context size as a fraction of the model's context window."""
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
"""Tuple specifying context size as a number of tokens."""
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
"""Tuple specifying context size as a number of messages."""
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
"""Context size tuple to specify how much history to preserve."""
|
||||
|
||||
ContextCondition = ContextSize | list[ContextSize | list[ContextSize]]
|
||||
"""Recursive type to support nested AND/OR conditions
|
||||
|
||||
Top-level list = OR logic, nested list = AND logic."""
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
@@ -70,34 +84,100 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
model: str | BaseChatModel,
|
||||
max_tokens_before_summary: int | None = None,
|
||||
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
|
||||
*,
|
||||
trigger: ContextCondition | None = None,
|
||||
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
|
||||
token_counter: TokenCounter = count_tokens_approximately,
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
|
||||
summary_prefix: str = SUMMARY_PREFIX,
|
||||
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
|
||||
**deprecated_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the summarization middleware.
|
||||
"""Initialize summarization middleware.
|
||||
|
||||
Args:
|
||||
model: The language model to use for generating summaries.
|
||||
max_tokens_before_summary: Token threshold to trigger summarization.
|
||||
If `None`, summarization is disabled.
|
||||
messages_to_keep: Number of recent messages to preserve after summarization.
|
||||
trigger: One or more thresholds that trigger summarization. Supports flexible
|
||||
AND/OR logic via nested lists. Top-level list items are combined with OR,
|
||||
nested lists are combined with AND. Examples:
|
||||
- Single condition: `("messages", 50)`
|
||||
- OR conditions: `[("tokens", 3000), ("messages", 100)]` (triggers when
|
||||
tokens >= 3000 OR messages >= 100)
|
||||
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as a nested list
|
||||
within the top-level list
|
||||
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
|
||||
(triggers when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8))
|
||||
keep: Context retention policy applied after summarization.
|
||||
|
||||
Provide a `ContextSize` tuple to specify how much history to preserve.
|
||||
|
||||
Defaults to keeping the most recent 20 messages.
|
||||
|
||||
Examples: `("messages", 20)`, `("tokens", 3000)`, or
|
||||
`("fraction", 0.3)`.
|
||||
token_counter: Function to count tokens in messages.
|
||||
summary_prompt: Prompt template for generating summaries.
|
||||
summary_prefix: Prefix added to system message when including summary.
|
||||
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
|
||||
the summarization call.
|
||||
|
||||
Pass `None` to skip trimming entirely.
|
||||
"""
|
||||
# Handle deprecated parameters
if "max_tokens_before_summary" in deprecated_kwargs:
value = deprecated_kwargs["max_tokens_before_summary"]
warnings.warn(
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if trigger is None and value is not None:
trigger = ("tokens", value)

if "messages_to_keep" in deprecated_kwargs:
value = deprecated_kwargs["messages_to_keep"]
warnings.warn(
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
keep = ("messages", value)

super().__init__()

if isinstance(model, str):
model = init_chat_model(model)

self.model = model
self.max_tokens_before_summary = max_tokens_before_summary
self.messages_to_keep = messages_to_keep
if trigger is None:
self.trigger: ContextCondition | None = None
trigger_conditions: list[ContextSize | list[ContextSize]] = []
elif isinstance(trigger, list):
# Validate and normalize nested structure
validated_list = self._validate_trigger_conditions(trigger)
self.trigger = validated_list
trigger_conditions = validated_list
else:
# Single ContextSize tuple
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
self._trigger_conditions = trigger_conditions

self.keep = self._validate_context_size(keep, "keep")
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
self.trim_tokens_to_summarize = trim_tokens_to_summarize

requires_profile = self._requires_profile(self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
msg = (
"Model profile information is required to use fractional token limits. "
'pip install "langchain[model-profiles]" or use absolute token counts '
"instead."
)
raise ValueError(msg)

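A minimal sketch of how the new `trigger`/`keep` parameters compose, based on the docstring above; the import path and model id are assumptions, not taken from this diff:

```python
# Hypothetical usage sketch (import path and model id assumed, not from this diff).
from langchain.agents.middleware import SummarizationMiddleware

middleware = SummarizationMiddleware(
    model="openai:gpt-4o-mini",  # assumed model id
    # OR of two conditions: 100+ messages, or (3000+ tokens AND 80% of the window)
    trigger=[("messages", 100), [("tokens", 3000), ("fraction", 0.8)]],
    keep=("tokens", 1500),  # retain roughly the most recent 1500 tokens
)
```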
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
@@ -105,13 +185,10 @@ class SummarizationMiddleware(AgentMiddleware):
self._ensure_message_ids(messages)

total_tokens = self.token_counter(messages)
if (
self.max_tokens_before_summary is not None
and total_tokens < self.max_tokens_before_summary
):
if not self._should_summarize(messages, total_tokens):
return None

cutoff_index = self._find_safe_cutoff(messages)
cutoff_index = self._determine_cutoff_index(messages)

if cutoff_index <= 0:
return None
@@ -129,6 +206,218 @@ class SummarizationMiddleware(AgentMiddleware):
]
}

async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)

total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None

cutoff_index = self._determine_cutoff_index(messages)

if cutoff_index <= 0:
return None

messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)

summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)

return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}

def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
"""Determine whether summarization should run for the current token usage.

Evaluates trigger conditions with AND/OR logic:
- Top-level items are OR'd together
- Nested lists are AND'd together
"""
if not self._trigger_conditions:
return False

# OR logic across top-level conditions
for condition in self._trigger_conditions:
if isinstance(condition, list):
# AND group - all must be satisfied
if self._check_and_group(condition, messages, total_tokens):
return True
elif self._check_single_condition(condition, messages, total_tokens):
# Single condition
return True
return False

def _check_and_group(
self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
) -> bool:
"""Check if all conditions in an AND group are satisfied."""
for condition in and_group:
if not self._check_single_condition(condition, messages, total_tokens):
return False
return True

def _check_single_condition(
self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
) -> bool:
"""Check if a single condition is satisfied."""
kind, value = condition
if kind == "messages":
return len(messages) >= value
if kind == "tokens":
return total_tokens >= value
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return False
threshold = int(max_input_tokens * value)
if threshold <= 0:
threshold = 1
return total_tokens >= threshold
return False

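The OR-of-ANDs evaluation above can be read as the following standalone sketch (names are illustrative, not part of the diff):

```python
# Standalone sketch of the trigger evaluation: top-level items are OR'd,
# nested lists are AND groups. `check` decides a single (kind, value) tuple.
def matches(conditions, check) -> bool:
    for condition in conditions:
        if isinstance(condition, list):  # AND group
            if all(check(c) for c in condition):
                return True
        elif check(condition):  # single condition
            return True
    return False

# matches([("messages", 10), [("tokens", 500), ("fraction", 0.8)]], check)
# is True when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8).
```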
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
"""Choose cutoff index respecting retention configuration."""
kind, value = self.keep
if kind in {"tokens", "fraction"}:
token_based_cutoff = self._find_token_based_cutoff(messages)
if token_based_cutoff is not None:
return token_based_cutoff
# None cutoff -> model profile data not available (caught in __init__ but
# here for safety), fallback to message count
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
return self._find_safe_cutoff(messages, cast("int", value))

def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
"""Find cutoff index based on target token retention."""
if not messages:
return 0

kind, value = self.keep
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return None
target_token_count = int(max_input_tokens * value)
elif kind == "tokens":
target_token_count = int(value)
else:
return None

if target_token_count <= 0:
target_token_count = 1

if self.token_counter(messages) <= target_token_count:
return 0

# Use binary search to identify the earliest message index that keeps the
# suffix within the token budget.
left, right = 0, len(messages)
cutoff_candidate = len(messages)
max_iterations = len(messages).bit_length() + 1
for _ in range(max_iterations):
if left >= right:
break

mid = (left + right) // 2
if self.token_counter(messages[mid:]) <= target_token_count:
cutoff_candidate = mid
right = mid
else:
left = mid + 1

if cutoff_candidate == len(messages):
cutoff_candidate = left

if cutoff_candidate >= len(messages):
if len(messages) == 1:
return 0
cutoff_candidate = len(messages) - 1

for i in range(cutoff_candidate, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i

return 0

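The binary search above finds the longest suffix of the history that fits the token budget; a self-contained sketch of the same idea (illustrative only):

```python
# Find the smallest start index whose suffix fits the budget, i.e. the
# longest affordable suffix. Mirrors the loop above, written standalone.
def earliest_fitting_suffix(messages, count_tokens, budget: int) -> int:
    left, right = 0, len(messages)
    cutoff = len(messages)
    while left < right:
        mid = (left + right) // 2
        if count_tokens(messages[mid:]) <= budget:
            cutoff, right = mid, mid  # suffix fits; try keeping more history
        else:
            left = mid + 1  # suffix too expensive; drop more of the prefix
    return cutoff

# With per-message costs [5, 5, 5, 5] and a budget of 10, index 2 is the
# earliest start whose suffix fits: sum(messages[2:]) == 10.
assert earliest_fitting_suffix([5, 5, 5, 5], sum, 10) == 2
```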
def _get_profile_limits(self) -> int | None:
"""Retrieve max input token limit from the model profile."""
try:
profile = self.model.profile
except (AttributeError, ImportError):
return None

if not isinstance(profile, Mapping):
return None

max_input_tokens = profile.get("max_input_tokens")

if not isinstance(max_input_tokens, int):
return None

return max_input_tokens

def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
"""Validate context configuration tuples."""
kind, value = context
if kind == "fraction":
if not 0 < value <= 1:
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
raise ValueError(msg)
elif kind in {"tokens", "messages"}:
if value <= 0:
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
raise ValueError(msg)
else:
msg = f"Unsupported context size type {kind} for {parameter_name}."
raise ValueError(msg)
return context

def _validate_trigger_conditions(
self, conditions: list[Any]
) -> list[ContextSize | list[ContextSize]]:
"""Validate and normalize trigger conditions with nested AND/OR logic.

Args:
conditions: List of ContextSize tuples or nested lists of ContextSize tuples.

Returns:
Validated list where top-level items are OR'd and nested lists are AND'd.
"""
validated: list[ContextSize | list[ContextSize]] = []
for item in conditions:
if isinstance(item, tuple):
# Single condition (tuple)
validated.append(self._validate_context_size(item, "trigger"))
elif isinstance(item, list):
# AND group (nested list)
if not item:
msg = "Empty AND groups are not allowed in trigger conditions."
raise ValueError(msg)
and_group = [self._validate_context_size(cond, "trigger") for cond in item]
validated.append(and_group)
else:
msg = f"Trigger conditions must be tuples or lists, got {type(item).__name__}."
raise ValueError(msg)
return validated

def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
"""Check if any condition requires model profile information."""
for condition in conditions:
if isinstance(condition, list):
# AND group
if any(c[0] == "fraction" for c in condition):
return True
elif condition[0] == "fraction":
return True
return False

def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
@@ -151,16 +440,16 @@ class SummarizationMiddleware(AgentMiddleware):

return messages_to_summarize, preserved_messages

def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.

Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
related AI and Tool messages. Returns `0` if no safe cutoff is found.
"""
if len(messages) <= self.messages_to_keep:
if len(messages) <= messages_to_keep:
return 0

target_cutoff = len(messages) - self.messages_to_keep
target_cutoff = len(messages) - messages_to_keep

for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
@@ -229,16 +518,35 @@ class SummarizationMiddleware(AgentMiddleware):

try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
return cast("str", response.content).strip()
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"

async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
if not messages_to_summarize:
return "No previous conversation history."

trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return "Previous conversation was too long to summarize."

try:
response = await self.model.ainvoke(
self.summary_prompt.format(messages=trimmed_messages)
)
return response.text.strip()
except Exception as e: # noqa: BLE001
return f"Error generating summary: {e!s}"

def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
"""Trim messages to fit within summary generation limits."""
try:
if self.trim_tokens_to_summarize is None:
return messages
return trim_messages(
messages,
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",

@@ -150,12 +150,6 @@ class TodoListMiddleware(AgentMiddleware):

print(result["todos"]) # Array of todo items with status tracking
```

Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
tool_description: Custom description for the write_todos tool.
If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
"""

state_schema = PlanningState
@@ -166,11 +160,12 @@ class TodoListMiddleware(AgentMiddleware):
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
) -> None:
"""Initialize the TodoListMiddleware with optional custom prompts.
"""Initialize the `TodoListMiddleware` with optional custom prompts.

Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
tool_description: Custom description for the write_todos tool.
system_prompt: Custom system prompt to guide the agent on using the todo
tool.
tool_description: Custom description for the `write_todos` tool.
"""
super().__init__()
self.system_prompt = system_prompt

@@ -23,22 +23,23 @@ if TYPE_CHECKING:
ExitBehavior = Literal["continue", "error", "end"]
"""How to handle execution when tool call limits are exceeded.

- `"continue"`: Block exceeded tools with error messages, let other tools continue (default)
- `"error"`: Raise a `ToolCallLimitExceededError` exception
- `"end"`: Stop execution immediately, injecting a ToolMessage and an AI message
for the single tool call that exceeded the limit. Raises `NotImplementedError`
if there are other pending tool calls (due to parallel tool calling).
- `'continue'`: Block exceeded tools with error messages, let other tools continue
(default)
- `'error'`: Raise a `ToolCallLimitExceededError` exception
- `'end'`: Stop execution immediately, injecting a `ToolMessage` and an `AIMessage` for
the single tool call that exceeded the limit. Raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling).
"""


class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
"""State schema for ToolCallLimitMiddleware.
"""State schema for `ToolCallLimitMiddleware`.

Extends AgentState with tool call tracking fields.
Extends `AgentState` with tool call tracking fields.

The count fields are dictionaries mapping tool names to execution counts.
This allows multiple middleware instances to track different tools independently.
The special key "__all__" is used for tracking all tool calls globally.
The count fields are dictionaries mapping tool names to execution counts. This
allows multiple middleware instances to track different tools independently. The
special key `'__all__'` is used for tracking all tool calls globally.
"""

thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
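An illustrative shape for the count fields described above (values and tool names are hypothetical, not from the diff):

```python
# Hypothetical contents of `thread_tool_call_count` after a few tool calls:
{
    "search": 4,      # per-tool counter tracked by a tool-specific instance
    "calculator": 1,
    "__all__": 5,     # global counter tracked by an all-tools instance
}
```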
@@ -46,13 +47,13 @@ class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):


def _build_tool_message_content(tool_name: str | None) -> str:
"""Build the error message content for ToolMessage when limit is exceeded.
"""Build the error message content for `ToolMessage` when limit is exceeded.

This message is sent to the model, so it should not reference thread/run concepts
that the model has no notion of.

Args:
tool_name: Tool name being limited (if specific tool), or None for all tools.
tool_name: Tool name being limited (if specific tool), or `None` for all tools.

Returns:
A concise message instructing the model not to call the tool again.
@@ -70,7 +71,7 @@ def _build_final_ai_message_content(
run_limit: int | None,
tool_name: str | None,
) -> str:
"""Build the final AI message content for 'end' behavior.
"""Build the final AI message content for `'end'` behavior.

This message is displayed to the user, so it should include detailed information
about which limits were exceeded.
@@ -80,7 +81,7 @@ def _build_final_ai_message_content(
run_count: Current run tool call count.
thread_limit: Thread tool call limit (if set).
run_limit: Run tool call limit (if set).
tool_name: Tool name being limited (if specific tool), or None for all tools.
tool_name: Tool name being limited (if specific tool), or `None` for all tools.

Returns:
A formatted message describing which limits were exceeded.
@@ -100,8 +101,8 @@
class ToolCallLimitExceededError(Exception):
"""Exception raised when tool call limits are exceeded.

This exception is raised when the configured exit behavior is 'error'
and either the thread or run tool call limit has been exceeded.
This exception is raised when the configured exit behavior is `'error'` and either
the thread or run tool call limit has been exceeded.
"""

def __init__(
@@ -145,48 +146,53 @@ class ToolCallLimitMiddleware(

Configuration:
- `exit_behavior`: How to handle when limits are exceeded
- `"continue"`: Block exceeded tools, let execution continue (default)
- `"error"`: Raise an exception
- `"end"`: Stop immediately with a ToolMessage + AI message for the single
tool call that exceeded the limit (raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling)).
- `'continue'`: Block exceeded tools, let execution continue (default)
- `'error'`: Raise an exception
- `'end'`: Stop immediately with a `ToolMessage` + AI message for the single
tool call that exceeded the limit (raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling)).

Examples:
Continue execution with blocked tools (default):
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
!!! example "Continue execution with blocked tools (default)"

# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent

agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)

Stop immediately when limit exceeded:
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```

agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
!!! example "Stop immediately when limit exceeded"

Raise exception on limit:
```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")

agent = create_agent("openai:gpt-4o", middleware=[limiter])
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```

try:
result = agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
!!! example "Raise exception on limit"

```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, exit_behavior="error"
)

agent = create_agent("openai:gpt-4o", middleware=[limiter])

try:
result = agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```

"""

@@ -204,23 +210,24 @@ class ToolCallLimitMiddleware(

Args:
tool_name: Name of the specific tool to limit. If `None`, limits apply
to all tools. Defaults to `None`.
to all tools.
thread_limit: Maximum number of tool calls allowed per thread.
`None` means no limit. Defaults to `None`.
`None` means no limit.
run_limit: Maximum number of tool calls allowed per run.
`None` means no limit. Defaults to `None`.
`None` means no limit.
exit_behavior: How to handle when limits are exceeded.
- `"continue"`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end. (default)
- `"error"`: Raise a `ToolCallLimitExceededError` exception
- `"end"`: Stop execution immediately with a ToolMessage + AI message
for the single tool call that exceeded the limit. Raises
`NotImplementedError` if there are multiple parallel tool
calls to other tools or multiple pending tool calls.

- `'continue'`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end.
- `'error'`: Raise a `ToolCallLimitExceededError` exception
- `'end'`: Stop execution immediately with a `ToolMessage` + AI message
for the single tool call that exceeded the limit. Raises
`NotImplementedError` if there are multiple parallel tool
calls to other tools or multiple pending tool calls.

Raises:
ValueError: If both limits are `None`, if exit_behavior is invalid,
or if run_limit exceeds thread_limit.
ValueError: If both limits are `None`, if `exit_behavior` is invalid,
or if `run_limit` exceeds `thread_limit`.
"""
super().__init__()

@@ -293,7 +300,8 @@ class ToolCallLimitMiddleware(
run_count: Current run call count.

Returns:
Tuple of (allowed_calls, blocked_calls, final_thread_count, final_run_count).
Tuple of `(allowed_calls, blocked_calls, final_thread_count,
final_run_count)`.
"""
allowed_calls: list[ToolCall] = []
blocked_calls: list[ToolCall] = []
@@ -327,13 +335,13 @@ class ToolCallLimitMiddleware(

Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is "end", also includes a jump to end with a ToolMessage
and AI message for the single exceeded tool call.
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.

Raises:
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
is "error".
NotImplementedError: If limits are exceeded, exit_behavior is "end",
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
# Get the last AIMessage to check for tool calls
@@ -452,3 +460,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}

@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.

Args:
state: The current agent state.
runtime: The langgraph runtime.

Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.

Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

@@ -23,39 +23,44 @@ class LLMToolEmulator(AgentMiddleware):
"""Emulates specified tools using an LLM instead of executing them.

This middleware allows selective emulation of tools for testing purposes.
By default (when tools=None), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or BaseTool instances.

By default (when `tools=None`), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or `BaseTool` instances.

Examples:
Emulate all tools (default behavior):
```python
from langchain.agents.middleware import LLMToolEmulator
!!! example "Emulate all tools (default behavior)"

middleware = LLMToolEmulator()
```python
from langchain.agents.middleware import LLMToolEmulator

agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
middleware = LLMToolEmulator()

Emulate specific tools by name:
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```

Use a custom model for emulation:
```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by name"

Emulate specific tools by passing tool instances:
```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```

!!! example "Use a custom model for emulation"

```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```

!!! example "Emulate specific tools by passing tool instances"

```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
"""

def __init__(
@@ -67,12 +72,16 @@ class LLMToolEmulator(AgentMiddleware):
"""Initialize the tool emulator.

Args:
tools: List of tool names (str) or BaseTool instances to emulate.
If None (default), ALL tools will be emulated.
tools: List of tool names (`str`) or `BaseTool` instances to emulate.

If `None`, ALL tools will be emulated.

If empty list, no tools will be emulated.
model: Model to use for emulation.
Defaults to "anthropic:claude-sonnet-4-5-20250929".
Can be a model identifier string or BaseChatModel instance.

Defaults to `'anthropic:claude-sonnet-4-5-20250929'`.

Can be a model identifier string or `BaseChatModel` instance.
"""
super().__init__()

@@ -110,7 +119,7 @@ class LLMToolEmulator(AgentMiddleware):

Returns:
ToolMessage with emulated response if tool should be emulated,
otherwise calls handler for normal execution.
"""
tool_name = request.tool_call["name"]

@@ -152,7 +161,7 @@ class LLMToolEmulator(AgentMiddleware):
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async version of wrap_tool_call.
"""Async version of `wrap_tool_call`.

Emulate tool execution using LLM if tool should be emulated.

@@ -162,7 +171,7 @@ class LLMToolEmulator(AgentMiddleware):

Returns:
ToolMessage with emulated response if tool should be emulated,
otherwise calls handler for normal execution.
"""
tool_name = request.tool_call["name"]


@@ -26,89 +26,96 @@ class ToolRetryMiddleware(AgentMiddleware):
Supports retrying on specific exceptions and exponential backoff.

Examples:
Basic usage with default settings (2 retries, exponential backoff):
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
!!! example "Basic usage with default settings (2 retries, exponential backoff)"

agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware

Retry specific exceptions only:
```python
from requests.exceptions import RequestException, Timeout
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```

retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
!!! example "Retry specific exceptions only"

Custom exception filtering:
```python
from requests.exceptions import HTTPError
```python
from requests.exceptions import RequestException, Timeout

retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```

!!! example "Custom exception filtering"

```python
from requests.exceptions import HTTPError


def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False


retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```

Apply to specific tools with custom error handling:
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
!!! example "Apply to specific tools with custom error handling"

```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."


retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```

Apply to specific tools using BaseTool instances:
```python
from langchain_core.tools import tool
!!! example "Apply to specific tools using `BaseTool` instances"

```python
from langchain_core.tools import tool


@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results


retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```

Constant backoff (no exponential growth):
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
!!! example "Constant backoff (no exponential growth)"

Raise exception on failure:
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise", # Re-raise exception instead of returning message
)
```
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```

!!! example "Raise exception on failure"

```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise", # Re-raise exception instead of returning message
)
```
"""

def __init__(
@@ -125,34 +132,47 @@ class ToolRetryMiddleware(AgentMiddleware):
max_delay: float = 60.0,
jitter: bool = True,
) -> None:
"""Initialize ToolRetryMiddleware.
"""Initialize `ToolRetryMiddleware`.

Args:
max_retries: Maximum number of retry attempts after the initial call.
Default is 2 retries (3 total attempts). Must be >= 0.

Default is `2` retries (`3` total attempts).

Must be `>= 0`.
tools: Optional list of tools or tool names to apply retry logic to.

Can be a list of `BaseTool` instances or tool name strings.
If `None`, applies to all tools. Default is `None`.

If `None`, applies to all tools.
retry_on: Either a tuple of exception types to retry on, or a callable
that takes an exception and returns `True` if it should be retried.

Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted. Options:
- `"return_message"` (default): Return a ToolMessage with error details,
allowing the LLM to handle the failure and potentially recover.
- `"raise"`: Re-raise the exception, stopping agent execution.
- Custom callable: Function that takes the exception and returns a string
for the ToolMessage content, allowing custom error formatting.
backoff_factor: Multiplier for exponential backoff. Each retry waits
`initial_delay * (backoff_factor ** retry_number)` seconds.
Set to 0.0 for constant delay. Default is 2.0.
initial_delay: Initial delay in seconds before first retry. Default is 1.0.
max_delay: Maximum delay in seconds between retries. Caps exponential
backoff growth. Default is 60.0.
jitter: Whether to add random jitter (±25%) to delay to avoid thundering herd.
Default is `True`.
on_failure: Behavior when all retries are exhausted.

Options:

- `'return_message'`: Return a `ToolMessage` with error details,
allowing the LLM to handle the failure and potentially recover.
- `'raise'`: Re-raise the exception, stopping agent execution.
- **Custom callable:** Function that takes the exception and returns a
string for the `ToolMessage` content, allowing custom error
formatting.
backoff_factor: Multiplier for exponential backoff.

Each retry waits `initial_delay * (backoff_factor ** retry_number)`
seconds.

Set to `0.0` for constant delay.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.

Caps exponential backoff growth.
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.

Raises:
ValueError: If `max_retries < 0` or delays are negative.
"""
super().__init__()

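The delay schedule described by `backoff_factor`, `initial_delay`, `max_delay`, and `jitter` reduces to the following sketch (illustrative only, not taken from this diff):

```python
import random

# Delay before retry number n (0-based), per the parameters documented above.
def retry_delay(n: int, *, initial_delay: float = 1.0, backoff_factor: float = 2.0,
                max_delay: float = 60.0, jitter: bool = True) -> float:
    delay = min(initial_delay * (backoff_factor**n), max_delay)
    if jitter:
        delay *= random.uniform(0.75, 1.25)  # +/-25% jitter
    return delay

# Defaults give roughly 1s, 2s, 4s, ... between attempts, capped at max_delay.
```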
@@ -260,15 +280,15 @@ class ToolRetryMiddleware(AgentMiddleware):

Args:
tool_name: Name of the tool that failed.
tool_call_id: ID of the tool call (may be None).
tool_call_id: ID of the tool call (may be `None`).
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.

Returns:
ToolMessage with error details.
`ToolMessage` with error details.

Raises:
Exception: If on_failure is "raise", re-raises the exception.
Exception: If `on_failure` is `'raise'`, re-raises the exception.
"""
if self.on_failure == "raise":
raise exc
@@ -293,11 +313,11 @@ class ToolRetryMiddleware(AgentMiddleware):
"""Intercept tool execution and retry on failure.

Args:
request: Tool call request with call dict, BaseTool, state, and runtime.
request: Tool call request with call dict, `BaseTool`, state, and runtime.
handler: Callable to execute the tool (can be called multiple times).

Returns:
ToolMessage or Command (the final result).
`ToolMessage` or `Command` (the final result).
"""
tool_name = request.tool.name if request.tool else request.tool_call["name"]

@@ -342,11 +362,12 @@ class ToolRetryMiddleware(AgentMiddleware):
"""Intercept and control async tool execution with retry logic.

Args:
request: Tool call request with call dict, BaseTool, state, and runtime.
handler: Async callable that executes the tool and returns ToolMessage or Command.
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
handler: Async callable that executes the tool and returns `ToolMessage` or
`Command`.

Returns:
ToolMessage or Command (the final result).
`ToolMessage` or `Command` (the final result).
"""
tool_name = request.tool.name if request.tool else request.tool_call["name"]


@@ -49,7 +49,8 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
tools: Available tools to include in the schema.

Returns:
TypeAdapter for a schema where each tool name is a Literal with its description.
`TypeAdapter` for a schema where each tool name is a `Literal` with its
description.
"""
if not tools:
msg = "Invalid usage: tools must be non-empty"
@@ -92,23 +93,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
and helps the main model focus on the right tools.

Examples:
Limit to 3 tools:
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware
!!! example "Limit to 3 tools"

middleware = LLMToolSelectorMiddleware(max_tools=3)
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware

agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```
middleware = LLMToolSelectorMiddleware(max_tools=3)

Use a smaller model for selection:
```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```

!!! example "Use a smaller model for selection"

```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
"""

def __init__(
@@ -122,13 +125,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
"""Initialize the tool selector.

Args:
model: Model to use for selection. If not provided, uses the agent's main model.
Can be a model identifier string or BaseChatModel instance.
model: Model to use for selection.

If not provided, uses the agent's main model.

Can be a model identifier string or `BaseChatModel` instance.
system_prompt: Instructions for the selection model.
max_tools: Maximum number of tools to select. If the model selects more,
only the first max_tools will be used. No limit if not specified.
max_tools: Maximum number of tools to select.

If the model selects more, only the first `max_tools` will be used.

If not specified, there is no limit.
always_include: Tool names to always include regardless of selection.
These do not count against the max_tools limit.

These do not count against the `max_tools` limit.
"""
super().__init__()
self.system_prompt = system_prompt
@@ -144,7 +154,8 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
"""Prepare inputs for tool selection.

Returns:
SelectionRequest with prepared inputs, or None if no selection is needed.
`SelectionRequest` with prepared inputs, or `None` if no selection is
needed.
"""
# If no tools available, return None
if not request.tools or len(request.tools) == 0:
@@ -211,7 +222,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
valid_tool_names: list[str],
request: ModelRequest,
) -> ModelRequest:
"""Process the selection response and return filtered ModelRequest."""
"""Process the selection response and return filtered `ModelRequest`."""
selected_tool_names: list[str] = []
invalid_tool_selections = []


File diff suppressed because it is too large
@@ -1,10 +1,4 @@
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.

!!! warning "Reference docs"
This page contains **reference documentation** for chat models. See
[the docs](https://docs.langchain.com/oss/python/langchain/models) for conceptual
guides, tutorials, and examples on using chat models.
""" # noqa: E501
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.""" # noqa: E501

from langchain_core.language_models import BaseChatModel


@@ -87,6 +87,21 @@ def init_chat_model(

You can also specify model and model provider in a single argument using
`'{model_provider}:{model}'` format, e.g. `'openai:o1'`.

Will attempt to infer `model_provider` from model if not specified.

The following providers will be inferred based on these model prefixes:

- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
model_provider: The model provider if not specified as part of the model arg
(see above).

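Assumed usage of the inference rules listed above (model ids are illustrative, not from this diff):

```python
from langchain.chat_models import init_chat_model

gpt = init_chat_model("openai:gpt-4o")         # explicit "{model_provider}:{model}" form
claude = init_chat_model("claude-sonnet-4-5")  # prefix "claude..." infers anthropic
```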
@@ -110,24 +125,11 @@ def init_chat_model(
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)

Will attempt to infer `model_provider` from model if not specified. The
following providers will be inferred based on these model prefixes:

- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
configurable_fields: Which model parameters are configurable at runtime:

- `None`: No configurable fields (i.e., a fixed model).
@@ -142,6 +144,7 @@ def init_chat_model(
If `model` is not specified, then defaults to `("model", "model_provider")`.

!!! warning "Security note"

Setting `configurable_fields="any"` means fields like `api_key`,
`base_url`, etc., can be altered at runtime, potentially redirecting
model requests to a different service/user.

@@ -1,10 +1,5 @@
"""Embeddings models.

!!! warning "Reference docs"
This page contains **reference documentation** for Embeddings. See
[the docs](https://docs.langchain.com/oss/python/langchain/retrieval#embedding-models)
for conceptual guides, tutorials, and examples on using Embeddings.

!!! warning "Modules moved"
With the release of `langchain 1.0.0`, several embeddings modules were moved to
`langchain-classic`, such as `CacheBackedEmbeddings` and all community

@@ -2,11 +2,6 @@

Includes message types for different roles (e.g., human, AI, system), as well as types
for message content blocks (e.g., text, image, audio) and tool calls.

!!! warning "Reference docs"
This page contains **reference documentation** for Messages. See
[the docs](https://docs.langchain.com/oss/python/langchain/messages) for conceptual
guides, tutorials, and examples on using Messages.
"""

from langchain_core.messages import (

@@ -1,10 +1,4 @@
"""Tools.

!!! warning "Reference docs"
This page contains **reference documentation** for Tools. See
[the docs](https://docs.langchain.com/oss/python/langchain/tools) for conceptual
guides, tutorials, and examples on using Tools.
"""
"""Tools."""

from langchain_core.tools import (
BaseTool,

@@ -9,10 +9,10 @@ license = { text = "MIT" }
readme = "README.md"
authors = []

version = "1.0.4"
version = "1.0.5"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langchain-core>=1.0.2,<2.0.0",
"langchain-core>=1.0.4,<2.0.0",
"langgraph>=1.0.2,<1.1.0",
"pydantic>=2.7.4,<3.0.0",
]
@@ -57,6 +57,7 @@ test = [
"pytest-mock",
"syrupy>=4.0.2,<5.0.0",
"toml>=0.10.2,<1.0.0",
"langchain-model-profiles",
"langchain-tests",
"langchain-openai",
]
@@ -75,6 +76,7 @@ test_integration = [
"cassio>=0.1.0,<1.0.0",
"langchainhub>=0.1.16,<1.0.0",
"langchain-core",
"langchain-model-profiles",
"langchain-text-splitters",
]

@@ -83,6 +85,7 @@ prerelease = "allow"

[tool.uv.sources]
langchain-core = { path = "../core", editable = true }
langchain-model-profiles = { path = "../model-profiles", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,79 +0,0 @@
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field

from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy


class WeatherBaseModel(BaseModel):
"""Weather response."""

temperature: float = Field(description="The temperature in fahrenheit")
condition: str = Field(description="Weather condition")


def get_weather(city: str) -> str: # noqa: ARG001
"""Get the weather for a city."""
return "The weather is sunny and 75°F."


@pytest.mark.requires("langchain_openai")
def test_inference_to_native_output() -> None:
"""Test that native output is inferred when a model supports it."""
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-5")
agent = create_agent(
model,
system_prompt=(
"You are a helpful weather assistant. Please call the get_weather tool, "
"then use the WeatherReport tool to generate the final response."
),
tools=[get_weather],
response_format=WeatherBaseModel,
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

assert isinstance(response["structured_response"], WeatherBaseModel)
assert response["structured_response"].temperature == 75.0
assert response["structured_response"].condition.lower() == "sunny"
assert len(response["messages"]) == 4

assert [m.type for m in response["messages"]] == [
"human", # "What's the weather?"
"ai", # tool call to get_weather
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
]


@pytest.mark.requires("langchain_openai")
def test_inference_to_tool_output() -> None:
"""Test that tool output is inferred when a model supports it."""
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4")
agent = create_agent(
model,
system_prompt=(
"You are a helpful weather assistant. Please call the get_weather tool, "
"then use the WeatherReport tool to generate the final response."
),
tools=[get_weather],
response_format=ToolStrategy(WeatherBaseModel),
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

assert isinstance(response["structured_response"], WeatherBaseModel)
assert response["structured_response"].temperature == 75.0
assert response["structured_response"].condition.lower() == "sunny"
assert len(response["messages"]) == 5

assert [m.type for m in response["messages"]] == [
"human", # "What's the weather?"
"ai", # tool call to get_weather
"tool", # "The weather is sunny and 75°F."
"ai", # structured response
"tool", # artificial tool message
]
@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
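Reading note on the snapshots above: the `\2e` sequences in node identifiers are Mermaid-escaped periods (0x2E), so `NoopSeven\2ebefore_model` renders as `NoopSeven.before_model`; dashed `-.->` edges are conditional transitions, solid `-->` edges unconditional. A minimal sketch of how these `.ambr` entries are produced and compared, assuming syrupy's `snapshot` fixture and the suite's `FakeToolCallingModel` helper, both of which appear later in this diff:

# Sketch only: regenerating a Mermaid snapshot like test_simple_agent_graph.
# FakeToolCallingModel and create_agent come from this PR's test suite;
# snapshot is syrupy's SnapshotAssertion fixture.
from syrupy.assertion import SnapshotAssertion

from langchain.agents.factory import create_agent
from tests.unit_tests.agents.model import FakeToolCallingModel


def test_simple_agent_graph(snapshot: SnapshotAssertion) -> None:
    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
    )
    # draw_mermaid() emits the "graph TD;" text stored in the .ambr file
    assert agent.get_graph().draw_mermaid() == snapshot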
@@ -0,0 +1,95 @@
# serializer version: 1
# name: test_async_middleware_with_can_jump_to_graph_snapshot
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_with_jump\2ebefore_model;
async_before_with_jump\2ebefore_model -.-> __end__;
async_before_with_jump\2ebefore_model -.-> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
async_after_with_jump\2eafter_model -.-> __end__;
async_after_with_jump\2eafter_model -.-> model;
model --> async_after_with_jump\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
async_after_retry\2eafter_model(async_after_retry.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> async_before_early_exit\2ebefore_model;
async_after_retry\2eafter_model -.-> __end__;
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
async_before_early_exit\2ebefore_model -.-> __end__;
async_before_early_exit\2ebefore_model -.-> model;
model --> async_after_retry\2eafter_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> sync_before_with_jump\2ebefore_model;
async_after_with_jumps\2eafter_model -.-> __end__;
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
model --> async_after_with_jumps\2eafter_model;
sync_before_with_jump\2ebefore_model -.-> __end__;
sync_before_with_jump\2ebefore_model -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
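The `can_jump_to` graphs above show that a hook declaring jump targets gets a dashed edge to each target. A hedged sketch of declaring such a hook, assuming the `hook_config` decorator imported elsewhere in this PR accepts a `can_jump_to` keyword and that returning a `{"jump_to": "end"}` update routes along the dashed `-.-> __end__` edge; neither detail is spelled out in the hunks shown here:

# Sketch under assumptions: hook_config's exact signature and the jump
# payload shape are inferred, not shown verbatim in this diff.
from langchain.agents.middleware.types import AgentMiddleware, hook_config


class BeforeEarlyExit(AgentMiddleware):
    @hook_config(can_jump_to=["end"])  # assumed keyword, matching the test names
    def before_model(self, state, runtime):
        if state["messages"]:
            return None  # fall through the solid edge to model
        return {"jump_to": "end"}  # assumed payload; follows the dashed edge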
@@ -0,0 +1,289 @@
# serializer version: 1
# name: test_create_agent_diagram
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.1
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.10
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
__end__([<p>__end__</p>]):::last
NoopTen\2ebefore_model --> model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopTen\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.11
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2ebefore_model(NoopEleven.before_model)
NoopEleven\2eafter_model(NoopEleven.after_model)
__end__([<p>__end__</p>]):::last
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
NoopEleven\2ebefore_model --> model;
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
__start__ --> NoopTen\2ebefore_model;
model --> NoopEleven\2eafter_model;
NoopTen\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.2
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopTwo\2ebefore_model --> model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.3
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopOne\2ebefore_model(NoopOne.before_model)
NoopTwo\2ebefore_model(NoopTwo.before_model)
NoopThree\2ebefore_model(NoopThree.before_model)
__end__([<p>__end__</p>]):::last
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
NoopThree\2ebefore_model --> model;
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
__start__ --> NoopOne\2ebefore_model;
model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.4
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model --> NoopFour\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.5
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
__start__ --> model;
model --> NoopFive\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.6
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopFour\2eafter_model(NoopFour.after_model)
NoopFive\2eafter_model(NoopFive.after_model)
NoopSix\2eafter_model(NoopSix.after_model)
__end__([<p>__end__</p>]):::last
NoopFive\2eafter_model --> NoopFour\2eafter_model;
NoopSix\2eafter_model --> NoopFive\2eafter_model;
__start__ --> model;
model --> NoopSix\2eafter_model;
NoopFour\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.7
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
__end__([<p>__end__</p>]):::last
NoopSeven\2ebefore_model --> model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopSeven\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.8
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_diagram.9
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
NoopNine\2eafter_model --> NoopEight\2eafter_model;
NoopNine\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopNine\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
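A pattern worth noting across the diagrams above: `before_model` hooks run in registration order while `after_model` hooks run in reverse, so middleware wrap the model call like an onion. A minimal sketch reproducing the two-middleware shape of `test_create_agent_diagram.8`, reusing helpers from this PR's test suite:

# Sketch: two middleware with both hooks; the printed edges match the
# .8 snapshot above (helpers taken from this PR's test suite).
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware
from tests.unit_tests.agents.model import FakeToolCallingModel


class Outer(AgentMiddleware):
    def before_model(self, state, runtime):  # first in
        pass

    def after_model(self, state, runtime):  # last out
        pass


class Inner(AgentMiddleware):
    def before_model(self, state, runtime):  # second in
        pass

    def after_model(self, state, runtime):  # first out, right after the model
        pass


agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[Outer(), Inner()])
# __start__ --> Outer.before_model --> Inner.before_model --> model
# model --> Inner.after_model --> Outer.after_model --> __end__
print(agent.get_graph().draw_mermaid())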
@@ -22,7 +22,7 @@ from langchain.agents.middleware.types import (
     hook_config,
 )
 from langchain.agents.factory import create_agent, _get_can_jump_to
-from .model import FakeToolCallingModel
+from ...model import FakeToolCallingModel


 class CustomState(AgentState):
@@ -0,0 +1,193 @@
from collections.abc import Callable

from syrupy.assertion import SnapshotAssertion

from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain_core.messages import AIMessage

from ...model import FakeToolCallingModel


def test_create_agent_diagram(
    snapshot: SnapshotAssertion,
):
    class NoopOne(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

    class NoopTwo(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

    class NoopThree(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

    class NoopFour(AgentMiddleware):
        def after_model(self, state, runtime):
            pass

    class NoopFive(AgentMiddleware):
        def after_model(self, state, runtime):
            pass

    class NoopSix(AgentMiddleware):
        def after_model(self, state, runtime):
            pass

    class NoopSeven(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

        def after_model(self, state, runtime):
            pass

    class NoopEight(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

        def after_model(self, state, runtime):
            pass

    class NoopNine(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

        def after_model(self, state, runtime):
            pass

    class NoopTen(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            return handler(request)

        def after_model(self, state, runtime):
            pass

    class NoopEleven(AgentMiddleware):
        def before_model(self, state, runtime):
            pass

        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            return handler(request)

        def after_model(self, state, runtime):
            pass

    agent_zero = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
    )

    assert agent_zero.get_graph().draw_mermaid() == snapshot

    agent_one = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopOne()],
    )

    assert agent_one.get_graph().draw_mermaid() == snapshot

    agent_two = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopOne(), NoopTwo()],
    )

    assert agent_two.get_graph().draw_mermaid() == snapshot

    agent_three = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopOne(), NoopTwo(), NoopThree()],
    )

    assert agent_three.get_graph().draw_mermaid() == snapshot

    agent_four = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopFour()],
    )

    assert agent_four.get_graph().draw_mermaid() == snapshot

    agent_five = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopFour(), NoopFive()],
    )

    assert agent_five.get_graph().draw_mermaid() == snapshot

    agent_six = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopFour(), NoopFive(), NoopSix()],
    )

    assert agent_six.get_graph().draw_mermaid() == snapshot

    agent_seven = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopSeven()],
    )

    assert agent_seven.get_graph().draw_mermaid() == snapshot

    agent_eight = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopSeven(), NoopEight()],
    )

    assert agent_eight.get_graph().draw_mermaid() == snapshot

    agent_nine = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopSeven(), NoopEight(), NoopNine()],
    )

    assert agent_nine.get_graph().draw_mermaid() == snapshot

    agent_ten = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopTen()],
    )

    assert agent_ten.get_graph().draw_mermaid() == snapshot

    agent_eleven = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[NoopTen(), NoopEleven()],
    )

    assert agent_eleven.get_graph().draw_mermaid() == snapshot
File diff suppressed because it is too large
@@ -14,7 +14,7 @@ from langgraph.checkpoint.memory import InMemorySaver
 from langchain.agents.factory import create_agent
 from langchain.agents.middleware.types import AgentMiddleware, wrap_tool_call
 from langchain.agents.middleware.types import ToolCallRequest
-from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
+from tests.unit_tests.agents.model import FakeToolCallingModel


 @tool
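The hunk above pulls in `wrap_tool_call` and `ToolCallRequest`. By analogy with the `wrap_model_call` hooks exercised below, a hedged sketch of a tool-call wrapper; the `(request, handler)` shape is an assumption carried over from wrap_model_call, not confirmed by the lines shown here:

# Assumed shape, mirroring wrap_model_call's (request, handler) protocol.
from langchain.agents.middleware.types import AgentMiddleware, ToolCallRequest


class LoggingToolMiddleware(AgentMiddleware):
    def wrap_tool_call(self, request: ToolCallRequest, handler):
        # Inspect or rewrite the tool call before it runs (assumed protocol),
        # then pass through and observe the result on the way back out.
        return handler(request)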
@@ -9,7 +9,7 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState, Model
|
||||
from langgraph.prebuilt.tool_node import ToolNode
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from .model import FakeToolCallingModel
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_model_request_tools_are_base_tools() -> None:
|
||||
@@ -1,18 +1,30 @@
|
||||
"""Unit tests for wrap_model_call middleware generator protocol."""
|
||||
"""Unit tests for wrap_model_call hook and @wrap_model_call decorator.
|
||||
|
||||
This module tests the wrap_model_call functionality in three forms:
|
||||
1. As a middleware method (AgentMiddleware.wrap_model_call)
|
||||
2. As a decorator (@wrap_model_call)
|
||||
3. Async variant (AgentMiddleware.awrap_model_call)
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
wrap_model_call,
|
||||
)
|
||||
|
||||
from ...model import FakeToolCallingModel
|
||||
|
||||
class TestBasicOnModelCall:
|
||||
|
||||
class TestBasicWrapModelCall:
|
||||
"""Test basic wrap_model_call functionality."""
|
||||
|
||||
def test_passthrough_middleware(self) -> None:
|
||||
@@ -70,7 +82,7 @@ class TestBasicOnModelCall:
|
||||
assert counter.call_count == 1
|
||||
|
||||
|
||||
class TestRetryMiddleware:
|
||||
class TestRetryLogic:
|
||||
"""Test retry logic with wrap_model_call."""
|
||||
|
||||
def test_simple_retry_on_error(self) -> None:
|
||||
@@ -91,12 +103,10 @@ class TestRetryMiddleware:
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
try:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
except Exception:
|
||||
self.retry_count += 1
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
|
||||
retry_middleware = RetryOnceMiddleware()
|
||||
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
|
||||
@@ -125,8 +135,7 @@ class TestRetryMiddleware:
|
||||
for attempt in range(self.max_retries):
|
||||
self.attempts.append(attempt + 1)
|
||||
try:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
continue
|
||||
@@ -143,6 +152,75 @@ class TestRetryMiddleware:
|
||||
|
||||
assert retry_middleware.attempts == [1, 2, 3]
|
||||
|
||||
def test_no_retry_propagates_error(self) -> None:
|
||||
"""Test that error is propagated when middleware doesn't retry."""
|
||||
|
||||
class FailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Model error")
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "failing"
|
||||
|
||||
class NoRetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
|
||||
|
||||
with pytest.raises(ValueError, match="Model error"):
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
def test_max_attempts_limit(self) -> None:
|
||||
"""Test that middleware controls termination via retry limits."""
|
||||
|
||||
class AlwaysFailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Always fails")
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "always_failing"
|
||||
|
||||
class LimitedRetryMiddleware(AgentMiddleware):
|
||||
"""Middleware that limits its own retries."""
|
||||
|
||||
def __init__(self, max_retries: int = 10):
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.attempt_count = 0
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
last_exception = None
|
||||
for attempt in range(self.max_retries):
|
||||
self.attempt_count += 1
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
# Continue to retry
|
||||
|
||||
# All retries exhausted, re-raise the last error
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
||||
model = AlwaysFailingModel()
|
||||
middleware = LimitedRetryMiddleware(max_retries=10)
|
||||
|
||||
agent = create_agent(model=model, middleware=[middleware])
|
||||
|
||||
# Should fail with the model's error after middleware stops retrying
|
||||
with pytest.raises(ValueError, match="Always fails"):
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# Should have attempted exactly 10 times as configured
|
||||
assert middleware.attempt_count == 10
|
||||
|
||||
|
||||
class TestResponseRewriting:
|
||||
"""Test response content rewriting with wrap_model_call."""
|
||||
@@ -185,6 +263,28 @@ class TestResponseRewriting:
|
||||
|
||||
assert result["messages"][1].content == "[BOT]: Response"
|
||||
|
||||
def test_multi_stage_transformation(self) -> None:
|
||||
"""Test middleware applying multiple transformations."""
|
||||
|
||||
class MultiTransformMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
|
||||
# First transformation: uppercase
|
||||
content = ai_message.content.upper()
|
||||
# Second transformation: add prefix and suffix
|
||||
content = f"[START] {content} [END]"
|
||||
return AIMessage(content=content)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))
|
||||
agent = create_agent(model=model, middleware=[MultiTransformMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert result["messages"][1].content == "[START] HELLO [END]"
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Test error handling with wrap_model_call."""
@@ -200,9 +300,8 @@ class TestErrorHandling:
            def wrap_model_call(self, request, handler):
                try:
                    return handler(request)
                except Exception as e:
                    return AIMessage(content=f"Error occurred: {e}. Using fallback response.")

        model = AlwaysFailModel(messages=iter([]))
        agent = create_agent(model=model, middleware=[ErrorToSuccessMiddleware()])
@@ -210,7 +309,8 @@ class TestErrorHandling:
        # Should not raise, middleware converts error to response
        result = agent.invoke({"messages": [HumanMessage("Test")]})

        assert "Error occurred" in result["messages"][1].content
        assert "fallback response" in result["messages"][1].content

    def test_selective_error_handling(self) -> None:
        """Test middleware that only handles specific errors."""
@@ -224,8 +324,7 @@ class TestErrorHandling:
                try:
                    return handler(request)
                except ConnectionError:
                    return AIMessage(content="Network issue, try again later")

        model = SpecificErrorModel(messages=iter([]))
        agent = create_agent(model=model, middleware=[SelectiveErrorMiddleware()])
@@ -247,8 +346,7 @@ class TestErrorHandling:
                    return result
                except Exception:
                    call_log.append("caught-error")
                    return AIMessage(content="Recovered from error")

        # Test 1: Success path
        call_log.clear()
@@ -403,7 +501,6 @@ class TestStateAndRuntime:
            for attempt in range(max_retries):
                try:
                    return handler(request)
                except Exception:
                    if attempt == max_retries - 1:
                        raise
@@ -460,6 +557,49 @@ class TestMiddlewareComposition:
            "outer-after",
        ]

    def test_three_middleware_composition(self) -> None:
        """Test composition of three middleware."""
        execution_order = []

        class FirstMiddleware(AgentMiddleware):
            def wrap_model_call(self, request, handler):
                execution_order.append("first-before")
                response = handler(request)
                execution_order.append("first-after")
                return response

        class SecondMiddleware(AgentMiddleware):
            def wrap_model_call(self, request, handler):
                execution_order.append("second-before")
                response = handler(request)
                execution_order.append("second-after")
                return response

        class ThirdMiddleware(AgentMiddleware):
            def wrap_model_call(self, request, handler):
                execution_order.append("third-before")
                response = handler(request)
                execution_order.append("third-after")
                return response

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=model,
            middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
        )

        agent.invoke({"messages": [HumanMessage("Test")]})

        # First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
        assert execution_order == [
            "first-before",
            "second-before",
            "third-before",
            "third-after",
            "second-after",
            "first-after",
        ]

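# Aside: a sketch (not part of this diff) of the onion ordering the assertion
# above relies on: middleware=[First, Second, Third] behaves like
# First(Second(Third(model_call))). compose_wrappers is an illustrative helper,
# not library API.
def compose_wrappers(wrappers, base_call):
    call = base_call
    for wrap in reversed(wrappers):  # innermost middleware wraps the model first
        def make(w, inner):
            return lambda request: w(request, inner)
        call = make(wrap, call)
    return call  # calling this runs wrappers outside-in, then the model
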
    def test_retry_with_logging(self) -> None:
        """Test retry middleware composed with logging middleware."""
        call_count = {"value": 0}
@@ -549,11 +689,9 @@ class TestMiddlewareComposition:
        class RetryMiddleware(AgentMiddleware):
            def wrap_model_call(self, request, handler):
                try:
                    return handler(request)
                except Exception:
                    return handler(request)

        class UppercaseMiddleware(AgentMiddleware):
            def wrap_model_call(self, request, handler):
@@ -571,49 +709,6 @@ class TestMiddlewareComposition:
        # Should retry and uppercase the result
        assert result["messages"][1].content == "SUCCESS"

    def test_middle_retry_middleware(self) -> None:
        """Test that middle middleware doing retry causes inner to execute twice."""
        execution_order = []
@@ -674,7 +769,306 @@ class TestMiddlewareComposition:
        assert len(model_calls) == 2


class TestWrapModelCallDecorator:
    """Test the @wrap_model_call decorator for creating middleware."""

    def test_basic_decorator_usage(self) -> None:
        """Test basic decorator usage without parameters."""

        @wrap_model_call
        def passthrough_middleware(request, handler):
            return handler(request)

        # Should return an AgentMiddleware instance
        assert isinstance(passthrough_middleware, AgentMiddleware)

        # Should work in agent
        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
        agent = create_agent(model=model, middleware=[passthrough_middleware])

        result = agent.invoke({"messages": [HumanMessage("Hi")]})
        assert len(result["messages"]) == 2
        assert result["messages"][1].content == "Hello"

    def test_decorator_with_custom_name(self) -> None:
        """Test decorator with custom middleware name."""

        @wrap_model_call(name="CustomMiddleware")
        def my_middleware(request, handler):
            return handler(request)

        assert isinstance(my_middleware, AgentMiddleware)
        assert my_middleware.__class__.__name__ == "CustomMiddleware"

    def test_decorator_retry_logic(self) -> None:
        """Test decorator for implementing retry logic."""
        call_count = {"value": 0}

        class FailOnceThenSucceed(GenericFakeChatModel):
            def _generate(self, messages, **kwargs):
                call_count["value"] += 1
                if call_count["value"] == 1:
                    raise ValueError("First call fails")
                return super()._generate(messages, **kwargs)

        @wrap_model_call
        def retry_once(request, handler):
            try:
                return handler(request)
            except Exception:
                # Retry once
                return handler(request)

        model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
        agent = create_agent(model=model, middleware=[retry_once])

        result = agent.invoke({"messages": [HumanMessage("Test")]})

        assert call_count["value"] == 2
        assert result["messages"][1].content == "Success"

    def test_decorator_response_rewriting(self) -> None:
        """Test decorator for rewriting responses."""

        @wrap_model_call
        def uppercase_responses(request, handler):
            result = handler(request)
            # result is ModelResponse, extract AIMessage from it
            ai_message = result.result[0]
            return AIMessage(content=ai_message.content.upper())

        model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
        agent = create_agent(model=model, middleware=[uppercase_responses])

        result = agent.invoke({"messages": [HumanMessage("Test")]})

        assert result["messages"][1].content == "HELLO WORLD"

    def test_decorator_error_handling(self) -> None:
        """Test decorator for error recovery."""

        class AlwaysFailModel(GenericFakeChatModel):
            def _generate(self, messages, **kwargs):
                raise ValueError("Model error")

        @wrap_model_call
        def error_to_fallback(request, handler):
            try:
                return handler(request)
            except Exception:
                return AIMessage(content="Fallback response")

        model = AlwaysFailModel(messages=iter([]))
        agent = create_agent(model=model, middleware=[error_to_fallback])

        result = agent.invoke({"messages": [HumanMessage("Test")]})

        assert result["messages"][1].content == "Fallback response"

    def test_decorator_with_state_access(self) -> None:
        """Test decorator accessing agent state."""
        state_values = []

        @wrap_model_call
        def log_state(request, handler):
            state_values.append(request.state.get("messages"))
            return handler(request)

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=model, middleware=[log_state])

        agent.invoke({"messages": [HumanMessage("Test")]})

        # State should contain the user message
        assert len(state_values) == 1
        assert len(state_values[0]) == 1
        assert state_values[0][0].content == "Test"

    def test_multiple_decorated_middleware(self) -> None:
        """Test composition of multiple decorated middleware."""
        execution_order = []

        @wrap_model_call
        def outer_middleware(request, handler):
            execution_order.append("outer-before")
            result = handler(request)
            execution_order.append("outer-after")
            return result

        @wrap_model_call
        def inner_middleware(request, handler):
            execution_order.append("inner-before")
            result = handler(request)
            execution_order.append("inner-after")
            return result

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=model, middleware=[outer_middleware, inner_middleware])

        agent.invoke({"messages": [HumanMessage("Test")]})

        assert execution_order == [
            "outer-before",
            "inner-before",
            "inner-after",
            "outer-after",
        ]

    def test_decorator_with_custom_state_schema(self) -> None:
        """Test decorator with custom state schema."""
        from typing_extensions import TypedDict

        class CustomState(TypedDict):
            messages: list
            custom_field: str

        @wrap_model_call(state_schema=CustomState)
        def middleware_with_schema(request, handler):
            return handler(request)

        assert isinstance(middleware_with_schema, AgentMiddleware)
        # Custom state schema should be set
        assert middleware_with_schema.state_schema == CustomState

    def test_decorator_with_tools_parameter(self) -> None:
        """Test decorator with tools parameter."""
        from langchain_core.tools import tool

        @tool
        def test_tool(query: str) -> str:
            """A test tool."""
            return f"Result: {query}"

        @wrap_model_call(tools=[test_tool])
        def middleware_with_tools(request, handler):
            return handler(request)

        assert isinstance(middleware_with_tools, AgentMiddleware)
        assert len(middleware_with_tools.tools) == 1
        assert middleware_with_tools.tools[0].name == "test_tool"

    def test_decorator_parentheses_optional(self) -> None:
        """Test that decorator works both with and without parentheses."""

        # Without parentheses
        @wrap_model_call
        def middleware_no_parens(request, handler):
            return handler(request)

        # With parentheses
        @wrap_model_call()
        def middleware_with_parens(request, handler):
            return handler(request)

        assert isinstance(middleware_no_parens, AgentMiddleware)
        assert isinstance(middleware_with_parens, AgentMiddleware)

    def test_decorator_preserves_function_name(self) -> None:
        """Test that decorator uses function name for class name."""

        @wrap_model_call
        def my_custom_middleware(request, handler):
            return handler(request)

        assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"

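# Aside: an application-side sketch (not part of this diff) combining the
# decorator behaviors exercised above; latency_logger is an illustrative name.
import time

@wrap_model_call
def latency_logger(request, handler):
    start = time.monotonic()
    try:
        return handler(request)
    finally:
        # Runs on success and on failure alike.
        print(f"model call took {time.monotonic() - start:.3f}s")
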
    def test_decorator_mixed_with_class_middleware(self) -> None:
        """Test decorated middleware mixed with class-based middleware."""
        execution_order = []

        @wrap_model_call
        def decorated_middleware(request, handler):
            execution_order.append("decorated-before")
            result = handler(request)
            execution_order.append("decorated-after")
            return result

        class ClassMiddleware(AgentMiddleware):
            def wrap_model_call(self, request, handler):
                execution_order.append("class-before")
                result = handler(request)
                execution_order.append("class-after")
                return result

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=model,
            middleware=[decorated_middleware, ClassMiddleware()],
        )

        agent.invoke({"messages": [HumanMessage("Test")]})

        # Decorated is outer, class-based is inner
        assert execution_order == [
            "decorated-before",
            "class-before",
            "class-after",
            "decorated-after",
        ]

    def test_decorator_complex_retry_logic(self) -> None:
        """Test decorator with complex retry logic and backoff."""
        attempts = []
        call_count = {"value": 0}

        class UnreliableModel(GenericFakeChatModel):
            def _generate(self, messages, **kwargs):
                call_count["value"] += 1
                if call_count["value"] <= 2:
                    raise ValueError(f"Attempt {call_count['value']} failed")
                return super()._generate(messages, **kwargs)

        @wrap_model_call
        def retry_with_tracking(request, handler):
            max_retries = 3
            for attempt in range(max_retries):
                attempts.append(attempt + 1)
                try:
                    return handler(request)
                except Exception:
                    # On error, continue to next attempt
                    if attempt < max_retries - 1:
                        continue  # Retry
                    else:
                        raise  # All retries failed

        model = UnreliableModel(messages=iter([AIMessage(content="Finally worked")]))
        agent = create_agent(model=model, middleware=[retry_with_tracking])

        result = agent.invoke({"messages": [HumanMessage("Test")]})

        assert attempts == [1, 2, 3]
        assert result["messages"][1].content == "Finally worked"

    def test_decorator_request_modification(self) -> None:
        """Test decorator modifying request before execution."""
        modified_prompts = []

        @wrap_model_call
        def add_system_prompt(request, handler):
            # Modify request to add system prompt
            modified_request = ModelRequest(
                messages=request.messages,
                model=request.model,
                system_prompt="You are a helpful assistant",
                tool_choice=request.tool_choice,
                tools=request.tools,
                response_format=request.response_format,
                state={},
                runtime=None,
            )
            modified_prompts.append(modified_request.system_prompt)
            return handler(modified_request)

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=model, middleware=[add_system_prompt])

        agent.invoke({"messages": [HumanMessage("Test")]})

        assert modified_prompts == ["You are a helpful assistant"]


class TestAsyncWrapModelCall:
    """Test async execution with wrap_model_call."""

    async def test_async_model_with_middleware(self) -> None:
@@ -686,7 +1080,6 @@ class TestAsyncOnModelCall:
                log.append("before")
                result = await handler(request)
                log.append("after")
                return result

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
@@ -723,6 +1116,92 @@ class TestAsyncOnModelCall:
        assert call_count["value"] == 2
        assert result["messages"][1].content == "Async success"

    async def test_decorator_with_async_agent(self) -> None:
        """Test that decorated middleware works with async agent invocation."""
        call_log = []

        @wrap_model_call
        async def logging_middleware(request, handler):
            call_log.append("before")
            result = await handler(request)
            call_log.append("after")
            return result

        model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
        agent = create_agent(model=model, middleware=[logging_middleware])

        result = await agent.ainvoke({"messages": [HumanMessage("Test")]})

        assert call_log == ["before", "after"]
        assert result["messages"][1].content == "Async response"


class TestSyncAsyncInterop:
    """Test sync/async interoperability."""

    def test_sync_invoke_with_only_async_middleware_raises_error(self) -> None:
        """Test that sync invoke with only async middleware raises error."""

        class AsyncOnlyMiddleware(AgentMiddleware):
            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[AIMessage]],
            ) -> AIMessage:
                return await handler(request)

        agent = create_agent(
            model=FakeToolCallingModel(),
            tools=[],
            system_prompt="You are a helpful assistant.",
            middleware=[AsyncOnlyMiddleware()],
        )

        with pytest.raises(NotImplementedError):
            agent.invoke({"messages": [HumanMessage("hello")]})

    def test_sync_invoke_with_mixed_middleware(self) -> None:
        """Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
        calls = []

        class MixedMiddleware(AgentMiddleware):
            def before_model(self, state, runtime) -> None:
                calls.append("MixedMiddleware.before_model")

            async def abefore_model(self, state, runtime) -> None:
                calls.append("MixedMiddleware.abefore_model")

            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], AIMessage],
            ) -> AIMessage:
                calls.append("MixedMiddleware.wrap_model_call")
                return handler(request)

            async def awrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], Awaitable[AIMessage]],
            ) -> AIMessage:
                calls.append("MixedMiddleware.awrap_model_call")
                return await handler(request)

        agent = create_agent(
            model=FakeToolCallingModel(),
            tools=[],
            system_prompt="You are a helpful assistant.",
            middleware=[MixedMiddleware()],
        )

        result = agent.invoke({"messages": [HumanMessage("hello")]})

        # In sync mode, only sync methods should be called
        assert calls == [
            "MixedMiddleware.before_model",
            "MixedMiddleware.wrap_model_call",
        ]


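# Aside: a sketch (not part of this diff) of the sync/async pairing convention
# these interop tests rely on; DualModeMiddleware is an illustrative name.
class DualModeMiddleware(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        # Sync path, used by agent.invoke
        return handler(request)

    async def awrap_model_call(self, request, handler):
        # Async path, used by agent.ainvoke
        return await handler(request)
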
class TestEdgeCases:
    """Test edge cases and error conditions."""
@@ -753,12 +1232,10 @@ class TestEdgeCases:
            def wrap_model_call(self, request, handler):
                attempts.append("first-attempt")
                try:
                    return handler(request)
                except Exception:
                    attempts.append("retry-attempt")
                    return handler(request)

        call_count = {"value": 0}

@@ -14,7 +14,7 @@ from langgraph.types import Command
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import ToolCallRequest
from tests.unit_tests.agents.model import FakeToolCallingModel


@tool
@@ -7,6 +7,9 @@ import pytest

from langchain.agents.middleware.file_search import (
    FilesystemFileSearchMiddleware,
    _expand_include_patterns,
    _is_valid_include_pattern,
    _match_include_pattern,
)


@@ -259,3 +262,105 @@ class TestPathTraversalSecurity:

        assert result == "No matches found"
        assert "secret" not in result


class TestExpandIncludePatterns:
    """Tests for _expand_include_patterns helper function."""

    def test_expand_patterns_basic_brace_expansion(self) -> None:
        """Test basic brace expansion with multiple options."""
        result = _expand_include_patterns("*.{py,txt}")
        assert result == ["*.py", "*.txt"]

    def test_expand_patterns_nested_braces(self) -> None:
        """Test nested brace expansion."""
        result = _expand_include_patterns("test.{a,b}.{c,d}")
        assert result is not None
        assert len(result) == 4
        assert "test.a.c" in result
        assert "test.b.d" in result

    @pytest.mark.parametrize(
        "pattern",
        [
            "*.py}",  # closing brace without opening
            "*.{}",  # empty braces
            "*.{py",  # unclosed brace
        ],
    )
    def test_expand_patterns_invalid_braces(self, pattern: str) -> None:
        """Test patterns with invalid brace syntax return None."""
        result = _expand_include_patterns(pattern)
        assert result is None


class TestValidateIncludePattern:
    """Tests for _is_valid_include_pattern helper function."""

    @pytest.mark.parametrize(
        "pattern",
        [
            "",  # empty pattern
            "*.py\x00",  # null byte
            "*.py\n",  # newline
        ],
    )
    def test_validate_invalid_patterns(self, pattern: str) -> None:
        """Test that invalid patterns are rejected."""
        assert not _is_valid_include_pattern(pattern)


class TestMatchIncludePattern:
    """Tests for _match_include_pattern helper function."""

    def test_match_pattern_with_braces(self) -> None:
        """Test matching with brace expansion."""
        assert _match_include_pattern("test.py", "*.{py,txt}")
        assert _match_include_pattern("test.txt", "*.{py,txt}")
        assert not _match_include_pattern("test.md", "*.{py,txt}")

    def test_match_pattern_invalid_expansion(self) -> None:
        """Test matching with pattern that cannot be expanded returns False."""
        assert not _match_include_pattern("test.py", "*.{}")


class TestGrepEdgeCases:
    """Tests for edge cases in grep search."""

    def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None:
        """Test grep with special characters in pattern."""
        (tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        result = middleware.grep_search.func(pattern="def.*:")

        assert "/test.py" in result

    def test_grep_case_insensitive(self, tmp_path: Path) -> None:
        """Test grep with case-insensitive search."""
        (tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        result = middleware.grep_search.func(pattern="(?i)hello")

        assert "/test.py" in result

    def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None:
        """Test that grep skips files larger than max_file_size_mb."""
        # Create a file larger than 1MB
        large_content = "x" * (2 * 1024 * 1024)  # 2MB
        (tmp_path / "large.txt").write_text(large_content, encoding="utf-8")
        (tmp_path / "small.txt").write_text("x", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(
            root_path=str(tmp_path),
            use_ripgrep=False,
            max_file_size_mb=1,  # 1MB limit
        )

        result = middleware.grep_search.func(pattern="x")

        # Large file should be skipped
        assert "/small.txt" in result

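# Aside: a usage sketch (not part of this diff) wiring the search middleware
# these tests exercise into an agent. The root path and the chat_model variable
# are illustrative assumptions; constructor arguments mirror the tests above.
search = FilesystemFileSearchMiddleware(
    root_path="/tmp/project",  # illustrative project root
    use_ripgrep=False,         # fall back to the pure-Python matcher
    max_file_size_mb=1,        # skip files larger than 1MB
)
agent = create_agent(model=chat_model, middleware=[search])  # chat_model: any chat model
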
@@ -0,0 +1,575 @@
from typing import Any
from unittest.mock import patch

import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime

from langchain.agents.middleware.human_in_the_loop import (
    Action,
    HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.types import AgentState


def test_human_in_the_loop_middleware_initialization() -> None:
    """Test HumanInTheLoopMiddleware initialization."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
        description_prefix="Custom prefix",
    )

    assert middleware.interrupt_on == {
        "test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}
    }
    assert middleware.description_prefix == "Custom prefix"


def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
    """Test HumanInTheLoopMiddleware when no interrupts are needed."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    # Test with no messages
    state: dict[str, Any] = {"messages": []}
    result = middleware.after_model(state, None)
    assert result is None

    # Test with message but no tool calls
    state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
    result = middleware.after_model(state, None)
    assert result is None

    # Test with tool calls that don't require interrupts
    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}
    result = middleware.after_model(state, None)
    assert result is None


def test_human_in_the_loop_middleware_single_tool_accept() -> None:
    """Test HumanInTheLoopMiddleware with single tool accept response."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_accept(requests):
        return {"decisions": [{"type": "approve"}]}

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0] == ai_message
        assert result["messages"][0].tool_calls == ai_message.tool_calls

        state["messages"].append(
            ToolMessage(content="Tool message", name="test_tool", tool_call_id="1")
        )
        state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))

        result = middleware.after_model(state, None)
        # No interrupts needed
        assert result is None


def test_human_in_the_loop_middleware_single_tool_edit() -> None:
    """Test HumanInTheLoopMiddleware with single tool edit response."""
    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_edit(requests):
        return {
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": Action(
                        name="test_tool",
                        args={"input": "edited"},
                    ),
                }
            ]
        }

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
        assert result["messages"][0].tool_calls[0]["id"] == "1"  # ID should be preserved


def test_human_in_the_loop_middleware_single_tool_response() -> None:
    """Test HumanInTheLoopMiddleware with single tool response with custom message."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_response(requests):
        return {"decisions": [{"type": "reject", "message": "Custom response message"}]}

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
    ):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 2
        assert isinstance(result["messages"][0], AIMessage)
        assert isinstance(result["messages"][1], ToolMessage)
        assert result["messages"][1].content == "Custom response message"
        assert result["messages"][1].name == "test_tool"
        assert result["messages"][1].tool_call_id == "1"


def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
    """Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={
            "get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
            "get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
        }
    )

    ai_message = AIMessage(
        content="I'll help you with weather",
        tool_calls=[
            {"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
            {"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}

    def mock_mixed_responses(requests):
        return {
            "decisions": [
                {"type": "approve"},
                {"type": "reject", "message": "User rejected this tool call"},
            ]
        }

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
    ):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert (
            len(result["messages"]) == 2
        )  # AI message with accepted tool call + tool message for rejected

        # First message should be the AI message with both tool calls
        updated_ai_message = result["messages"][0]
        assert len(updated_ai_message.tool_calls) == 2  # Both tool calls remain
        assert updated_ai_message.tool_calls[0]["name"] == "get_forecast"  # Accepted
        assert updated_ai_message.tool_calls[1]["name"] == "get_temperature"  # Got response

        # Second message should be the tool message for the rejected tool call
        tool_message = result["messages"][1]
        assert isinstance(tool_message, ToolMessage)
        assert tool_message.content == "User rejected this tool call"
        assert tool_message.name == "get_temperature"


def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
    """Test HumanInTheLoopMiddleware with multiple tools and edit responses."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={
            "get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
            "get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
        }
    )

    ai_message = AIMessage(
        content="I'll help you with weather",
        tool_calls=[
            {"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
            {"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}

    def mock_edit_responses(requests):
        return {
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": Action(
                        name="get_forecast",
                        args={"location": "New York"},
                    ),
                },
                {
                    "type": "edit",
                    "edited_action": Action(
                        name="get_temperature",
                        args={"location": "New York"},
                    ),
                },
            ]
        }

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
    ):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1

        updated_ai_message = result["messages"][0]
        assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
        assert updated_ai_message.tool_calls[0]["id"] == "1"  # ID preserved
        assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
        assert updated_ai_message.tool_calls[1]["id"] == "2"  # ID preserved


def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
    """Test HumanInTheLoopMiddleware with edit action that includes modified args."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_edit_with_args(requests):
        return {
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": Action(
                        name="test_tool",
                        args={"input": "modified"},
                    ),
                }
            ]
        }

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=mock_edit_with_args,
    ):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1

        # Should have modified args
        updated_ai_message = result["messages"][0]
        assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
        assert updated_ai_message.tool_calls[0]["id"] == "1"  # ID preserved


def test_human_in_the_loop_middleware_unknown_response_type() -> None:
    """Test HumanInTheLoopMiddleware with unknown response type."""
    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_unknown(requests):
        return {"decisions": [{"type": "unknown"}]}

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
        with pytest.raises(
            ValueError,
            match=r"Unexpected human decision: {'type': 'unknown'}. Decision type 'unknown' is not allowed for tool 'test_tool'. Expected one of \['approve', 'edit', 'reject'\] based on the tool's configuration.",
        ):
            middleware.after_model(state, None)


def test_human_in_the_loop_middleware_disallowed_action() -> None:
    """Test HumanInTheLoopMiddleware with action not allowed by tool config."""

    # edit is not allowed by tool config
    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_disallowed_action(requests):
        return {
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": Action(
                        name="test_tool",
                        args={"input": "modified"},
                    ),
                }
            ]
        }

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=mock_disallowed_action,
    ):
        with pytest.raises(
            ValueError,
            match=r"Unexpected human decision: {'type': 'edit', 'edited_action': {'name': 'test_tool', 'args': {'input': 'modified'}}}. Decision type 'edit' is not allowed for tool 'test_tool'. Expected one of \['approve', 'reject'\] based on the tool's configuration.",
        ):
            middleware.after_model(state, None)


def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
    """Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"interrupt_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[
            {"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
            {"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_accept(requests):
        return {"decisions": [{"type": "approve"}]}

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1

        updated_ai_message = result["messages"][0]
        # Should have both tools: auto-approved first, then interrupt tool
        assert len(updated_ai_message.tool_calls) == 2
        assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
        assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"


def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
    """Test that interrupt requests are structured correctly."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
        description_prefix="Custom prefix",
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    captured_request = None

    def mock_capture_requests(request):
        nonlocal captured_request
        captured_request = request
        return {"decisions": [{"type": "approve"}]}

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
    ):
        middleware.after_model(state, None)

    assert captured_request is not None
    assert "action_requests" in captured_request
    assert "review_configs" in captured_request

    assert len(captured_request["action_requests"]) == 1
    action_request = captured_request["action_requests"][0]
    assert action_request["name"] == "test_tool"
    assert action_request["args"] == {"input": "test", "location": "SF"}
    assert "Custom prefix" in action_request["description"]
    assert "Tool: test_tool" in action_request["description"]
    assert "Args: {'input': 'test', 'location': 'SF'}" in action_request["description"]

    assert len(captured_request["review_configs"]) == 1
    review_config = captured_request["review_configs"][0]
    assert review_config["action_name"] == "test_tool"
    assert review_config["allowed_decisions"] == ["approve", "edit", "reject"]


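# Aside: a condensed sketch (not part of this diff) of the interrupt_on
# configuration surface asserted above. Tool names are illustrative.
hitl = HumanInTheLoopMiddleware(
    interrupt_on={
        "delete_record": {"allowed_decisions": ["approve", "reject"]},  # edits disallowed
        "send_email": True,    # boolean shorthand: interrupt before this tool runs
        "read_record": False,  # never interrupt for this tool
    },
    description_prefix="Please review this action",
)
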
def test_human_in_the_loop_middleware_boolean_configs() -> None:
    """Test HITL middleware with boolean tool configs."""
    middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test accept
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value={"decisions": [{"type": "approve"}]},
    ):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0].tool_calls == ai_message.tool_calls

    # Test edit
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value={
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": Action(
                        name="test_tool",
                        args={"input": "edited"},
                    ),
                }
            ]
        },
    ):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}

    middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})

    result = middleware.after_model(state, None)
    # No interruption should occur
    assert result is None


def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
    """Test that sequence mismatch in resume raises an error."""
    middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test with too few responses
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value={"decisions": []},  # No responses for 1 tool call
    ):
        with pytest.raises(
            ValueError,
            match=r"Number of human decisions \(0\) does not match number of hanging tool calls \(1\)\.",
        ):
            middleware.after_model(state, None)

    # Test with too many responses
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value={
            "decisions": [
                {"type": "approve"},
                {"type": "approve"},
            ]
        },  # 2 responses for 1 tool call
    ):
        with pytest.raises(
            ValueError,
            match=r"Number of human decisions \(2\) does not match number of hanging tool calls \(1\)\.",
        ):
            middleware.after_model(state, None)


def test_human_in_the_loop_middleware_description_as_callable() -> None:
    """Test that description field accepts both string and callable."""

    def custom_description(tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
        """Generate a custom description."""
        return f"Custom: {tool_call['name']} with args {tool_call['args']}"

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={
            "tool_with_callable": {
                "allowed_decisions": ["approve"],
                "description": custom_description,
            },
            "tool_with_string": {
                "allowed_decisions": ["approve"],
                "description": "Static description",
            },
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[
            {"name": "tool_with_callable", "args": {"x": 1}, "id": "1"},
            {"name": "tool_with_string", "args": {"y": 2}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    captured_request = None

    def mock_capture_requests(request):
        nonlocal captured_request
        captured_request = request
        return {"decisions": [{"type": "approve"}, {"type": "approve"}]}

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
    ):
        middleware.after_model(state, None)

    assert captured_request is not None
    assert "action_requests" in captured_request
    assert len(captured_request["action_requests"]) == 2

    # Check callable description
    assert (
        captured_request["action_requests"][0]["description"]
        == "Custom: tool_with_callable with args {'x': 1}"
    )

    # Check string description
    assert captured_request["action_requests"][1]["description"] == "Static description"
@@ -0,0 +1,224 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver

from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import (
    ModelCallLimitMiddleware,
    ModelCallLimitExceededError,
)

from ...model import FakeToolCallingModel


@tool
def simple_tool(input: str) -> str:
    """A simple tool"""
    return input


def test_middleware_unit_functionality():
    """Test that the middleware works as expected in isolation."""
    # Test with end behavior
    middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)

    # Mock runtime (not used in current implementation)
    runtime = None

    # Test when limits are not exceeded
    state = {"thread_model_call_count": 0, "run_model_call_count": 0}
    result = middleware.before_model(state, runtime)
    assert result is None

    # Test when thread limit is exceeded
    state = {"thread_model_call_count": 2, "run_model_call_count": 0}
    result = middleware.before_model(state, runtime)
    assert result is not None
    assert result["jump_to"] == "end"
    assert "messages" in result
    assert len(result["messages"]) == 1
    assert "thread limit (2/2)" in result["messages"][0].content

    # Test when run limit is exceeded
    state = {"thread_model_call_count": 1, "run_model_call_count": 1}
    result = middleware.before_model(state, runtime)
    assert result is not None
    assert result["jump_to"] == "end"
    assert "messages" in result
    assert len(result["messages"]) == 1
    assert "run limit (1/1)" in result["messages"][0].content

    # Test with error behavior
    middleware_exception = ModelCallLimitMiddleware(
        thread_limit=2, run_limit=1, exit_behavior="error"
    )

    # Test exception when thread limit exceeded
    state = {"thread_model_call_count": 2, "run_model_call_count": 0}
    with pytest.raises(ModelCallLimitExceededError) as exc_info:
        middleware_exception.before_model(state, runtime)

    assert "thread limit (2/2)" in str(exc_info.value)

    # Test exception when run limit exceeded
    state = {"thread_model_call_count": 1, "run_model_call_count": 1}
    with pytest.raises(ModelCallLimitExceededError) as exc_info:
        middleware_exception.before_model(state, runtime)

    assert "run limit (1/1)" in str(exc_info.value)


def test_thread_limit_with_create_agent():
    """Test that thread limits work correctly with create_agent."""
    model = FakeToolCallingModel()

    # Set thread limit to 1 (should be exceeded after 1 call)
    agent = create_agent(
        model=model,
        tools=[simple_tool],
        middleware=[ModelCallLimitMiddleware(thread_limit=1)],
        checkpointer=InMemorySaver(),
    )

    # First invocation should work - 1 model call, within thread limit
    result = agent.invoke(
        {"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
    )

    # Should complete successfully with 1 model call
    assert "messages" in result
    assert len(result["messages"]) == 2  # Human + AI messages

    # Second invocation in same thread should hit thread limit
    # The agent should jump to end after detecting the limit
    result2 = agent.invoke(
        {"messages": [HumanMessage("Hello again")]}, {"configurable": {"thread_id": "thread1"}}
    )

    assert "messages" in result2
    # The agent should have detected the limit and jumped to end with a limit exceeded message
    # So we should have: previous messages + new human message + limit exceeded AI message
    assert len(result2["messages"]) == 4  # Previous Human + AI + New Human + Limit AI
    assert isinstance(result2["messages"][0], HumanMessage)  # First human
    assert isinstance(result2["messages"][1], AIMessage)  # First AI response
    assert isinstance(result2["messages"][2], HumanMessage)  # Second human
    assert isinstance(result2["messages"][3], AIMessage)  # Limit exceeded message
    assert "thread limit" in result2["messages"][3].content


def test_run_limit_with_create_agent():
    """Test that run limits work correctly with create_agent."""
    # Create a model that will make 2 calls
    model = FakeToolCallingModel(
        tool_calls=[
            [{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
            [],  # No tool calls on second call
        ]
    )

    # Set run limit to 1 (should be exceeded after 1 call)
    agent = create_agent(
        model=model,
        tools=[simple_tool],
        middleware=[ModelCallLimitMiddleware(run_limit=1)],
        checkpointer=InMemorySaver(),
    )

    # This should hit the run limit after the first model call
    result = agent.invoke(
        {"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
    )

    assert "messages" in result
    # The agent should have made 1 model call then jumped to end with limit exceeded message
    # So we should have: Human + AI + Tool + Limit exceeded AI message
    assert len(result["messages"]) == 4  # Human + AI + Tool + Limit AI
    assert isinstance(result["messages"][0], HumanMessage)
    assert isinstance(result["messages"][1], AIMessage)
    assert isinstance(result["messages"][2], ToolMessage)
    assert isinstance(result["messages"][3], AIMessage)  # Limit exceeded message
    assert "run limit" in result["messages"][3].content


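# Aside: an application-side sketch (not part of this diff) of the wiring these
# tests exercise. Limit values are illustrative; exit_behavior="end" injects a
# limit-exceeded message, while exit_behavior="error" raises
# ModelCallLimitExceededError instead.
limiter = ModelCallLimitMiddleware(thread_limit=25, run_limit=5, exit_behavior="end")
agent = create_agent(
    model=FakeToolCallingModel(),  # stand-in model from these tests
    tools=[simple_tool],
    middleware=[limiter],
    checkpointer=InMemorySaver(),  # the thread count persists via the checkpointer
)
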
def test_middleware_initialization_validation():
    """Test that middleware initialization validates parameters correctly."""
    # Test that at least one limit must be specified
    with pytest.raises(ValueError, match="At least one limit must be specified"):
        ModelCallLimitMiddleware()

    # Test invalid exit behavior
    with pytest.raises(ValueError, match="Invalid exit_behavior"):
        ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")

    # Test valid initialization
    middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
    assert middleware.thread_limit == 5
    assert middleware.run_limit == 3
    assert middleware.exit_behavior == "end"

    # Test with only thread limit
    middleware = ModelCallLimitMiddleware(thread_limit=5)
    assert middleware.thread_limit == 5
    assert middleware.run_limit is None

    # Test with only run limit
    middleware = ModelCallLimitMiddleware(run_limit=3)
    assert middleware.thread_limit is None
    assert middleware.run_limit == 3


def test_exception_error_message():
    """Test that the exception provides clear error messages."""
    middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")

    # Test thread limit exceeded
    state = {"thread_model_call_count": 2, "run_model_call_count": 0}
    with pytest.raises(ModelCallLimitExceededError) as exc_info:
        middleware.before_model(state, None)

    error_msg = str(exc_info.value)
    assert "Model call limits exceeded" in error_msg
    assert "thread limit (2/2)" in error_msg

    # Test run limit exceeded
    state = {"thread_model_call_count": 0, "run_model_call_count": 1}
    with pytest.raises(ModelCallLimitExceededError) as exc_info:
        middleware.before_model(state, None)

    error_msg = str(exc_info.value)
    assert "Model call limits exceeded" in error_msg
    assert "run limit (1/1)" in error_msg

    # Test both limits exceeded
    state = {"thread_model_call_count": 2, "run_model_call_count": 1}
    with pytest.raises(ModelCallLimitExceededError) as exc_info:
        middleware.before_model(state, None)

    error_msg = str(exc_info.value)
    assert "Model call limits exceeded" in error_msg
    assert "thread limit (2/2)" in error_msg
    assert "run limit (1/1)" in error_msg


def test_run_limit_resets_between_invocations() -> None:
    """Test that run_model_call_count resets between invocations, but thread_model_call_count accumulates."""

    # First: No tool calls per invocation, so model does not increment call counts internally
    middleware = ModelCallLimitMiddleware(thread_limit=3, run_limit=1, exit_behavior="error")
    model = FakeToolCallingModel(
        tool_calls=[[], [], [], []]
    )  # No tool calls, so only one model call per run

    agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())

    thread_config = {"configurable": {"thread_id": "test_thread"}}
    agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
    agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
    agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)

    # Fourth run: should raise, thread_model_call_count == 3 (limit)
    with pytest.raises(ModelCallLimitExceededError) as exc_info:
        agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
    error_msg = str(exc_info.value)
    assert "thread limit (3/3)" in error_msg
@@ -5,13 +5,18 @@ from __future__ import annotations

from typing import cast

import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime

from ...model import FakeToolCallingModel


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())
@@ -213,3 +218,90 @@ async def test_all_models_fail_async() -> None:

    with pytest.raises(ValueError, match="Model failed"):
        await middleware.awrap_model_call(request, mock_handler)


def test_model_fallback_middleware_with_agent() -> None:
    """Test ModelFallbackMiddleware with agent.invoke and fallback models only."""

    class FailingModel(BaseChatModel):
        """Model that always fails."""

        def _generate(self, messages, **kwargs):
            raise ValueError("Primary model failed")

        @property
        def _llm_type(self):
            return "failing"

    class SuccessModel(BaseChatModel):
        """Model that succeeds."""

        def _generate(self, messages, **kwargs):
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Fallback success"))]
            )

        @property
        def _llm_type(self):
            return "success"

    primary = FailingModel()
    fallback = SuccessModel()

    # Only pass fallback models to the middleware (not the primary)
    fallback_middleware = ModelFallbackMiddleware(fallback)

    agent = create_agent(model=primary, middleware=[fallback_middleware])

    result = agent.invoke({"messages": [HumanMessage("Test")]})

    # Should have succeeded with the fallback model
    assert len(result["messages"]) == 2
    assert result["messages"][1].content == "Fallback success"


def test_model_fallback_middleware_exhausted_with_agent() -> None:
    """Test ModelFallbackMiddleware with agent.invoke when all models fail."""

    class AlwaysFailingModel(BaseChatModel):
        """Model that always fails."""

        def __init__(self, name: str):
            super().__init__()
            self.name = name

        def _generate(self, messages, **kwargs):
            raise ValueError(f"{self.name} failed")

        @property
        def _llm_type(self):
            return self.name

    primary = AlwaysFailingModel("primary")
    fallback1 = AlwaysFailingModel("fallback1")
    fallback2 = AlwaysFailingModel("fallback2")

    # Primary fails (attempt 1), then fallback1 (attempt 2), then fallback2 (attempt 3)
    fallback_middleware = ModelFallbackMiddleware(fallback1, fallback2)

    agent = create_agent(model=primary, middleware=[fallback_middleware])

    # Should fail with the last fallback's error
    with pytest.raises(ValueError, match="fallback2 failed"):
        agent.invoke({"messages": [HumanMessage("Test")]})


def test_model_fallback_middleware_initialization() -> None:
    """Test ModelFallbackMiddleware initialization."""

    # Test with no models - now a TypeError (missing required argument)
    with pytest.raises(TypeError):
        ModelFallbackMiddleware()  # type: ignore[call-arg]

    # Test with one fallback model (valid)
    middleware = ModelFallbackMiddleware(FakeToolCallingModel())
    assert len(middleware.models) == 1

    # Test with multiple fallback models
    middleware = ModelFallbackMiddleware(FakeToolCallingModel(), FakeToolCallingModel())
    assert len(middleware.models) == 2