Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-08 18:19:21 +00:00

Compare commits (46 commits): `langchain=` ... `sr/typing-`
Commits (SHA1):

4cc3a15af8, bacafb2fc6, b7d1831f9d, 328ba36601, 6f677ef5c1, d47d41cbd3, 32bbe99efc, 990e346c46,
9b7792631d, 558a8fe25b, 52b1516d44, 8a3bb73c05, 099c042395, 2d4f00a451, 9bd401a6d4, 6aa3794b74,
189dcf7295, 1bc88028e6, d2942351ce, 83c078f363, 26d39ffc4a, 421e2ceeee, 275dcbf69f, 9f87b27a5b,
b2e1196e29, 2dc1396380, 77941ab3ce, ee19a30dde, 5d799b3174, 8f33a985a2, 78eeccef0e, 3d415441e8,
74385e0ebd, 2bfbc29ccc, ef79c26f18, fbe32c8e89, 2511c28f92, 637bb1cbbc, 3dfea96ec1, 68643153e5,
462762f75b, 4f3729c004, ba428cdf54, 69c7d1b01b, 733299ec13, e1adf781c6
`.github/ISSUE_TEMPLATE/bug-report.yml` (vendored, 77 changes)

@@ -8,16 +8,15 @@ body:
value: |
Thank you for taking the time to file a bug report.

Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).

Relevant links to check before filing a bug report to see if your issue has already been reported, fixed or
if there's another way to solve your problem:
Check these before submitting to see if your issue has already been reported, fixed or if there's another way to solve your problem:

* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -36,16 +35,48 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- label: I read what a minimal reproducible example is (https://stackoverflow.com/help/minimal-reproducible-example).
required: true
- label: I posted a self-contained, minimal, reproducible example. A maintainer can copy it and run it AS IS.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this bug related to? Select at least one.

Note that if the package you are reporting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).

Please report issues for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Example Code
label: Example Code (Python)
description: |
Please add a self-contained, [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example) with your use case.

@@ -53,15 +84,12 @@ body:

**Important!**

* Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.
* Use code tags (e.g., ```python ... ```) to correctly [format your code](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting).
* INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).
* Avoid screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible.

(This will be automatically formatted into code, so no need for backticks.)
render: python
placeholder: |
The following code:

```python
from langchain_core.runnables import RunnableLambda

def bad_code(inputs) -> int:
@@ -69,17 +97,14 @@ body:

chain = RunnableLambda(bad_code)
chain.invoke('Hello!')
```
- type: textarea
id: error
validations:
required: false
attributes:
label: Error Message and Stack Trace (if applicable)
description: |
If you are reporting an error, please include the full error message and stack trace.
placeholder: |
Exception + full stack trace
If you are reporting an error, please copy and paste the full error message and
stack trace.
(This will be automatically formatted into code, so no need for backticks.)
render: shell
- type: textarea
id: description
attributes:
@@ -99,9 +124,7 @@ body:
attributes:
label: System Info
description: |
Please share your system info with us. Do NOT skip this step and please don't trim
the output. Most users don't include enough information here and it makes it harder
for us to help you.
Please share your system info with us.

Run the following command in your terminal and paste the output here:

@@ -113,8 +136,6 @@ body:
from langchain_core import sys_info
sys_info.print_sys_info()
```

alternatively, put the entire output of `pip freeze` here.
placeholder: |
python -m langchain_core.sys_info
validations:
`.github/ISSUE_TEMPLATE/config.yml` (vendored, 11 changes)

@@ -1,9 +1,18 @@
blank_issues_enabled: false
version: 2.1
contact_links:
- name: 📚 Documentation
- name: 📚 Documentation issue
url: https://github.com/langchain-ai/docs/issues/new?template=01-langchain.yml
about: Report an issue related to the LangChain documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: General community discussions and support
- name: 📚 LangChain Documentation
url: https://docs.langchain.com/oss/python/langchain/overview
about: View the official LangChain documentation
- name: 📚 API Reference Documentation
url: https://reference.langchain.com/python/
about: View the official LangChain API reference documentation
- name: 💬 LangChain Forum
url: https://forum.langchain.com/
about: Ask questions and get help from the community
`.github/ISSUE_TEMPLATE/feature-request.yml` (vendored, 40 changes)

@@ -13,11 +13,11 @@ body:
Relevant links to check before filing a feature request to see if your request has already been made or
if there's another way to achieve what you want:

* [LangChain Forum](https://forum.langchain.com/),
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference](https://reference.langchain.com/python/),
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
* [API Reference Documentation](https://reference.langchain.com/python/),
* [LangChain ChatBot](https://chat.langchain.com/)
* [GitHub search](https://github.com/langchain-ai/langchain),
* [LangChain Forum](https://forum.langchain.com/),
- type: checkboxes
id: checks
attributes:
@@ -34,6 +34,40 @@ body:
required: true
- label: This is not related to the langchain-community package.
required: true
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Which `langchain` package(s) is this request related to? Select at least one.

Note that if the package you are requesting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in `langchain-ai/langchain`).

Please submit feature requests for other packages to their respective repositories.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
- type: textarea
id: feature-description
validations:
`.github/ISSUE_TEMPLATE/privileged.yml` (vendored, 30 changes)

@@ -18,3 +18,33 @@ body:
attributes:
label: Issue Content
description: Add the content of the issue here.
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this issue is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
`.github/ISSUE_TEMPLATE/task.yml` (vendored, 48 changes)

@@ -25,13 +25,13 @@ body:
label: Task Description
description: |
Provide a clear and detailed description of the task.

What needs to be done? Be specific about the scope and requirements.
placeholder: |
This task involves...

The goal is to...

Specific requirements:
- ...
- ...
@@ -43,7 +43,7 @@ body:
label: Acceptance Criteria
description: |
Define the criteria that must be met for this task to be considered complete.

What are the specific deliverables or outcomes expected?
placeholder: |
This task will be complete when:
@@ -58,15 +58,15 @@ body:
label: Context and Background
description: |
Provide any relevant context, background information, or links to related issues/PRs.

Why is this task needed? What problem does it solve?
placeholder: |
Background:
- ...

Related issues/PRs:
- #...

Additional context:
- ...
validations:
@@ -77,15 +77,45 @@ body:
label: Dependencies
description: |
List any dependencies or blockers for this task.

Are there other tasks, issues, or external factors that need to be completed first?
placeholder: |
This task depends on:
- [ ] Issue #...
- [ ] PR #...
- [ ] External dependency: ...

Blocked by:
- ...
validations:
required: false
- type: checkboxes
id: package
attributes:
label: Package (Required)
description: |
Please select package(s) that this task is related to.
options:
- label: langchain
- label: langchain-openai
- label: langchain-anthropic
- label: langchain-classic
- label: langchain-core
- label: langchain-cli
- label: langchain-model-profiles
- label: langchain-tests
- label: langchain-text-splitters
- label: langchain-chroma
- label: langchain-deepseek
- label: langchain-exa
- label: langchain-fireworks
- label: langchain-groq
- label: langchain-huggingface
- label: langchain-mistralai
- label: langchain-nomic
- label: langchain-ollama
- label: langchain-perplexity
- label: langchain-prompty
- label: langchain-qdrant
- label: langchain-xai
- label: Other / not sure / general
`.github/workflows/_release.yml` (vendored, 2 changes)

@@ -396,7 +396,7 @@ jobs:
contents: read
strategy:
matrix:
partner: [openai, anthropic]
partner: [anthropic]
fail-fast: false # Continue testing other partners if one fails
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
`.github/workflows/auto-label-by-package.yml` (vendored, new file, 107 changes)

@@ -0,0 +1,107 @@
name: Auto Label Issues by Package

on:
issues:
types: [opened, edited]

jobs:
label-by-package:
permissions:
issues: write
runs-on: ubuntu-latest

steps:
- name: Sync package labels
uses: actions/github-script@v6
with:
script: |
const body = context.payload.issue.body || "";

// Extract text under "### Package"
const match = body.match(/### Package\s+([\s\S]*?)\n###/i);
if (!match) return;

const packageSection = match[1].trim();

// Mapping table for package names to labels
const mapping = {
"langchain": "langchain",
"langchain-openai": "openai",
"langchain-anthropic": "anthropic",
"langchain-classic": "langchain-classic",
"langchain-core": "core",
"langchain-cli": "cli",
"langchain-model-profiles": "model-profiles",
"langchain-tests": "standard-tests",
"langchain-text-splitters": "text-splitters",
"langchain-chroma": "chroma",
"langchain-deepseek": "deepseek",
"langchain-exa": "exa",
"langchain-fireworks": "fireworks",
"langchain-groq": "groq",
"langchain-huggingface": "huggingface",
"langchain-mistralai": "mistralai",
"langchain-nomic": "nomic",
"langchain-ollama": "ollama",
"langchain-perplexity": "perplexity",
"langchain-prompty": "prompty",
"langchain-qdrant": "qdrant",
"langchain-xai": "xai",
};

// All possible package labels we manage
const allPackageLabels = Object.values(mapping);
const selectedLabels = [];

// Check if this is checkbox format (multiple selection)
const checkboxMatches = packageSection.match(/- \[x\]\s+([^\n\r]+)/gi);
if (checkboxMatches) {
// Handle checkbox format
for (const match of checkboxMatches) {
const packageName = match.replace(/- \[x\]\s+/i, '').trim();
const label = mapping[packageName];
if (label && !selectedLabels.includes(label)) {
selectedLabels.push(label);
}
}
} else {
// Handle dropdown format (single selection)
const label = mapping[packageSection];
if (label) {
selectedLabels.push(label);
}
}

// Get current issue labels
const issue = await github.rest.issues.get({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number
});

const currentLabels = issue.data.labels.map(label => label.name);
const currentPackageLabels = currentLabels.filter(label => allPackageLabels.includes(label));

// Determine labels to add and remove
const labelsToAdd = selectedLabels.filter(label => !currentPackageLabels.includes(label));
const labelsToRemove = currentPackageLabels.filter(label => !selectedLabels.includes(label));

// Add new labels
if (labelsToAdd.length > 0) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: labelsToAdd
});
}

// Remove old labels
for (const label of labelsToRemove) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: label
});
}
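For reference, a minimal Python sketch of the same extraction logic, assuming an issue body shaped the way GitHub issue forms render it (each form field becomes an `### ...` heading); the sample body below is hypothetical, not taken from this diff:

```python
import re

# Hypothetical rendered issue body; GitHub issue forms emit each field as an H3 heading.
body = """### Package

- [X] langchain-core
- [ ] langchain-openai

### Example Code
"""

# Same two regexes as the github-script step above, re-expressed in Python.
section = re.search(r"### Package\s+([\s\S]*?)\n###", body, re.IGNORECASE)
if section:
    checked = re.findall(r"- \[x\]\s+([^\n\r]+)", section.group(1), re.IGNORECASE)
    print(checked)  # ['langchain-core']
```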
`.github/workflows/pr_lint.yml` (vendored, 4 changes)

@@ -26,12 +26,14 @@
# * revert — reverts a previous commit
# * release — prepare a new release
#
# Allowed Scopes (optional):
# Allowed Scope(s) (optional):
# core, cli, langchain, langchain_v1, langchain-classic, standard-tests,
# text-splitters, docs, anthropic, chroma, deepseek, exa, fireworks, groq,
# huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant,
# xai, infra, deps
#
# Multiple scopes can be used by separating them with a comma.
#
# Rules:
# 1. The 'Type' must start with a lowercase letter.
# 2. Breaking changes: append "!" after type/scope (e.g., feat!: drop x support)
`.gitignore` (vendored, 3 changes)

@@ -163,3 +163,6 @@ node_modules

prof
virtualenv/
scratch/

.langgraph_api/
@@ -6,9 +6,8 @@ import hashlib
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from git import Repo
|
||||
|
||||
@@ -18,6 +17,9 @@ from langchain_cli.constants import (
|
||||
DEFAULT_GIT_SUBDIRECTORY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
if TYPE_CHECKING:
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .file import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Folder:
|
||||
def __init__(self, name: str, *files: Folder | File) -> None:
|
||||
|
||||
@@ -34,7 +34,7 @@ The LangChain ecosystem is built on top of `langchain-core`. Some of the benefit
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/).
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -5,13 +5,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -39,7 +39,6 @@ from langchain_core.tracers.context import (
|
||||
tracing_v2_callback_var,
|
||||
)
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
|
||||
@@ -52,6 +51,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,16 +6,9 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
TypedDict,
|
||||
@@ -29,6 +22,16 @@ from langchain_core.exceptions import LangChainException
|
||||
from langchain_core.indexing.base import DocumentIndex, RecordManager
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
# Magic UUID to use as a namespace for hashing.
|
||||
# Used to try and generate a unique UUID for each document
|
||||
# from hashing the document content and metadata.
|
||||
|
||||
@@ -299,6 +299,9 @@ class BaseLanguageModel(

Useful for checking if an input fits in a model's context window.

This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.

Args:
text: The string input to tokenize.

@@ -317,9 +320,17 @@

Useful for checking if an input fits in a model's context window.

This should be overridden by model-specific implementations to provide accurate
token counts via model-specific tokenizers.

!!! note
The base implementation of `get_num_tokens_from_messages` ignores tool
schemas.

* The base implementation of `get_num_tokens_from_messages` ignores tool
schemas.
* The base implementation of `get_num_tokens_from_messages` adds additional
prefixes to messages in represent user roles, which will add to the
overall token count. Model-specific implementations may choose to
handle this differently.

Args:
messages: The message inputs to tokenize.
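As a usage sketch of the token-counting API these docstrings describe (assuming `langchain-openai` is installed and `OPENAI_API_KEY` is set; the model name is only an example):

```python
# A minimal sketch, not from the diff: counting tokens for a message list.
# Assumes `langchain-openai` is installed and OPENAI_API_KEY is set in the environment.
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o-mini")
messages = [HumanMessage("How many tokens is this?")]

# Provider implementations override the base counting with their own tokenizer.
# Per the note above, the base implementation ignores tool schemas and adds role
# prefixes, so treat the result as an approximation.
print(model.get_num_tokens_from_messages(messages))
```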
@@ -91,7 +91,10 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
try:
metadata["body"] = response.json()
except Exception:
metadata["body"] = getattr(response, "text", None)
try:
metadata["body"] = getattr(response, "text", None)
except Exception:
metadata["body"] = None
if hasattr(response, "headers"):
try:
metadata["headers"] = dict(response.headers)
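The fallback order above (parsed JSON body, then raw text, then `None` when both raise, for example on an unread streaming response) can be sketched in isolation; this is illustrative, not the library helper itself:

```python
# Illustrative sketch of the fallback order used above (not the library code itself):
# prefer the parsed JSON body, then the raw text, then None if both raise
# (e.g. an httpx-style ResponseNotRead on a streaming response that was never read).
def summarize_error_body(response: object) -> object:
    try:
        return response.json()  # type: ignore[attr-defined]
    except Exception:
        try:
            return getattr(response, "text", None)
        except Exception:
            return None
```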
@@ -61,13 +61,15 @@ class Reviver:
"""Initialize the reviver.

Args:
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.

If a secret is not found in the map, it will be loaded from the
environment if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings

You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.
"""
@@ -195,13 +197,15 @@ def loads(

Args:
text: The string to load.
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.

If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings

You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.

@@ -237,13 +241,15 @@ def load(

Args:
obj: The object to load.
secrets_map: A map of secrets to load. If a secret is not found in
the map, it will be loaded from the environment if `secrets_from_env`
is True.
secrets_map: A map of secrets to load.

If a secret is not found in the map, it will be loaded from the environment
if `secrets_from_env` is `True`.
valid_namespaces: A list of additional namespaces (modules)
to allow to be deserialized.
secrets_from_env: Whether to load secrets from the environment.
additional_import_mappings: A dictionary of additional namespace mappings

You can use this to override default mappings or add new mappings.
ignore_unserializable_fields: Whether to ignore unserializable fields.

@@ -265,8 +271,6 @@ def load(
return reviver(loaded_obj)
if isinstance(obj, list):
return [_load(o) for o in obj]
if isinstance(obj, str) and obj in reviver.secrets_map:
return reviver.secrets_map[obj]
return obj

return _load(obj)
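For orientation, a minimal round-trip sketch of the `dumpd` / `load` API whose docstrings are edited above; the prompt object used here is an ordinary langchain-core example, not taken from this diff:

```python
# A minimal round-trip sketch for the serialization API documented above
# (illustrative only; assumes langchain-core is installed).
from langchain_core.load import dumpd, load
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages([("human", "Tell me a joke about {topic}")])
serialized = dumpd(prompt)  # plain, JSON-compatible dict
restored = load(serialized)  # revived object; secrets_map / valid_namespaces kwargs as documented above
assert restored == prompt
```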
@@ -5,11 +5,9 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, cast, overload
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core._api.deprecation import warn_deprecated
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@@ -17,6 +15,9 @@ from langchain_core.utils.interactive_env import is_interactive_env
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
|
||||
@@ -12,10 +12,11 @@ the implementation in `BaseMessage`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from langchain_core.language_models._utils import (
|
||||
@@ -14,6 +13,8 @@ from langchain_core.language_models._utils import (
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
|
||||
|
||||
@@ -2,15 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""A single chat generation output.
|
||||
|
||||
@@ -6,7 +6,7 @@ import contextlib
|
||||
import json
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Mapping
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@@ -33,6 +33,8 @@ from langchain_core.runnables.config import ensure_config
|
||||
from langchain_core.utils.pydantic import create_model_v2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from collections.abc import Callable, Sequence
|
||||
from string import Formatter
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
@@ -16,6 +15,9 @@ from langchain_core.utils import get_colored_text, mustache
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@@ -22,7 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -642,6 +641,7 @@ class Graph:
retry_delay: float = 1.0,
frontmatter_config: dict[str, Any] | None = None,
base_url: str | None = None,
proxies: dict[str, str] | None = None,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid.

@@ -674,11 +674,10 @@
}
```
base_url: The base URL of the Mermaid server for rendering via API.

proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).

Returns:
The PNG image as bytes.

"""
# Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import ( # noqa: PLC0415
@@ -699,6 +698,7 @@ class Graph:
padding=padding,
max_retries=max_retries,
retry_delay=retry_delay,
proxies=proxies,
base_url=base_url,
)
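A small usage sketch for the `base_url` / `proxies` parameters added above; rendering via the API method needs network access, and the proxy address is just the example value from the docstring:

```python
# Illustrative only: render a runnable's graph through the Mermaid rendering API
# behind an HTTP proxy. Assumes network access and that the proxy address is valid.
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.graph import MermaidDrawMethod

chain = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x)
png_bytes = chain.get_graph().draw_mermaid_png(
    draw_method=MermaidDrawMethod.API,
    proxies={"http": "http://127.0.0.1:7890"},  # example value from the docstring above
)
with open("graph.png", "wb") as f:
    f.write(png_bytes)
```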
@@ -7,7 +7,6 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
try:
|
||||
@@ -20,6 +19,8 @@ except ImportError:
|
||||
_HAS_GRANDALF = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langchain_core.runnables.graph import Edge as LangEdge
|
||||
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ def draw_mermaid_png(
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
base_url: str | None = None,
|
||||
proxies: dict[str, str] | None = None,
|
||||
) -> bytes:
|
||||
"""Draws a Mermaid graph as PNG using provided syntax.
|
||||
|
||||
@@ -293,6 +294,7 @@ def draw_mermaid_png(
|
||||
max_retries: Maximum number of retries (MermaidDrawMethod.API).
|
||||
retry_delay: Delay between retries (MermaidDrawMethod.API).
|
||||
base_url: Base URL for the Mermaid.ink API.
|
||||
proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).
|
||||
|
||||
Returns:
|
||||
PNG image bytes.
|
||||
@@ -314,6 +316,7 @@ def draw_mermaid_png(
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_url=base_url,
|
||||
proxies=proxies,
|
||||
)
|
||||
else:
|
||||
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
|
||||
@@ -405,6 +408,7 @@ def _render_mermaid_using_api(
|
||||
file_type: Literal["jpeg", "png", "webp"] | None = "png",
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
proxies: dict[str, str] | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> bytes:
|
||||
"""Renders Mermaid graph using the Mermaid.INK API."""
|
||||
@@ -445,7 +449,7 @@ def _render_mermaid_using_api(
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.get(image_url, timeout=10)
|
||||
response = requests.get(image_url, timeout=10, proxies=proxies)
|
||||
if response.status_code == requests.codes.ok:
|
||||
img_bytes = response.content
|
||||
if output_file_path is not None:
|
||||
|
||||
@@ -7,8 +7,7 @@ import asyncio
|
||||
import inspect
|
||||
import sys
|
||||
import textwrap
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from contextvars import Context
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from itertools import groupby
|
||||
@@ -31,9 +30,11 @@ if TYPE_CHECKING:
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
)
|
||||
from contextvars import Context
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
|
||||
|
||||
@@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818

ArgsSchema = TypeBaseModel | dict[str, Any]

_EMPTY_SET: frozenset[str] = frozenset()


class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
"""Base class for all LangChain tools.
@@ -569,6 +571,11 @@ class ChildTool(BaseTool):
self.name, full_schema, fields, fn_description=self.description
)

@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:
# base implementation doesn't manage injected args
return _EMPTY_SET

# --- Runnable ---

@override
@@ -649,6 +656,7 @@ class ChildTool(BaseTool):
if isinstance(input_args, dict):
return tool_input
if issubclass(input_args, BaseModel):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
@@ -664,6 +672,7 @@ class ChildTool(BaseTool):
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
# Check args_schema for InjectedToolCallId
for k, v in get_all_basemodel_annotations(input_args).items():
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
if tool_call_id is None:
@@ -683,9 +692,25 @@ class ChildTool(BaseTool):
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
)
raise NotImplementedError(msg)
return {
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
validated_input = {
k: getattr(result, k) for k in result_dict if k in tool_input
}
for k in self._injected_args_keys:
if k == "tool_call_id":
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
validated_input[k] = tool_call_id
if k in tool_input:
injected_val = tool_input[k]
validated_input[k] = injected_val
return validated_input
return tool_input

@abstractmethod
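The `_injected_args_keys` pass-through above is exercised by the new tests further down in this diff; a condensed usage sketch, assuming the `@tool` decorator and `InjectedToolCallId` from `langchain_core.tools`:

```python
# Condensed from the new tests further down: a custom args_schema that omits the
# injected argument, which is filled in from the ToolCall at invocation time.
from typing import Annotated

from pydantic import BaseModel

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool


class InputSchema(BaseModel):
    x: int


@tool(args_schema=InputSchema)
def injected_tool(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
    """Tool with injected tool_call_id and custom schema."""
    return ToolMessage(str(x), tool_call_id=tool_call_id)


result = injected_tool.invoke(
    {"type": "tool_call", "args": {"x": 42}, "name": "injected_tool", "id": "test_call_id"}
)
print(result)  # ToolMessage("42", tool_call_id="test_call_id")
```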
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import textwrap
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import signature
|
||||
@@ -21,10 +22,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
_EMPTY_SET,
|
||||
FILTERED_ARGS,
|
||||
ArgsSchema,
|
||||
BaseTool,
|
||||
_get_runnable_config_param,
|
||||
_is_injected_arg_type,
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
@@ -241,6 +244,17 @@ class StructuredTool(BaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
fn = self.func or self.coroutine
|
||||
if fn is None:
|
||||
return _EMPTY_SET
|
||||
return frozenset(
|
||||
k
|
||||
for k, v in signature(fn).parameters.items()
|
||||
if _is_injected_arg_type(v.annotation)
|
||||
)
|
||||
|
||||
|
||||
def _filter_schema_args(func: Callable) -> list[str]:
|
||||
filter_args = list(FILTERED_ARGS)
|
||||
|
||||
@@ -15,12 +15,6 @@ from typing import (
|
||||
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -31,6 +25,12 @@ if TYPE_CHECKING:
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import logging
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -33,6 +32,8 @@ from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
value = match.group(2)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from functools import lru_cache, wraps
|
||||
from types import GenericAlias
|
||||
@@ -41,10 +40,12 @@ from pydantic.json_schema import (
|
||||
)
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import create_model as create_model_v1
|
||||
from pydantic.v1.fields import ModelField
|
||||
from typing_extensions import deprecated, override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic_core import core_schema
|
||||
|
||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||
|
||||
@@ -11,7 +11,6 @@ import logging
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from itertools import cycle
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -29,7 +28,7 @@ from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Iterable, Iterator, Sequence
|
||||
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -20,7 +19,7 @@ from langchain_core.vectorstores.utils import _cosine_similarity as cosine_simil
|
||||
from langchain_core.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
"""langchain-core version information and utilities."""

VERSION = "1.0.4"
VERSION = "1.0.6"

@@ -9,7 +9,7 @@ license = {text = "MIT"}
readme = "README.md"
authors = []

version = "1.0.4"
version = "1.0.6"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langsmith>=0.3.45,<1.0.0",
@@ -18,6 +18,7 @@ from langchain_core.language_models import (
|
||||
ParrotFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.chat_models import _generate_response_from_error
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModelError,
|
||||
GenericFakeChatModel,
|
||||
@@ -1234,3 +1235,93 @@ def test_model_profiles() -> None:
|
||||
model = MyModel(messages=iter([]))
|
||||
profile = model.profile
|
||||
assert profile
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Mock response for testing _generate_response_from_error."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int = 400,
|
||||
headers: dict[str, str] | None = None,
|
||||
json_data: dict[str, Any] | None = None,
|
||||
json_raises: type[Exception] | None = None,
|
||||
text_raises: type[Exception] | None = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self._json_data = json_data
|
||||
self._json_raises = json_raises
|
||||
self._text_raises = text_raises
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
if self._json_raises:
|
||||
msg = "JSON parsing failed"
|
||||
raise self._json_raises(msg)
|
||||
return self._json_data or {}
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
if self._text_raises:
|
||||
msg = "Text access failed"
|
||||
raise self._text_raises(msg)
|
||||
return ""
|
||||
|
||||
|
||||
class MockAPIError(Exception):
|
||||
"""Mock API error with response attribute."""
|
||||
|
||||
def __init__(self, message: str, response: MockResponse | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
if response is not None:
|
||||
self.response = response
|
||||
|
||||
|
||||
def test_generate_response_from_error_with_valid_json() -> None:
|
||||
"""Test `_generate_response_from_error` with valid JSON response."""
|
||||
response = MockResponse(
|
||||
status_code=400,
|
||||
headers={"content-type": "application/json"},
|
||||
json_data={"error": {"message": "Bad request", "type": "invalid_request"}},
|
||||
)
|
||||
error = MockAPIError("API Error", response=response)
|
||||
|
||||
generations = _generate_response_from_error(error)
|
||||
|
||||
assert len(generations) == 1
|
||||
generation = generations[0]
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.message, AIMessage)
|
||||
assert generation.message.content == ""
|
||||
|
||||
metadata = generation.message.response_metadata
|
||||
assert metadata["body"] == {
|
||||
"error": {"message": "Bad request", "type": "invalid_request"}
|
||||
}
|
||||
assert metadata["headers"] == {"content-type": "application/json"}
|
||||
assert metadata["status_code"] == 400
|
||||
|
||||
|
||||
def test_generate_response_from_error_handles_streaming_response_failure() -> None:
|
||||
# Simulates scenario where accessing response.json() or response.text
|
||||
# raises ResponseNotRead on streaming responses
|
||||
response = MockResponse(
|
||||
status_code=400,
|
||||
headers={"content-type": "application/json"},
|
||||
json_raises=Exception, # Simulates ResponseNotRead or similar
|
||||
text_raises=Exception,
|
||||
)
|
||||
error = MockAPIError("API Error", response=response)
|
||||
|
||||
# This should NOT raise an exception, but should handle it gracefully
|
||||
generations = _generate_response_from_error(error)
|
||||
|
||||
assert len(generations) == 1
|
||||
generation = generations[0]
|
||||
metadata = generation.message.response_metadata
|
||||
|
||||
# When both fail, body should be None instead of raising an exception
|
||||
assert metadata["body"] is None
|
||||
assert metadata["headers"] == {"content-type": "application/json"}
|
||||
assert metadata["status_code"] == 400
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Test for Serializable base class."""
|
||||
|
||||
from langchain_core.load.load import load
|
||||
|
||||
|
||||
def test_load_with_string_secrets() -> None:
|
||||
obj = {"api_key": "__SECRET_API_KEY__"}
|
||||
secrets_map = {"__SECRET_API_KEY__": "hello"}
|
||||
result = load(obj, secrets_map=secrets_map)
|
||||
|
||||
assert result["api_key"] == "hello"
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Test groq block translator."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
|
||||
from langchain_core.messages.block_translators import PROVIDER_TRANSLATORS
|
||||
from langchain_core.messages.block_translators.groq import (
|
||||
_parse_code_json,
|
||||
translate_content,
|
||||
)
|
||||
|
||||
|
||||
def test_groq_translator_registered() -> None:
|
||||
"""Test that groq translator is properly registered."""
|
||||
assert "groq" in PROVIDER_TRANSLATORS
|
||||
assert "translate_content" in PROVIDER_TRANSLATORS["groq"]
|
||||
assert "translate_content_chunk" in PROVIDER_TRANSLATORS["groq"]
|
||||
|
||||
|
||||
def test_extract_reasoning_from_additional_kwargs_exists() -> None:
|
||||
"""Test that _extract_reasoning_from_additional_kwargs can be imported."""
|
||||
# Verify it's callable
|
||||
assert callable(_extract_reasoning_from_additional_kwargs)
|
||||
|
||||
|
||||
def test_groq_translate_content_basic() -> None:
|
||||
"""Test basic groq content translation."""
|
||||
# Test with simple text message
|
||||
message = AIMessage(content="Hello world")
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0]["type"] == "text"
|
||||
assert blocks[0]["text"] == "Hello world"
|
||||
|
||||
|
||||
def test_groq_translate_content_with_reasoning() -> None:
|
||||
"""Test groq content translation with reasoning content."""
|
||||
# Test with reasoning content in additional_kwargs
|
||||
message = AIMessage(
|
||||
content="Final answer",
|
||||
additional_kwargs={"reasoning_content": "Let me think about this..."},
|
||||
)
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
assert len(blocks) == 2
|
||||
|
||||
# First block should be reasoning
|
||||
assert blocks[0]["type"] == "reasoning"
|
||||
assert blocks[0]["reasoning"] == "Let me think about this..."
|
||||
|
||||
# Second block should be text
|
||||
assert blocks[1]["type"] == "text"
|
||||
assert blocks[1]["text"] == "Final answer"
|
||||
|
||||
|
||||
def test_groq_translate_content_with_tool_calls() -> None:
|
||||
"""Test groq content translation with tool calls."""
|
||||
# Test with tool calls
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "search",
|
||||
"args": {"query": "test"},
|
||||
"id": "call_123",
|
||||
}
|
||||
],
|
||||
)
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0]["type"] == "tool_call"
|
||||
assert blocks[0]["name"] == "search"
|
||||
assert blocks[0]["args"] == {"query": "test"}
|
||||
assert blocks[0]["id"] == "call_123"
|
||||
|
||||
|
||||
def test_groq_translate_content_with_executed_tools() -> None:
|
||||
"""Test groq content translation with executed tools (built-in tools)."""
|
||||
# Test with executed_tools in additional_kwargs (Groq built-in tools)
|
||||
message = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"executed_tools": [
|
||||
{
|
||||
"type": "python",
|
||||
"arguments": '{"code": "print(\\"hello\\")"}',
|
||||
"output": "hello\\n",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
# Should have server_tool_call and server_tool_result
|
||||
assert len(blocks) >= 2
|
||||
|
||||
# Check for server_tool_call
|
||||
tool_call_blocks = [
|
||||
cast("types.ServerToolCall", b)
|
||||
for b in blocks
|
||||
if b.get("type") == "server_tool_call"
|
||||
]
|
||||
assert len(tool_call_blocks) == 1
|
||||
assert tool_call_blocks[0]["name"] == "code_interpreter"
|
||||
assert "code" in tool_call_blocks[0]["args"]
|
||||
|
||||
# Check for server_tool_result
|
||||
tool_result_blocks = [
|
||||
cast("types.ServerToolResult", b)
|
||||
for b in blocks
|
||||
if b.get("type") == "server_tool_result"
|
||||
]
|
||||
assert len(tool_result_blocks) == 1
|
||||
assert tool_result_blocks[0]["output"] == "hello\\n"
|
||||
assert tool_result_blocks[0]["status"] == "success"
|
||||
|
||||
|
||||
def test_parse_code_json() -> None:
|
||||
"""Test the _parse_code_json helper function."""
|
||||
# Test valid code JSON
|
||||
result = _parse_code_json('{"code": "print(\'hello\')"}')
|
||||
assert result == {"code": "print('hello')"}
|
||||
|
||||
# Test code with unescaped quotes (Groq format)
|
||||
result = _parse_code_json('{"code": "print("hello")"}')
|
||||
assert result == {"code": 'print("hello")'}
|
||||
|
||||
# Test invalid format raises ValueError
|
||||
with pytest.raises(ValueError, match="Could not extract Python code"):
|
||||
_parse_code_json('{"invalid": "format"}')
|
||||
@@ -3,12 +3,14 @@
|
||||
import asyncio
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -3,21 +3,25 @@ from __future__ import annotations
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
||||
from inspect import isasyncgenfunction
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langsmith import Client, get_current_run_tree, traceable
|
||||
from langsmith.run_helpers import tracing_context
|
||||
from langsmith.run_trees import RunTree
|
||||
from langsmith.utils import get_env_var
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
||||
|
||||
from langsmith.run_trees import RunTree
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
|
||||
def _get_posts(client: Client) -> list:
|
||||
mock_calls = client.session.request.mock_calls # type: ignore[attr-defined]
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import textwrap
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
@@ -55,6 +56,7 @@ from langchain_core.tools.base import (
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
_DirectlyInjectedToolArg,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
get_all_basemodel_annotations,
|
||||
@@ -2331,6 +2333,101 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema_format", ["model", "json_schema"])
|
||||
def test_tool_allows_extra_runtime_args_with_custom_schema(
|
||||
schema_format: Literal["model", "json_schema"],
|
||||
) -> None:
|
||||
"""Ensure runtime args are preserved even if not in the args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
query: str
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@dataclass
|
||||
class MyRuntime(_DirectlyInjectedToolArg):
|
||||
some_obj: object
|
||||
|
||||
args_schema = (
|
||||
InputSchema if schema_format == "model" else InputSchema.model_json_schema()
|
||||
)
|
||||
|
||||
@tool(args_schema=args_schema)
|
||||
def runtime_tool(query: str, runtime: MyRuntime) -> str:
|
||||
"""Echo the query and capture runtime value."""
|
||||
captured["runtime"] = runtime
|
||||
return query
|
||||
|
||||
runtime_obj = object()
|
||||
runtime = MyRuntime(some_obj=runtime_obj)
|
||||
assert runtime_tool.invoke({"query": "hello", "runtime": runtime}) == "hello"
|
||||
assert captured["runtime"] is runtime
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id_with_custom_schema() -> None:
|
||||
"""Ensure InjectedToolCallId works with custom args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
x: int
|
||||
|
||||
@tool(args_schema=InputSchema)
|
||||
def injected_tool(
|
||||
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
|
||||
) -> ToolMessage:
|
||||
"""Tool with injected tool_call_id and custom schema."""
|
||||
return ToolMessage(str(x), tool_call_id=tool_call_id)
|
||||
|
||||
# Test that tool_call_id is properly injected even though not in custom schema
|
||||
result = injected_tool.invoke(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"args": {"x": 42},
|
||||
"name": "injected_tool",
|
||||
"id": "test_call_id",
|
||||
}
|
||||
)
|
||||
assert result == ToolMessage("42", tool_call_id="test_call_id")
|
||||
|
||||
# Test that it still raises error when invoked without a ToolCall
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="When tool includes an InjectedToolCallId argument, "
|
||||
"tool must always be invoked with a full model ToolCall",
|
||||
):
|
||||
injected_tool.invoke({"x": 42})
|
||||
|
||||
|
||||
def test_tool_injected_arg_with_custom_schema() -> None:
|
||||
"""Ensure InjectedToolArg works with custom args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
query: str
|
||||
|
||||
class CustomContext:
|
||||
"""Custom context object to be injected."""
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@tool(args_schema=InputSchema)
|
||||
def search_tool(
|
||||
query: str, context: Annotated[CustomContext, InjectedToolArg]
|
||||
) -> str:
|
||||
"""Search with custom context."""
|
||||
captured["context"] = context
|
||||
return f"Results for {query} with context {context.value}"
|
||||
|
||||
# Test that context is properly injected even though not in custom schema
|
||||
ctx = CustomContext("test_context")
|
||||
result = search_tool.invoke({"query": "hello", "context": ctx})
|
||||
|
||||
assert result == "Results for hello with context test_context"
|
||||
assert captured["context"] is ctx
|
||||
assert captured["context"].value == "test_context"
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -23,6 +22,9 @@ from langchain_core.utils import (
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.utils import secret_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("package", "check_kwargs", "actual_version", "expected"),
|
||||
|
||||
4
libs/core/uv.lock
generated
@@ -960,7 +960,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.4"
|
||||
version = "1.0.6"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -1057,7 +1057,7 @@ typing = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-model-profiles"
|
||||
version = "0.0.3"
|
||||
version = "0.0.4"
|
||||
source = { directory = "../model-profiles" }
|
||||
dependencies = [
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
|
||||
@@ -24,7 +24,7 @@ In most cases, you should be using the main [`langchain`](https://pypi.org/proje
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic).
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: all start_services stop_services coverage test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
|
||||
.PHONY: all start_services stop_services coverage coverage_agents test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
@@ -27,8 +27,17 @@ coverage:
|
||||
--cov-report term-missing:skip-covered \
|
||||
$(TEST_FILE)
|
||||
|
||||
# Run middleware and agent tests with coverage report.
|
||||
coverage_agents:
|
||||
uv run --group test pytest \
|
||||
tests/unit_tests/agents/middleware/ \
|
||||
tests/unit_tests/agents/test_*.py \
|
||||
--cov=langchain.agents \
|
||||
--cov-report=term-missing \
|
||||
--cov-report=html:htmlcov \
|
||||
|
||||
test:
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered --snapshot-update; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
@@ -93,6 +102,7 @@ help:
|
||||
@echo 'lint - run linters'
|
||||
@echo '-- TESTS --'
|
||||
@echo 'coverage - run unit tests and generate coverage report'
|
||||
@echo 'coverage_agents - run middleware and agent tests with coverage report'
|
||||
@echo 'test - run unit tests with all services'
|
||||
@echo 'test_fast - run unit tests with in-memory services only'
|
||||
@echo 'tests - run unit tests (alias for "make test")'
|
||||
|
||||
@@ -26,7 +26,7 @@ LangChain [agents](https://docs.langchain.com/oss/python/langchain/agents) are b
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/).
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -22,17 +22,20 @@ from langgraph.graph.state import StateGraph
|
||||
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
|
||||
from langgraph.runtime import Runtime # noqa: TC002
|
||||
from langgraph.types import Command, Send
|
||||
from langgraph.typing import ContextT # noqa: TC002
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
AsyncModelCallHandler,
|
||||
JumpTo,
|
||||
ModelCallHandler,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OmitFromSchema,
|
||||
ResponseT,
|
||||
StateT,
|
||||
StateT_co,
|
||||
_InputAgentState,
|
||||
_OutputAgentState,
|
||||
@@ -63,6 +66,18 @@ if TYPE_CHECKING:
|
||||
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
    # if langchain-model-profiles is not installed, these models are assumed to support
    # structured output
    "grok",
    "gpt-5",
    "gpt-4.1",
    "gpt-4o",
    "gpt-oss",
    "o3-pro",
    "o3-mini",
]

def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
|
||||
"""Normalize middleware return value to ModelResponse."""
|
||||
@@ -74,13 +89,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp
|
||||
def _chain_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse | AIMessage,
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse,
|
||||
]
|
||||
| None
|
||||
@@ -128,8 +143,8 @@ def _chain_model_call_handlers(
|
||||
single_handler = handlers[0]
|
||||
|
||||
def normalized_single(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
result = single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
@@ -138,25 +153,25 @@ def _chain_model_call_handlers(
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse | AIMessage,
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse | AIMessage,
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse,
|
||||
]:
|
||||
"""Compose two handlers where outer wraps inner."""
|
||||
|
||||
def composed(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
def inner_handler(req: ModelRequest) -> ModelResponse:
|
||||
def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
inner_result = inner(req, handler)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
@@ -173,8 +188,8 @@ def _chain_model_call_handlers(
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
def final_normalized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = result(request, handler)
|
||||
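The generic parameters added above change only the typing; the composition order is unchanged: the first handler in the sequence becomes the outermost wrapper, and each wrapper receives a callable that runs everything nested inside it. A stripped-down illustration of that nesting (plain strings stand in for `ModelRequest`/`ModelResponse`, so this is purely a sketch):

```python
from typing import Callable

Handler = Callable[[str], str]
Wrapper = Callable[[str, Handler], str]


def compose_two(outer: Wrapper, inner: Wrapper) -> Wrapper:
    """Outer wraps inner, which wraps whatever base handler is passed in."""

    def composed(request: str, base: Handler) -> str:
        def inner_handler(req: str) -> str:
            return inner(req, base)

        return outer(request, inner_handler)

    return composed


def shout(request: str, handler: Handler) -> str:
    return handler(request.upper())


def exclaim(request: str, handler: Handler) -> str:
    return handler(request + "!")


chained = compose_two(shout, exclaim)
print(chained("hi", lambda req: f"model saw: {req}"))  # model saw: HI!
```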
@@ -186,13 +201,13 @@ def _chain_model_call_handlers(
|
||||
def _chain_async_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse],
|
||||
]
|
||||
| None
|
||||
@@ -213,8 +228,8 @@ def _chain_async_model_call_handlers(
|
||||
single_handler = handlers[0]
|
||||
|
||||
async def normalized_single(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
result = await single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
@@ -223,25 +238,25 @@ def _chain_async_model_call_handlers(
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse],
|
||||
]:
|
||||
"""Compose two async handlers where outer wraps inner."""
|
||||
|
||||
async def composed(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
async def inner_handler(req: ModelRequest) -> ModelResponse:
|
||||
async def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
inner_result = await inner(req, handler)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
@@ -258,8 +273,8 @@ def _chain_async_model_call_handlers(
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
async def final_normalized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = await result(request, handler)
|
||||
@@ -349,11 +364,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
return []
|
||||
|
||||
|
||||
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
|
||||
"""Check if a model supports provider-specific structured output.
|
||||
|
||||
Args:
|
||||
model: Model name string or `BaseChatModel` instance.
|
||||
tools: Optional list of tools provided to the agent. Needed because some models
|
||||
don't support structured output together with tool calling.
|
||||
|
||||
Returns:
|
||||
`True` if the model supports provider-specific structured output, `False` otherwise.
|
||||
@@ -362,11 +379,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
elif isinstance(model, BaseChatModel):
|
||||
model_name = getattr(model, "model_name", None)
|
||||
model_name = (
|
||||
getattr(model, "model_name", None)
|
||||
or getattr(model, "model", None)
|
||||
or getattr(model, "model_id", "")
|
||||
)
|
||||
try:
|
||||
model_profile = model.profile
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if (
|
||||
model_profile.get("structured_output")
|
||||
# We make an exception for Gemini models, which currently do not support
|
||||
# simultaneous tool use with structured output
|
||||
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
|
||||
):
|
||||
return True
|
||||
|
||||
return (
|
||||
"grok" in model_name.lower()
|
||||
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
|
||||
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
|
||||
if model_name
|
||||
else False
|
||||
)
|
||||
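With this change the decision becomes: prefer the model profile's `structured_output` flag when `langchain-model-profiles` is installed (except for Gemini models when tools are also bound), otherwise fall back to the name-substring heuristic over `FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT`. The fallback branch, restated in isolation (illustrative only; the real function also inspects the model profile and handles `BaseChatModel` attributes):

```python
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
    "grok", "gpt-5", "gpt-4.1", "gpt-4o", "gpt-oss", "o3-pro", "o3-mini",
]


def _name_based_fallback(model_name: str | None) -> bool:
    """Name-substring heuristic used when no profile data is available."""
    if not model_name:
        return False
    return any(
        part in model_name.lower()
        for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT
    )


assert _name_based_fallback("openai:gpt-4.1-mini") is True
assert _name_based_fallback("claude-sonnet-4") is False
```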
@@ -517,9 +549,9 @@ def create_agent( # noqa: PLR0915
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
state_schema: type[AgentState[ResponseT]] | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
context_schema: type[ContextT] | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
@@ -939,7 +971,9 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
return {"messages": [output]}
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
|
||||
def _get_bound_model(
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
) -> tuple[Runnable, ResponseFormat | None]:
|
||||
"""Get the model with appropriate tool bindings.
|
||||
|
||||
Performs auto-detection of strategy if needed based on model capabilities.
|
||||
@@ -988,7 +1022,7 @@ def create_agent( # noqa: PLR0915
|
||||
effective_response_format: ResponseFormat | None
|
||||
if isinstance(request.response_format, AutoStrategy):
|
||||
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
|
||||
if _supports_provider_strategy(request.model):
|
||||
if _supports_provider_strategy(request.model, tools=request.tools):
|
||||
# Model supports provider strategy - use it
|
||||
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
|
||||
else:
|
||||
@@ -1009,7 +1043,7 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
# Bind model based on effective response format
|
||||
if isinstance(effective_response_format, ProviderStrategy):
|
||||
# Use provider-specific structured output
|
||||
# (Backward compatibility) Use OpenAI format structured output
|
||||
kwargs = effective_response_format.to_model_kwargs()
|
||||
return (
|
||||
request.model.bind_tools(
|
||||
@@ -1053,7 +1087,7 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
return request.model.bind(**request.model_settings), None
|
||||
|
||||
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
|
||||
def _execute_model_sync(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
"""Execute model and return response.
|
||||
|
||||
This is the core model execution logic wrapped by `wrap_model_call` handlers.
|
||||
@@ -1077,9 +1111,9 @@ def create_agent( # noqa: PLR0915
|
||||
structured_response=structured_response,
|
||||
)
|
||||
|
||||
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
def model_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
request = ModelRequest[StateT, ContextT](
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
@@ -1104,7 +1138,7 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
return state_updates
|
||||
|
||||
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
|
||||
async def _execute_model_async(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
"""Execute model asynchronously and return response.
|
||||
|
||||
This is the core async model execution logic wrapped by `wrap_model_call`
|
||||
@@ -1130,9 +1164,9 @@ def create_agent( # noqa: PLR0915
|
||||
structured_response=structured_response,
|
||||
)
|
||||
|
||||
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
async def amodel_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
request = ModelRequest[StateT, ContextT](
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
|
||||
@@ -4,6 +4,7 @@ from .context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEditingMiddleware,
|
||||
)
|
||||
from .file_search import FilesystemFileSearchMiddleware
|
||||
from .human_in_the_loop import (
|
||||
HumanInTheLoopMiddleware,
|
||||
InterruptOnConfig,
|
||||
@@ -46,6 +47,7 @@ __all__ = [
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"ContextEditingMiddleware",
|
||||
"DockerExecutionPolicy",
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HostExecutionPolicy",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"InterruptOnConfig",
|
||||
|
||||
@@ -10,6 +10,7 @@ chat model.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -238,10 +239,11 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
return handler(request)
|
||||
return handler(request.override(messages=edited_messages))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -266,10 +268,11 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
return await handler(request)
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
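The fix here is that context edits are now applied to a deep copy and handed to the inner handler through `request.override(messages=...)`, instead of mutating `request.messages` in place, so other middleware and the persisted state never see the trimmed view. A hedged sketch of the same pattern in a custom middleware (the class and pruning step are hypothetical; only the `override` keyword is taken from the diff above):

```python
from copy import deepcopy


class DropOldToolOutputs:  # hypothetical custom middleware, sketch only
    def wrap_model_call(self, request, handler):
        edited = deepcopy(list(request.messages))
        # ... prune or rewrite `edited` as needed ...
        # The original request (and the agent state) stays untouched.
        return handler(request.override(messages=edited))
```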
|
||||
__all__ = [
|
||||
|
||||
@@ -120,9 +120,9 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
||||
|
||||
Args:
|
||||
root_path: Root directory to search.
|
||||
use_ripgrep: Whether to use ripgrep for search.
|
||||
use_ripgrep: Whether to use `ripgrep` for search.
|
||||
|
||||
Falls back to Python if ripgrep unavailable.
|
||||
Falls back to Python if `ripgrep` unavailable.
|
||||
max_file_size_mb: Maximum file size to search in MB.
|
||||
"""
|
||||
self.root_path = Path(root_path).resolve()
|
||||
|
||||
@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
last_ai_msg.tool_calls = revised_tool_calls
|
||||
|
||||
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -133,6 +133,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
|
||||
`None` means no limit.
|
||||
exit_behavior: What to do when limits are exceeded.
|
||||
|
||||
- `'end'`: Jump to the end of the agent execution and
|
||||
inject an artificial AI message indicating that the limit was
|
||||
exceeded.
|
||||
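For instance, a limiter that ends the run gracefully once its budget is spent might be configured as follows (a usage sketch assuming the documented parameters; the import path and model id are illustrative):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware

limiter = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```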
@@ -198,6 +199,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: ModelCallLimitState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check model call limits before making a model call.
|
||||
|
||||
Args:
|
||||
state: The current agent state containing call counts.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
If limits are exceeded and exit_behavior is `'end'`, returns
|
||||
a `Command` to jump to the end with a limit exceeded message. Otherwise
|
||||
returns `None`.
|
||||
|
||||
Raises:
|
||||
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
||||
is `'error'`.
|
||||
"""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Increment model call counts after a model call.
|
||||
|
||||
@@ -212,3 +236,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
||||
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
||||
}
|
||||
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: ModelCallLimitState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async increment model call counts after a model call.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
State updates with incremented call counts.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -92,9 +92,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return handler(request)
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
@@ -127,9 +126,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return await handler(request)
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
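As with the context-editing change, the fallback loop now leaves the incoming request untouched and swaps the model via `request.override(model=...)`. Condensed, the retry shape after this change looks roughly like the sketch below (reconstructed from the hunks above, not the verbatim source; the primary-model attempt is assumed to sit just before the loop):

```python
class _FallbackSketch:  # illustrative shape only, not the shipped class
    def __init__(self, models):
        self.models = models

    def wrap_model_call(self, request, handler):
        try:
            return handler(request)  # primary model first (assumed, outside the hunk)
        except Exception as exc:
            last_exception = exc
        for fallback_model in self.models:
            try:
                # Swap the model through override instead of mutating the request.
                return handler(request.override(model=fallback_model))
            except Exception as exc:
                last_exception = exc
        raise last_exception
```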
@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
|
||||
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check user messages and tool results for PII before model invocation.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
Updated state with PII handled according to strategy, or `None` if no PII
|
||||
detected.
|
||||
|
||||
Raises:
|
||||
PIIDetectionError: If PII is detected and strategy is `'block'`.
|
||||
"""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
def after_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
|
||||
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check AI messages for PII after model invocation.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
Updated state with PII handled according to strategy, or `None` if no PII
detected.
|
||||
|
||||
Raises:
|
||||
PIIDetectionError: If PII is detected and strategy is `'block'`.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PIIDetectionError",
|
||||
|
||||
@@ -11,17 +11,17 @@ import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
import uuid
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import BaseTool, ToolException
|
||||
from langchain_core.tools.base import ToolException
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic.json_schema import SkipJsonSchema
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from langchain.agents.middleware._execution import (
|
||||
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
|
||||
ResolvedRedactionRule,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
|
||||
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
|
||||
"session remains stable. Outputs may be truncated when they become very large, and long "
|
||||
"running commands will be terminated once their configured timeout elapses."
|
||||
)
|
||||
SHELL_TOOL_NAME = "shell"
|
||||
|
||||
|
||||
def _cleanup_resources(
|
||||
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
|
||||
"""Input schema for the persistent shell tool."""
|
||||
|
||||
command: str | None = None
|
||||
"""The shell command to execute."""
|
||||
|
||||
restart: bool | None = None
|
||||
"""Whether to restart the shell session."""
|
||||
|
||||
runtime: Annotated[Any, SkipJsonSchema()] = None
|
||||
"""The runtime for the shell tool.
|
||||
|
||||
Included as a workaround at the moment because args_schema doesn't work with
injected ToolRuntime.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self) -> _ShellToolInput:
|
||||
@@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class _PersistentShellTool(BaseTool):
|
||||
"""Tool wrapper that relies on middleware interception for execution."""
|
||||
|
||||
name: str = "shell"
|
||||
description: str = DEFAULT_TOOL_DESCRIPTION
|
||||
args_schema: type[BaseModel] = _ShellToolInput
|
||||
|
||||
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
|
||||
super().__init__()
|
||||
self._middleware = middleware
|
||||
if description is not None:
|
||||
self.description = description
|
||||
|
||||
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
|
||||
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
"""Middleware that registers a persistent shell tool for agents.
|
||||
|
||||
@@ -393,10 +385,11 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
execution_policy: BaseExecutionPolicy | None = None,
|
||||
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
|
||||
tool_description: str | None = None,
|
||||
tool_name: str = SHELL_TOOL_NAME,
|
||||
shell_command: Sequence[str] | str | None = None,
|
||||
env: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the middleware.
|
||||
"""Initialize an instance of `ShellToolMiddleware`.
|
||||
|
||||
Args:
|
||||
workspace_root: Base directory for the shell session.
|
||||
@@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
returning it to the model.
|
||||
tool_description: Optional override for the registered shell tool
|
||||
description.
|
||||
tool_name: Name for the registered shell tool.
|
||||
|
||||
Defaults to `"shell"`.
|
||||
shell_command: Optional shell executable (string) or argument sequence used
|
||||
to launch the persistent session.
|
||||
|
||||
@@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
"""
|
||||
super().__init__()
|
||||
self._workspace_root = Path(workspace_root) if workspace_root else None
|
||||
self._tool_name = tool_name
|
||||
self._shell_command = self._normalize_shell_command(shell_command)
|
||||
self._environment = self._normalize_env(env)
|
||||
if execution_policy is not None:
|
||||
@@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
self._startup_commands = self._normalize_commands(startup_commands)
|
||||
self._shutdown_commands = self._normalize_commands(shutdown_commands)
|
||||
|
||||
# Create a proper tool that executes directly (no interception needed)
|
||||
description = tool_description or DEFAULT_TOOL_DESCRIPTION
|
||||
self._tool = _PersistentShellTool(self, description=description)
|
||||
self.tools = [self._tool]
|
||||
|
||||
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
|
||||
def shell_tool(
|
||||
*,
|
||||
runtime: ToolRuntime[None, ShellToolState],
|
||||
command: str | None = None,
|
||||
restart: bool = False,
|
||||
) -> ToolMessage | str:
|
||||
resources = self._get_or_create_resources(runtime.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
{"command": command, "restart": restart},
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
||||
self._shell_tool = shell_tool
|
||||
self.tools = [self._shell_tool]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_commands(
|
||||
@@ -478,36 +491,48 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
|
||||
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Start the shell session and run startup commands."""
|
||||
resources = self._create_resources()
|
||||
resources = self._get_or_create_resources(state)
|
||||
return {"shell_session_resources": resources}
|
||||
|
||||
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async counterpart to `before_agent`."""
|
||||
"""Async start the shell session and run startup commands."""
|
||||
return self.before_agent(state, runtime)
|
||||
|
||||
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
|
||||
"""Run shutdown commands and release resources when an agent completes."""
|
||||
resources = self._ensure_resources(state)
|
||||
resources = state.get("shell_session_resources")
|
||||
if not isinstance(resources, _SessionResources):
|
||||
# Resources were never created, nothing to clean up
|
||||
return
|
||||
try:
|
||||
self._run_shutdown_commands(resources.session)
|
||||
finally:
|
||||
resources._finalizer()
|
||||
|
||||
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
||||
"""Async counterpart to `after_agent`."""
|
||||
"""Async run shutdown commands and release resources when an agent completes."""
|
||||
return self.after_agent(state, runtime)
|
||||
|
||||
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
"""Get existing resources from state or create new ones if they don't exist.
|
||||
|
||||
This method enables resumability by checking if resources already exist in the state
|
||||
(e.g., after an interrupt), and only creating new resources if they're not present.
|
||||
|
||||
Args:
|
||||
state: The agent state which may contain shell session resources.
|
||||
|
||||
Returns:
|
||||
Session resources, either retrieved from state or newly created.
|
||||
"""
|
||||
resources = state.get("shell_session_resources")
|
||||
if resources is not None and not isinstance(resources, _SessionResources):
|
||||
resources = None
|
||||
if resources is None:
|
||||
msg = (
|
||||
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
|
||||
"before invoking the shell tool."
|
||||
)
|
||||
raise ToolException(msg)
|
||||
return resources
|
||||
if isinstance(resources, _SessionResources):
|
||||
return resources
|
||||
|
||||
new_resources = self._create_resources()
|
||||
# Cast needed to make state dict-like for mutation
|
||||
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
|
||||
return new_resources
|
||||
|
||||
def _create_resources(self) -> _SessionResources:
|
||||
workspace = self._workspace_root
|
||||
@@ -669,36 +694,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
artifact=artifact,
|
||||
)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept local shell tool calls and execute them via the managed session."""
|
||||
if isinstance(request.tool, _PersistentShellTool):
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
request.tool_call["args"],
|
||||
tool_call_id=request.tool_call.get("id"),
|
||||
)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Async interception mirroring the synchronous tool handler."""
|
||||
if isinstance(request.tool, _PersistentShellTool):
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
request.tool_call["args"],
|
||||
tool_call_id=request.tool_call.get("id"),
|
||||
)
|
||||
return await handler(request)
|
||||
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
@@ -713,7 +708,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=self._tool.name,
|
||||
name=self._tool_name,
|
||||
status=status,
|
||||
artifact=artifact,
|
||||
)
|
||||
|
||||
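With the shell tool now registered through the `@tool` decorator inside the constructor (so the configurable `tool_name` and the injected `ToolRuntime` both work without the removed `_PersistentShellTool` wrapper and `wrap_tool_call` interception), enabling it on an agent is just a middleware entry. A usage sketch; the import path, model id, and invocation payload are assumptions:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ShellToolMiddleware

agent = create_agent(
    "openai:gpt-4o",
    middleware=[ShellToolMiddleware(workspace_root="/tmp/agent-shell")],
)
result = agent.invoke({"messages": [{"role": "user", "content": "List the files here."}]})
```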
@@ -1,8 +1,9 @@
|
||||
"""Summarization middleware."""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -51,13 +52,17 @@ Messages to summarize:
|
||||
{messages}
|
||||
</messages>""" # noqa: E501
|
||||
|
||||
SUMMARY_PREFIX = "## Previous conversation summary:"
|
||||
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
@@ -70,34 +75,95 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
model: str | BaseChatModel,
|
||||
max_tokens_before_summary: int | None = None,
|
||||
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
|
||||
*,
|
||||
trigger: ContextSize | list[ContextSize] | None = None,
|
||||
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
|
||||
token_counter: TokenCounter = count_tokens_approximately,
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
|
||||
summary_prefix: str = SUMMARY_PREFIX,
|
||||
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
|
||||
**deprecated_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize summarization middleware.
|
||||
|
||||
Args:
|
||||
model: The language model to use for generating summaries.
|
||||
max_tokens_before_summary: Token threshold to trigger summarization.
|
||||
If `None`, summarization is disabled.
|
||||
messages_to_keep: Number of recent messages to preserve after summarization.
|
||||
trigger: One or more thresholds that trigger summarization.
|
||||
|
||||
Provide a single `ContextSize` tuple or a list of tuples, in which case
|
||||
summarization runs when any threshold is breached.
|
||||
|
||||
Examples: `("messages", 50)`, `("tokens", 3000)`, `[("fraction", 0.8),
|
||||
("messages", 100)]`.
|
||||
keep: Context retention policy applied after summarization.
|
||||
|
||||
Provide a `ContextSize` tuple to specify how much history to preserve.
|
||||
|
||||
Defaults to keeping the most recent 20 messages.
|
||||
|
||||
Examples: `("messages", 20)`, `("tokens", 3000)`, or
|
||||
`("fraction", 0.3)`.
|
||||
token_counter: Function to count tokens in messages.
|
||||
summary_prompt: Prompt template for generating summaries.
|
||||
summary_prefix: Prefix added to system message when including summary.
|
||||
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
|
||||
the summarization call.
|
||||
|
||||
Pass `None` to skip trimming entirely.
|
||||
"""
|
||||
# Handle deprecated parameters
|
||||
if "max_tokens_before_summary" in deprecated_kwargs:
|
||||
value = deprecated_kwargs["max_tokens_before_summary"]
|
||||
warnings.warn(
|
||||
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if trigger is None and value is not None:
|
||||
trigger = ("tokens", value)
|
||||
|
||||
if "messages_to_keep" in deprecated_kwargs:
|
||||
value = deprecated_kwargs["messages_to_keep"]
|
||||
warnings.warn(
|
||||
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
|
||||
keep = ("messages", value)
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
self.model = model
|
||||
self.max_tokens_before_summary = max_tokens_before_summary
|
||||
self.messages_to_keep = messages_to_keep
|
||||
if trigger is None:
|
||||
self.trigger: ContextSize | list[ContextSize] | None = None
|
||||
trigger_conditions: list[ContextSize] = []
|
||||
elif isinstance(trigger, list):
|
||||
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
|
||||
self.trigger = validated_list
|
||||
trigger_conditions = validated_list
|
||||
else:
|
||||
validated = self._validate_context_size(trigger, "trigger")
|
||||
self.trigger = validated
|
||||
trigger_conditions = [validated]
|
||||
self._trigger_conditions = trigger_conditions
|
||||
|
||||
self.keep = self._validate_context_size(keep, "keep")
|
||||
self.token_counter = token_counter
|
||||
self.summary_prompt = summary_prompt
|
||||
self.summary_prefix = summary_prefix
|
||||
self.trim_tokens_to_summarize = trim_tokens_to_summarize
|
||||
|
||||
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
|
||||
if self.keep[0] == "fraction":
|
||||
requires_profile = True
|
||||
if requires_profile and self._get_profile_limits() is None:
|
||||
msg = (
|
||||
"Model profile information is required to use fractional token limits. "
|
||||
'pip install "langchain[model-profiles]" or use absolute token counts '
|
||||
"instead."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
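Put together, the new configuration surface lets callers mix message-, token-, and fraction-based thresholds. A usage sketch assuming the constructor shown above (import path and model id are illustrative):

```python
from langchain.agents.middleware import SummarizationMiddleware

summarizer = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    # Summarize when either threshold is breached...
    trigger=[("fraction", 0.8), ("messages", 100)],
    # ...and keep roughly the last 3000 tokens of history afterwards.
    keep=("tokens", 3000),
)
# Note: ("fraction", ...) thresholds require model profile data
# (install langchain[model-profiles]), per the ValueError raised above.
```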
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Process messages before model invocation, potentially triggering summarization."""
|
||||
@@ -105,13 +171,10 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self._ensure_message_ids(messages)
|
||||
|
||||
total_tokens = self.token_counter(messages)
|
||||
if (
|
||||
self.max_tokens_before_summary is not None
|
||||
and total_tokens < self.max_tokens_before_summary
|
||||
):
|
||||
if not self._should_summarize(messages, total_tokens):
|
||||
return None
|
||||
|
||||
cutoff_index = self._find_safe_cutoff(messages)
|
||||
cutoff_index = self._determine_cutoff_index(messages)
|
||||
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
@@ -129,6 +192,151 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
]
|
||||
}
|
||||
|
||||
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Process messages before model invocation, potentially triggering summarization."""
|
||||
messages = state["messages"]
|
||||
self._ensure_message_ids(messages)
|
||||
|
||||
total_tokens = self.token_counter(messages)
|
||||
if not self._should_summarize(messages, total_tokens):
|
||||
return None
|
||||
|
||||
cutoff_index = self._determine_cutoff_index(messages)
|
||||
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
|
||||
|
||||
summary = await self._acreate_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||
*new_messages,
|
||||
*preserved_messages,
|
||||
]
|
||||
}
|
||||
|
||||
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
|
||||
"""Determine whether summarization should run for the current token usage."""
|
||||
if not self._trigger_conditions:
|
||||
return False
|
||||
|
||||
for kind, value in self._trigger_conditions:
|
||||
if kind == "messages" and len(messages) >= value:
|
||||
return True
|
||||
if kind == "tokens" and total_tokens >= value:
|
||||
return True
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
continue
|
||||
threshold = int(max_input_tokens * value)
|
||||
if threshold <= 0:
|
||||
threshold = 1
|
||||
if total_tokens >= threshold:
|
||||
return True
|
||||
return False
|
||||
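# Worked example (hypothetical numbers): with a model profile reporting
# max_input_tokens=128000 and trigger=("fraction", 0.8), summarization fires
# once the running token count reaches int(128000 * 0.8) == 102400.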
|
||||
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
|
||||
"""Choose cutoff index respecting retention configuration."""
|
||||
kind, value = self.keep
|
||||
if kind in {"tokens", "fraction"}:
|
||||
token_based_cutoff = self._find_token_based_cutoff(messages)
|
||||
if token_based_cutoff is not None:
|
||||
return token_based_cutoff
|
||||
# None cutoff -> model profile data not available (caught in __init__ but
# here for safety); fall back to the message count
|
||||
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
|
||||
return self._find_safe_cutoff(messages, cast("int", value))
|
||||
|
||||
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
|
||||
"""Find cutoff index based on target token retention."""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
kind, value = self.keep
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
return None
|
||||
target_token_count = int(max_input_tokens * value)
|
||||
elif kind == "tokens":
|
||||
target_token_count = int(value)
|
||||
else:
|
||||
return None
|
||||
|
||||
if target_token_count <= 0:
|
||||
target_token_count = 1
|
||||
|
||||
if self.token_counter(messages) <= target_token_count:
|
||||
return 0
|
||||
|
||||
# Use binary search to identify the earliest message index that keeps the
|
||||
# suffix within the token budget.
|
||||
left, right = 0, len(messages)
|
||||
cutoff_candidate = len(messages)
|
||||
max_iterations = len(messages).bit_length() + 1
|
||||
for _ in range(max_iterations):
|
||||
if left >= right:
|
||||
break
|
||||
|
||||
mid = (left + right) // 2
|
||||
if self.token_counter(messages[mid:]) <= target_token_count:
|
||||
cutoff_candidate = mid
|
||||
right = mid
|
||||
else:
|
||||
left = mid + 1
|
||||
|
||||
if cutoff_candidate == len(messages):
|
||||
cutoff_candidate = left
|
||||
|
||||
if cutoff_candidate >= len(messages):
|
||||
if len(messages) == 1:
|
||||
return 0
|
||||
cutoff_candidate = len(messages) - 1
|
||||
|
||||
for i in range(cutoff_candidate, -1, -1):
|
||||
if self._is_safe_cutoff_point(messages, i):
|
||||
return i
|
||||
|
||||
return 0
|
||||
|
||||
def _get_profile_limits(self) -> int | None:
|
||||
"""Retrieve max input token limit from the model profile."""
|
||||
try:
|
||||
profile = self.model.profile
|
||||
except (AttributeError, ImportError):
|
||||
return None
|
||||
|
||||
if not isinstance(profile, Mapping):
|
||||
return None
|
||||
|
||||
max_input_tokens = profile.get("max_input_tokens")
|
||||
|
||||
if not isinstance(max_input_tokens, int):
|
||||
return None
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
|
||||
"""Validate context configuration tuples."""
|
||||
kind, value = context
|
||||
if kind == "fraction":
|
||||
if not 0 < value <= 1:
|
||||
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
|
||||
raise ValueError(msg)
|
||||
elif kind in {"tokens", "messages"}:
|
||||
if value <= 0:
|
||||
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"Unsupported context size type {kind} for {parameter_name}."
|
||||
raise ValueError(msg)
|
||||
return context
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [
|
||||
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
|
||||
@@ -151,16 +359,16 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
|
||||
"""Find safe cutoff point that preserves AI/Tool message pairs.
|
||||
|
||||
Returns the index where messages can be safely cut without separating
|
||||
related AI and Tool messages. Returns 0 if no safe cutoff is found.
|
||||
related AI and Tool messages. Returns `0` if no safe cutoff is found.
|
||||
"""
|
||||
if len(messages) <= self.messages_to_keep:
|
||||
if len(messages) <= messages_to_keep:
|
||||
return 0
|
||||
|
||||
target_cutoff = len(messages) - self.messages_to_keep
|
||||
target_cutoff = len(messages) - messages_to_keep
|
||||
|
||||
for i in range(target_cutoff, -1, -1):
|
||||
if self._is_safe_cutoff_point(messages, i):
|
||||
@@ -229,16 +437,35 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
try:
|
||||
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
|
||||
return cast("str", response.content).strip()
|
||||
return response.text.strip()
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary for the given messages."""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed_messages:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
try:
|
||||
response = await self.model.ainvoke(
|
||||
self.summary_prompt.format(messages=trimmed_messages)
|
||||
)
|
||||
return response.text.strip()
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||
"""Trim messages to fit within summary generation limits."""
|
||||
try:
|
||||
if self.trim_tokens_to_summarize is None:
|
||||
return messages
|
||||
return trim_messages(
|
||||
messages,
|
||||
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
|
||||
max_tokens=self.trim_tokens_to_summarize,
|
||||
token_counter=self.token_counter,
|
||||
start_on="human",
|
||||
strategy="last",
|
||||
|
||||
@@ -150,10 +150,6 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
|
||||
print(result["todos"]) # Array of todo items with status tracking
|
||||
```
|
||||
|
||||
Args:
|
||||
system_prompt: Custom system prompt to guide the agent on using the todo tool.
|
||||
tool_description: Custom description for the write_todos tool.
|
||||
"""
|
||||
|
||||
state_schema = PlanningState
|
||||
@@ -198,12 +194,12 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt."""
|
||||
request.system_prompt = (
|
||||
new_system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return handler(request)
|
||||
return handler(request.override(system_prompt=new_system_prompt))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -211,9 +207,9 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt (async version)."""
|
||||
request.system_prompt = (
|
||||
new_system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return await handler(request)
|
||||
return await handler(request.override(system_prompt=new_system_prompt))
|
||||
|
||||
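The same override-instead-of-mutate pattern lands here for system prompts. A minimal custom middleware following it (a sketch; the `override(system_prompt=...)` keyword is taken from the diff, everything else is hypothetical):

```python
class AppendHouseRules:  # hypothetical middleware, sketch only
    RULES = "Always answer in English and cite your sources."

    def wrap_model_call(self, request, handler):
        new_prompt = (
            request.system_prompt + "\n\n" + self.RULES
            if request.system_prompt
            else self.RULES
        )
        return handler(request.override(system_prompt=new_prompt))
```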
@@ -153,38 +153,46 @@ class ToolCallLimitMiddleware(
are other pending tool calls (due to parallel tool calling).

Examples:
```python title="Continue execution with blocked tools (default)"
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
!!! example "Continue execution with blocked tools (default)"

# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue",  # default
)
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent

agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue",  # default
)

```python title="Stop immediately when limit exceeded"
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```

agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
!!! example "Stop immediately when limit exceeded"

```python title="Raise exception on limit"
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")

agent = create_agent("openai:gpt-4o", middleware=[limiter])
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```

try:
result = await agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
!!! example "Raise exception on limit"

```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, exit_behavior="error"
)

agent = create_agent("openai:gpt-4o", middleware=[limiter])

try:
result = await agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```

"""

@@ -208,6 +216,7 @@ class ToolCallLimitMiddleware(
run_limit: Maximum number of tool calls allowed per run.
`None` means no limit.
exit_behavior: How to handle when limits are exceeded.

- `'continue'`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end.
- `'error'`: Raise a `ToolCallLimitExceededError` exception
@@ -218,7 +227,7 @@ class ToolCallLimitMiddleware(

Raises:
ValueError: If both limits are `None`, if `exit_behavior` is invalid,
or if `run_limit` exceeds thread_limit.
or if `run_limit` exceeds `thread_limit`.
"""
super().__init__()

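A small sketch of the constraint spelled out in the `Raises:` section above, i.e. that a per-run budget has to fit inside the per-thread budget (the numbers are arbitrary):

```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware

# Valid: the per-run budget is a subset of the per-thread budget.
ToolCallLimitMiddleware(thread_limit=20, run_limit=10)

# Per the docstring this should raise ValueError: run_limit exceeds thread_limit.
try:
    ToolCallLimitMiddleware(thread_limit=5, run_limit=10)
except ValueError as exc:
    print(exc)
```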
@@ -451,3 +460,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}

@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.

Args:
state: The current agent state.
runtime: The langgraph runtime.

Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.

Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

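For reference, `@hook_config(can_jump_to=["end"])` is what allows an `after_model` hook to short-circuit the run. A hedged sketch of a hypothetical hook using the same mechanism; the state key and the `jump_to` payload are assumptions inferred from the surrounding code, not a documented recipe:

```python
from typing import Any

from langchain.agents.middleware import AgentMiddleware, hook_config


class StopAfterFiveModelCalls(AgentMiddleware):
    """Hypothetical middleware that ends the run after five model calls."""

    @hook_config(can_jump_to=["end"])
    def after_model(self, state, runtime) -> dict[str, Any] | None:
        calls = state.get("model_call_count", 0) + 1
        if calls >= 5:
            # Jumping to "end" routes the graph straight to __end__.
            return {"model_call_count": calls, "jump_to": "end"}
        return {"model_call_count": calls}
```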
@@ -25,34 +25,42 @@ class LLMToolEmulator(AgentMiddleware):
This middleware allows selective emulation of tools for testing purposes.

By default (when `tools=None`), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or BaseTool instances.
tools to emulate by passing a list of tool names or `BaseTool` instances.

Examples:
```python title="Emulate all tools (default behavior)"
from langchain.agents.middleware import LLMToolEmulator
!!! example "Emulate all tools (default behavior)"

middleware = LLMToolEmulator()
```python
from langchain.agents.middleware import LLMToolEmulator

agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
middleware = LLMToolEmulator()

```python title="Emulate specific tools by name"
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```

```python title="Use a custom model for emulation"
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by name"

```python title="Emulate specific tools by passing tool instances"
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```

!!! example "Use a custom model for emulation"

```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```

!!! example "Emulate specific tools by passing tool instances"

```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
"""

def __init__(

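One way to read the examples above: an emulated tool's real body never runs, so even a stub that raises can be wired in while the emulator model invents plausible outputs. A sketch under that assumption (the model string and tool body are placeholders):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import LLMToolEmulator
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Get the weather for a city."""
    # If emulation is in effect, this body should never execute.
    raise RuntimeError("real weather backend not available in tests")


agent = create_agent(
    model="openai:gpt-4o",
    tools=[get_weather],
    middleware=[LLMToolEmulator(tools=["get_weather"])],
)
```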
@@ -26,96 +26,96 @@ class ToolRetryMiddleware(AgentMiddleware):
Supports retrying on specific exceptions and exponential backoff.

Examples:
Basic usage with default settings (2 retries, exponential backoff):
!!! example "Basic usage with default settings (2 retries, exponential backoff)"

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware

agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```

Retry specific exceptions only:
!!! example "Retry specific exceptions only"

```python
from requests.exceptions import RequestException, Timeout
```python
from requests.exceptions import RequestException, Timeout

retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```

Custom exception filtering:
!!! example "Custom exception filtering"

```python
from requests.exceptions import HTTPError
```python
from requests.exceptions import HTTPError


def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False


retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```

Apply to specific tools with custom error handling:
!!! example "Apply to specific tools with custom error handling"

```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."


retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```

Apply to specific tools using BaseTool instances:
!!! example "Apply to specific tools using `BaseTool` instances"

```python
from langchain_core.tools import tool
```python
from langchain_core.tools import tool


@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results


retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database],  # Pass BaseTool instance
)
```
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database],  # Pass BaseTool instance
)
```

Constant backoff (no exponential growth):
!!! example "Constant backoff (no exponential growth)"

```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0,  # No exponential growth
initial_delay=2.0,  # Always wait 2 seconds
)
```
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0,  # No exponential growth
initial_delay=2.0,  # Always wait 2 seconds
)
```

Raise exception on failure:
!!! example "Raise exception on failure"

```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise",  # Re-raise exception instead of returning message
)
```
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="raise",  # Re-raise exception instead of returning message
)
```
"""

def __init__(
@@ -136,7 +136,10 @@ class ToolRetryMiddleware(AgentMiddleware):

Args:
max_retries: Maximum number of retry attempts after the initial call.
Default is `2` retries (`3` total attempts). Must be `>= 0`.

Default is `2` retries (`3` total attempts).

Must be `>= 0`.
tools: Optional list of tools or tool names to apply retry logic to.

Can be a list of `BaseTool` instances or tool name strings.
@@ -146,12 +149,14 @@ class ToolRetryMiddleware(AgentMiddleware):
that takes an exception and returns `True` if it should be retried.

Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted. Options:
on_failure: Behavior when all retries are exhausted.

Options:

- `'return_message'`: Return a `ToolMessage` with error details,
allowing the LLM to handle the failure and potentially recover.
- `'raise'`: Re-raise the exception, stopping agent execution.
- Custom callable: Function that takes the exception and returns a
- **Custom callable:** Function that takes the exception and returns a
string for the `ToolMessage` content, allowing custom error
formatting.
backoff_factor: Multiplier for exponential backoff.

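The interplay of `initial_delay` and `backoff_factor` described above is easiest to see as the usual exponential-backoff formula. The sketch below assumes `delay = initial_delay * backoff_factor ** attempt`, with a factor of `0.0` collapsing to a constant delay; that reading matches the docstring examples, but the exact formula is an assumption rather than something taken from the implementation:

```python
def retry_delay(attempt: int, initial_delay: float, backoff_factor: float) -> float:
    """Assumed delay (in seconds) before retry number `attempt`, 0-indexed."""
    if backoff_factor == 0.0:
        return initial_delay  # constant backoff, as in the example above
    return initial_delay * (backoff_factor**attempt)


# With initial_delay=2.0 and backoff_factor=1.5: 2.0s, 3.0s, 4.5s, ...
print([retry_delay(n, 2.0, 1.5) for n in range(3)])
```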
@@ -93,21 +93,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
and helps the main model focus on the right tools.

Examples:
```python title="Limit to 3 tools"
from langchain.agents.middleware import LLMToolSelectorMiddleware
!!! example "Limit to 3 tools"

middleware = LLMToolSelectorMiddleware(max_tools=3)
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware

agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```
middleware = LLMToolSelectorMiddleware(max_tools=3)

```python title="Use a smaller model for selection"
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```

!!! example "Use a smaller model for selection"

```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
"""

def __init__(
@@ -131,7 +135,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):

If the model selects more, only the first `max_tools` will be used.

No limit if not specified.
If not specified, there is no limit.
always_include: Tool names to always include regardless of selection.

These do not count against the `max_tools` limit.
@@ -251,8 +255,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]

request.tools = [*selected_tools, *provider_tools]
return request
return request.override(tools=[*selected_tools, *provider_tools])

def wrap_model_call(
self,

File diff suppressed because it is too large
@@ -125,7 +125,7 @@ def init_chat_model(
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)

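The provider table above is what lets `init_chat_model` resolve a model string to an integration package. A short usage sketch (running it needs the matching partner package installed and an API key configured):

```python
from langchain.chat_models import init_chat_model

# Explicit "provider:model" form; the corrected row above maps ibm -> langchain-ibm.
model = init_chat_model("openai:gpt-4o", temperature=0)

# The provider can also be passed separately and is inferred for well-known names.
same_model = init_chat_model("gpt-4o", model_provider="openai")

print(model.invoke("Say hello in five words.").content)
```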
@@ -57,6 +57,7 @@ test = [
"pytest-mock",
"syrupy>=4.0.2,<5.0.0",
"toml>=0.10.2,<1.0.0",
"langchain-model-profiles",
"langchain-tests",
"langchain-openai",
]
@@ -75,6 +76,7 @@ test_integration = [
"cassio>=0.1.0,<1.0.0",
"langchainhub>=0.1.16,<1.0.0",
"langchain-core",
"langchain-model-profiles",
"langchain-text-splitters",
]

@@ -83,6 +85,7 @@ prerelease = "allow"

[tool.uv.sources]
langchain-core = { path = "../core", editable = true }
langchain-model-profiles = { path = "../model-profiles", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,79 +0,0 @@
import pytest
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field

from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy


class WeatherBaseModel(BaseModel):
    """Weather response."""

    temperature: float = Field(description="The temperature in fahrenheit")
    condition: str = Field(description="Weather condition")


def get_weather(city: str) -> str:  # noqa: ARG001
    """Get the weather for a city."""
    return "The weather is sunny and 75°F."


@pytest.mark.requires("langchain_openai")
def test_inference_to_native_output() -> None:
    """Test that native output is inferred when a model supports it."""
    from langchain_openai import ChatOpenAI

    model = ChatOpenAI(model="gpt-5")
    agent = create_agent(
        model,
        system_prompt=(
            "You are a helpful weather assistant. Please call the get_weather tool, "
            "then use the WeatherReport tool to generate the final response."
        ),
        tools=[get_weather],
        response_format=WeatherBaseModel,
    )
    response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

    assert isinstance(response["structured_response"], WeatherBaseModel)
    assert response["structured_response"].temperature == 75.0
    assert response["structured_response"].condition.lower() == "sunny"
    assert len(response["messages"]) == 4

    assert [m.type for m in response["messages"]] == [
        "human",  # "What's the weather?"
        "ai",  # "What's the weather?"
        "tool",  # "The weather is sunny and 75°F."
        "ai",  # structured response
    ]


@pytest.mark.requires("langchain_openai")
def test_inference_to_tool_output() -> None:
    """Test that tool output is inferred when a model supports it."""
    from langchain_openai import ChatOpenAI

    model = ChatOpenAI(model="gpt-4")
    agent = create_agent(
        model,
        system_prompt=(
            "You are a helpful weather assistant. Please call the get_weather tool, "
            "then use the WeatherReport tool to generate the final response."
        ),
        tools=[get_weather],
        response_format=ToolStrategy(WeatherBaseModel),
    )
    response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

    assert isinstance(response["structured_response"], WeatherBaseModel)
    assert response["structured_response"].temperature == 75.0
    assert response["structured_response"].condition.lower() == "sunny"
    assert len(response["messages"]) == 5

    assert [m.type for m in response["messages"]] == [
        "human",  # "What's the weather?"
        "ai",  # "What's the weather?"
        "tool",  # "The weather is sunny and 75°F."
        "ai",  # structured response
        "tool",  # artificial tool message
    ]
@@ -0,0 +1,212 @@
|
||||
# serializer version: 1
|
||||
# name: test_agent_graph_with_jump_to_end_as_after_agent
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopZero\2ebefore_agent(NoopZero.before_agent)
|
||||
NoopOne\2eafter_agent(NoopOne.after_agent)
|
||||
NoopTwo\2eafter_agent(NoopTwo.after_agent)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
|
||||
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
|
||||
NoopZero\2ebefore_agent -.-> model;
|
||||
__start__ --> NoopZero\2ebefore_agent;
|
||||
model -.-> NoopTwo\2eafter_agent;
|
||||
model -.-> tools;
|
||||
tools -.-> model;
|
||||
NoopOne\2eafter_agent --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[memory]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres_pipe]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres_pool]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[sqlite]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_simple_agent_graph
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,95 @@
|
||||
# serializer version: 1
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> async_before_with_jump\2ebefore_model;
|
||||
async_before_with_jump\2ebefore_model -.-> __end__;
|
||||
async_before_with_jump\2ebefore_model -.-> model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
async_after_with_jump\2eafter_model -.-> __end__;
|
||||
async_after_with_jump\2eafter_model -.-> model;
|
||||
model --> async_after_with_jump\2eafter_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
|
||||
async_after_retry\2eafter_model(async_after_retry.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> async_before_early_exit\2ebefore_model;
|
||||
async_after_retry\2eafter_model -.-> __end__;
|
||||
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
|
||||
async_before_early_exit\2ebefore_model -.-> __end__;
|
||||
async_before_early_exit\2ebefore_model -.-> model;
|
||||
model --> async_after_retry\2eafter_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
|
||||
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> sync_before_with_jump\2ebefore_model;
|
||||
async_after_with_jumps\2eafter_model -.-> __end__;
|
||||
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
|
||||
model --> async_after_with_jumps\2eafter_model;
|
||||
sync_before_with_jump\2ebefore_model -.-> __end__;
|
||||
sync_before_with_jump\2ebefore_model -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,289 @@
|
||||
# serializer version: 1
|
||||
# name: test_create_agent_diagram
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.1
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne\2ebefore_model --> model;
|
||||
__start__ --> NoopOne\2ebefore_model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.10
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopTen\2ebefore_model(NoopTen.before_model)
|
||||
NoopTen\2eafter_model(NoopTen.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopTen\2ebefore_model --> model;
|
||||
__start__ --> NoopTen\2ebefore_model;
|
||||
model --> NoopTen\2eafter_model;
|
||||
NoopTen\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.11
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopTen\2ebefore_model(NoopTen.before_model)
|
||||
NoopTen\2eafter_model(NoopTen.after_model)
|
||||
NoopEleven\2ebefore_model(NoopEleven.before_model)
|
||||
NoopEleven\2eafter_model(NoopEleven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
|
||||
NoopEleven\2ebefore_model --> model;
|
||||
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
|
||||
__start__ --> NoopTen\2ebefore_model;
|
||||
model --> NoopEleven\2eafter_model;
|
||||
NoopTen\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.2
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
|
||||
NoopTwo\2ebefore_model --> model;
|
||||
__start__ --> NoopOne\2ebefore_model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.3
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
||||
NoopThree\2ebefore_model(NoopThree.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
|
||||
NoopThree\2ebefore_model --> model;
|
||||
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
|
||||
__start__ --> NoopOne\2ebefore_model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.4
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopFour\2eafter_model(NoopFour.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model --> NoopFour\2eafter_model;
|
||||
NoopFour\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.5
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopFour\2eafter_model(NoopFour.after_model)
|
||||
NoopFive\2eafter_model(NoopFive.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFive\2eafter_model --> NoopFour\2eafter_model;
|
||||
__start__ --> model;
|
||||
model --> NoopFive\2eafter_model;
|
||||
NoopFour\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.6
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopFour\2eafter_model(NoopFour.after_model)
|
||||
NoopFive\2eafter_model(NoopFive.after_model)
|
||||
NoopSix\2eafter_model(NoopSix.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFive\2eafter_model --> NoopFour\2eafter_model;
|
||||
NoopSix\2eafter_model --> NoopFive\2eafter_model;
|
||||
__start__ --> model;
|
||||
model --> NoopSix\2eafter_model;
|
||||
NoopFour\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.7
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopSeven\2ebefore_model --> model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopSeven\2eafter_model;
|
||||
NoopSeven\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.8
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model --> model;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
NoopSeven\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.9
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
NoopNine\2ebefore_model(NoopNine.before_model)
|
||||
NoopNine\2eafter_model(NoopNine.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
|
||||
NoopNine\2eafter_model --> NoopEight\2eafter_model;
|
||||
NoopNine\2ebefore_model --> model;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopNine\2eafter_model;
|
||||
NoopSeven\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,212 @@
|
||||
# serializer version: 1
|
||||
# name: test_agent_graph_with_jump_to_end_as_after_agent
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopZero\2ebefore_agent(NoopZero.before_agent)
|
||||
NoopOne\2eafter_agent(NoopOne.after_agent)
|
||||
NoopTwo\2eafter_agent(NoopTwo.after_agent)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
|
||||
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
|
||||
NoopZero\2ebefore_agent -.-> model;
|
||||
__start__ --> NoopZero\2ebefore_agent;
|
||||
model -.-> NoopTwo\2eafter_agent;
|
||||
model -.-> tools;
|
||||
tools -.-> model;
|
||||
NoopOne\2eafter_agent --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[memory]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres_pipe]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres_pool]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[sqlite]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model -.-> __end__;
|
||||
NoopEight\2ebefore_model -.-> model;
|
||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||
NoopSeven\2eafter_model -.-> __end__;
|
||||
NoopSeven\2eafter_model -.-> tools;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_simple_agent_graph
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,95 @@
|
||||
# serializer version: 1
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
async_before_with_jump\2ebefore_model(async_before_with_jump.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> async_before_with_jump\2ebefore_model;
|
||||
async_before_with_jump\2ebefore_model -.-> __end__;
|
||||
async_before_with_jump\2ebefore_model -.-> model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot.1
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
async_after_with_jump\2eafter_model(async_after_with_jump.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
async_after_with_jump\2eafter_model -.-> __end__;
|
||||
async_after_with_jump\2eafter_model -.-> model;
|
||||
model --> async_after_with_jump\2eafter_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot.2
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
async_before_early_exit\2ebefore_model(async_before_early_exit.before_model)
|
||||
async_after_retry\2eafter_model(async_after_retry.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> async_before_early_exit\2ebefore_model;
|
||||
async_after_retry\2eafter_model -.-> __end__;
|
||||
async_after_retry\2eafter_model -.-> async_before_early_exit\2ebefore_model;
|
||||
async_before_early_exit\2ebefore_model -.-> __end__;
|
||||
async_before_early_exit\2ebefore_model -.-> model;
|
||||
model --> async_after_retry\2eafter_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_async_middleware_with_can_jump_to_graph_snapshot.3
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
sync_before_with_jump\2ebefore_model(sync_before_with_jump.before_model)
|
||||
async_after_with_jumps\2eafter_model(async_after_with_jumps.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> sync_before_with_jump\2ebefore_model;
|
||||
async_after_with_jumps\2eafter_model -.-> __end__;
|
||||
async_after_with_jumps\2eafter_model -.-> sync_before_with_jump\2ebefore_model;
|
||||
model --> async_after_with_jumps\2eafter_model;
|
||||
sync_before_with_jump\2ebefore_model -.-> __end__;
|
||||
sync_before_with_jump\2ebefore_model -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,289 @@
|
||||
# serializer version: 1
|
||||
# name: test_create_agent_diagram
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.1
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne\2ebefore_model --> model;
|
||||
__start__ --> NoopOne\2ebefore_model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.10
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopTen\2ebefore_model(NoopTen.before_model)
|
||||
NoopTen\2eafter_model(NoopTen.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopTen\2ebefore_model --> model;
|
||||
__start__ --> NoopTen\2ebefore_model;
|
||||
model --> NoopTen\2eafter_model;
|
||||
NoopTen\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.11
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopTen\2ebefore_model(NoopTen.before_model)
|
||||
NoopTen\2eafter_model(NoopTen.after_model)
|
||||
NoopEleven\2ebefore_model(NoopEleven.before_model)
|
||||
NoopEleven\2eafter_model(NoopEleven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
|
||||
NoopEleven\2ebefore_model --> model;
|
||||
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
|
||||
__start__ --> NoopTen\2ebefore_model;
|
||||
model --> NoopEleven\2eafter_model;
|
||||
NoopTen\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.2
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
|
||||
NoopTwo\2ebefore_model --> model;
|
||||
__start__ --> NoopOne\2ebefore_model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.3
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
||||
NoopThree\2ebefore_model(NoopThree.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
|
||||
NoopThree\2ebefore_model --> model;
|
||||
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
|
||||
__start__ --> NoopOne\2ebefore_model;
|
||||
model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.4
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopFour\2eafter_model(NoopFour.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model --> NoopFour\2eafter_model;
|
||||
NoopFour\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.5
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopFour\2eafter_model(NoopFour.after_model)
|
||||
NoopFive\2eafter_model(NoopFive.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFive\2eafter_model --> NoopFour\2eafter_model;
|
||||
__start__ --> model;
|
||||
model --> NoopFive\2eafter_model;
|
||||
NoopFour\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.6
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopFour\2eafter_model(NoopFour.after_model)
|
||||
NoopFive\2eafter_model(NoopFive.after_model)
|
||||
NoopSix\2eafter_model(NoopSix.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFive\2eafter_model --> NoopFour\2eafter_model;
|
||||
NoopSix\2eafter_model --> NoopFive\2eafter_model;
|
||||
__start__ --> model;
|
||||
model --> NoopSix\2eafter_model;
|
||||
NoopFour\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.7
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopSeven\2ebefore_model --> model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopSeven\2eafter_model;
|
||||
NoopSeven\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.8
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
NoopEight\2ebefore_model --> model;
|
||||
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
NoopSeven\2eafter_model --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.9
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
NoopNine\2eafter_model --> NoopEight\2eafter_model;
NoopNine\2ebefore_model --> model;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopNine\2eafter_model;
NoopSeven\2eafter_model --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
@@ -0,0 +1,212 @@
# serializer version: 1
# name: test_agent_graph_with_jump_to_end_as_after_agent
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopZero\2ebefore_agent(NoopZero.before_agent)
NoopOne\2eafter_agent(NoopOne.after_agent)
NoopTwo\2eafter_agent(NoopTwo.after_agent)
__end__([<p>__end__</p>]):::last
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
NoopZero\2ebefore_agent -.-> model;
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[memory]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[postgres]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[postgres_pipe]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[postgres_pool]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_create_agent_jump[sqlite]
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
NoopEight\2ebefore_model -.-> __end__;
NoopEight\2ebefore_model -.-> model;
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
NoopSeven\2eafter_model -.-> __end__;
NoopSeven\2eafter_model -.-> tools;
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
# name: test_simple_agent_graph
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc

'''
# ---
@@ -230,9 +230,7 @@ class TestChainModelCallHandlers:
         test_runtime = {"test": "runtime"}
 
         # Create request with state and runtime
-        test_request = create_test_request()
-        test_request.state = test_state
-        test_request.runtime = test_runtime
+        test_request = create_test_request(state=test_state, runtime=test_runtime)
         result = composed(test_request, create_mock_base_handler())
 
         # Both handlers should see same state and runtime
@@ -22,7 +22,7 @@ from langchain.agents.middleware.types import (
|
||||
hook_config,
|
||||
)
|
||||
from langchain.agents.factory import create_agent, _get_can_jump_to
|
||||
from .model import FakeToolCallingModel
|
||||
from ...model import FakeToolCallingModel
|
||||
|
||||
|
||||
class CustomState(AgentState):
|
||||
@@ -90,8 +90,7 @@ def test_on_model_call_decorator() -> None:
|
||||
|
||||
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="CustomOnModelCall")
|
||||
def custom_on_model_call(request, handler):
|
||||
request.system_prompt = "Modified"
|
||||
return handler(request)
|
||||
return handler(request.override(system_prompt="Modified"))
|
||||
|
||||
# Verify all options were applied
|
||||
assert isinstance(custom_on_model_call, AgentMiddleware)
|
||||
@@ -277,8 +276,7 @@ def test_async_on_model_call_decorator() -> None:
|
||||
|
||||
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="AsyncOnModelCall")
|
||||
async def async_on_model_call(request, handler):
|
||||
request.system_prompt = "Modified async"
|
||||
return await handler(request)
|
||||
return await handler(request.override(system_prompt="Modified async"))
|
||||
|
||||
assert isinstance(async_on_model_call, AgentMiddleware)
|
||||
assert async_on_model_call.state_schema == CustomState
|
||||
@@ -0,0 +1,193 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from ...model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_create_agent_diagram(
|
||||
snapshot: SnapshotAssertion,
|
||||
):
|
||||
class NoopOne(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopTwo(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopThree(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopFour(AgentMiddleware):
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopFive(AgentMiddleware):
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopSix(AgentMiddleware):
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopSeven(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopEight(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopNine(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopTen(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
return handler(request)
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
class NoopEleven(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
return handler(request)
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
pass
|
||||
|
||||
agent_zero = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
assert agent_zero.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_one = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopOne()],
|
||||
)
|
||||
|
||||
assert agent_one.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_two = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopOne(), NoopTwo()],
|
||||
)
|
||||
|
||||
assert agent_two.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_three = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopOne(), NoopTwo(), NoopThree()],
|
||||
)
|
||||
|
||||
assert agent_three.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_four = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopFour()],
|
||||
)
|
||||
|
||||
assert agent_four.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_five = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopFour(), NoopFive()],
|
||||
)
|
||||
|
||||
assert agent_five.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_six = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopFour(), NoopFive(), NoopSix()],
|
||||
)
|
||||
|
||||
assert agent_six.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_seven = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven()],
|
||||
)
|
||||
|
||||
assert agent_seven.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_eight = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven(), NoopEight()],
|
||||
)
|
||||
|
||||
assert agent_eight.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_nine = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven(), NoopEight(), NoopNine()],
|
||||
)
|
||||
|
||||
assert agent_nine.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_ten = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopTen()],
|
||||
)
|
||||
|
||||
assert agent_ten.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_eleven = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopTen(), NoopEleven()],
|
||||
)
|
||||
|
||||
assert agent_eleven.get_graph().draw_mermaid() == snapshot
|
||||
@@ -14,7 +14,7 @@ from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, wrap_tool_call
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
@tool
|
||||
@@ -9,7 +9,7 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState, Model
|
||||
from langgraph.prebuilt.tool_node import ToolNode
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from .model import FakeToolCallingModel
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_model_request_tools_are_base_tools() -> None:
|
||||
@@ -79,8 +79,8 @@ def test_middleware_can_modify_tools() -> None:
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
# Only allow tool_a and tool_b
|
||||
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
|
||||
return handler(request)
|
||||
filtered_tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
|
||||
return handler(request.override(tools=filtered_tools))
|
||||
|
||||
# Model will try to call tool_a
|
||||
model = FakeToolCallingModel(
|
||||
@@ -123,8 +123,7 @@ def test_unknown_tool_raises_error() -> None:
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
# Add an unknown tool
|
||||
request.tools = request.tools + [unknown_tool]
|
||||
return handler(request)
|
||||
return handler(request.override(tools=request.tools + [unknown_tool]))
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
@@ -163,7 +162,8 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
) -> AIMessage:
|
||||
# Remove admin_tool if not admin
|
||||
if not request.state.get("is_admin", False):
|
||||
request.tools = [t for t in request.tools if t.name != "admin_tool"]
|
||||
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
@@ -200,7 +200,7 @@ def test_empty_tools_list_is_valid() -> None:
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
# Remove all tools
|
||||
request.tools = []
|
||||
request = request.override(tools=[])
|
||||
return handler(request)
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
@@ -244,7 +244,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
) -> AIMessage:
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Remove tool_c
|
||||
request.tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
@@ -257,7 +258,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
# Should not see tool_c here
|
||||
assert all(t.name != "tool_c" for t in request.tools)
|
||||
# Remove tool_b
|
||||
request.tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
@@ -1,18 +1,30 @@
|
||||
"""Unit tests for wrap_model_call middleware generator protocol."""
|
||||
"""Unit tests for wrap_model_call hook and @wrap_model_call decorator.
|
||||
|
||||
This module tests the wrap_model_call functionality in three forms:
|
||||
1. As a middleware method (AgentMiddleware.wrap_model_call)
|
||||
2. As a decorator (@wrap_model_call)
|
||||
3. Async variant (AgentMiddleware.awrap_model_call)
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
wrap_model_call,
|
||||
)
|
||||
|
||||
from ...model import FakeToolCallingModel
|
||||
|
||||
class TestBasicOnModelCall:
|
||||
|
||||
class TestBasicWrapModelCall:
|
||||
"""Test basic wrap_model_call functionality."""
|
||||
|
||||
def test_passthrough_middleware(self) -> None:
|
||||
@@ -70,7 +82,7 @@ class TestBasicOnModelCall:
|
||||
assert counter.call_count == 1
|
||||
|
||||
|
||||
class TestRetryMiddleware:
|
||||
class TestRetryLogic:
|
||||
"""Test retry logic with wrap_model_call."""
|
||||
|
||||
def test_simple_retry_on_error(self) -> None:
|
||||
@@ -91,12 +103,10 @@ class TestRetryMiddleware:
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
try:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
except Exception:
|
||||
self.retry_count += 1
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
|
||||
retry_middleware = RetryOnceMiddleware()
|
||||
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
|
||||
@@ -125,8 +135,7 @@ class TestRetryMiddleware:
|
||||
for attempt in range(self.max_retries):
|
||||
self.attempts.append(attempt + 1)
|
||||
try:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
continue
|
||||
@@ -143,6 +152,75 @@ class TestRetryMiddleware:
|
||||
|
||||
assert retry_middleware.attempts == [1, 2, 3]
|
||||
|
||||
def test_no_retry_propagates_error(self) -> None:
|
||||
"""Test that error is propagated when middleware doesn't retry."""
|
||||
|
||||
class FailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Model error")
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "failing"
|
||||
|
||||
class NoRetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
|
||||
|
||||
with pytest.raises(ValueError, match="Model error"):
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
def test_max_attempts_limit(self) -> None:
|
||||
"""Test that middleware controls termination via retry limits."""
|
||||
|
||||
class AlwaysFailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Always fails")
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "always_failing"
|
||||
|
||||
class LimitedRetryMiddleware(AgentMiddleware):
|
||||
"""Middleware that limits its own retries."""
|
||||
|
||||
def __init__(self, max_retries: int = 10):
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.attempt_count = 0
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
last_exception = None
|
||||
for attempt in range(self.max_retries):
|
||||
self.attempt_count += 1
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
# Continue to retry
|
||||
|
||||
# All retries exhausted, re-raise the last error
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
||||
model = AlwaysFailingModel()
|
||||
middleware = LimitedRetryMiddleware(max_retries=10)
|
||||
|
||||
agent = create_agent(model=model, middleware=[middleware])
|
||||
|
||||
# Should fail with the model's error after middleware stops retrying
|
||||
with pytest.raises(ValueError, match="Always fails"):
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# Should have attempted exactly 10 times as configured
|
||||
assert middleware.attempt_count == 10
|
||||
|
||||
|
||||
class TestResponseRewriting:
|
||||
"""Test response content rewriting with wrap_model_call."""
|
||||
@@ -185,6 +263,28 @@ class TestResponseRewriting:
|
||||
|
||||
assert result["messages"][1].content == "[BOT]: Response"
|
||||
|
||||
def test_multi_stage_transformation(self) -> None:
|
||||
"""Test middleware applying multiple transformations."""
|
||||
|
||||
class MultiTransformMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
|
||||
# First transformation: uppercase
|
||||
content = ai_message.content.upper()
|
||||
# Second transformation: add prefix and suffix
|
||||
content = f"[START] {content} [END]"
|
||||
return AIMessage(content=content)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))
|
||||
agent = create_agent(model=model, middleware=[MultiTransformMiddleware()])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert result["messages"][1].content == "[START] HELLO [END]"
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling with wrap_model_call."""
|
||||
@@ -200,9 +300,8 @@ class TestErrorHandling:
|
||||
def wrap_model_call(self, request, handler):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
fallback = AIMessage(content="Error handled gracefully")
|
||||
return fallback
|
||||
except Exception as e:
|
||||
return AIMessage(content=f"Error occurred: {e}. Using fallback response.")
|
||||
|
||||
model = AlwaysFailModel(messages=iter([]))
|
||||
agent = create_agent(model=model, middleware=[ErrorToSuccessMiddleware()])
|
||||
@@ -210,7 +309,8 @@ class TestErrorHandling:
|
||||
# Should not raise, middleware converts error to response
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert "Error handled gracefully" in result["messages"][1].content
|
||||
assert "Error occurred" in result["messages"][1].content
|
||||
assert "fallback response" in result["messages"][1].content
|
||||
|
||||
def test_selective_error_handling(self) -> None:
|
||||
"""Test middleware that only handles specific errors."""
|
||||
@@ -224,8 +324,7 @@ class TestErrorHandling:
|
||||
try:
|
||||
return handler(request)
|
||||
except ConnectionError:
|
||||
fallback = AIMessage(content="Network issue, try again later")
|
||||
return fallback
|
||||
return AIMessage(content="Network issue, try again later")
|
||||
|
||||
model = SpecificErrorModel(messages=iter([]))
|
||||
agent = create_agent(model=model, middleware=[SelectiveErrorMiddleware()])
|
||||
@@ -247,8 +346,7 @@ class TestErrorHandling:
|
||||
return result
|
||||
except Exception:
|
||||
call_log.append("caught-error")
|
||||
fallback = AIMessage(content="Recovered from error")
|
||||
return fallback
|
||||
return AIMessage(content="Recovered from error")
|
||||
|
||||
# Test 1: Success path
|
||||
call_log.clear()
|
||||
@@ -403,7 +501,6 @@ class TestStateAndRuntime:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return handler(request)
|
||||
break # Success
|
||||
except Exception:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
@@ -460,6 +557,49 @@ class TestMiddlewareComposition:
|
||||
"outer-after",
|
||||
]
|
||||
|
||||
def test_three_middleware_composition(self) -> None:
|
||||
"""Test composition of three middleware."""
|
||||
execution_order = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("first-before")
|
||||
response = handler(request)
|
||||
execution_order.append("first-after")
|
||||
return response
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("second-before")
|
||||
response = handler(request)
|
||||
execution_order.append("second-after")
|
||||
return response
|
||||
|
||||
class ThirdMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("third-before")
|
||||
response = handler(request)
|
||||
execution_order.append("third-after")
|
||||
return response
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
|
||||
assert execution_order == [
|
||||
"first-before",
|
||||
"second-before",
|
||||
"third-before",
|
||||
"third-after",
|
||||
"second-after",
|
||||
"first-after",
|
||||
]
|
||||
|
||||
def test_retry_with_logging(self) -> None:
|
||||
"""Test retry middleware composed with logging middleware."""
|
||||
call_count = {"value": 0}
|
||||
@@ -549,11 +689,9 @@ class TestMiddlewareComposition:
|
||||
class RetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
try:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
except Exception:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
|
||||
class UppercaseMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
@@ -571,49 +709,6 @@ class TestMiddlewareComposition:
|
||||
# Should retry and uppercase the result
|
||||
assert result["messages"][1].content == "SUCCESS"
|
||||
|
||||
def test_three_middleware_composition(self) -> None:
|
||||
"""Test composition of three middleware."""
|
||||
execution_order = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("first-before")
|
||||
response = handler(request)
|
||||
execution_order.append("first-after")
|
||||
return response
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("second-before")
|
||||
response = handler(request)
|
||||
execution_order.append("second-after")
|
||||
return response
|
||||
|
||||
class ThirdMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("third-before")
|
||||
response = handler(request)
|
||||
execution_order.append("third-after")
|
||||
return response
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()],
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# First wraps Second wraps Third: 1-before, 2-before, 3-before, model, 3-after, 2-after, 1-after
|
||||
assert execution_order == [
|
||||
"first-before",
|
||||
"second-before",
|
||||
"third-before",
|
||||
"third-after",
|
||||
"second-after",
|
||||
"first-after",
|
||||
]
|
||||
|
||||
def test_middle_retry_middleware(self) -> None:
|
||||
"""Test that middle middleware doing retry causes inner to execute twice."""
|
||||
execution_order = []
|
||||
@@ -674,7 +769,306 @@ class TestMiddlewareComposition:
|
||||
assert len(model_calls) == 2
|
||||
|
||||
|
||||
class TestAsyncOnModelCall:
|
||||
class TestWrapModelCallDecorator:
|
||||
"""Test the @wrap_model_call decorator for creating middleware."""
|
||||
|
||||
def test_basic_decorator_usage(self) -> None:
|
||||
"""Test basic decorator usage without parameters."""
|
||||
|
||||
@wrap_model_call
|
||||
def passthrough_middleware(request, handler):
|
||||
return handler(request)
|
||||
|
||||
# Should return an AgentMiddleware instance
|
||||
assert isinstance(passthrough_middleware, AgentMiddleware)
|
||||
|
||||
# Should work in agent
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
agent = create_agent(model=model, middleware=[passthrough_middleware])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Hello"
|
||||
|
||||
def test_decorator_with_custom_name(self) -> None:
|
||||
"""Test decorator with custom middleware name."""
|
||||
|
||||
@wrap_model_call(name="CustomMiddleware")
|
||||
def my_middleware(request, handler):
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(my_middleware, AgentMiddleware)
|
||||
assert my_middleware.__class__.__name__ == "CustomMiddleware"
|
||||
|
||||
def test_decorator_retry_logic(self) -> None:
|
||||
"""Test decorator for implementing retry logic."""
|
||||
call_count = {"value": 0}
|
||||
|
||||
class FailOnceThenSucceed(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
raise ValueError("First call fails")
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
@wrap_model_call
|
||||
def retry_once(request, handler):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
# Retry once
|
||||
return handler(request)
|
||||
|
||||
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
|
||||
agent = create_agent(model=model, middleware=[retry_once])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert call_count["value"] == 2
|
||||
assert result["messages"][1].content == "Success"
|
||||
|
||||
def test_decorator_response_rewriting(self) -> None:
|
||||
"""Test decorator for rewriting responses."""
|
||||
|
||||
@wrap_model_call
|
||||
def uppercase_responses(request, handler):
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
return AIMessage(content=ai_message.content.upper())
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
|
||||
agent = create_agent(model=model, middleware=[uppercase_responses])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert result["messages"][1].content == "HELLO WORLD"
|
||||
|
||||
def test_decorator_error_handling(self) -> None:
|
||||
"""Test decorator for error recovery."""
|
||||
|
||||
class AlwaysFailModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Model error")
|
||||
|
||||
@wrap_model_call
|
||||
def error_to_fallback(request, handler):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
return AIMessage(content="Fallback response")
|
||||
|
||||
model = AlwaysFailModel(messages=iter([]))
|
||||
agent = create_agent(model=model, middleware=[error_to_fallback])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert result["messages"][1].content == "Fallback response"
|
||||
|
||||
def test_decorator_with_state_access(self) -> None:
|
||||
"""Test decorator accessing agent state."""
|
||||
state_values = []
|
||||
|
||||
@wrap_model_call
|
||||
def log_state(request, handler):
|
||||
state_values.append(request.state.get("messages"))
|
||||
return handler(request)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
agent = create_agent(model=model, middleware=[log_state])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# State should contain the user message
|
||||
assert len(state_values) == 1
|
||||
assert len(state_values[0]) == 1
|
||||
assert state_values[0][0].content == "Test"
|
||||
|
||||
def test_multiple_decorated_middleware(self) -> None:
|
||||
"""Test composition of multiple decorated middleware."""
|
||||
execution_order = []
|
||||
|
||||
@wrap_model_call
|
||||
def outer_middleware(request, handler):
|
||||
execution_order.append("outer-before")
|
||||
result = handler(request)
|
||||
execution_order.append("outer-after")
|
||||
return result
|
||||
|
||||
@wrap_model_call
|
||||
def inner_middleware(request, handler):
|
||||
execution_order.append("inner-before")
|
||||
result = handler(request)
|
||||
execution_order.append("inner-after")
|
||||
return result
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
agent = create_agent(model=model, middleware=[outer_middleware, inner_middleware])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert execution_order == [
|
||||
"outer-before",
|
||||
"inner-before",
|
||||
"inner-after",
|
||||
"outer-after",
|
||||
]
|
||||
|
||||
def test_decorator_with_custom_state_schema(self) -> None:
|
||||
"""Test decorator with custom state schema."""
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
class CustomState(TypedDict):
|
||||
messages: list
|
||||
custom_field: str
|
||||
|
||||
@wrap_model_call(state_schema=CustomState)
|
||||
def middleware_with_schema(request, handler):
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(middleware_with_schema, AgentMiddleware)
|
||||
# Custom state schema should be set
|
||||
assert middleware_with_schema.state_schema == CustomState
|
||||
|
||||
def test_decorator_with_tools_parameter(self) -> None:
|
||||
"""Test decorator with tools parameter."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool(query: str) -> str:
|
||||
"""A test tool."""
|
||||
return f"Result: {query}"
|
||||
|
||||
@wrap_model_call(tools=[test_tool])
|
||||
def middleware_with_tools(request, handler):
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(middleware_with_tools, AgentMiddleware)
|
||||
assert len(middleware_with_tools.tools) == 1
|
||||
assert middleware_with_tools.tools[0].name == "test_tool"
|
||||
|
||||
def test_decorator_parentheses_optional(self) -> None:
|
||||
"""Test that decorator works both with and without parentheses."""
|
||||
|
||||
# Without parentheses
|
||||
@wrap_model_call
|
||||
def middleware_no_parens(request, handler):
|
||||
return handler(request)
|
||||
|
||||
# With parentheses
|
||||
@wrap_model_call()
|
||||
def middleware_with_parens(request, handler):
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(middleware_no_parens, AgentMiddleware)
|
||||
assert isinstance(middleware_with_parens, AgentMiddleware)
|
||||
|
||||
def test_decorator_preserves_function_name(self) -> None:
|
||||
"""Test that decorator uses function name for class name."""
|
||||
|
||||
@wrap_model_call
|
||||
def my_custom_middleware(request, handler):
|
||||
return handler(request)
|
||||
|
||||
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
|
||||
|
||||
def test_decorator_mixed_with_class_middleware(self) -> None:
|
||||
"""Test decorated middleware mixed with class-based middleware."""
|
||||
execution_order = []
|
||||
|
||||
@wrap_model_call
|
||||
def decorated_middleware(request, handler):
|
||||
execution_order.append("decorated-before")
|
||||
result = handler(request)
|
||||
execution_order.append("decorated-after")
|
||||
return result
|
||||
|
||||
class ClassMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
execution_order.append("class-before")
|
||||
result = handler(request)
|
||||
execution_order.append("class-after")
|
||||
return result
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
middleware=[decorated_middleware, ClassMiddleware()],
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# Decorated is outer, class-based is inner
|
||||
assert execution_order == [
|
||||
"decorated-before",
|
||||
"class-before",
|
||||
"class-after",
|
||||
"decorated-after",
|
||||
]
|
||||
|
||||
def test_decorator_complex_retry_logic(self) -> None:
|
||||
"""Test decorator with complex retry logic and backoff."""
|
||||
attempts = []
|
||||
call_count = {"value": 0}
|
||||
|
||||
class UnreliableModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] <= 2:
|
||||
raise ValueError(f"Attempt {call_count['value']} failed")
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
@wrap_model_call
|
||||
def retry_with_tracking(request, handler):
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
attempts.append(attempt + 1)
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
# On error, continue to next attempt
|
||||
if attempt < max_retries - 1:
|
||||
continue # Retry
|
||||
else:
|
||||
raise # All retries failed
|
||||
|
||||
model = UnreliableModel(messages=iter([AIMessage(content="Finally worked")]))
|
||||
agent = create_agent(model=model, middleware=[retry_with_tracking])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert attempts == [1, 2, 3]
|
||||
assert result["messages"][1].content == "Finally worked"
|
||||
|
||||
def test_decorator_request_modification(self) -> None:
|
||||
"""Test decorator modifying request before execution."""
|
||||
modified_prompts = []
|
||||
|
||||
@wrap_model_call
|
||||
def add_system_prompt(request, handler):
|
||||
# Modify request to add system prompt
|
||||
modified_request = ModelRequest(
|
||||
messages=request.messages,
|
||||
model=request.model,
|
||||
system_prompt="You are a helpful assistant",
|
||||
tool_choice=request.tool_choice,
|
||||
tools=request.tools,
|
||||
response_format=request.response_format,
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
modified_prompts.append(modified_request.system_prompt)
|
||||
return handler(modified_request)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
agent = create_agent(model=model, middleware=[add_system_prompt])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert modified_prompts == ["You are a helpful assistant"]
|
||||
|
||||
|
||||
class TestAsyncWrapModelCall:
|
||||
"""Test async execution with wrap_model_call."""
|
||||
|
||||
async def test_async_model_with_middleware(self) -> None:
|
||||
@@ -686,7 +1080,6 @@ class TestAsyncOnModelCall:
|
||||
log.append("before")
|
||||
result = await handler(request)
|
||||
log.append("after")
|
||||
|
||||
return result
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
|
||||
@@ -723,6 +1116,92 @@ class TestAsyncOnModelCall:
|
||||
assert call_count["value"] == 2
|
||||
assert result["messages"][1].content == "Async success"
|
||||
|
||||
async def test_decorator_with_async_agent(self) -> None:
|
||||
"""Test that decorated middleware works with async agent invocation."""
|
||||
call_log = []
|
||||
|
||||
@wrap_model_call
|
||||
async def logging_middleware(request, handler):
|
||||
call_log.append("before")
|
||||
result = await handler(request)
|
||||
call_log.append("after")
|
||||
return result
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Async response")]))
|
||||
agent = create_agent(model=model, middleware=[logging_middleware])
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert call_log == ["before", "after"]
|
||||
assert result["messages"][1].content == "Async response"
|
||||
|
||||
|
||||
class TestSyncAsyncInterop:
|
||||
"""Test sync/async interoperability."""
|
||||
|
||||
def test_sync_invoke_with_only_async_middleware_raises_error(self) -> None:
|
||||
"""Test that sync invoke with only async middleware raises error."""
|
||||
|
||||
class AsyncOnlyMiddleware(AgentMiddleware):
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
|
||||
) -> AIMessage:
|
||||
return await handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[AsyncOnlyMiddleware()],
|
||||
)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
agent.invoke({"messages": [HumanMessage("hello")]})
|
||||
|
||||
def test_sync_invoke_with_mixed_middleware(self) -> None:
|
||||
"""Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
|
||||
calls = []
|
||||
|
||||
class MixedMiddleware(AgentMiddleware):
|
||||
def before_model(self, state, runtime) -> None:
|
||||
calls.append("MixedMiddleware.before_model")
|
||||
|
||||
async def abefore_model(self, state, runtime) -> None:
|
||||
calls.append("MixedMiddleware.abefore_model")
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
calls.append("MixedMiddleware.wrap_model_call")
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
|
||||
) -> AIMessage:
|
||||
calls.append("MixedMiddleware.awrap_model_call")
|
||||
return await handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[MixedMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
||||
|
||||
# In sync mode, only sync methods should be called
|
||||
assert calls == [
|
||||
"MixedMiddleware.before_model",
|
||||
"MixedMiddleware.wrap_model_call",
|
||||
]
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error conditions."""
|
||||
@@ -753,12 +1232,10 @@ class TestEdgeCases:
|
||||
def wrap_model_call(self, request, handler):
|
||||
attempts.append("first-attempt")
|
||||
try:
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
except Exception:
|
||||
attempts.append("retry-attempt")
|
||||
result = handler(request)
|
||||
return result
|
||||
return handler(request)
|
||||
|
||||
call_count = {"value": 0}
|
||||
|
||||
@@ -14,7 +14,7 @@ from langgraph.types import Command
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import wrap_tool_call
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
@tool
|
||||
@@ -82,16 +82,23 @@ def test_no_edit_when_below_trigger() -> None:
|
||||
edits=[ClearToolUsesEdit(trigger=50)],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# The request should have been modified in place
|
||||
# The modified request passed to handler should be the same since no edits applied
|
||||
assert modified_request is not None
|
||||
assert modified_request.messages[0].content == ""
|
||||
assert modified_request.messages[1].content == "12345"
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].content == ""
|
||||
assert request.messages[1].content == "12345"
|
||||
assert state["messages"] == request.messages
|
||||
|
||||
|
||||
def test_clear_tool_outputs_and_inputs() -> None:
|
||||
@@ -115,14 +122,19 @@ def test_clear_tool_outputs_and_inputs() -> None:
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request with edits
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_ai = request.messages[0]
|
||||
cleared_tool = request.messages[1]
|
||||
assert modified_request is not None
|
||||
cleared_ai = modified_request.messages[0]
|
||||
cleared_tool = modified_request.messages[1]
|
||||
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared output]"
|
||||
@@ -134,7 +146,9 @@ def test_clear_tool_outputs_and_inputs() -> None:
|
||||
assert context_meta is not None
|
||||
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
|
||||
|
||||
assert state["messages"] == request.messages
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
|
||||
assert request.messages[1].content == "x" * 200
|
||||
|
||||
|
||||
def test_respects_keep_last_tool_results() -> None:
|
||||
@@ -167,21 +181,26 @@ def test_respects_keep_last_tool_results() -> None:
|
||||
token_count_method="model",
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request with edits
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
assert modified_request is not None
|
||||
cleared_messages = [
|
||||
msg
|
||||
for msg in request.messages
|
||||
for msg in modified_request.messages
|
||||
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
|
||||
]
|
||||
|
||||
assert len(cleared_messages) == 2
|
||||
assert isinstance(request.messages[-1], ToolMessage)
|
||||
assert request.messages[-1].content != "[cleared]"
|
||||
assert isinstance(modified_request.messages[-1], ToolMessage)
|
||||
assert modified_request.messages[-1].content != "[cleared]"
|
||||
|
||||
|
||||
def test_exclude_tools_prevents_clearing() -> None:
|
||||
@@ -215,14 +234,19 @@ def test_exclude_tools_prevents_clearing() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request with edits
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
search_tool = request.messages[1]
|
||||
calc_tool = request.messages[3]
|
||||
assert modified_request is not None
|
||||
search_tool = modified_request.messages[1]
|
||||
calc_tool = modified_request.messages[3]
|
||||
|
||||
assert isinstance(search_tool, ToolMessage)
|
||||
assert search_tool.content == "search-results" * 20
|
||||
@@ -249,16 +273,23 @@ async def test_no_edit_when_below_trigger_async() -> None:
|
||||
edits=[ClearToolUsesEdit(trigger=50)],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# The request should have been modified in place
|
||||
# The modified request passed to handler should be the same since no edits applied
|
||||
assert modified_request is not None
|
||||
assert modified_request.messages[0].content == ""
|
||||
assert modified_request.messages[1].content == "12345"
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].content == ""
|
||||
assert request.messages[1].content == "12345"
|
||||
assert state["messages"] == request.messages
|
||||
|
||||
|
||||
async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
@@ -283,14 +314,19 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request with edits
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_ai = request.messages[0]
|
||||
cleared_tool = request.messages[1]
|
||||
assert modified_request is not None
|
||||
cleared_ai = modified_request.messages[0]
|
||||
cleared_tool = modified_request.messages[1]
|
||||
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared output]"
|
||||
@@ -302,7 +338,9 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
assert context_meta is not None
|
||||
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
|
||||
|
||||
assert state["messages"] == request.messages
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
|
||||
assert request.messages[1].content == "x" * 200
|
||||
|
||||
|
||||
async def test_respects_keep_last_tool_results_async() -> None:
|
||||
@@ -336,21 +374,26 @@ async def test_respects_keep_last_tool_results_async() -> None:
|
||||
token_count_method="model",
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request with edits
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
assert modified_request is not None
|
||||
cleared_messages = [
|
||||
msg
|
||||
for msg in request.messages
|
||||
for msg in modified_request.messages
|
||||
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
|
||||
]
|
||||
|
||||
assert len(cleared_messages) == 2
|
||||
assert isinstance(request.messages[-1], ToolMessage)
|
||||
assert request.messages[-1].content != "[cleared]"
|
||||
assert isinstance(modified_request.messages[-1], ToolMessage)
|
||||
assert modified_request.messages[-1].content != "[cleared]"
|
||||
|
||||
|
||||
async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
@@ -385,14 +428,19 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request with edits
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
search_tool = request.messages[1]
|
||||
calc_tool = request.messages[3]
|
||||
assert modified_request is not None
|
||||
search_tool = modified_request.messages[1]
|
||||
calc_tool = modified_request.messages[3]
|
||||
|
||||
assert isinstance(search_tool, ToolMessage)
|
||||
assert search_tool.content == "search-results" * 20
|
||||
@@ -7,6 +7,9 @@ import pytest

from langchain.agents.middleware.file_search import (
    FilesystemFileSearchMiddleware,
    _expand_include_patterns,
    _is_valid_include_pattern,
    _match_include_pattern,
)


@@ -259,3 +262,105 @@ class TestPathTraversalSecurity:

        assert result == "No matches found"
        assert "secret" not in result


class TestExpandIncludePatterns:
    """Tests for _expand_include_patterns helper function."""

    def test_expand_patterns_basic_brace_expansion(self) -> None:
        """Test basic brace expansion with multiple options."""
        result = _expand_include_patterns("*.{py,txt}")
        assert result == ["*.py", "*.txt"]

    def test_expand_patterns_nested_braces(self) -> None:
        """Test nested brace expansion."""
        result = _expand_include_patterns("test.{a,b}.{c,d}")
        assert result is not None
        assert len(result) == 4
        assert "test.a.c" in result
        assert "test.b.d" in result

    @pytest.mark.parametrize(
        "pattern",
        [
            "*.py}", # closing brace without opening
            "*.{}", # empty braces
            "*.{py", # unclosed brace
        ],
    )
    def test_expand_patterns_invalid_braces(self, pattern: str) -> None:
        """Test patterns with invalid brace syntax return None."""
        result = _expand_include_patterns(pattern)
        assert result is None


class TestValidateIncludePattern:
    """Tests for _is_valid_include_pattern helper function."""

    @pytest.mark.parametrize(
        "pattern",
        [
            "", # empty pattern
            "*.py\x00", # null byte
            "*.py\n", # newline
        ],
    )
    def test_validate_invalid_patterns(self, pattern: str) -> None:
        """Test that invalid patterns are rejected."""
        assert not _is_valid_include_pattern(pattern)


class TestMatchIncludePattern:
    """Tests for _match_include_pattern helper function."""

    def test_match_pattern_with_braces(self) -> None:
        """Test matching with brace expansion."""
        assert _match_include_pattern("test.py", "*.{py,txt}")
        assert _match_include_pattern("test.txt", "*.{py,txt}")
        assert not _match_include_pattern("test.md", "*.{py,txt}")

    def test_match_pattern_invalid_expansion(self) -> None:
        """Test matching with pattern that cannot be expanded returns False."""
        assert not _match_include_pattern("test.py", "*.{}")


class TestGrepEdgeCases:
    """Tests for edge cases in grep search."""

    def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None:
        """Test grep with special characters in pattern."""
        (tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        result = middleware.grep_search.func(pattern="def.*:")

        assert "/test.py" in result

    def test_grep_case_insensitive(self, tmp_path: Path) -> None:
        """Test grep with case-insensitive search."""
        (tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        result = middleware.grep_search.func(pattern="(?i)hello")

        assert "/test.py" in result

    def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None:
        """Test that grep skips files larger than max_file_size_mb."""
        # Create a file larger than 1MB
        large_content = "x" * (2 * 1024 * 1024) # 2MB
        (tmp_path / "large.txt").write_text(large_content, encoding="utf-8")
        (tmp_path / "small.txt").write_text("x", encoding="utf-8")

        middleware = FilesystemFileSearchMiddleware(
            root_path=str(tmp_path),
            use_ripgrep=False,
            max_file_size_mb=1, # 1MB limit
        )

        result = middleware.grep_search.func(pattern="x")

        # Large file should be skipped
        assert "/small.txt" in result
@@ -0,0 +1,575 @@
from typing import Any
from unittest.mock import patch

import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime

from langchain.agents.middleware.human_in_the_loop import (
    Action,
    HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.types import AgentState


def test_human_in_the_loop_middleware_initialization() -> None:
    """Test HumanInTheLoopMiddleware initialization."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
        description_prefix="Custom prefix",
    )

    assert middleware.interrupt_on == {
        "test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}
    }
    assert middleware.description_prefix == "Custom prefix"


def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
    """Test HumanInTheLoopMiddleware when no interrupts are needed."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    # Test with no messages
    state: dict[str, Any] = {"messages": []}
    result = middleware.after_model(state, None)
    assert result is None

    # Test with message but no tool calls
    state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
    result = middleware.after_model(state, None)
    assert result is None

    # Test with tool calls that don't require interrupts
    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}
    result = middleware.after_model(state, None)
    assert result is None


def test_human_in_the_loop_middleware_single_tool_accept() -> None:
    """Test HumanInTheLoopMiddleware with single tool accept response."""

    middleware = HumanInTheLoopMiddleware(
        interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_accept(requests):
        return {"decisions": [{"type": "approve"}]}

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
        result = middleware.after_model(state, None)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0] == ai_message
        assert result["messages"][0].tool_calls == ai_message.tool_calls

        state["messages"].append(
            ToolMessage(content="Tool message", name="test_tool", tool_call_id="1")
        )
        state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))

        result = middleware.after_model(state, None)
        # No interrupts needed
        assert result is None


def test_human_in_the_loop_middleware_single_tool_edit() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with single tool edit response."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_edit(requests):
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": Action(
|
||||
name="test_tool",
|
||||
args={"input": "edited"},
|
||||
),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||
assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_single_tool_response() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with single tool response with custom message."""
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_response(requests):
|
||||
return {"decisions": [{"type": "reject", "message": "Custom response message"}]}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 2
|
||||
assert isinstance(result["messages"][0], AIMessage)
|
||||
assert isinstance(result["messages"][1], ToolMessage)
|
||||
assert result["messages"][1].content == "Custom response message"
|
||||
assert result["messages"][1].name == "test_tool"
|
||||
assert result["messages"][1].tool_call_id == "1"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||
"get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you with weather",
|
||||
tool_calls=[
|
||||
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
|
||||
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||
|
||||
def mock_mixed_responses(requests):
|
||||
return {
|
||||
"decisions": [
|
||||
{"type": "approve"},
|
||||
{"type": "reject", "message": "User rejected this tool call"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert (
|
||||
len(result["messages"]) == 2
|
||||
) # AI message with accepted tool call + tool message for rejected
|
||||
|
||||
# First message should be the AI message with both tool calls
|
||||
updated_ai_message = result["messages"][0]
|
||||
assert len(updated_ai_message.tool_calls) == 2 # Both tool calls remain
|
||||
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
|
||||
assert updated_ai_message.tool_calls[1]["name"] == "get_temperature" # Got response
|
||||
|
||||
# Second message should be the tool message for the rejected tool call
|
||||
tool_message = result["messages"][1]
|
||||
assert isinstance(tool_message, ToolMessage)
|
||||
assert tool_message.content == "User rejected this tool call"
|
||||
assert tool_message.name == "get_temperature"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"get_forecast": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||
"get_temperature": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you with weather",
|
||||
tool_calls=[
|
||||
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
|
||||
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||
|
||||
def mock_edit_responses(requests):
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": Action(
|
||||
name="get_forecast",
|
||||
args={"location": "New York"},
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": Action(
|
||||
name="get_temperature",
|
||||
args={"location": "New York"},
|
||||
),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
|
||||
updated_ai_message = result["messages"][0]
|
||||
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
|
||||
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
|
||||
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
|
||||
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_edit_with_args(requests):
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": Action(
|
||||
name="test_tool",
|
||||
args={"input": "modified"},
|
||||
),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
side_effect=mock_edit_with_args,
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
|
||||
# Should have modified args
|
||||
updated_ai_message = result["messages"][0]
|
||||
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
|
||||
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with unknown response type."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_unknown(requests):
|
||||
return {"decisions": [{"type": "unknown"}]}
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Unexpected human decision: {'type': 'unknown'}. Decision type 'unknown' is not allowed for tool 'test_tool'. Expected one of \['approve', 'edit', 'reject'\] based on the tool's configuration.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_disallowed_action() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
|
||||
|
||||
# edit is not allowed by tool config
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "reject"]}}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_disallowed_action(requests):
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": Action(
|
||||
name="test_tool",
|
||||
args={"input": "modified"},
|
||||
),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
side_effect=mock_disallowed_action,
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Unexpected human decision: {'type': 'edit', 'edited_action': {'name': 'test_tool', 'args': {'input': 'modified'}}}. Decision type 'edit' is not allowed for tool 'test_tool'. Expected one of \['approve', 'reject'\] based on the tool's configuration.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"interrupt_tool": {"allowed_decisions": ["approve", "edit", "reject"]}}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[
|
||||
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
|
||||
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
def mock_accept(requests):
|
||||
return {"decisions": [{"type": "approve"}]}
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
|
||||
updated_ai_message = result["messages"][0]
|
||||
# Should have both tools: auto-approved first, then interrupt tool
|
||||
assert len(updated_ai_message.tool_calls) == 2
|
||||
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
|
||||
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||
"""Test that interrupt requests are structured correctly."""
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={"test_tool": {"allowed_decisions": ["approve", "edit", "reject"]}},
|
||||
description_prefix="Custom prefix",
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_capture_requests(request):
|
||||
nonlocal captured_request
|
||||
captured_request = request
|
||||
return {"decisions": [{"type": "approve"}]}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
assert captured_request is not None
|
||||
assert "action_requests" in captured_request
|
||||
assert "review_configs" in captured_request
|
||||
|
||||
assert len(captured_request["action_requests"]) == 1
|
||||
action_request = captured_request["action_requests"][0]
|
||||
assert action_request["name"] == "test_tool"
|
||||
assert action_request["args"] == {"input": "test", "location": "SF"}
|
||||
assert "Custom prefix" in action_request["description"]
|
||||
assert "Tool: test_tool" in action_request["description"]
|
||||
assert "Args: {'input': 'test', 'location': 'SF'}" in action_request["description"]
|
||||
|
||||
assert len(captured_request["review_configs"]) == 1
|
||||
review_config = captured_request["review_configs"][0]
|
||||
assert review_config["action_name"] == "test_tool"
|
||||
assert review_config["allowed_decisions"] == ["approve", "edit", "reject"]
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||
"""Test HITL middleware with boolean tool configs."""
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
# Test accept
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={"decisions": [{"type": "approve"}]},
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||
|
||||
# Test edit
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={
|
||||
"decisions": [
|
||||
{
|
||||
"type": "edit",
|
||||
"edited_action": Action(
|
||||
name="test_tool",
|
||||
args={"input": "edited"},
|
||||
),
|
||||
}
|
||||
]
|
||||
},
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})
|
||||
|
||||
result = middleware.after_model(state, None)
|
||||
# No interruption should occur
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
|
||||
"""Test that sequence mismatch in resume raises an error."""
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": True})
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
# Test with too few responses
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={"decisions": []}, # No responses for 1 tool call
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Number of human decisions \(0\) does not match number of hanging tool calls \(1\)\.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
# Test with too many responses
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={
|
||||
"decisions": [
|
||||
{"type": "approve"},
|
||||
{"type": "approve"},
|
||||
]
|
||||
}, # 2 responses for 1 tool call
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Number of human decisions \(2\) does not match number of hanging tool calls \(1\)\.",
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
||||
"""Test that description field accepts both string and callable."""
|
||||
|
||||
def custom_description(tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
|
||||
"""Generate a custom description."""
|
||||
return f"Custom: {tool_call['name']} with args {tool_call['args']}"
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"tool_with_callable": {
|
||||
"allowed_decisions": ["approve"],
|
||||
"description": custom_description,
|
||||
},
|
||||
"tool_with_string": {
|
||||
"allowed_decisions": ["approve"],
|
||||
"description": "Static description",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[
|
||||
{"name": "tool_with_callable", "args": {"x": 1}, "id": "1"},
|
||||
{"name": "tool_with_string", "args": {"y": 2}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_capture_requests(request):
|
||||
nonlocal captured_request
|
||||
captured_request = request
|
||||
return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
|
||||
assert captured_request is not None
|
||||
assert "action_requests" in captured_request
|
||||
assert len(captured_request["action_requests"]) == 2
|
||||
|
||||
# Check callable description
|
||||
assert (
|
||||
captured_request["action_requests"][0]["description"]
|
||||
== "Custom: tool_with_callable with args {'x': 1}"
|
||||
)
|
||||
|
||||
# Check string description
|
||||
assert captured_request["action_requests"][1]["description"] == "Static description"
|
||||
@@ -0,0 +1,224 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver

from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import (
    ModelCallLimitMiddleware,
    ModelCallLimitExceededError,
)

from ...model import FakeToolCallingModel


@tool
def simple_tool(input: str) -> str:
    """A simple tool"""
    return input


def test_middleware_unit_functionality():
|
||||
"""Test that the middleware works as expected in isolation."""
|
||||
# Test with end behavior
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
|
||||
|
||||
# Mock runtime (not used in current implementation)
|
||||
runtime = None
|
||||
|
||||
# Test when limits are not exceeded
|
||||
state = {"thread_model_call_count": 0, "run_model_call_count": 0}
|
||||
result = middleware.before_model(state, runtime)
|
||||
assert result is None
|
||||
|
||||
# Test when thread limit is exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
||||
result = middleware.before_model(state, runtime)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "end"
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert "thread limit (2/2)" in result["messages"][0].content
|
||||
|
||||
# Test when run limit is exceeded
|
||||
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
|
||||
result = middleware.before_model(state, runtime)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "end"
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert "run limit (1/1)" in result["messages"][0].content
|
||||
|
||||
# Test with error behavior
|
||||
middleware_exception = ModelCallLimitMiddleware(
|
||||
thread_limit=2, run_limit=1, exit_behavior="error"
|
||||
)
|
||||
|
||||
# Test exception when thread limit exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware_exception.before_model(state, runtime)
|
||||
|
||||
assert "thread limit (2/2)" in str(exc_info.value)
|
||||
|
||||
# Test exception when run limit exceeded
|
||||
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware_exception.before_model(state, runtime)
|
||||
|
||||
assert "run limit (1/1)" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_thread_limit_with_create_agent():
|
||||
"""Test that thread limits work correctly with create_agent."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
# Set thread limit to 1 (should be exceeded after 1 call)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[simple_tool],
|
||||
middleware=[ModelCallLimitMiddleware(thread_limit=1)],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
# First invocation should work - 1 model call, within thread limit
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
|
||||
)
|
||||
|
||||
# Should complete successfully with 1 model call
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 2 # Human + AI messages
|
||||
|
||||
# Second invocation in same thread should hit thread limit
|
||||
# The agent should jump to end after detecting the limit
|
||||
result2 = agent.invoke(
|
||||
{"messages": [HumanMessage("Hello again")]}, {"configurable": {"thread_id": "thread1"}}
|
||||
)
|
||||
|
||||
assert "messages" in result2
|
||||
# The agent should have detected the limit and jumped to end with a limit exceeded message
|
||||
# So we should have: previous messages + new human message + limit exceeded AI message
|
||||
assert len(result2["messages"]) == 4 # Previous Human + AI + New Human + Limit AI
|
||||
assert isinstance(result2["messages"][0], HumanMessage) # First human
|
||||
assert isinstance(result2["messages"][1], AIMessage) # First AI response
|
||||
assert isinstance(result2["messages"][2], HumanMessage) # Second human
|
||||
assert isinstance(result2["messages"][3], AIMessage) # Limit exceeded message
|
||||
assert "thread limit" in result2["messages"][3].content
|
||||
|
||||
|
||||
def test_run_limit_with_create_agent():
|
||||
"""Test that run limits work correctly with create_agent."""
|
||||
# Create a model that will make 2 calls
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"name": "simple_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
[], # No tool calls on second call
|
||||
]
|
||||
)
|
||||
|
||||
# Set run limit to 1 (should be exceeded after 1 call)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[simple_tool],
|
||||
middleware=[ModelCallLimitMiddleware(run_limit=1)],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
# This should hit the run limit after the first model call
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Hello")]}, {"configurable": {"thread_id": "thread1"}}
|
||||
)
|
||||
|
||||
assert "messages" in result
|
||||
# The agent should have made 1 model call then jumped to end with limit exceeded message
|
||||
# So we should have: Human + AI + Tool + Limit exceeded AI message
|
||||
assert len(result["messages"]) == 4 # Human + AI + Tool + Limit AI
|
||||
assert isinstance(result["messages"][0], HumanMessage)
|
||||
assert isinstance(result["messages"][1], AIMessage)
|
||||
assert isinstance(result["messages"][2], ToolMessage)
|
||||
assert isinstance(result["messages"][3], AIMessage) # Limit exceeded message
|
||||
assert "run limit" in result["messages"][3].content
|
||||
|
||||
|
||||
def test_middleware_initialization_validation():
|
||||
"""Test that middleware initialization validates parameters correctly."""
|
||||
# Test that at least one limit must be specified
|
||||
with pytest.raises(ValueError, match="At least one limit must be specified"):
|
||||
ModelCallLimitMiddleware()
|
||||
|
||||
# Test invalid exit behavior
|
||||
with pytest.raises(ValueError, match="Invalid exit_behavior"):
|
||||
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")
|
||||
|
||||
# Test valid initialization
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
|
||||
assert middleware.thread_limit == 5
|
||||
assert middleware.run_limit == 3
|
||||
assert middleware.exit_behavior == "end"
|
||||
|
||||
# Test with only thread limit
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=5)
|
||||
assert middleware.thread_limit == 5
|
||||
assert middleware.run_limit is None
|
||||
|
||||
# Test with only run limit
|
||||
middleware = ModelCallLimitMiddleware(run_limit=3)
|
||||
assert middleware.thread_limit is None
|
||||
assert middleware.run_limit == 3
|
||||
|
||||
|
||||
def test_exception_error_message():
|
||||
"""Test that the exception provides clear error messages."""
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
|
||||
|
||||
# Test thread limit exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware.before_model(state, None)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Model call limits exceeded" in error_msg
|
||||
assert "thread limit (2/2)" in error_msg
|
||||
|
||||
# Test run limit exceeded
|
||||
state = {"thread_model_call_count": 0, "run_model_call_count": 1}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware.before_model(state, None)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Model call limits exceeded" in error_msg
|
||||
assert "run limit (1/1)" in error_msg
|
||||
|
||||
# Test both limits exceeded
|
||||
state = {"thread_model_call_count": 2, "run_model_call_count": 1}
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
middleware.before_model(state, None)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Model call limits exceeded" in error_msg
|
||||
assert "thread limit (2/2)" in error_msg
|
||||
assert "run limit (1/1)" in error_msg
|
||||
|
||||
|
||||
def test_run_limit_resets_between_invocations() -> None:
|
||||
"""Test that run_model_call_count resets between invocations, but thread_model_call_count accumulates."""
|
||||
|
||||
# First: No tool calls per invocation, so model does not increment call counts internally
|
||||
middleware = ModelCallLimitMiddleware(thread_limit=3, run_limit=1, exit_behavior="error")
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[[], [], [], []]
|
||||
) # No tool calls, so only model call per run
|
||||
|
||||
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test_thread"}}
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
|
||||
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
|
||||
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
|
||||
|
||||
# Fourth run: should raise, thread_model_call_count == 3 (limit)
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
|
||||
error_msg = str(exc_info.value)
|
||||
assert "thread limit (3/3)" in error_msg
|
||||
@@ -2,16 +2,22 @@

from __future__ import annotations

import warnings
from typing import cast

import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime

from ...model import FakeToolCallingModel


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())
@@ -40,7 +46,7 @@ def test_primary_model_succeeds() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
# Simulate successful model call
|
||||
@@ -65,7 +71,7 @@ def test_fallback_on_primary_failure() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
@@ -90,7 +96,7 @@ def test_multiple_fallbacks() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
@@ -114,7 +120,7 @@ def test_all_models_fail() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
@@ -131,7 +137,7 @@ async def test_primary_model_succeeds_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
# Simulate successful async model call
|
||||
@@ -156,7 +162,7 @@ async def test_fallback_on_primary_failure_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
@@ -181,7 +187,7 @@ async def test_multiple_fallbacks_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
@@ -205,7 +211,7 @@ async def test_all_models_fail_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
@@ -213,3 +219,133 @@ async def test_all_models_fail_async() -> None:
|
||||
|
||||
with pytest.raises(ValueError, match="Model failed"):
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
|
||||
def test_model_fallback_middleware_with_agent() -> None:
|
||||
"""Test ModelFallbackMiddleware with agent.invoke and fallback models only."""
|
||||
|
||||
class FailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError("Primary model failed")
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "failing"
|
||||
|
||||
class SuccessModel(BaseChatModel):
|
||||
"""Model that succeeds."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content="Fallback success"))]
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "success"
|
||||
|
||||
primary = FailingModel()
|
||||
fallback = SuccessModel()
|
||||
|
||||
# Only pass fallback models to middleware (not the primary)
|
||||
fallback_middleware = ModelFallbackMiddleware(fallback)
|
||||
|
||||
agent = create_agent(model=primary, middleware=[fallback_middleware])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# Should have succeeded with fallback model
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][1].content == "Fallback success"
|
||||
|
||||
|
||||
def test_model_fallback_middleware_exhausted_with_agent() -> None:
|
||||
"""Test ModelFallbackMiddleware with agent.invoke when all models fail."""
|
||||
|
||||
class AlwaysFailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
raise ValueError(f"{self.name} failed")
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return self.name
|
||||
|
||||
primary = AlwaysFailingModel("primary")
|
||||
fallback1 = AlwaysFailingModel("fallback1")
|
||||
fallback2 = AlwaysFailingModel("fallback2")
|
||||
|
||||
# Primary fails (attempt 1), then fallback1 (attempt 2), then fallback2 (attempt 3)
|
||||
fallback_middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
|
||||
agent = create_agent(model=primary, middleware=[fallback_middleware])
|
||||
|
||||
# Should fail with the last fallback's error
|
||||
with pytest.raises(ValueError, match="fallback2 failed"):
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
|
||||
def test_model_fallback_middleware_initialization() -> None:
|
||||
"""Test ModelFallbackMiddleware initialization."""
|
||||
|
||||
# Test with no models - now a TypeError (missing required argument)
|
||||
with pytest.raises(TypeError):
|
||||
ModelFallbackMiddleware() # type: ignore[call-arg]
|
||||
|
||||
# Test with one fallback model (valid)
|
||||
middleware = ModelFallbackMiddleware(FakeToolCallingModel())
|
||||
assert len(middleware.models) == 1
|
||||
|
||||
# Test with multiple fallback models
|
||||
middleware = ModelFallbackMiddleware(FakeToolCallingModel(), FakeToolCallingModel())
|
||||
assert len(middleware.models) == 2
|
||||
|
||||
|
||||
def test_model_request_is_frozen() -> None:
|
||||
"""Test that ModelRequest raises deprecation warning on direct attribute assignment."""
|
||||
request = _make_request()
|
||||
new_model = GenericFakeChatModel(messages=iter([AIMessage(content="new model")]))
|
||||
|
||||
# Direct attribute assignment should raise DeprecationWarning but still work
|
||||
with pytest.warns(
|
||||
DeprecationWarning, match="Direct attribute assignment to ModelRequest.model is deprecated"
|
||||
):
|
||||
request.model = new_model # type: ignore[misc]
|
||||
|
||||
# Verify the assignment actually worked
|
||||
assert request.model == new_model
|
||||
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Direct attribute assignment to ModelRequest.system_prompt is deprecated",
|
||||
):
|
||||
request.system_prompt = "new prompt" # type: ignore[misc]
|
||||
|
||||
assert request.system_prompt == "new prompt"
|
||||
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Direct attribute assignment to ModelRequest.messages is deprecated",
|
||||
):
|
||||
request.messages = [] # type: ignore[misc]
|
||||
|
||||
assert request.messages == []
|
||||
|
||||
# Using override method should work without warnings
|
||||
request2 = _make_request()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error") # Turn warnings into errors
|
||||
new_request = request2.override(model=new_model, system_prompt="override prompt")
|
||||
|
||||
assert new_request.model == new_model
|
||||
assert new_request.system_prompt == "override prompt"
|
||||
# Original request should be unchanged
|
||||
assert request2.model != new_model
|
||||
assert request2.system_prompt != "override prompt"
|
||||
@@ -14,7 +14,7 @@ from langchain.agents.middleware.pii import (
)
from langchain.agents.factory import create_agent

from .model import FakeToolCallingModel
from tests.unit_tests.agents.model import FakeToolCallingModel


# ============================================================================
@@ -0,0 +1,556 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.tools.base import ToolException
|
||||
|
||||
from langchain.agents.middleware.shell_tool import (
|
||||
HostExecutionPolicy,
|
||||
RedactionRule,
|
||||
ShellToolMiddleware,
|
||||
_SessionResources,
|
||||
_ShellToolInput,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentState
|
||||
|
||||
|
||||
def _empty_state() -> AgentState:
|
||||
return {"messages": []} # type: ignore[return-value]
|
||||
|
||||
|
||||
def test_executes_command_and_persists_state(tmp_path: Path) -> None:
|
||||
workspace = tmp_path / "workspace"
|
||||
middleware = ShellToolMiddleware(workspace_root=workspace)
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
middleware._run_shell_tool(resources, {"command": "cd /"}, tool_call_id=None)
|
||||
result = middleware._run_shell_tool(resources, {"command": "pwd"}, tool_call_id=None)
|
||||
assert isinstance(result, str)
|
||||
assert result.strip() == "/"
|
||||
echo_result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo ready"}, tool_call_id=None
|
||||
)
|
||||
assert "ready" in echo_result
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
|
||||
def test_restart_resets_session_environment(tmp_path: Path) -> None:
|
||||
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
middleware._run_shell_tool(resources, {"command": "export FOO=bar"}, tool_call_id=None)
|
||||
restart_message = middleware._run_shell_tool(
|
||||
resources, {"restart": True}, tool_call_id=None
|
||||
)
|
||||
assert "restarted" in restart_message.lower()
|
||||
resources = middleware._get_or_create_resources(state) # reacquire after restart
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
|
||||
)
|
||||
assert "unset" in result
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
|
||||
def test_truncation_indicator_present(tmp_path: Path) -> None:
|
||||
policy = HostExecutionPolicy(max_output_lines=5, command_timeout=5.0)
|
||||
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
result = middleware._run_shell_tool(resources, {"command": "seq 1 20"}, tool_call_id=None)
|
||||
assert "Output truncated" in result
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
|
||||
def test_timeout_returns_error(tmp_path: Path) -> None:
|
||||
policy = HostExecutionPolicy(command_timeout=0.5)
|
||||
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
start = time.monotonic()
|
||||
result = middleware._run_shell_tool(resources, {"command": "sleep 2"}, tool_call_id=None)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed < policy.command_timeout + 2.0
|
||||
assert "timed out" in result.lower()
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
|
||||
def test_redaction_policy_applies(tmp_path: Path) -> None:
|
||||
middleware = ShellToolMiddleware(
|
||||
workspace_root=tmp_path / "workspace",
|
||||
redaction_rules=(RedactionRule(pii_type="email", strategy="redact"),),
|
||||
)
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
message = middleware._run_shell_tool(
|
||||
resources,
|
||||
{"command": "printf 'Contact: user@example.com\\n'"},
|
||||
tool_call_id=None,
|
||||
)
|
||||
assert "[REDACTED_EMAIL]" in message
|
||||
assert "user@example.com" not in message
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
|
||||
def test_startup_and_shutdown_commands(tmp_path: Path) -> None:
|
||||
workspace = tmp_path / "workspace"
|
||||
middleware = ShellToolMiddleware(
|
||||
workspace_root=workspace,
|
||||
startup_commands=("touch startup.txt",),
|
||||
shutdown_commands=("touch shutdown.txt",),
|
||||
)
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
assert (workspace / "startup.txt").exists()
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
assert (workspace / "shutdown.txt").exists()
|
||||
|
||||
|
||||
def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
|
||||
policy = HostExecutionPolicy(termination_timeout=0.1)
|
||||
|
||||
class DummySession:
|
||||
def __init__(self) -> None:
|
||||
self.stopped: bool = False
|
||||
|
||||
def stop(self, timeout: float) -> None: # noqa: ARG002
|
||||
self.stopped = True
|
||||
|
||||
session = DummySession()
|
||||
tempdir = tempfile.TemporaryDirectory(dir=tmp_path)
|
||||
tempdir_path = Path(tempdir.name)
|
||||
resources = _SessionResources(session=session, tempdir=tempdir, policy=policy) # type: ignore[arg-type]
|
||||
finalizer = resources._finalizer
|
||||
|
||||
# Drop our last strong reference and force collection.
|
||||
del resources
|
||||
gc.collect()
|
||||
|
||||
assert not finalizer.alive
|
||||
assert session.stopped
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_shell_tool_input_validation() -> None:
|
||||
"""Test _ShellToolInput validation rules."""
|
||||
# Both command and restart not allowed
|
||||
with pytest.raises(ValueError, match="only one"):
|
||||
_ShellToolInput(command="ls", restart=True)
|
||||
|
||||
# Neither command nor restart provided
|
||||
with pytest.raises(ValueError, match="requires either"):
|
||||
_ShellToolInput()
|
||||
|
||||
# Valid: command only
|
||||
valid_cmd = _ShellToolInput(command="ls")
|
||||
assert valid_cmd.command == "ls"
|
||||
assert not valid_cmd.restart
|
||||
|
||||
# Valid: restart only
|
||||
valid_restart = _ShellToolInput(restart=True)
|
||||
assert valid_restart.restart is True
|
||||
assert valid_restart.command is None
|
||||
|
||||
|
||||
def test_normalize_shell_command_empty() -> None:
|
||||
"""Test that empty shell command raises an error."""
|
||||
with pytest.raises(ValueError, match="at least one argument"):
|
||||
ShellToolMiddleware(shell_command=[])
|
||||
|
||||
|
||||
def test_normalize_env_non_string_keys() -> None:
|
||||
"""Test that non-string environment keys raise an error."""
|
||||
with pytest.raises(TypeError, match="must be strings"):
|
||||
ShellToolMiddleware(env={123: "value"}) # type: ignore[dict-item]
|
||||
|
||||
|
||||
def test_normalize_env_coercion(tmp_path: Path) -> None:
|
||||
"""Test that environment values are coerced to strings."""
|
||||
middleware = ShellToolMiddleware(
|
||||
workspace_root=tmp_path / "workspace", env={"NUM": 42, "BOOL": True}
|
||||
)
|
||||
try:
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None
|
||||
)
|
||||
assert "42" in result
|
||||
assert "True" in result
|
||||
finally:
|
||||
updates = middleware.after_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
|
||||
def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
    """Test that shell tool raises an error when command is not a string."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        with pytest.raises(ToolException, match="expects a 'command' string"):
            middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None)

        with pytest.raises(ToolException, match="expects a 'command' string"):
            middleware._run_shell_tool(
                resources,
                {"command": 123},  # type: ignore[dict-item]
                tool_call_id=None,
            )
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
    """Test that tool messages are properly formatted with tool_call_id."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources, {"command": "echo test"}, tool_call_id="test-id-123"
        )

        assert isinstance(result, ToolMessage)
        assert result.tool_call_id == "test-id-123"
        assert result.name == "shell"
        assert result.status == "success"
        assert "test" in result.content
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
    """Test that non-zero exit codes are marked as errors."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources,
            {"command": "false"},  # Command that exits with 1 but doesn't kill shell
            tool_call_id="test-id",
        )

        assert isinstance(result, ToolMessage)
        assert result.status == "error"
        assert "Exit code: 1" in result.content
        assert result.artifact["exit_code"] == 1  # type: ignore[index]
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_truncation_by_bytes(tmp_path: Path) -> None:
    """Test that output is truncated by bytes when max_output_bytes is exceeded."""
    policy = HostExecutionPolicy(max_output_bytes=50, command_timeout=5.0)
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None
        )

        assert "truncated at 50 bytes" in result.lower()
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_startup_command_failure(tmp_path: Path) -> None:
    """Test that startup command failure raises an error."""
    policy = HostExecutionPolicy(startup_timeout=1.0)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", startup_commands=("exit 1",), execution_policy=policy
    )
    state: AgentState = _empty_state()
    with pytest.raises(RuntimeError, match="Startup command.*failed"):
        middleware.before_agent(state, None)


def test_shutdown_command_failure_logged(tmp_path: Path) -> None:
    """Test that shutdown command failures are logged but don't raise."""
    policy = HostExecutionPolicy(command_timeout=1.0)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        shutdown_commands=("exit 1",),
        execution_policy=policy,
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
    finally:
        # Should not raise despite shutdown command failing
        middleware.after_agent(state, None)


def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
    """Test that shutdown command timeouts are logged but don't raise."""
    policy = HostExecutionPolicy(command_timeout=0.1)
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=policy,
        shutdown_commands=("sleep 2",),
    )
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
    finally:
        # Should not raise despite shutdown command timing out
        middleware.after_agent(state, None)


def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
    """Test that empty command output is replaced with '<no output>'."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources,
            {"command": "true"},  # Command that produces no output
            tool_call_id=None,
        )

        assert "<no output>" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


def test_stderr_output_labeling(tmp_path: Path) -> None:
    """Test that stderr output is properly labeled."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        result = middleware._run_shell_tool(
            resources, {"command": "echo error >&2"}, tool_call_id=None
        )

        assert "[stderr] error" in result
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)


@pytest.mark.parametrize(
    ("startup_commands", "expected"),
    [
        ("echo test", ("echo test",)),  # String
        (["echo test", "pwd"], ("echo test", "pwd")),  # List
        (("echo test",), ("echo test",)),  # Tuple
        (None, ()),  # None
    ],
)
def test_normalize_commands_string_tuple_list(
    tmp_path: Path,
    startup_commands: str | list[str] | tuple[str, ...] | None,
    expected: tuple[str, ...],
) -> None:
    """Test various command normalization formats."""
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace", startup_commands=startup_commands
    )
    assert middleware._startup_commands == expected  # type: ignore[attr-defined]


def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
    """Test that async methods properly delegate to sync methods."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    try:
        state: AgentState = _empty_state()

        # Test abefore_agent
        updates = asyncio.run(middleware.abefore_agent(state, None))
        if updates:
            state.update(updates)

        # Test aafter_agent
        asyncio.run(middleware.aafter_agent(state, None))
    finally:
        pass


def test_shell_middleware_resumable_after_interrupt(tmp_path: Path) -> None:
    """Test that shell middleware is resumable after an interrupt.

    This test simulates a scenario where:
    1. The middleware creates a shell session
    2. A command is executed
    3. The agent is interrupted (state is preserved)
    4. The agent resumes with the same state
    5. The shell session is reused (not recreated)
    """
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)

    # Simulate first execution (before interrupt)
    state: AgentState = _empty_state()
    updates = middleware.before_agent(state, None)
    if updates:
        state.update(updates)

    # Get the resources and verify they exist
    resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    initial_session = resources.session
    initial_tempdir = resources.tempdir

    # Execute a command to set state
    middleware._run_shell_tool(resources, {"command": "export TEST_VAR=hello"}, tool_call_id=None)

    # Simulate interrupt - state is preserved, but we don't call after_agent
    # In a real scenario, the state would be checkpointed here

    # Simulate resumption - call before_agent again with same state
    # This should reuse existing resources, not create new ones
    updates = middleware.before_agent(state, None)
    if updates:
        state.update(updates)

    # Get resources again - should be the same session
    resumed_resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

    # Verify the session was reused (same object reference)
    assert resumed_resources.session is initial_session
    assert resumed_resources.tempdir is initial_tempdir

    # Verify the session state persisted (environment variable still set)
    result = middleware._run_shell_tool(
        resumed_resources, {"command": "echo ${TEST_VAR:-unset}"}, tool_call_id=None
    )
    assert "hello" in result
    assert "unset" not in result

    # Clean up
    middleware.after_agent(state, None)


def test_get_or_create_resources_creates_when_missing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources creates resources when they don't exist."""
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)

    state: AgentState = _empty_state()

    # State has no resources initially
    assert "shell_session_resources" not in state

    # Call _get_or_create_resources - should create new resources
    resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

    assert isinstance(resources, _SessionResources)
    assert resources.session is not None
    assert state.get("shell_session_resources") is resources

    # Clean up
    resources._finalizer()


def test_get_or_create_resources_reuses_existing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources reuses existing resources."""
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)

    state: AgentState = _empty_state()

    # Create resources first time
    resources1 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

    # Call again - should return the same resources
    resources2 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

    assert resources1 is resources2
    assert resources1.session is resources2.session

    # Clean up
    resources1._finalizer()