mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-08 10:09:46 +00:00
Compare commits
54 Commits
langchain-
...
fix/respon
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d765a91c5c | ||
|
|
88af494b37 | ||
|
|
5b0a55ad35 | ||
|
|
08c4055347 | ||
|
|
6e2f46d04c | ||
|
|
5bf0b218c8 | ||
|
|
4e39c164bb | ||
|
|
0b3af47335 | ||
|
|
bc91a4811c | ||
|
|
05a61f9508 | ||
|
|
aa63de9366 | ||
|
|
86fa34f3eb | ||
|
|
36037c9251 | ||
|
|
ad26c892ea | ||
|
|
4828a85ab0 | ||
|
|
b999f356e8 | ||
|
|
062196a7b3 | ||
|
|
dc9f941326 | ||
|
|
238ecd09e0 | ||
|
|
6b5fdfb804 | ||
|
|
b42dac5fe6 | ||
|
|
e0a4af8d8b | ||
|
|
fcf7175392 | ||
|
|
1f2ab17dff | ||
|
|
2dc89a2ae7 | ||
|
|
e3c4aeaea1 | ||
|
|
444939945a | ||
|
|
ae8db86486 | ||
|
|
8a1419dad1 | ||
|
|
840e4c8e9f | ||
|
|
37aff0a153 | ||
|
|
a163d59988 | ||
|
|
b26e52aa4d | ||
|
|
38cdd7a2ec | ||
|
|
26e5d1302b | ||
|
|
107425c68d | ||
|
|
009cc3bf50 | ||
|
|
6185558449 | ||
|
|
0928ff5b12 | ||
|
|
7f9b0772fc | ||
|
|
d6e618258f | ||
|
|
806bc593ab | ||
|
|
047bcbaa13 | ||
|
|
18db07c292 | ||
|
|
1fe2c4084b | ||
|
|
c6c7fce6c9 | ||
|
|
3d08b6bd11 | ||
|
|
f2dcdae467 | ||
|
|
45f1b67340 | ||
|
|
8e37d39d66 | ||
|
|
be9274054f | ||
|
|
03b9214737 | ||
|
|
fcebafea9b | ||
|
|
3efa31d786 |
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
9
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,3 +1,5 @@
|
||||
(Replace this entire block of text)
|
||||
|
||||
Thank you for contributing to LangChain! Follow these steps to mark your pull request as ready for review. **If any of these steps are not completed, your PR will not be considered for review.**
|
||||
|
||||
- [ ] **PR title**: Follows the format: {TYPE}({SCOPE}): {DESCRIPTION}
|
||||
@@ -9,14 +11,13 @@ Thank you for contributing to LangChain! Follow these steps to mark your pull re
|
||||
- feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert, release
|
||||
- Allowed `{SCOPE}` values (optional):
|
||||
- core, cli, langchain, standard-tests, docs, anthropic, chroma, deepseek, exa, fireworks, groq, huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant, xai
|
||||
- Note: the `{DESCRIPTION}` must not start with an uppercase letter.
|
||||
- *Note:* the `{DESCRIPTION}` must not start with an uppercase letter.
|
||||
- Once you've written the title, please delete this checklist item; do not include it in the PR.
|
||||
|
||||
- [ ] **PR message**: ***Delete this entire checklist*** and replace with
|
||||
- **Description:** a description of the change. Include a [closing keyword](https://docs.github.com/en/issues/tracking-your-work-with-issues/using-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) if applicable to a relevant issue.
|
||||
- **Issue:** the issue # it fixes, if applicable (e.g. Fixes #123)
|
||||
- **Dependencies:** any dependencies required for this change
|
||||
- **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out!
|
||||
|
||||
- [ ] **Add tests and docs**: If you're adding a new integration, you must include:
|
||||
1. A test for the integration, preferably unit tests that do not rely on network access,
|
||||
@@ -26,7 +27,7 @@ Thank you for contributing to LangChain! Follow these steps to mark your pull re
|
||||
|
||||
Additional guidelines:
|
||||
|
||||
- Make sure optional dependencies are imported within a function.
|
||||
- Please do not add dependencies to `pyproject.toml` files (even optional ones) unless they are **required** for unit tests.
|
||||
- Most PRs should not touch more than one package.
|
||||
- Please do not add dependencies to `pyproject.toml` files (even optional ones) unless they are **required** for unit tests.
|
||||
- Changes should be backwards compatible.
|
||||
- Make sure optional dependencies are imported within a function.
|
||||
|
||||
2
.github/scripts/check_diff.py
vendored
2
.github/scripts/check_diff.py
vendored
@@ -132,6 +132,8 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
|
||||
|
||||
elif dir_ == "libs/langchain" and job == "extended-tests":
|
||||
py_versions = ["3.9", "3.13"]
|
||||
elif dir_ == "libs/langchain_v1":
|
||||
py_versions = ["3.10", "3.13"]
|
||||
|
||||
elif dir_ == ".":
|
||||
# unable to install with 3.13 because tokenizers doesn't support 3.13 yet
|
||||
|
||||
29
.github/workflows/check_core_versions.yml
vendored
29
.github/workflows/check_core_versions.yml
vendored
@@ -20,15 +20,30 @@ jobs:
|
||||
|
||||
- name: '✅ Verify pyproject.toml & version.py Match'
|
||||
run: |
|
||||
PYPROJECT_VERSION=$(grep -Po '(?<=^version = ")[^"]*' libs/core/pyproject.toml)
|
||||
VERSION_PY_VERSION=$(grep -Po '(?<=^VERSION = ")[^"]*' libs/core/langchain_core/version.py)
|
||||
# Check core versions
|
||||
CORE_PYPROJECT_VERSION=$(grep -Po '(?<=^version = ")[^"]*' libs/core/pyproject.toml)
|
||||
CORE_VERSION_PY_VERSION=$(grep -Po '(?<=^VERSION = ")[^"]*' libs/core/langchain_core/version.py)
|
||||
|
||||
# Compare the two versions
|
||||
if [ "$PYPROJECT_VERSION" != "$VERSION_PY_VERSION" ]; then
|
||||
# Compare core versions
|
||||
if [ "$CORE_PYPROJECT_VERSION" != "$CORE_VERSION_PY_VERSION" ]; then
|
||||
echo "langchain-core versions in pyproject.toml and version.py do not match!"
|
||||
echo "pyproject.toml version: $PYPROJECT_VERSION"
|
||||
echo "version.py version: $VERSION_PY_VERSION"
|
||||
echo "pyproject.toml version: $CORE_PYPROJECT_VERSION"
|
||||
echo "version.py version: $CORE_VERSION_PY_VERSION"
|
||||
exit 1
|
||||
else
|
||||
echo "Versions match: $PYPROJECT_VERSION"
|
||||
echo "Core versions match: $CORE_PYPROJECT_VERSION"
|
||||
fi
|
||||
|
||||
# Check langchain_v1 versions
|
||||
LANGCHAIN_PYPROJECT_VERSION=$(grep -Po '(?<=^version = ")[^"]*' libs/langchain_v1/pyproject.toml)
|
||||
LANGCHAIN_INIT_PY_VERSION=$(grep -Po '(?<=^__version__ = ")[^"]*' libs/langchain_v1/langchain/__init__.py)
|
||||
|
||||
# Compare langchain_v1 versions
|
||||
if [ "$LANGCHAIN_PYPROJECT_VERSION" != "$LANGCHAIN_INIT_PY_VERSION" ]; then
|
||||
echo "langchain_v1 versions in pyproject.toml and __init__.py do not match!"
|
||||
echo "pyproject.toml version: $LANGCHAIN_PYPROJECT_VERSION"
|
||||
echo "version.py version: $LANGCHAIN_INIT_PY_VERSION"
|
||||
exit 1
|
||||
else
|
||||
echo "Langchain v1 versions match: $LANGCHAIN_PYPROJECT_VERSION"
|
||||
fi
|
||||
|
||||
@@ -31,7 +31,7 @@ The conceptual guide does not cover step-by-step instructions or specific implem
|
||||
- **[Vector stores](/docs/concepts/vectorstores)**: Storage of and efficient search over vectors and associated metadata.
|
||||
- **[Retriever](/docs/concepts/retrievers)**: A component that returns relevant documents from a knowledge base in response to a query.
|
||||
- **[Retrieval Augmented Generation (RAG)](/docs/concepts/rag)**: A technique that enhances language models by combining them with external knowledge bases.
|
||||
- **[Agents](/docs/concepts/agents)**: Use a [language model](/docs/concepts/chat_models) to choose a sequence of actions to take. Agents can interact with external resources via [tool](/docs/concepts/tools).
|
||||
- **[Agents](/docs/concepts/agents)**: Use a [language model](/docs/concepts/chat_models) to choose a sequence of actions to take. Agents can interact with external resources via [tools](/docs/concepts/tools).
|
||||
- **[Prompt templates](/docs/concepts/prompt_templates)**: Component for factoring out the static parts of a model "prompt" (usually a sequence of messages). Useful for serializing, versioning, and reusing these static parts.
|
||||
- **[Output parsers](/docs/concepts/output_parsers)**: Responsible for taking the output of a model and transforming it into a more suitable format for downstream tasks. Output parsers were primarily useful prior to the general availability of [tool calling](/docs/concepts/tool_calling) and [structured outputs](/docs/concepts/structured_outputs).
|
||||
- **[Few-shot prompting](/docs/concepts/few_shot_prompting)**: A technique for improving model performance by providing a few examples of the task to perform in the prompt.
|
||||
@@ -48,7 +48,7 @@ The conceptual guide does not cover step-by-step instructions or specific implem
|
||||
- **[AIMessage](/docs/concepts/messages#aimessage)**: Represents a complete response from an AI model.
|
||||
- **[astream_events](/docs/concepts/chat_models#key-methods)**: Stream granular information from [LCEL](/docs/concepts/lcel) chains.
|
||||
- **[BaseTool](/docs/concepts/tools/#tool-interface)**: The base class for all tools in LangChain.
|
||||
- **[batch](/docs/concepts/runnables)**: Use to execute a runnable with batch inputs.
|
||||
- **[batch](/docs/concepts/runnables)**: Used to execute a runnable with batch inputs.
|
||||
- **[bind_tools](/docs/concepts/tool_calling/#tool-binding)**: Allows models to interact with tools.
|
||||
- **[Caching](/docs/concepts/chat_models#caching)**: Storing results to avoid redundant calls to a chat model.
|
||||
- **[Chat models](/docs/concepts/multimodality/#multimodality-in-chat-models)**: Chat models that handle multiple data modalities.
|
||||
|
||||
@@ -7,4 +7,4 @@ Traces contain individual steps called `runs`. These can be individual calls fro
|
||||
tool, or sub-chains.
|
||||
Tracing gives you observability inside your chains and agents, and is vital in diagnosing issues.
|
||||
|
||||
For a deeper dive, check out [this LangSmith conceptual guide](https://docs.smith.langchain.com/concepts/tracing).
|
||||
For a deeper dive, check out [this LangSmith conceptual guide](https://docs.langchain.com/langsmith/observability-quickstart).
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
Here are some things to keep in mind for all types of contributions:
|
||||
|
||||
- Follow the ["fork and pull request"](https://docs.github.com/en/get-started/exploring-projects-on-github/contributing-to-a-project) workflow.
|
||||
- Fill out the checked-in pull request template when opening pull requests. Note related issues and tag relevant maintainers.
|
||||
- Fill out the checked-in pull request template when opening pull requests. Note related issues.
|
||||
- Ensure your PR passes formatting, linting, and testing checks before requesting a review.
|
||||
- If you would like comments or feedback on your current progress, please open an issue or discussion and tag a maintainer.
|
||||
- If you would like comments or feedback on your current progress, please open an issue or discussion.
|
||||
- See the sections on [Testing](setup.mdx#testing) and [Formatting and Linting](setup.mdx#formatting-and-linting) for how to run these checks locally.
|
||||
- Backwards compatibility is key. Your changes must not be breaking, except in case of critical bug and security fixes.
|
||||
- Look for duplicate PRs or issues that have already been opened before opening a new one.
|
||||
|
||||
@@ -79,7 +79,7 @@ Here are some high-level tips on writing a good how-to guide:
|
||||
|
||||
### Conceptual guide
|
||||
|
||||
LangChain's conceptual guide falls under the **Explanation** quadrant of Diataxis. These guides should cover LangChain terms and concepts
|
||||
LangChain's conceptual guides fall under the **Explanation** quadrant of Diataxis. These guides should cover LangChain terms and concepts
|
||||
in a more abstract way than how-to guides or tutorials, targeting curious users interested in
|
||||
gaining a deeper understanding and insights of the framework. Try to avoid excessively large code examples as the primary goal is to
|
||||
provide perspective to the user rather than to finish a practical project. These guides should cover **why** things work the way they do.
|
||||
@@ -105,7 +105,7 @@ Here are some high-level tips on writing a good conceptual guide:
|
||||
### References
|
||||
|
||||
References contain detailed, low-level information that describes exactly what functionality exists and how to use it.
|
||||
In LangChain, this is mainly our API reference pages, which are populated from docstrings within code.
|
||||
In LangChain, these are mainly our API reference pages, which are populated from docstrings within code.
|
||||
References pages are generally not read end-to-end, but are consulted as necessary when a user needs to know
|
||||
how to use something specific.
|
||||
|
||||
@@ -119,7 +119,7 @@ but here are some high-level tips on writing a good docstring:
|
||||
- Be concise
|
||||
- Discuss special cases and deviations from a user's expectations
|
||||
- Go into detail on required inputs and outputs
|
||||
- Light details on when one might use the feature are fine, but in-depth details belong in other sections.
|
||||
- Light details on when one might use the feature are fine, but in-depth details belong in other sections
|
||||
|
||||
Each category serves a distinct purpose and requires a specific approach to writing and structuring the content.
|
||||
|
||||
@@ -127,17 +127,17 @@ Each category serves a distinct purpose and requires a specific approach to writ
|
||||
|
||||
Here are some other guidelines you should think about when writing and organizing documentation.
|
||||
|
||||
We generally do not merge new tutorials from outside contributors without an actue need.
|
||||
We generally do not merge new tutorials from outside contributors without an acute need.
|
||||
We welcome updates as well as new integration docs, how-tos, and references.
|
||||
|
||||
### Avoid duplication
|
||||
|
||||
Multiple pages that cover the same material in depth are difficult to maintain and cause confusion. There should
|
||||
be only one (very rarely two), canonical pages for a given concept or feature. Instead, you should link to other guides.
|
||||
be only one (very rarely two) canonical pages for a given concept or feature. Instead, you should link to other guides.
|
||||
|
||||
### Link to other sections
|
||||
|
||||
Because sections of the docs do not exist in a vacuum, it is important to link to other sections frequently,
|
||||
Because sections of the docs do not exist in a vacuum, it is important to link to other sections frequently
|
||||
to allow a developer to learn more about an unfamiliar topic within the flow of reading.
|
||||
|
||||
This includes linking to the API references and conceptual sections!
|
||||
|
||||
@@ -33,7 +33,7 @@ Sometimes you want to make a small change, like fixing a typo, and the easiest w
|
||||
- Click the "Commit changes..." button at the top-right corner of the page.
|
||||
- Give your commit a title like "Fix typo in X section."
|
||||
- Optionally, write an extended commit description.
|
||||
- Click "Propose changes"
|
||||
- Click "Propose changes".
|
||||
|
||||
5. **Submit a pull request (PR):**
|
||||
- GitHub will redirect you to a page where you can create a pull request.
|
||||
|
||||
@@ -5,7 +5,7 @@ sidebar_class_name: hidden
|
||||
|
||||
# How-to guides
|
||||
|
||||
Here you’ll find answers to “How do I….?” types of questions.
|
||||
Here you’ll find answers to "How do I….?" types of questions.
|
||||
These guides are *goal-oriented* and *concrete*; they're meant to help you complete a specific task.
|
||||
For conceptual explanations see the [Conceptual guide](/docs/concepts/).
|
||||
For end-to-end walkthroughs see [Tutorials](/docs/tutorials).
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
"source": [
|
||||
"## Defining tool schemas\n",
|
||||
"\n",
|
||||
"For a model to be able to call tools, we need to pass in tool schemas that describe what the tool does and what it's arguments are. Chat models that support tool calling features implement a `.bind_tools()` method for passing tool schemas to the model. Tool schemas can be passed in as Python functions (with typehints and docstrings), Pydantic models, TypedDict classes, or LangChain [Tool objects](https://python.langchain.com/api_reference/core/tools/langchain_core.tools.base.BaseTool.html#basetool). Subsequent invocations of the model will pass in these tool schemas along with the prompt.\n",
|
||||
"For a model to be able to call tools, we need to pass in tool schemas that describe what the tool does and what its arguments are. Chat models that support tool calling features implement a `.bind_tools()` method for passing tool schemas to the model. Tool schemas can be passed in as Python functions (with typehints and docstrings), Pydantic models, TypedDict classes, or LangChain [Tool objects](https://python.langchain.com/api_reference/core/tools/langchain_core.tools.base.BaseTool.html#basetool). Subsequent invocations of the model will pass in these tool schemas along with the prompt.\n",
|
||||
"\n",
|
||||
"### Python functions\n",
|
||||
"Our tool schemas can be Python functions:"
|
||||
|
||||
@@ -2,67 +2,91 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Oracle Autonomous Database\n",
|
||||
"\n",
|
||||
"Oracle autonomous database is a cloud database that uses machine learning to automate database tuning, security, backups, updates, and other routine management tasks traditionally performed by DBAs.\n",
|
||||
"Oracle Autonomous Database is a cloud database that uses machine learning to automate database tuning, security, backups, updates, and other routine management tasks traditionally performed by DBAs.\n",
|
||||
"\n",
|
||||
"This notebook covers how to load documents from oracle autonomous database, the loader supports connection with connection string or tns configuration.\n",
|
||||
"This notebook covers how to load documents from Oracle Autonomous Database.\n",
|
||||
"\n",
|
||||
"## Prerequisites\n",
|
||||
"1. Database runs in a 'Thin' mode:\n",
|
||||
" https://python-oracledb.readthedocs.io/en/latest/user_guide/appendix_b.html\n",
|
||||
"2. `pip install oracledb`:\n",
|
||||
" https://python-oracledb.readthedocs.io/en/latest/user_guide/installation.html"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
"1. Install python-oracledb:\n",
|
||||
"\n",
|
||||
" `pip install oracledb`\n",
|
||||
" \n",
|
||||
" See [Installing python-oracledb](https://python-oracledb.readthedocs.io/en/latest/user_guide/installation.html).\n",
|
||||
"\n",
|
||||
"2. A database that python-oracledb's default 'Thin' mode can connected to. This is true of Oracle Autonomous Database, see [python-oracledb Architecture](https://python-oracledb.readthedocs.io/en/latest/user_guide/introduction.html#architecture).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Instructions"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pip install oracledb"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.document_loaders import OracleAutonomousDatabaseLoader\n",
|
||||
"from settings import s"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"With mutual TLS authentication (mTLS), wallet_location and wallet_password are required to create the connection, user can create connection by providing either connection string or tns configuration details."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"With mutual TLS authentication (mTLS), wallet_location and wallet_password parameters are required to create the connection. See python-oracledb documentation [Connecting to Oracle Cloud Autonomous Databases](https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-oracle-cloud-autonomous-databases)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SQL_QUERY = \"select prod_id, time_id from sh.costs fetch first 5 rows only\"\n",
|
||||
@@ -89,24 +113,30 @@
|
||||
" wallet_password=s.PASSWORD,\n",
|
||||
")\n",
|
||||
"doc_2 = doc_loader_2.load()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"With TLS authentication, wallet_location and wallet_password are not required.\n",
|
||||
"Bind variable option is provided by argument \"parameters\"."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"With 1-way TLS authentication, only the database credentials and connection string are required to establish a connection.\n",
|
||||
"The example below also shows passing bind variable values with the argument \"parameters\"."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SQL_QUERY = \"select channel_id, channel_desc from sh.channels where channel_desc = :1 fetch first 5 rows only\"\n",
|
||||
@@ -131,31 +161,28 @@
|
||||
" parameters=[\"Direct Sales\"],\n",
|
||||
")\n",
|
||||
"doc_4 = doc_loader_4.load()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ Head to the reference section for full documentation of all classes and methods
|
||||
Trace and evaluate your language model applications and intelligent agents to help you move from prototype to production.
|
||||
|
||||
### [🦜🕸️ LangGraph](https://langchain-ai.github.io/langgraph)
|
||||
Build stateful, multi-actor applications with LLMs. Integrates smoothly with LangChain, but can be used without it. LangGraph powers production-grade agents, trusted by Linkedin, Uber, Klarna, GitLab, and many more.
|
||||
Build stateful, multi-actor applications with LLMs. Integrates smoothly with LangChain, but can be used without it. LangGraph powers production-grade agents, trusted by LinkedIn, Uber, Klarna, GitLab, and many more.
|
||||
|
||||
## Additional resources
|
||||
|
||||
|
||||
@@ -44,4 +44,4 @@ You can peruse [LangSmith tutorials here](https://docs.smith.langchain.com/).
|
||||
|
||||
LangSmith helps you evaluate the performance of your LLM applications. The tutorial below is a great way to get started:
|
||||
|
||||
- [Evaluate your LLM application](https://docs.smith.langchain.com/tutorials/Developers/evaluation)
|
||||
- [Evaluate your LLM application](https://docs.langchain.com/langsmith/evaluate-llm-application)
|
||||
|
||||
@@ -159,7 +159,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "1b2481f0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -178,8 +178,8 @@
|
||||
"from langchain_core.messages import HumanMessage, SystemMessage\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" SystemMessage(\"Translate the following from English into Italian\"),\n",
|
||||
" HumanMessage(\"hi!\"),\n",
|
||||
" SystemMessage(content=\"Translate the following from English into Italian\"),\n",
|
||||
" HumanMessage(content=\"hi!\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"model.invoke(messages)"
|
||||
|
||||
@@ -118,7 +118,8 @@ export default function ChatModelTabs(props) {
|
||||
{
|
||||
value: "anthropic",
|
||||
label: "Anthropic",
|
||||
model: "claude-3-5-sonnet-latest",
|
||||
model: "claude-3-7-sonnet-20250219",
|
||||
comment: "# Note: Model versions may become outdated. Check https://docs.anthropic.com/en/docs/models-overview for latest versions",
|
||||
apiKeyName: "ANTHROPIC_API_KEY",
|
||||
packageName: "langchain[anthropic]",
|
||||
},
|
||||
@@ -269,6 +270,9 @@ if not os.environ.get("${selectedTabItem.apiKeyName}"):
|
||||
|
||||
${llmVarName} = init_chat_model("${selectedTabItem.model}", model_provider="${selectedTabItem.value}"${selectedTabItem?.kwargs ? `, ${selectedTabItem.kwargs}` : ""})`;
|
||||
|
||||
// Add comment if available
|
||||
const commentText = selectedTabItem?.comment ? selectedTabItem.comment + "\n\n" : "";
|
||||
|
||||
return (
|
||||
<div>
|
||||
<CustomDropdown
|
||||
@@ -282,7 +286,7 @@ ${llmVarName} = init_chat_model("${selectedTabItem.model}", model_provider="${se
|
||||
{`pip install -qU "${selectedTabItem.packageName}"`}
|
||||
</CodeBlock>
|
||||
<CodeBlock language="python">
|
||||
{apiKeyText ? apiKeyText + "\n\n" + initModelText : initModelText}
|
||||
{apiKeyText ? apiKeyText + "\n\n" + commentText + initModelText : commentText + initModelText}
|
||||
</CodeBlock>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -72,9 +72,7 @@ def new(
|
||||
name_str = name
|
||||
pip_bool = bool(pip) # None should be false
|
||||
else:
|
||||
name_str = (
|
||||
name if name else typer.prompt("What folder would you like to create?")
|
||||
)
|
||||
name_str = name or typer.prompt("What folder would you like to create?")
|
||||
if not has_packages:
|
||||
package = []
|
||||
package_prompt = "What package would you like to add? (leave blank to skip)"
|
||||
|
||||
@@ -6,7 +6,9 @@ import os
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
# Should bring us to [root]/src
|
||||
@@ -22,10 +24,11 @@ class ImportExtractor(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, *, from_package: Optional[str] = None) -> None:
|
||||
"""Extract all imports from the given code, optionally filtering by package."""
|
||||
self.imports: list = []
|
||||
self.imports: list[tuple[str, str]] = []
|
||||
self.package = from_package
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802
|
||||
@override
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if node.module and (
|
||||
self.package is None or str(node.module).startswith(self.package)
|
||||
):
|
||||
@@ -44,7 +47,8 @@ def _get_class_names(code: str) -> list[str]:
|
||||
|
||||
# Define a node visitor class to collect class names
|
||||
class ClassVisitor(ast.NodeVisitor):
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: N802
|
||||
@override
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
class_names.append(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
@@ -54,7 +58,7 @@ def _get_class_names(code: str) -> list[str]:
|
||||
return class_names
|
||||
|
||||
|
||||
def is_subclass(class_obj: Any, classes_: list[type]) -> bool:
|
||||
def is_subclass(class_obj: type, classes_: list[type]) -> bool:
|
||||
"""Check if the given class object is a subclass of any class in list classes."""
|
||||
return any(
|
||||
issubclass(class_obj, kls)
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
|
||||
import rich
|
||||
import typer
|
||||
from gritql import run # type: ignore[import]
|
||||
from gritql import run # type: ignore[import-untyped]
|
||||
from typer import Option
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ def migrate(
|
||||
final_code = run.apply_pattern(
|
||||
"langchain_all_migrations()",
|
||||
args,
|
||||
grit_dir=get_gritdir_path(),
|
||||
grit_dir=str(get_gritdir_path()),
|
||||
)
|
||||
|
||||
raise typer.Exit(code=final_code)
|
||||
|
||||
@@ -34,7 +34,7 @@ def new(
|
||||
package_name_split = computed_name.split("/")
|
||||
package_name = (
|
||||
package_name_split[-2]
|
||||
if len(package_name_split) > 1 and package_name_split[-1] == ""
|
||||
if len(package_name_split) > 1 and not package_name_split[-1]
|
||||
else package_name_split[-1]
|
||||
)
|
||||
module_name = re.sub(
|
||||
|
||||
@@ -6,7 +6,7 @@ app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def redirect_root_to_docs():
|
||||
async def redirect_root_to_docs() -> RedirectResponse:
|
||||
return RedirectResponse("/docs")
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class EventDict(TypedDict):
|
||||
properties: Optional[dict[str, Any]]
|
||||
|
||||
|
||||
def create_events(events: list[EventDict]) -> Optional[Any]:
|
||||
def create_events(events: list[EventDict]) -> Optional[dict[str, Any]]:
|
||||
"""Create events."""
|
||||
try:
|
||||
data = {
|
||||
@@ -48,7 +48,8 @@ def create_events(events: list[EventDict]) -> Optional[Any]:
|
||||
|
||||
res = conn.getresponse()
|
||||
|
||||
return json.loads(res.read())
|
||||
response_data = json.loads(res.read())
|
||||
return response_data if isinstance(response_data, dict) else None
|
||||
except (http.client.HTTPException, OSError, json.JSONDecodeError) as exc:
|
||||
typer.echo(f"Error sending events: {exc}")
|
||||
return None
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
import shutil
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Optional, TypedDict
|
||||
from typing import Any, Optional, TypedDict
|
||||
|
||||
from git import Repo
|
||||
|
||||
@@ -26,7 +26,7 @@ class DependencySource(TypedDict):
|
||||
ref: Optional[str]
|
||||
subdirectory: Optional[str]
|
||||
api_path: Optional[str]
|
||||
event_metadata: dict
|
||||
event_metadata: dict[str, Any]
|
||||
|
||||
|
||||
# use poetry dependency string format
|
||||
@@ -138,8 +138,8 @@ def parse_dependencies(
|
||||
if (
|
||||
(dependencies and len(dependencies) != num_deps)
|
||||
or (api_path and len(api_path) != num_deps)
|
||||
or (repo and len(repo) not in [1, num_deps])
|
||||
or (branch and len(branch) not in [1, num_deps])
|
||||
or (repo and len(repo) not in {1, num_deps})
|
||||
or (branch and len(branch) not in {1, num_deps})
|
||||
):
|
||||
msg = (
|
||||
"Number of defined repos/branches/api_paths did not match the "
|
||||
@@ -151,15 +151,15 @@ def parse_dependencies(
|
||||
inner_repos = _list_arg_to_length(repo, num_deps)
|
||||
inner_branches = _list_arg_to_length(branch, num_deps)
|
||||
|
||||
return [
|
||||
parse_dependency_string(iter_dep, iter_repo, iter_branch, iter_api_path)
|
||||
for iter_dep, iter_repo, iter_branch, iter_api_path in zip(
|
||||
return list(
|
||||
map(
|
||||
parse_dependency_string,
|
||||
inner_deps,
|
||||
inner_repos,
|
||||
inner_branches,
|
||||
inner_api_paths,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_repo_path(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
|
||||
@@ -167,7 +167,7 @@ def _get_repo_path(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
|
||||
ref_str = ref if ref is not None else ""
|
||||
hashed = hashlib.sha256((f"{gitstring}:{ref_str}").encode()).hexdigest()[:8]
|
||||
|
||||
removed_protocol = gitstring.split("://")[-1]
|
||||
removed_protocol = gitstring.split("://", maxsplit=1)[-1]
|
||||
removed_basename = re.split(r"[/:]", removed_protocol, maxsplit=1)[-1]
|
||||
removed_extras = removed_basename.split("#")[0]
|
||||
foldername = re.sub(r"\W", "_", removed_extras)
|
||||
|
||||
@@ -7,7 +7,7 @@ authors = [{ name = "Erick Friis", email = "erick@langchain.dev" }]
|
||||
license = { text = "MIT" }
|
||||
requires-python = ">=3.9"
|
||||
dependencies = [
|
||||
"typer[all]<1.0.0,>=0.9.0",
|
||||
"typer<1.0.0,>=0.9.0",
|
||||
"gitpython<4,>=3",
|
||||
"langserve[all]>=0.0.51",
|
||||
"uvicorn<1.0,>=0.23",
|
||||
@@ -15,7 +15,7 @@ dependencies = [
|
||||
"gritql<1.0.0,>=0.2.0",
|
||||
]
|
||||
name = "langchain-cli"
|
||||
version = "0.0.36"
|
||||
version = "0.0.37"
|
||||
description = "CLI for interacting with LangChain"
|
||||
readme = "README.md"
|
||||
|
||||
@@ -29,8 +29,8 @@ langchain = "langchain_cli.cli:app"
|
||||
langchain-cli = "langchain_cli.cli:app"
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["pytest<8.0.0,>=7.4.2", "pytest-watcher<1.0.0,>=0.3.4"]
|
||||
lint = ["ruff<0.13,>=0.12.2", "mypy<2.0.0,>=1.13.0"]
|
||||
dev = ["pytest<9.0.0,>=7.4.2", "pytest-watcher<1.0.0,>=0.3.4"]
|
||||
lint = ["ruff<0.13,>=0.12.2", "mypy<1.18,>=1.17.1"]
|
||||
test = ["langchain-core", "langchain"]
|
||||
typing = ["langchain"]
|
||||
test_integration = []
|
||||
@@ -63,9 +63,7 @@ ignore = [
|
||||
"TD003", # Missing issue link in TODO
|
||||
|
||||
# TODO rules
|
||||
"ANN401",
|
||||
"BLE",
|
||||
"D1",
|
||||
]
|
||||
unfixable = [
|
||||
"B028", # People should intentionally tune the stacklevel
|
||||
@@ -84,6 +82,11 @@ pyupgrade.keep-runtime-typing = true
|
||||
"scripts/**" = [ "INP", "S",]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ["pydantic.mypy"]
|
||||
strict = true
|
||||
enable_error_code = "deprecated"
|
||||
warn_unreachable = true
|
||||
|
||||
exclude = [
|
||||
"langchain_cli/integration_template",
|
||||
"langchain_cli/package_template",
|
||||
|
||||
@@ -52,7 +52,7 @@ def cli() -> None:
|
||||
def generic(
|
||||
pkg1: str,
|
||||
pkg2: str,
|
||||
output: str,
|
||||
output: Optional[str],
|
||||
filter_by_all: bool, # noqa: FBT001
|
||||
format_: str,
|
||||
) -> None:
|
||||
@@ -73,7 +73,7 @@ def generic(
|
||||
else:
|
||||
dumped = dump_migrations_as_grit(name, migrations)
|
||||
|
||||
Path(output).write_text(dumped)
|
||||
Path(output).write_text(dumped, encoding="utf-8")
|
||||
|
||||
|
||||
def handle_partner(pkg: str, output: Optional[str] = None) -> None:
|
||||
@@ -84,7 +84,7 @@ def handle_partner(pkg: str, output: Optional[str] = None) -> None:
|
||||
data = dump_migrations_as_grit(name, migrations)
|
||||
output_name = f"{name}.grit" if output is None else output
|
||||
if migrations:
|
||||
Path(output_name).write_text(data)
|
||||
Path(output_name).write_text(data, encoding="utf-8")
|
||||
click.secho(f"LangChain migration script saved to {output_name}")
|
||||
else:
|
||||
click.secho(f"No migrations found for {pkg}", fg="yellow")
|
||||
@@ -109,7 +109,7 @@ def json_to_grit(json_file: str) -> None:
|
||||
name = file.stem
|
||||
data = dump_migrations_as_grit(name, migrations)
|
||||
output_name = f"{name}.grit"
|
||||
Path(output_name).write_text(data)
|
||||
Path(output_name).write_text(data, encoding="utf-8")
|
||||
click.secho(f"GritQL migration script saved to {output_name}")
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -16,7 +16,7 @@ def _assert_dependency_equals(
|
||||
git: Optional[str] = None,
|
||||
ref: Optional[str] = None,
|
||||
subdirectory: Optional[str] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
event_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if dep["git"] != git:
|
||||
msg = f"Expected git to be {git} but got {dep['git']}"
|
||||
|
||||
111
libs/cli/uv.lock
generated
111
libs/cli/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.9"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12.4'",
|
||||
@@ -425,15 +425,11 @@ dependencies = [
|
||||
requires-dist = [
|
||||
{ name = "async-timeout", marker = "python_full_version < '3.11'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "langchain-anthropic", marker = "extra == 'anthropic'" },
|
||||
{ name = "langchain-aws", marker = "extra == 'aws'" },
|
||||
{ name = "langchain-azure-ai", marker = "extra == 'azure-ai'" },
|
||||
{ name = "langchain-cohere", marker = "extra == 'cohere'" },
|
||||
{ name = "langchain-community", marker = "extra == 'community'" },
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "langchain-deepseek", marker = "extra == 'deepseek'" },
|
||||
{ name = "langchain-fireworks", marker = "extra == 'fireworks'" },
|
||||
{ name = "langchain-google-genai", marker = "extra == 'google-genai'" },
|
||||
{ name = "langchain-google-vertexai", marker = "extra == 'google-vertexai'" },
|
||||
{ name = "langchain-groq", marker = "extra == 'groq'" },
|
||||
{ name = "langchain-huggingface", marker = "extra == 'huggingface'" },
|
||||
{ name = "langchain-mistralai", marker = "extra == 'mistralai'" },
|
||||
@@ -442,14 +438,13 @@ requires-dist = [
|
||||
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
|
||||
{ name = "langchain-text-splitters", editable = "../text-splitters" },
|
||||
{ name = "langchain-together", marker = "extra == 'together'" },
|
||||
{ name = "langchain-xai", marker = "extra == 'xai'" },
|
||||
{ name = "langsmith", specifier = ">=0.1.17" },
|
||||
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
|
||||
{ name = "pyyaml", specifier = ">=5.3" },
|
||||
{ name = "requests", specifier = ">=2,<3" },
|
||||
{ name = "sqlalchemy", specifier = ">=1.4,<3" },
|
||||
]
|
||||
provides-extras = ["community", "anthropic", "openai", "azure-ai", "cohere", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
|
||||
provides-extras = ["community", "anthropic", "openai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "deepseek", "perplexity"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
||||
@@ -520,7 +515,7 @@ typing = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-cli"
|
||||
version = "0.0.36"
|
||||
version = "0.0.37"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "gitpython" },
|
||||
@@ -554,17 +549,17 @@ requires-dist = [
|
||||
{ name = "gritql", specifier = ">=0.2.0,<1.0.0" },
|
||||
{ name = "langserve", extras = ["all"], specifier = ">=0.0.51" },
|
||||
{ name = "tomlkit", specifier = ">=0.12" },
|
||||
{ name = "typer", extras = ["all"], specifier = ">=0.9.0,<1.0.0" },
|
||||
{ name = "typer", specifier = ">=0.9.0,<1.0.0" },
|
||||
{ name = "uvicorn", specifier = ">=0.23,<1.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "pytest", specifier = ">=7.4.2,<8.0.0" },
|
||||
{ name = "pytest", specifier = ">=7.4.2,<9.0.0" },
|
||||
{ name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
|
||||
]
|
||||
lint = [
|
||||
{ name = "mypy", specifier = ">=1.13.0,<2.0.0" },
|
||||
{ name = "mypy", specifier = ">=1.17.1,<1.18" },
|
||||
{ name = "ruff", specifier = ">=0.12.2,<0.13" },
|
||||
]
|
||||
test = [
|
||||
@@ -576,7 +571,7 @@ typing = [{ name = "langchain", editable = "../langchain" }]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.3.72"
|
||||
version = "0.3.75"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -627,14 +622,14 @@ test = [
|
||||
test-integration = []
|
||||
typing = [
|
||||
{ name = "langchain-text-splitters", directory = "../text-splitters" },
|
||||
{ name = "mypy", specifier = ">=1.15,<1.16" },
|
||||
{ name = "mypy", specifier = ">=1.17.1,<1.18" },
|
||||
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
|
||||
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-text-splitters"
|
||||
version = "0.3.9"
|
||||
version = "0.3.11"
|
||||
source = { editable = "../text-splitters" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
@@ -650,7 +645,7 @@ dev = [
|
||||
]
|
||||
lint = [
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "ruff", specifier = ">=0.12.2,<0.13" },
|
||||
{ name = "ruff", specifier = ">=0.12.8,<0.13" },
|
||||
]
|
||||
test = [
|
||||
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
|
||||
@@ -663,15 +658,17 @@ test = [
|
||||
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
|
||||
]
|
||||
test-integration = [
|
||||
{ name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" },
|
||||
{ name = "nltk", specifier = ">=3.9.1,<4.0.0" },
|
||||
{ name = "sentence-transformers", specifier = ">=3.0.1" },
|
||||
{ name = "spacy", specifier = ">=3.8.7,<4.0.0" },
|
||||
{ name = "thinc", specifier = ">=8.3.6,<9.0.0" },
|
||||
{ name = "tiktoken", specifier = ">=0.8.0,<1.0.0" },
|
||||
{ name = "transformers", specifier = ">=4.51.3,<5.0.0" },
|
||||
]
|
||||
typing = [
|
||||
{ name = "lxml-stubs", specifier = ">=0.5.1,<1.0.0" },
|
||||
{ name = "mypy", specifier = ">=1.15,<2.0" },
|
||||
{ name = "mypy", specifier = ">=1.17.1,<1.18" },
|
||||
{ name = "tiktoken", specifier = ">=0.8.0,<1.0.0" },
|
||||
{ name = "types-requests", specifier = ">=2.31.0.20240218,<3.0.0.0" },
|
||||
]
|
||||
@@ -738,46 +735,53 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mypy"
|
||||
version = "1.14.1"
|
||||
version = "1.17.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mypy-extensions" },
|
||||
{ name = "pathspec" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b9/eb/2c92d8ea1e684440f54fa49ac5d9a5f19967b7b472a281f419e69a8d228e/mypy-1.14.1.tar.gz", hash = "sha256:7ec88144fe9b510e8475ec2f5f251992690fcf89ccb4500b214b4226abcd32d6", size = 3216051, upload-time = "2024-12-30T16:39:07.335Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/7a/87ae2adb31d68402da6da1e5f30c07ea6063e9f09b5e7cfc9dfa44075e74/mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb", size = 11211002, upload-time = "2024-12-30T16:37:22.435Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/23/eada4c38608b444618a132be0d199b280049ded278b24cbb9d3fc59658e4/mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0", size = 10358400, upload-time = "2024-12-30T16:37:53.526Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/c9/d6785c6f66241c62fd2992b05057f404237deaad1566545e9f144ced07f5/mypy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:90716d8b2d1f4cd503309788e51366f07c56635a3309b0f6a32547eaaa36a64d", size = 12095172, upload-time = "2024-12-30T16:37:50.332Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/62/daa7e787770c83c52ce2aaf1a111eae5893de9e004743f51bfcad9e487ec/mypy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ae753f5c9fef278bcf12e1a564351764f2a6da579d4a81347e1d5a15819997b", size = 12828732, upload-time = "2024-12-30T16:37:29.96Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/a2/5fb18318a3637f29f16f4e41340b795da14f4751ef4f51c99ff39ab62e52/mypy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e0fe0f5feaafcb04505bcf439e991c6d8f1bf8b15f12b05feeed96e9e7bf1427", size = 13012197, upload-time = "2024-12-30T16:38:05.037Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/99/e153ce39105d164b5f02c06c35c7ba958aaff50a2babba7d080988b03fe7/mypy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:7d54bd85b925e501c555a3227f3ec0cfc54ee8b6930bd6141ec872d1c572f81f", size = 9780836, upload-time = "2024-12-30T16:37:19.726Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/11/a9422850fd506edbcdc7f6090682ecceaf1f87b9dd847f9df79942da8506/mypy-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f995e511de847791c3b11ed90084a7a0aafdc074ab88c5a9711622fe4751138c", size = 11120432, upload-time = "2024-12-30T16:37:11.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/9e/47e450fd39078d9c02d620545b2cb37993a8a8bdf7db3652ace2f80521ca/mypy-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d64169ec3b8461311f8ce2fd2eb5d33e2d0f2c7b49116259c51d0d96edee48d1", size = 10279515, upload-time = "2024-12-30T16:37:40.724Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/b5/6c8d33bd0f851a7692a8bfe4ee75eb82b6983a3cf39e5e32a5d2a723f0c1/mypy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ba24549de7b89b6381b91fbc068d798192b1b5201987070319889e93038967a8", size = 12025791, upload-time = "2024-12-30T16:36:58.73Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/4c/e10e2c46ea37cab5c471d0ddaaa9a434dc1d28650078ac1b56c2d7b9b2e4/mypy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:183cf0a45457d28ff9d758730cd0210419ac27d4d3f285beda038c9083363b1f", size = 12749203, upload-time = "2024-12-30T16:37:03.741Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/55/beacb0c69beab2153a0f57671ec07861d27d735a0faff135a494cd4f5020/mypy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f2a0ecc86378f45347f586e4163d1769dd81c5a223d577fe351f26b179e148b1", size = 12885900, upload-time = "2024-12-30T16:37:57.948Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/75/8c93ff7f315c4d086a2dfcde02f713004357d70a163eddb6c56a6a5eff40/mypy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:ad3301ebebec9e8ee7135d8e3109ca76c23752bac1e717bc84cd3836b4bf3eae", size = 9777869, upload-time = "2024-12-30T16:37:33.428Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/1b/b38c079609bb4627905b74fc6a49849835acf68547ac33d8ceb707de5f52/mypy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:30ff5ef8519bbc2e18b3b54521ec319513a26f1bba19a7582e7b1f58a6e69f14", size = 11266668, upload-time = "2024-12-30T16:38:02.211Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/75/2ed0d2964c1ffc9971c729f7a544e9cd34b2cdabbe2d11afd148d7838aa2/mypy-1.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cb9f255c18052343c70234907e2e532bc7e55a62565d64536dbc7706a20b78b9", size = 10254060, upload-time = "2024-12-30T16:37:46.131Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/5f/7b8051552d4da3c51bbe8fcafffd76a6823779101a2b198d80886cd8f08e/mypy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b4e3413e0bddea671012b063e27591b953d653209e7a4fa5e48759cda77ca11", size = 11933167, upload-time = "2024-12-30T16:37:43.534Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/90/f53971d3ac39d8b68bbaab9a4c6c58c8caa4d5fd3d587d16f5927eeeabe1/mypy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:553c293b1fbdebb6c3c4030589dab9fafb6dfa768995a453d8a5d3b23784af2e", size = 12864341, upload-time = "2024-12-30T16:37:36.249Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/d2/8bc0aeaaf2e88c977db41583559319f1821c069e943ada2701e86d0430b7/mypy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fad79bfe3b65fe6a1efaed97b445c3d37f7be9fdc348bdb2d7cac75579607c89", size = 12972991, upload-time = "2024-12-30T16:37:06.743Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/17/07815114b903b49b0f2cf7499f1c130e5aa459411596668267535fe9243c/mypy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:8fa2220e54d2946e94ab6dbb3ba0a992795bd68b16dc852db33028df2b00191b", size = 9879016, upload-time = "2024-12-30T16:37:15.02Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/15/bb6a686901f59222275ab228453de741185f9d54fecbaacec041679496c6/mypy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:92c3ed5afb06c3a8e188cb5da4984cab9ec9a77ba956ee419c68a388b4595255", size = 11252097, upload-time = "2024-12-30T16:37:25.144Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/b3/8b0f74dfd072c802b7fa368829defdf3ee1566ba74c32a2cb2403f68024c/mypy-1.14.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:dbec574648b3e25f43d23577309b16534431db4ddc09fda50841f1e34e64ed34", size = 10239728, upload-time = "2024-12-30T16:38:08.634Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/9b/4fd95ab20c52bb5b8c03cc49169be5905d931de17edfe4d9d2986800b52e/mypy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8c6d94b16d62eb3e947281aa7347d78236688e21081f11de976376cf010eb31a", size = 11924965, upload-time = "2024-12-30T16:38:12.132Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/9d/4a236b9c57f5d8f08ed346914b3f091a62dd7e19336b2b2a0d85485f82ff/mypy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d4b19b03fdf54f3c5b2fa474c56b4c13c9dbfb9a2db4370ede7ec11a2c5927d9", size = 12867660, upload-time = "2024-12-30T16:38:17.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/88/a61a5497e2f68d9027de2bb139c7bb9abaeb1be1584649fa9d807f80a338/mypy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0c911fde686394753fff899c409fd4e16e9b294c24bfd5e1ea4675deae1ac6fd", size = 12969198, upload-time = "2024-12-30T16:38:32.839Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/da/3d6fc5d92d324701b0c23fb413c853892bfe0e1dbe06c9138037d459756b/mypy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:8b21525cb51671219f5307be85f7e646a153e5acc656e5cebf64bfa076c50107", size = 9885276, upload-time = "2024-12-30T16:38:20.828Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/1f/186d133ae2514633f8558e78cd658070ba686c0e9275c5a5c24a1e1f0d67/mypy-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3888a1816d69f7ab92092f785a462944b3ca16d7c470d564165fe703b0970c35", size = 11200493, upload-time = "2024-12-30T16:38:26.935Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/fc/4842485d034e38a4646cccd1369f6b1ccd7bc86989c52770d75d719a9941/mypy-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:46c756a444117c43ee984bd055db99e498bc613a70bbbc120272bd13ca579fbc", size = 10357702, upload-time = "2024-12-30T16:38:50.623Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/e6/457b83f2d701e23869cfec013a48a12638f75b9d37612a9ddf99072c1051/mypy-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:27fc248022907e72abfd8e22ab1f10e903915ff69961174784a3900a8cba9ad9", size = 12091104, upload-time = "2024-12-30T16:38:53.735Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/bf/76a569158db678fee59f4fd30b8e7a0d75bcbaeef49edd882a0d63af6d66/mypy-1.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:499d6a72fb7e5de92218db961f1a66d5f11783f9ae549d214617edab5d4dbdbb", size = 12830167, upload-time = "2024-12-30T16:38:56.437Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/bc/0bc6b694b3103de9fed61867f1c8bd33336b913d16831431e7cb48ef1c92/mypy-1.14.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:57961db9795eb566dc1d1b4e9139ebc4c6b0cb6e7254ecde69d1552bf7613f60", size = 13013834, upload-time = "2024-12-30T16:38:59.204Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/79/5f5ec47849b6df1e6943d5fd8e6632fbfc04b4fd4acfa5a5a9535d11b4e2/mypy-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:07ba89fdcc9451f2ebb02853deb6aaaa3d2239a236669a63ab3801bbf923ef5c", size = 9781231, upload-time = "2024-12-30T16:39:05.124Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/b5/32dd67b69a16d088e533962e5044e51004176a9952419de0370cdaead0f8/mypy-1.14.1-py3-none-any.whl", hash = "sha256:b66a60cc4073aeb8ae00057f9c1f64d49e90f918fbcef9a977eb121da8b8f1d1", size = 2752905, upload-time = "2024-12-30T16:38:42.021Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/a9/3d7aa83955617cdf02f94e50aab5c830d205cfa4320cf124ff64acce3a8e/mypy-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3fbe6d5555bf608c47203baa3e72dbc6ec9965b3d7c318aa9a4ca76f465bd972", size = 11003299, upload-time = "2025-07-31T07:54:06.425Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/83/e8/72e62ff837dd5caaac2b4a5c07ce769c8e808a00a65e5d8f94ea9c6f20ab/mypy-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80ef5c058b7bce08c83cac668158cb7edea692e458d21098c7d3bce35a5d43e7", size = 10125451, upload-time = "2025-07-31T07:53:52.974Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/10/f3f3543f6448db11881776f26a0ed079865926b0c841818ee22de2c6bbab/mypy-1.17.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4a580f8a70c69e4a75587bd925d298434057fe2a428faaf927ffe6e4b9a98df", size = 11916211, upload-time = "2025-07-31T07:53:18.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/06/bf/63e83ed551282d67bb3f7fea2cd5561b08d2bb6eb287c096539feb5ddbc5/mypy-1.17.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dd86bb649299f09d987a2eebb4d52d10603224500792e1bee18303bbcc1ce390", size = 12652687, upload-time = "2025-07-31T07:53:30.544Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/66/68f2eeef11facf597143e85b694a161868b3b006a5fbad50e09ea117ef24/mypy-1.17.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:a76906f26bd8d51ea9504966a9c25419f2e668f012e0bdf3da4ea1526c534d94", size = 12896322, upload-time = "2025-07-31T07:53:50.74Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/87/8e3e9c2c8bd0d7e071a89c71be28ad088aaecbadf0454f46a540bda7bca6/mypy-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:e79311f2d904ccb59787477b7bd5d26f3347789c06fcd7656fa500875290264b", size = 9507962, upload-time = "2025-07-31T07:53:08.431Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = "2025-07-31T07:54:12.842Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/82/aec2fc9b9b149f372850291827537a508d6c4d3664b1750a324b91f71355/mypy-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93378d3203a5c0800c6b6d850ad2f19f7a3cdf1a3701d3416dbf128805c6a6a7", size = 11075338, upload-time = "2025-07-31T07:53:38.873Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/ac/ee93fbde9d2242657128af8c86f5d917cd2887584cf948a8e3663d0cd737/mypy-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:15d54056f7fe7a826d897789f53dd6377ec2ea8ba6f776dc83c2902b899fee81", size = 10113066, upload-time = "2025-07-31T07:54:14.707Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/68/946a1e0be93f17f7caa56c45844ec691ca153ee8b62f21eddda336a2d203/mypy-1.17.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:209a58fed9987eccc20f2ca94afe7257a8f46eb5df1fb69958650973230f91e6", size = 11875473, upload-time = "2025-07-31T07:53:14.504Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/0f/478b4dce1cb4f43cf0f0d00fba3030b21ca04a01b74d1cd272a528cf446f/mypy-1.17.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:099b9a5da47de9e2cb5165e581f158e854d9e19d2e96b6698c0d64de911dd849", size = 12744296, upload-time = "2025-07-31T07:53:03.896Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/70/afa5850176379d1b303f992a828de95fc14487429a7139a4e0bdd17a8279/mypy-1.17.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa6ffadfbe6994d724c5a1bb6123a7d27dd68fc9c059561cd33b664a79578e14", size = 12914657, upload-time = "2025-07-31T07:54:08.576Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/f9/4a83e1c856a3d9c8f6edaa4749a4864ee98486e9b9dbfbc93842891029c2/mypy-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:9a2b7d9180aed171f033c9f2fc6c204c1245cf60b0cb61cf2e7acc24eea78e0a", size = 9593320, upload-time = "2025-07-31T07:53:01.341Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/56/79c2fac86da57c7d8c48622a05873eaab40b905096c33597462713f5af90/mypy-1.17.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:15a83369400454c41ed3a118e0cc58bd8123921a602f385cb6d6ea5df050c733", size = 11040037, upload-time = "2025-07-31T07:54:10.942Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/c3/adabe6ff53638e3cad19e3547268482408323b1e68bf082c9119000cd049/mypy-1.17.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:55b918670f692fc9fba55c3298d8a3beae295c5cded0a55dccdc5bbead814acd", size = 10131550, upload-time = "2025-07-31T07:53:41.307Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/c5/2e234c22c3bdeb23a7817af57a58865a39753bde52c74e2c661ee0cfc640/mypy-1.17.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:62761474061feef6f720149d7ba876122007ddc64adff5ba6f374fda35a018a0", size = 11872963, upload-time = "2025-07-31T07:53:16.878Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/26/c13c130f35ca8caa5f2ceab68a247775648fdcd6c9a18f158825f2bc2410/mypy-1.17.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c49562d3d908fd49ed0938e5423daed8d407774a479b595b143a3d7f87cdae6a", size = 12710189, upload-time = "2025-07-31T07:54:01.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/df/c7d79d09f6de8383fe800521d066d877e54d30b4fb94281c262be2df84ef/mypy-1.17.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:397fba5d7616a5bc60b45c7ed204717eaddc38f826e3645402c426057ead9a91", size = 12900322, upload-time = "2025-07-31T07:53:10.551Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/98/3d5a48978b4f708c55ae832619addc66d677f6dc59f3ebad71bae8285ca6/mypy-1.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:9d6b20b97d373f41617bd0708fd46aa656059af57f2ef72aa8c7d6a2b73b74ed", size = 9751879, upload-time = "2025-07-31T07:52:56.683Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/cb/673e3d34e5d8de60b3a61f44f80150a738bff568cd6b7efb55742a605e98/mypy-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5d1092694f166a7e56c805caaf794e0585cabdbf1df36911c414e4e9abb62ae9", size = 10992466, upload-time = "2025-07-31T07:53:57.574Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/d0/fe1895836eea3a33ab801561987a10569df92f2d3d4715abf2cfeaa29cb2/mypy-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:79d44f9bfb004941ebb0abe8eff6504223a9c1ac51ef967d1263c6572bbebc99", size = 10117638, upload-time = "2025-07-31T07:53:34.256Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/97/f3/514aa5532303aafb95b9ca400a31054a2bd9489de166558c2baaeea9c522/mypy-1.17.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b01586eed696ec905e61bd2568f48740f7ac4a45b3a468e6423a03d3788a51a8", size = 11915673, upload-time = "2025-07-31T07:52:59.361Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/c3/c0805f0edec96fe8e2c048b03769a6291523d509be8ee7f56ae922fa3882/mypy-1.17.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43808d9476c36b927fbcd0b0255ce75efe1b68a080154a38ae68a7e62de8f0f8", size = 12649022, upload-time = "2025-07-31T07:53:45.92Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/3e/d646b5a298ada21a8512fa7e5531f664535a495efa672601702398cea2b4/mypy-1.17.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:feb8cc32d319edd5859da2cc084493b3e2ce5e49a946377663cc90f6c15fb259", size = 12895536, upload-time = "2025-07-31T07:53:06.17Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/55/e13d0dcd276975927d1f4e9e2ec4fd409e199f01bdc671717e673cc63a22/mypy-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d7598cf74c3e16539d4e2f0b8d8c318e00041553d83d4861f87c7a72e95ac24d", size = 9512564, upload-time = "2025-07-31T07:53:12.346Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -871,6 +875,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pathspec"
|
||||
version = "0.12.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.5.0"
|
||||
|
||||
@@ -515,6 +515,7 @@ _WellKnownOpenAITools = (
|
||||
"mcp",
|
||||
"image_generation",
|
||||
"web_search_preview",
|
||||
"web_search",
|
||||
)
|
||||
|
||||
|
||||
|
||||
13
libs/core/uv.lock
generated
13
libs/core/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.9"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
|
||||
@@ -1154,28 +1154,27 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.25.0,<1" },
|
||||
{ name = "httpx", specifier = ">=0.28.1,<1" },
|
||||
{ name = "langchain-core", editable = "." },
|
||||
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.2" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
||||
{ name = "pytest", specifier = ">=7,<9" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.20,<1" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.20,<2" },
|
||||
{ name = "pytest-benchmark" },
|
||||
{ name = "pytest-codspeed" },
|
||||
{ name = "pytest-recording" },
|
||||
{ name = "pytest-socket", specifier = ">=0.6.0,<1" },
|
||||
{ name = "pytest-socket", specifier = ">=0.7.0,<1" },
|
||||
{ name = "syrupy", specifier = ">=4,<5" },
|
||||
{ name = "vcrpy", specifier = ">=7.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.8,<0.13" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.10,<0.13" }]
|
||||
test = [{ name = "langchain-core", editable = "." }]
|
||||
test-integration = []
|
||||
typing = [
|
||||
{ name = "langchain-core", editable = "." },
|
||||
{ name = "mypy", specifier = ">=1,<2" },
|
||||
{ name = "mypy", specifier = ">=1.17.1,<2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -326,7 +326,7 @@ class SQLRecordManager(RecordManager):
|
||||
records_to_upsert,
|
||||
)
|
||||
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
|
||||
"uix_key_namespace", # Name of constraint
|
||||
constraint="uix_key_namespace", # Name of constraint
|
||||
set_={
|
||||
"updated_at": pg_insert_stmt.excluded.updated_at,
|
||||
"group_id": pg_insert_stmt.excluded.group_id,
|
||||
@@ -408,7 +408,7 @@ class SQLRecordManager(RecordManager):
|
||||
records_to_upsert,
|
||||
)
|
||||
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
|
||||
"uix_key_namespace", # Name of constraint
|
||||
constraint="uix_key_namespace", # Name of constraint
|
||||
set_={
|
||||
"updated_at": pg_insert_stmt.excluded.updated_at,
|
||||
"group_id": pg_insert_stmt.excluded.group_id,
|
||||
|
||||
15
libs/langchain/uv.lock
generated
15
libs/langchain/uv.lock
generated
@@ -2190,7 +2190,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.3.74"
|
||||
version = "0.3.75"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -2346,7 +2346,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "0.3.30"
|
||||
version = "0.3.32"
|
||||
source = { editable = "../partners/openai" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
@@ -2429,28 +2429,27 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.25.0,<1" },
|
||||
{ name = "httpx", specifier = ">=0.28.1,<1" },
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.2" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
||||
{ name = "pytest", specifier = ">=7,<9" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.20,<1" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.20,<2" },
|
||||
{ name = "pytest-benchmark" },
|
||||
{ name = "pytest-codspeed" },
|
||||
{ name = "pytest-recording" },
|
||||
{ name = "pytest-socket", specifier = ">=0.6.0,<1" },
|
||||
{ name = "pytest-socket", specifier = ">=0.7.0,<1" },
|
||||
{ name = "syrupy", specifier = ">=4,<5" },
|
||||
{ name = "vcrpy", specifier = ">=7.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.8,<0.13" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.10,<0.13" }]
|
||||
test = [{ name = "langchain-core", editable = "../core" }]
|
||||
test-integration = []
|
||||
typing = [
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "mypy", specifier = ">=1,<2" },
|
||||
{ name = "mypy", specifier = ">=1.17.1,<2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests help extended_tests
|
||||
.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests help extended_tests start_services stop_services
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
@@ -7,6 +7,12 @@ all: help
|
||||
# TESTING AND COVERAGE
|
||||
######################
|
||||
|
||||
start_services:
|
||||
docker compose -f tests/unit_tests/agents/compose-postgres.yml -f tests/unit_tests/agents/compose-redis.yml up -V --force-recreate --wait --remove-orphans
|
||||
|
||||
stop_services:
|
||||
docker compose -f tests/unit_tests/agents/compose-postgres.yml -f tests/unit_tests/agents/compose-redis.yml down -v
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
@@ -21,17 +27,32 @@ coverage:
|
||||
--cov-report term-missing:skip-covered \
|
||||
$(TEST_FILE)
|
||||
|
||||
test tests:
|
||||
uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
test:
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
|
||||
test_fast:
|
||||
LANGGRAPH_TEST_FAST=1 uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
extended_tests:
|
||||
uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
|
||||
test_watch:
|
||||
uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
|
||||
test_watch_extended:
|
||||
uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
|
||||
integration_tests:
|
||||
uv run --group test --group test_integration pytest tests/integration_tests
|
||||
@@ -87,7 +108,8 @@ help:
|
||||
@echo 'spell_fix - run codespell on the project and fix the errors'
|
||||
@echo '-- TESTS --'
|
||||
@echo 'coverage - run unit tests and generate coverage report'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'test - run unit tests with all services'
|
||||
@echo 'test_fast - run unit tests with in-memory services only'
|
||||
@echo 'tests - run unit tests (alias for "make test")'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
@echo 'extended_tests - run only extended unit tests'
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
"""Main entrypoint into package."""
|
||||
"""Main entrypoint into LangChain."""
|
||||
|
||||
from importlib import metadata
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
__version__ = "1.0.0a3"
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any: # noqa: ANN401
|
||||
"""Get an attribute from the package."""
|
||||
"""Get an attribute from the package.
|
||||
|
||||
TODO: will be removed in a future alpha version.
|
||||
"""
|
||||
if name == "verbose":
|
||||
from langchain.globals import _verbose
|
||||
|
||||
|
||||
@@ -32,9 +32,4 @@ def format_document_xml(doc: Document) -> str:
|
||||
if doc.metadata:
|
||||
metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
|
||||
metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
|
||||
return (
|
||||
f"<document>{id_str}"
|
||||
f"<content>{doc.page_content}</content>"
|
||||
f"{metadata_str}"
|
||||
f"</document>"
|
||||
)
|
||||
return f"<document>{id_str}<content>{doc.page_content}</content>{metadata_str}</document>"
|
||||
|
||||
@@ -12,10 +12,10 @@ particularly for summarization chains and other document processing workflows.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.messages import MessageLikeRepresentation
|
||||
from langgraph.runtime import Runtime
|
||||
@@ -92,9 +92,7 @@ async def aresolve_prompt(
|
||||
str,
|
||||
None,
|
||||
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
Callable[
|
||||
[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
|
||||
],
|
||||
Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
|
||||
],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, Union
|
||||
|
||||
from langgraph.graph._node import StateNode
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dataclasses import Field
|
||||
|
||||
10
libs/langchain_v1/langchain/agents/__init__.py
Normal file
10
libs/langchain_v1/langchain/agents/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
|
||||
|
||||
from langchain.agents.react_agent import AgentState, create_agent
|
||||
from langchain.agents.tool_node import ToolNode
|
||||
|
||||
__all__ = [
|
||||
"AgentState",
|
||||
"ToolNode",
|
||||
"create_agent",
|
||||
]
|
||||
1
libs/langchain_v1/langchain/agents/_internal/__init__.py
Normal file
1
libs/langchain_v1/langchain/agents/_internal/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Internal utilities for agents."""
|
||||
13
libs/langchain_v1/langchain/agents/_internal/_typing.py
Normal file
13
libs/langchain_v1/langchain/agents/_internal/_typing.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Typing utilities for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
SyncOrAsync = Callable[P, Union[R, Awaitable[R]]]
|
||||
92
libs/langchain_v1/langchain/agents/interrupt.py
Normal file
92
libs/langchain_v1/langchain/agents/interrupt.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Interrupt types to use with agent inbox like setups."""
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class HumanInterruptConfig(TypedDict):
|
||||
"""Configuration that defines what actions are allowed for a human interrupt.
|
||||
|
||||
This controls the available interaction options when the graph is paused for human input.
|
||||
|
||||
Attributes:
|
||||
allow_ignore: Whether the human can choose to ignore/skip the current step
|
||||
allow_respond: Whether the human can provide a text response/feedback
|
||||
allow_edit: Whether the human can edit the provided content/state
|
||||
allow_accept: Whether the human can accept/approve the current state
|
||||
"""
|
||||
|
||||
allow_ignore: bool
|
||||
allow_respond: bool
|
||||
allow_edit: bool
|
||||
allow_accept: bool
|
||||
|
||||
|
||||
class ActionRequest(TypedDict):
|
||||
"""Represents a request for human action within the graph execution.
|
||||
|
||||
Contains the action type and any associated arguments needed for the action.
|
||||
|
||||
Attributes:
|
||||
action: The type or name of action being requested (e.g., "Approve XYZ action")
|
||||
args: Key-value pairs of arguments needed for the action
|
||||
"""
|
||||
|
||||
action: str
|
||||
args: dict
|
||||
|
||||
|
||||
class HumanInterrupt(TypedDict):
|
||||
"""Represents an interrupt triggered by the graph that requires human intervention.
|
||||
|
||||
This is passed to the `interrupt` function when execution is paused for human input.
|
||||
|
||||
Attributes:
|
||||
action_request: The specific action being requested from the human
|
||||
config: Configuration defining what actions are allowed
|
||||
description: Optional detailed description of what input is needed
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Extract a tool call from the state and create an interrupt request
|
||||
request = HumanInterrupt(
|
||||
action_request=ActionRequest(
|
||||
action="run_command", # The action being requested
|
||||
args={"command": "ls", "args": ["-l"]} # Arguments for the action
|
||||
),
|
||||
config=HumanInterruptConfig(
|
||||
allow_ignore=True, # Allow skipping this step
|
||||
allow_respond=True, # Allow text feedback
|
||||
allow_edit=False, # Don't allow editing
|
||||
allow_accept=True # Allow direct acceptance
|
||||
),
|
||||
description="Please review the command before execution"
|
||||
)
|
||||
# Send the interrupt request and get the response
|
||||
response = interrupt([request])[0]
|
||||
```
|
||||
"""
|
||||
|
||||
action_request: ActionRequest
|
||||
config: HumanInterruptConfig
|
||||
description: str | None
|
||||
|
||||
|
||||
class HumanResponse(TypedDict):
|
||||
"""The response provided by a human to an interrupt, which is returned when graph execution resumes.
|
||||
|
||||
Attributes:
|
||||
type: The type of response:
|
||||
- "accept": Approves the current state without changes
|
||||
- "ignore": Skips/ignores the current step
|
||||
- "response": Provides text feedback or instructions
|
||||
- "edit": Modifies the current state/content
|
||||
args: The response payload:
|
||||
- None: For ignore/accept actions
|
||||
- str: For text responses
|
||||
- ActionRequest: For edit actions with updated content
|
||||
"""
|
||||
|
||||
type: Literal["accept", "ignore", "response", "edit"]
|
||||
args: Union[None, str, ActionRequest]
|
||||
15
libs/langchain_v1/langchain/agents/middleware/__init__.py
Normal file
15
libs/langchain_v1/langchain/agents/middleware/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Middleware plugins for agents."""
|
||||
|
||||
from .human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from .summarization import SummarizationMiddleware
|
||||
from .types import AgentMiddleware, AgentState, ModelRequest
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"ModelRequest",
|
||||
"SummarizationMiddleware",
|
||||
]
|
||||
11
libs/langchain_v1/langchain/agents/middleware/_utils.py
Normal file
11
libs/langchain_v1/langchain/agents/middleware/_utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Utility functions for middleware."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _generate_correction_tool_messages(content: str, tool_calls: list) -> list[dict[str, Any]]:
|
||||
"""Generate tool messages for model behavior correction."""
|
||||
return [
|
||||
{"role": "tool", "content": content, "tool_call_id": tool_call["id"]}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Human in the loop middleware."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langgraph.prebuilt.interrupt import (
|
||||
ActionRequest,
|
||||
HumanInterrupt,
|
||||
HumanInterruptConfig,
|
||||
HumanResponse,
|
||||
)
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from langchain.agents.middleware._utils import _generate_correction_tool_messages
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
|
||||
ToolInterruptConfig = dict[str, HumanInterruptConfig]
|
||||
|
||||
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
"""Human in the loop middleware."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_configs: ToolInterruptConfig,
|
||||
message_prefix: str = "Tool execution requires approval",
|
||||
) -> None:
|
||||
"""Initialize the human in the loop middleware.
|
||||
|
||||
Args:
|
||||
tool_configs: The tool interrupt configs to use for the middleware.
|
||||
message_prefix: The message prefix to use when constructing interrupt content.
|
||||
"""
|
||||
super().__init__()
|
||||
self.tool_configs = tool_configs
|
||||
self.message_prefix = message_prefix
|
||||
|
||||
def after_model(self, state: AgentState) -> dict[str, Any] | None:
|
||||
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
||||
return None
|
||||
|
||||
# Separate tool calls that need interrupts from those that don't
|
||||
interrupt_tool_calls = []
|
||||
auto_approved_tool_calls = []
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
if tool_name in self.tool_configs:
|
||||
interrupt_tool_calls.append(tool_call)
|
||||
else:
|
||||
auto_approved_tool_calls.append(tool_call)
|
||||
|
||||
# If no interrupts needed, return early
|
||||
if not interrupt_tool_calls:
|
||||
return None
|
||||
|
||||
approved_tool_calls = auto_approved_tool_calls.copy()
|
||||
|
||||
# Right now, we do not support multiple tool calls with interrupts
|
||||
if len(interrupt_tool_calls) > 1:
|
||||
tool_names = [t["name"] for t in interrupt_tool_calls]
|
||||
msg = f"Called the following tools which require interrupts: {tool_names}\n\nYou may only call ONE tool that requires an interrupt at a time"
|
||||
return {
|
||||
"messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
|
||||
"jump_to": "model",
|
||||
}
|
||||
|
||||
# Right now, we do not support interrupting a tool call if other tool calls exist
|
||||
if auto_approved_tool_calls:
|
||||
tool_names = [t["name"] for t in interrupt_tool_calls]
|
||||
msg = f"Called the following tools which require interrupts: {tool_names}. You also called other tools that do not require interrupts. If you call a tool that requires and interrupt, you may ONLY call that tool."
|
||||
return {
|
||||
"messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
|
||||
"jump_to": "model",
|
||||
}
|
||||
|
||||
# Only one tool call will need interrupts
|
||||
tool_call = interrupt_tool_calls[0]
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||
tool_config = self.tool_configs[tool_name]
|
||||
|
||||
request: HumanInterrupt = {
|
||||
"action_request": ActionRequest(
|
||||
action=tool_name,
|
||||
args=tool_args,
|
||||
),
|
||||
"config": tool_config,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
responses: list[HumanResponse] = interrupt([request])
|
||||
response = responses[0]
|
||||
|
||||
if response["type"] == "accept":
|
||||
approved_tool_calls.append(tool_call)
|
||||
elif response["type"] == "edit":
|
||||
edited: ActionRequest = response["args"] # type: ignore[assignment]
|
||||
new_tool_call = {
|
||||
"type": "tool_call",
|
||||
"name": tool_call["name"],
|
||||
"args": edited["args"],
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
approved_tool_calls.append(new_tool_call)
|
||||
elif response["type"] == "ignore":
|
||||
return {"jump_to": "__end__"}
|
||||
elif response["type"] == "response":
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": response["args"],
|
||||
}
|
||||
return {"messages": [tool_message], "jump_to": "model"}
|
||||
else:
|
||||
msg = f"Unknown response type: {response['type']}"
|
||||
raise ValueError(msg)
|
||||
|
||||
last_message.tool_calls = approved_tool_calls
|
||||
|
||||
return {"messages": [last_message]}
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Anthropic prompt caching middleware."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
||||
|
||||
|
||||
class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
"""Prompt Caching Middleware - Optimizes API usage by caching conversation prefixes for Anthropic models.
|
||||
|
||||
Learn more about anthropic prompt caching [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Literal["ephemeral"] = "ephemeral",
|
||||
ttl: Literal["5m", "1h"] = "5m",
|
||||
min_messages_to_cache: int = 0,
|
||||
) -> None:
|
||||
"""Initialize the middleware with cache control settings.
|
||||
|
||||
Args:
|
||||
type: The type of cache to use, only "ephemeral" is supported.
|
||||
ttl: The time to live for the cache, only "5m" and "1h" are supported.
|
||||
min_messages_to_cache: The minimum number of messages until the cache is used, default is 0.
|
||||
"""
|
||||
self.type = type
|
||||
self.ttl = ttl
|
||||
self.min_messages_to_cache = min_messages_to_cache
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest: # noqa: ARG002
|
||||
"""Modify the model request to add cache control blocks."""
|
||||
try:
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
except ImportError:
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models."
|
||||
"Please install langchain-anthropic."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if not isinstance(request.model, ChatAnthropic):
|
||||
msg = (
|
||||
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, "
|
||||
f"not instances of {type(request.model)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
messages_count = (
|
||||
len(request.messages) + 1 if request.system_prompt else len(request.messages)
|
||||
)
|
||||
if messages_count < self.min_messages_to_cache:
|
||||
return request
|
||||
|
||||
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
return request
|
||||
248
libs/langchain_v1/langchain/agents/middleware/summarization.py
Normal file
248
libs/langchain_v1/langchain/agents/middleware/summarization.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Summarization middleware."""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
|
||||
from langgraph.graph.message import (
|
||||
REMOVE_ALL_MESSAGES,
|
||||
)
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.chat_models import BaseChatModel, init_chat_model
|
||||
|
||||
TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
|
||||
|
||||
DEFAULT_SUMMARY_PROMPT = """<role>
|
||||
Context Extraction Assistant
|
||||
</role>
|
||||
|
||||
<primary_objective>
|
||||
Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.
|
||||
</primary_objective>
|
||||
|
||||
<objective_information>
|
||||
You're nearing the total number of input tokens you can accept, so you must extract the highest quality/most relevant pieces of information from your conversation history.
|
||||
This context will then overwrite the conversation history presented below. Because of this, ensure the context you extract is only the most important information to your overall goal.
|
||||
</objective_information>
|
||||
|
||||
<instructions>
|
||||
The conversation history below will be replaced with the context you extract in this step. Because of this, you must do your very best to extract and record all of the most important context from the conversation history.
|
||||
You want to ensure that you don't repeat any actions you've already completed, so the context you extract from the conversation history should be focused on the most important information to your overall goal.
|
||||
</instructions>
|
||||
|
||||
The user will message you with the full message history you'll be extracting context from, to then replace. Carefully read over it all, and think deeply about what information is most important to your overall goal that should be saved:
|
||||
|
||||
With all of this in mind, please carefully read over the entire conversation history, and extract the most important and relevant context to replace it so that you can free up space in the conversation history.
|
||||
Respond ONLY with the extracted context. Do not include any additional information, or text before or after the extracted context.
|
||||
|
||||
<messages>
|
||||
Messages to summarize:
|
||||
{messages}
|
||||
</messages>"""
|
||||
|
||||
SUMMARY_PREFIX = "## Previous conversation summary:"
|
||||
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Middleware that summarizes conversation history when token limits are approached.
|
||||
|
||||
This middleware monitors message token counts and automatically summarizes older
|
||||
messages when a threshold is reached, preserving recent messages and maintaining
|
||||
context continuity by ensuring AI/Tool message pairs remain together.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str | BaseChatModel,
|
||||
max_tokens_before_summary: int | None = None,
|
||||
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
|
||||
token_counter: TokenCounter = count_tokens_approximately,
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
|
||||
summary_prefix: str = SUMMARY_PREFIX,
|
||||
) -> None:
|
||||
"""Initialize the summarization middleware.
|
||||
|
||||
Args:
|
||||
model: The language model to use for generating summaries.
|
||||
max_tokens_before_summary: Token threshold to trigger summarization.
|
||||
If None, summarization is disabled.
|
||||
messages_to_keep: Number of recent messages to preserve after summarization.
|
||||
token_counter: Function to count tokens in messages.
|
||||
summary_prompt: Prompt template for generating summaries.
|
||||
summary_prefix: Prefix added to system message when including summary.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
self.model = model
|
||||
self.max_tokens_before_summary = max_tokens_before_summary
|
||||
self.messages_to_keep = messages_to_keep
|
||||
self.token_counter = token_counter
|
||||
self.summary_prompt = summary_prompt
|
||||
self.summary_prefix = summary_prefix
|
||||
|
||||
def before_model(self, state: AgentState) -> dict[str, Any] | None:
|
||||
"""Process messages before model invocation, potentially triggering summarization."""
|
||||
messages = state["messages"]
|
||||
self._ensure_message_ids(messages)
|
||||
|
||||
total_tokens = self.token_counter(messages)
|
||||
if (
|
||||
self.max_tokens_before_summary is not None
|
||||
and total_tokens < self.max_tokens_before_summary
|
||||
):
|
||||
return None
|
||||
|
||||
cutoff_index = self._find_safe_cutoff(messages)
|
||||
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
|
||||
|
||||
summary = self._create_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||
*new_messages,
|
||||
*preserved_messages,
|
||||
]
|
||||
}
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [
|
||||
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
|
||||
]
|
||||
|
||||
def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
|
||||
"""Ensure all messages have unique IDs for the add_messages reducer."""
|
||||
for msg in messages:
|
||||
if msg.id is None:
|
||||
msg.id = str(uuid.uuid4())
|
||||
|
||||
def _partition_messages(
|
||||
self,
|
||||
conversation_messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Partition messages into those to summarize and those to preserve."""
|
||||
messages_to_summarize = conversation_messages[:cutoff_index]
|
||||
preserved_messages = conversation_messages[cutoff_index:]
|
||||
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
|
||||
"""Find safe cutoff point that preserves AI/Tool message pairs.
|
||||
|
||||
Returns the index where messages can be safely cut without separating
|
||||
related AI and Tool messages. Returns 0 if no safe cutoff is found.
|
||||
"""
|
||||
if len(messages) <= self.messages_to_keep:
|
||||
return 0
|
||||
|
||||
target_cutoff = len(messages) - self.messages_to_keep
|
||||
|
||||
for i in range(target_cutoff, -1, -1):
|
||||
if self._is_safe_cutoff_point(messages, i):
|
||||
return i
|
||||
|
||||
return 0
|
||||
|
||||
def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
|
||||
"""Check if cutting at index would separate AI/Tool message pairs."""
|
||||
if cutoff_index >= len(messages):
|
||||
return True
|
||||
|
||||
search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
|
||||
search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
|
||||
|
||||
for i in range(search_start, search_end):
|
||||
if not self._has_tool_calls(messages[i]):
|
||||
continue
|
||||
|
||||
tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
|
||||
if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _has_tool_calls(self, message: AnyMessage) -> bool:
|
||||
"""Check if message is an AI message with tool calls."""
|
||||
return (
|
||||
isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value]
|
||||
)
|
||||
|
||||
def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
|
||||
"""Extract tool call IDs from an AI message."""
|
||||
tool_call_ids = set()
|
||||
for tc in ai_message.tool_calls:
|
||||
call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
|
||||
if call_id is not None:
|
||||
tool_call_ids.add(call_id)
|
||||
return tool_call_ids
|
||||
|
||||
def _cutoff_separates_tool_pair(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
ai_message_index: int,
|
||||
cutoff_index: int,
|
||||
tool_call_ids: set[str],
|
||||
) -> bool:
|
||||
"""Check if cutoff separates an AI message from its corresponding tool messages."""
|
||||
for j in range(ai_message_index + 1, len(messages)):
|
||||
message = messages[j]
|
||||
if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
|
||||
ai_before_cutoff = ai_message_index < cutoff_index
|
||||
tool_before_cutoff = j < cutoff_index
|
||||
if ai_before_cutoff != tool_before_cutoff:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary for the given messages."""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed_messages:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
try:
|
||||
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
|
||||
return cast("str", response.content).strip()
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||
"""Trim messages to fit within summary generation limits."""
|
||||
try:
|
||||
return trim_messages(
|
||||
messages,
|
||||
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
|
||||
token_counter=self.token_counter,
|
||||
start_on="human",
|
||||
strategy="last",
|
||||
allow_partial=True,
|
||||
include_system=True,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
|
||||
78
libs/langchain_v1/langchain/agents/middleware/types.py
Normal file
78
libs/langchain_v1/langchain/agents/middleware/types.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Types for middleware and agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
|
||||
|
||||
# needed as top level import for pydantic schema generation on AgentState
|
||||
from langchain_core.messages import AnyMessage # noqa: TC002
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.graph.message import Messages, add_messages
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain.agents.structured_output import ResponseFormat
|
||||
|
||||
JumpTo = Literal["tools", "model", "__end__"]
|
||||
"""Destination to jump to when a middleware node returns."""
|
||||
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRequest:
|
||||
"""Model request information for the agent."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage] # excluding system prompt
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool]
|
||||
response_format: ResponseFormat | None
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentState(TypedDict, Generic[ResponseT]):
|
||||
"""State schema for the agent."""
|
||||
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]]
|
||||
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
|
||||
response: NotRequired[ResponseT]
|
||||
|
||||
|
||||
class PublicAgentState(TypedDict, Generic[ResponseT]):
|
||||
"""Input / output schema for the agent."""
|
||||
|
||||
messages: Required[Messages]
|
||||
response: NotRequired[ResponseT]
|
||||
|
||||
|
||||
StateT = TypeVar("StateT", bound=AgentState)
|
||||
|
||||
|
||||
class AgentMiddleware(Generic[StateT]):
|
||||
"""Base middleware class for an agent.
|
||||
|
||||
Subclass this and implement any of the defined methods to customize agent behavior between steps in the main agent loop.
|
||||
"""
|
||||
|
||||
state_schema: type[StateT] = cast("type[StateT]", AgentState)
|
||||
"""The schema for state passed to the middleware nodes."""
|
||||
|
||||
tools: list[BaseTool]
|
||||
"""Additional tools registered by the middleware."""
|
||||
|
||||
def before_model(self, state: StateT) -> dict[str, Any] | None:
|
||||
"""Logic to run before the model is called."""
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest: # noqa: ARG002
|
||||
"""Logic to modify request kwargs before the model is called."""
|
||||
return request
|
||||
|
||||
def after_model(self, state: StateT) -> dict[str, Any] | None:
|
||||
"""Logic to run after the model is called."""
|
||||
554
libs/langchain_v1/langchain/agents/middleware_agent.py
Normal file
554
libs/langchain_v1/langchain/agents/middleware_agent.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""Middleware agent implementation."""
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph.state import StateGraph
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import TypedDict, TypeVar
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
JumpTo,
|
||||
ModelRequest,
|
||||
PublicAgentState,
|
||||
)
|
||||
|
||||
# Import structured output classes from the old implementation
|
||||
from langchain.agents.structured_output import (
|
||||
MultipleStructuredOutputsError,
|
||||
OutputToolBinding,
|
||||
ProviderStrategy,
|
||||
ProviderStrategyBinding,
|
||||
ResponseFormat,
|
||||
StructuredOutputValidationError,
|
||||
ToolStrategy,
|
||||
)
|
||||
from langchain.agents.tool_node import ToolNode
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
|
||||
|
||||
def _merge_state_schemas(schemas: list[type]) -> type:
|
||||
"""Merge multiple TypedDict schemas into a single schema with all fields."""
|
||||
if not schemas:
|
||||
return AgentState
|
||||
|
||||
all_annotations = {}
|
||||
|
||||
for schema in schemas:
|
||||
all_annotations.update(schema.__annotations__)
|
||||
|
||||
return TypedDict("MergedState", all_annotations) # type: ignore[operator]
|
||||
|
||||
|
||||
def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]:
|
||||
"""Filter state to only include fields defined in the given schema."""
|
||||
if not hasattr(schema, "__annotations__"):
|
||||
return state
|
||||
|
||||
schema_fields = set(schema.__annotations__.keys())
|
||||
return {k: v for k, v in state.items() if k in schema_fields}
|
||||
|
||||
|
||||
def _supports_native_structured_output(model: Union[str, BaseChatModel]) -> bool:
|
||||
"""Check if a model supports native structured output."""
|
||||
model_name: str | None = None
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
elif isinstance(model, BaseChatModel):
|
||||
model_name = getattr(model, "model_name", None)
|
||||
|
||||
return (
|
||||
"grok" in model_name.lower()
|
||||
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
|
||||
if model_name
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
def _handle_structured_output_error(
|
||||
exception: Exception,
|
||||
response_format: ResponseFormat,
|
||||
) -> tuple[bool, str]:
|
||||
"""Handle structured output error. Returns (should_retry, retry_tool_message)."""
|
||||
if not isinstance(response_format, ToolStrategy):
|
||||
return False, ""
|
||||
|
||||
handle_errors = response_format.handle_errors
|
||||
|
||||
if handle_errors is False:
|
||||
return False, ""
|
||||
if handle_errors is True:
|
||||
return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
|
||||
if isinstance(handle_errors, str):
|
||||
return True, handle_errors
|
||||
if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
|
||||
if isinstance(exception, handle_errors):
|
||||
return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
|
||||
return False, ""
|
||||
if isinstance(handle_errors, tuple):
|
||||
if any(isinstance(exception, exc_type) for exc_type in handle_errors):
|
||||
return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
|
||||
return False, ""
|
||||
if callable(handle_errors):
|
||||
# type narrowing not working appropriately w/ callable check, can fix later
|
||||
return True, handle_errors(exception) # type: ignore[return-value,call-arg]
|
||||
return False, ""
|
||||
|
||||
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
def create_agent( # noqa: PLR0915
|
||||
*,
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
|
||||
system_prompt: str | None = None,
|
||||
middleware: Sequence[AgentMiddleware] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
) -> StateGraph[
|
||||
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
|
||||
]:
|
||||
"""Create a middleware agent graph."""
|
||||
# init chat model
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
# Handle tools being None or empty
|
||||
if tools is None:
|
||||
tools = []
|
||||
|
||||
# Setup structured output
|
||||
structured_output_tools: dict[str, OutputToolBinding] = {}
|
||||
native_output_binding: ProviderStrategyBinding | None = None
|
||||
|
||||
if response_format is not None:
|
||||
if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
|
||||
# Auto-detect strategy based on model capabilities
|
||||
if _supports_native_structured_output(model):
|
||||
response_format = ProviderStrategy(schema=response_format)
|
||||
else:
|
||||
response_format = ToolStrategy(schema=response_format)
|
||||
|
||||
if isinstance(response_format, ToolStrategy):
|
||||
# Setup tools strategy for structured output
|
||||
for response_schema in response_format.schema_specs:
|
||||
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
|
||||
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
||||
elif isinstance(response_format, ProviderStrategy):
|
||||
# Setup native strategy
|
||||
native_output_binding = ProviderStrategyBinding.from_schema_spec(
|
||||
response_format.schema_spec
|
||||
)
|
||||
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
||||
|
||||
# Setup tools
|
||||
tool_node: ToolNode | None = None
|
||||
if isinstance(tools, list):
|
||||
# Extract builtin provider tools (dict format)
|
||||
builtin_tools = [t for t in tools if isinstance(t, dict)]
|
||||
regular_tools = [t for t in tools if not isinstance(t, dict)]
|
||||
|
||||
# Add structured output tools to regular tools
|
||||
structured_tools = [info.tool for info in structured_output_tools.values()]
|
||||
all_tools = middleware_tools + regular_tools + structured_tools
|
||||
|
||||
# Only create ToolNode if we have tools
|
||||
tool_node = ToolNode(tools=all_tools) if all_tools else None
|
||||
default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
|
||||
elif isinstance(tools, ToolNode):
|
||||
# tools is ToolNode or None
|
||||
tool_node = tools
|
||||
if tool_node:
|
||||
default_tools = list(tool_node.tools_by_name.values()) + middleware_tools
|
||||
# Update tool node to know about tools provided by middleware
|
||||
all_tools = list(tool_node.tools_by_name.values()) + middleware_tools
|
||||
tool_node = ToolNode(all_tools)
|
||||
# Add structured output tools
|
||||
for info in structured_output_tools.values():
|
||||
default_tools.append(info.tool)
|
||||
else:
|
||||
default_tools = (
|
||||
list(structured_output_tools.values()) if structured_output_tools else []
|
||||
) + middleware_tools
|
||||
|
||||
# validate middleware
|
||||
assert len({m.__class__.__name__ for m in middleware}) == len(middleware), ( # noqa: S101
|
||||
"Please remove duplicate middleware instances."
|
||||
)
|
||||
middleware_w_before = [
|
||||
m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
|
||||
]
|
||||
middleware_w_modify_model_request = [
|
||||
m
|
||||
for m in middleware
|
||||
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
|
||||
]
|
||||
middleware_w_after = [
|
||||
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
|
||||
]
|
||||
|
||||
# Collect all middleware state schemas and create merged schema
|
||||
merged_state_schema: type[AgentState] = _merge_state_schemas(
|
||||
[m.state_schema for m in middleware]
|
||||
)
|
||||
|
||||
# create graph, add nodes
|
||||
graph = StateGraph(
|
||||
merged_state_schema,
|
||||
input_schema=PublicAgentState,
|
||||
output_schema=PublicAgentState,
|
||||
context_schema=context_schema,
|
||||
)
|
||||
|
||||
def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
|
||||
"""Prepare model request and messages."""
|
||||
request = state.get("model_request") or ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# prepare messages
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
|
||||
return request, messages
|
||||
|
||||
def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
|
||||
"""Handle model output including structured responses."""
|
||||
# Handle structured output with native strategy
|
||||
if isinstance(response_format, ProviderStrategy):
|
||||
if not output.tool_calls and native_output_binding:
|
||||
structured_response = native_output_binding.parse(output)
|
||||
return {"messages": [output], "response": structured_response}
|
||||
if state.get("response") is not None:
|
||||
return {"messages": [output], "response": None}
|
||||
return {"messages": [output]}
|
||||
|
||||
# Handle structured output with tools strategy
|
||||
if (
|
||||
isinstance(response_format, ToolStrategy)
|
||||
and isinstance(output, AIMessage)
|
||||
and output.tool_calls
|
||||
):
|
||||
structured_tool_calls = [
|
||||
tc for tc in output.tool_calls if tc["name"] in structured_output_tools
|
||||
]
|
||||
|
||||
if structured_tool_calls:
|
||||
exception: Exception | None = None
|
||||
if len(structured_tool_calls) > 1:
|
||||
# Handle multiple structured outputs error
|
||||
tool_names = [tc["name"] for tc in structured_tool_calls]
|
||||
exception = MultipleStructuredOutputsError(tool_names)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, response_format
|
||||
)
|
||||
if not should_retry:
|
||||
raise exception
|
||||
|
||||
# Add error messages and retry
|
||||
tool_messages = [
|
||||
ToolMessage(
|
||||
content=error_message,
|
||||
tool_call_id=tc["id"],
|
||||
name=tc["name"],
|
||||
)
|
||||
for tc in structured_tool_calls
|
||||
]
|
||||
return {"messages": [output, *tool_messages]}
|
||||
|
||||
# Handle single structured output
|
||||
tool_call = structured_tool_calls[0]
|
||||
try:
|
||||
structured_tool_binding = structured_output_tools[tool_call["name"]]
|
||||
structured_response = structured_tool_binding.parse(tool_call["args"])
|
||||
|
||||
tool_message_content = (
|
||||
response_format.tool_message_content
|
||||
if response_format.tool_message_content
|
||||
else f"Returning structured response: {structured_response}"
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
output,
|
||||
ToolMessage(
|
||||
content=tool_message_content,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
),
|
||||
],
|
||||
"response": structured_response,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exception = StructuredOutputValidationError(tool_call["name"], exc)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, response_format
|
||||
)
|
||||
if not should_retry:
|
||||
raise exception
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
output,
|
||||
ToolMessage(
|
||||
content=error_message,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
# Standard response handling
|
||||
if state.get("response") is not None:
|
||||
return {"messages": [output], "response": None}
|
||||
return {"messages": [output]}
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> Runnable:
|
||||
"""Get the model with appropriate tool bindings."""
|
||||
if isinstance(response_format, ProviderStrategy):
|
||||
# Use native structured output
|
||||
kwargs = response_format.to_model_kwargs()
|
||||
return request.model.bind_tools(
|
||||
request.tools, strict=True, **kwargs, **request.model_settings
|
||||
)
|
||||
if isinstance(response_format, ToolStrategy):
|
||||
tool_choice = "any" if structured_output_tools else request.tool_choice
|
||||
return request.model.bind_tools(
|
||||
request.tools, tool_choice=tool_choice, **request.model_settings
|
||||
)
|
||||
# Standard model binding
|
||||
if request.tools:
|
||||
return request.model.bind_tools(
|
||||
request.tools, tool_choice=request.tool_choice, **request.model_settings
|
||||
)
|
||||
return request.model.bind(**request.model_settings)
|
||||
|
||||
def model_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
# Start with the base model request
|
||||
request, messages = _prepare_model_request(state)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
# Filter state to only include fields defined in this middleware's schema
|
||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||
request = m.modify_model_request(request, filtered_state)
|
||||
|
||||
# Get the bound model with the final request
|
||||
model_ = _get_bound_model(request)
|
||||
output = model_.invoke(messages)
|
||||
return _handle_model_output(state, output)
|
||||
|
||||
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
# Start with the base model request
|
||||
request, messages = _prepare_model_request(state)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
# Filter state to only include fields defined in this middleware's schema
|
||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||
request = m.modify_model_request(request, filtered_state)
|
||||
|
||||
# Get the bound model with the final request
|
||||
model_ = _get_bound_model(request)
|
||||
output = await model_.ainvoke(messages)
|
||||
return _handle_model_output(state, output)
|
||||
|
||||
# Use sync or async based on model capabilities
|
||||
from langgraph._internal._runnable import RunnableCallable
|
||||
|
||||
graph.add_node("model_request", RunnableCallable(model_request, amodel_request))
|
||||
|
||||
# Only add tools node if we have tools
|
||||
if tool_node is not None:
|
||||
graph.add_node("tools", tool_node)
|
||||
|
||||
# Add middleware nodes
|
||||
for m in middleware:
|
||||
if m.__class__.before_model is not AgentMiddleware.before_model:
|
||||
graph.add_node(
|
||||
f"{m.__class__.__name__}.before_model",
|
||||
m.before_model,
|
||||
input_schema=m.state_schema,
|
||||
)
|
||||
|
||||
if m.__class__.after_model is not AgentMiddleware.after_model:
|
||||
graph.add_node(
|
||||
f"{m.__class__.__name__}.after_model",
|
||||
m.after_model,
|
||||
input_schema=m.state_schema,
|
||||
)
|
||||
|
||||
# add start edge
|
||||
first_node = (
|
||||
f"{middleware_w_before[0].__class__.__name__}.before_model"
|
||||
if middleware_w_before
|
||||
else "model_request"
|
||||
)
|
||||
last_node = (
|
||||
f"{middleware_w_after[0].__class__.__name__}.after_model"
|
||||
if middleware_w_after
|
||||
else "model_request"
|
||||
)
|
||||
graph.add_edge(START, first_node)
|
||||
|
||||
# add conditional edges only if tools exist
|
||||
if tool_node is not None:
|
||||
graph.add_conditional_edges(
|
||||
"tools",
|
||||
_make_tools_to_model_edge(tool_node, first_node),
|
||||
[first_node, END],
|
||||
)
|
||||
graph.add_conditional_edges(
|
||||
last_node,
|
||||
_make_model_to_tools_edge(first_node, structured_output_tools),
|
||||
[first_node, "tools", END],
|
||||
)
|
||||
elif last_node == "model_request":
|
||||
# If no tools, just go to END from model
|
||||
graph.add_edge(last_node, END)
|
||||
else:
|
||||
# If after_model, then need to check for jump_to
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_after[0].__class__.__name__}.after_model",
|
||||
END,
|
||||
first_node,
|
||||
tools_available=tool_node is not None,
|
||||
)
|
||||
|
||||
# Add middleware edges (same as before)
|
||||
if middleware_w_before:
|
||||
for m1, m2 in itertools.pairwise(middleware_w_before):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.__class__.__name__}.before_model",
|
||||
f"{m2.__class__.__name__}.before_model",
|
||||
first_node,
|
||||
tools_available=tool_node is not None,
|
||||
)
|
||||
# Go directly to model_request after the last before_model
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_before[-1].__class__.__name__}.before_model",
|
||||
"model_request",
|
||||
first_node,
|
||||
tools_available=tool_node is not None,
|
||||
)
|
||||
|
||||
if middleware_w_after:
|
||||
graph.add_edge("model_request", f"{middleware_w_after[-1].__class__.__name__}.after_model")
|
||||
for idx in range(len(middleware_w_after) - 1, 0, -1):
|
||||
m1 = middleware_w_after[idx]
|
||||
m2 = middleware_w_after[idx - 1]
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.__class__.__name__}.after_model",
|
||||
f"{m2.__class__.__name__}.after_model",
|
||||
first_node,
|
||||
tools_available=tool_node is not None,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
|
||||
if jump_to == "model":
|
||||
return first_node
|
||||
if jump_to:
|
||||
return jump_to
|
||||
return None
|
||||
|
||||
|
||||
def _make_model_to_tools_edge(
|
||||
first_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
||||
) -> Callable[[AgentState], str | None]:
|
||||
def model_to_tools(state: AgentState) -> str | None:
|
||||
if jump_to := state.get("jump_to"):
|
||||
return _resolve_jump(jump_to, first_node)
|
||||
|
||||
message = state["messages"][-1]
|
||||
|
||||
# Check if this is a ToolMessage from structured output - if so, end
|
||||
if isinstance(message, ToolMessage) and message.name in structured_output_tools:
|
||||
return END
|
||||
|
||||
# Check for tool calls
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
# If all tool calls are for structured output, don't go to tools
|
||||
non_structured_calls = [
|
||||
tc for tc in message.tool_calls if tc["name"] not in structured_output_tools
|
||||
]
|
||||
if non_structured_calls:
|
||||
return "tools"
|
||||
|
||||
return END
|
||||
|
||||
return model_to_tools
|
||||
|
||||
|
||||
def _make_tools_to_model_edge(
|
||||
tool_node: ToolNode, next_node: str
|
||||
) -> Callable[[AgentState], str | None]:
|
||||
def tools_to_model(state: AgentState) -> str | None:
|
||||
ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
|
||||
if all(
|
||||
tool_node.tools_by_name[c["name"]].return_direct
|
||||
for c in ai_message.tool_calls
|
||||
if c["name"] in tool_node.tools_by_name
|
||||
):
|
||||
return END
|
||||
|
||||
return next_node
|
||||
|
||||
return tools_to_model
|
||||
|
||||
|
||||
def _add_middleware_edge(
|
||||
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
|
||||
name: str,
|
||||
default_destination: str,
|
||||
model_destination: str,
|
||||
tools_available: bool, # noqa: FBT001
|
||||
) -> None:
|
||||
"""Add an edge to the graph for a middleware node.
|
||||
|
||||
Args:
|
||||
graph: The graph to add the edge to.
|
||||
method: The method to call for the middleware node.
|
||||
name: The name of the middleware node.
|
||||
default_destination: The default destination for the edge.
|
||||
model_destination: The destination for the edge to the model.
|
||||
tools_available: Whether tools are available for the edge to potentially route to.
|
||||
"""
|
||||
|
||||
def jump_edge(state: AgentState) -> str:
|
||||
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
|
||||
|
||||
destinations = [default_destination]
|
||||
if default_destination != END:
|
||||
destinations.append(END)
|
||||
if tools_available:
|
||||
destinations.append("tools")
|
||||
if name != model_destination:
|
||||
destinations.append(model_destination)
|
||||
|
||||
graph.add_conditional_edges(name, jump_edge, destinations)
|
||||
1203
libs/langchain_v1/langchain/agents/react_agent.py
Normal file
1203
libs/langchain_v1/langchain/agents/react_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
403
libs/langchain_v1/langchain/agents/structured_output.py
Normal file
403
libs/langchain_v1/langchain/agents/structured_output.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Types for setting agent response formats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from typing_extensions import Self, is_typeddict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts
|
||||
SchemaT = TypeVar("SchemaT")
|
||||
|
||||
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
|
||||
|
||||
|
||||
class StructuredOutputError(Exception):
|
||||
"""Base class for structured output errors."""
|
||||
|
||||
|
||||
class MultipleStructuredOutputsError(StructuredOutputError):
|
||||
"""Raised when model returns multiple structured output tool calls when only one is expected."""
|
||||
|
||||
def __init__(self, tool_names: list[str]) -> None:
|
||||
"""Initialize MultipleStructuredOutputsError.
|
||||
|
||||
Args:
|
||||
tool_names: The names of the tools called for structured output.
|
||||
"""
|
||||
self.tool_names = tool_names
|
||||
|
||||
super().__init__(
|
||||
f"Model incorrectly returned multiple structured responses ({', '.join(tool_names)}) when only one is expected."
|
||||
)
|
||||
|
||||
|
||||
class StructuredOutputValidationError(StructuredOutputError):
|
||||
"""Raised when structured output tool call arguments fail to parse according to the schema."""
|
||||
|
||||
def __init__(self, tool_name: str, source: Exception) -> None:
|
||||
"""Initialize StructuredOutputValidationError.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool that failed.
|
||||
source: The exception that occurred.
|
||||
"""
|
||||
self.tool_name = tool_name
|
||||
self.source = source
|
||||
super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")
|
||||
|
||||
|
||||
def _parse_with_schema(
|
||||
schema: Union[type[SchemaT], dict], schema_kind: SchemaKind, data: dict[str, Any]
|
||||
) -> Any:
|
||||
"""Parse data using for any supported schema type.
|
||||
|
||||
Args:
|
||||
schema: The schema type (Pydantic model, dataclass, or TypedDict)
|
||||
schema_kind: One of "pydantic", "dataclass", "typeddict", or "json_schema"
|
||||
data: The data to parse
|
||||
|
||||
Returns:
|
||||
The parsed instance according to the schema type
|
||||
|
||||
Raises:
|
||||
ValueError: If parsing fails
|
||||
"""
|
||||
if schema_kind == "json_schema":
|
||||
return data
|
||||
try:
|
||||
adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
|
||||
return adapter.validate_python(data)
|
||||
except Exception as e:
|
||||
schema_name = getattr(schema, "__name__", str(schema))
|
||||
msg = f"Failed to parse data to {schema_name}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class _SchemaSpec(Generic[SchemaT]):
|
||||
"""Describes a structured output schema."""
|
||||
|
||||
schema: type[SchemaT]
|
||||
"""The schema for the response, can be a Pydantic model, dataclass, TypedDict, or JSON schema dict."""
|
||||
|
||||
name: str
|
||||
"""Name of the schema, used for tool calling.
|
||||
|
||||
If not provided, the name will be the model name or "response_format" if it's a JSON schema.
|
||||
"""
|
||||
|
||||
description: str
|
||||
"""Custom description of the schema.
|
||||
|
||||
If not provided, provided will use the model's docstring.
|
||||
"""
|
||||
|
||||
schema_kind: SchemaKind
|
||||
"""The kind of schema."""
|
||||
|
||||
json_schema: dict[str, Any]
|
||||
"""JSON schema associated with the schema."""
|
||||
|
||||
strict: bool = False
|
||||
"""Whether to enforce strict validation of the schema."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: type[SchemaT],
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
strict: bool = False,
|
||||
) -> None:
|
||||
"""Initialize SchemaSpec with schema and optional parameters."""
|
||||
self.schema = schema
|
||||
|
||||
if name:
|
||||
self.name = name
|
||||
elif isinstance(schema, dict):
|
||||
self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}"))
|
||||
else:
|
||||
self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}"))
|
||||
|
||||
self.description = description or (
|
||||
schema.get("description", "")
|
||||
if isinstance(schema, dict)
|
||||
else getattr(schema, "__doc__", None) or ""
|
||||
)
|
||||
|
||||
self.strict = strict
|
||||
|
||||
if isinstance(schema, dict):
|
||||
self.schema_kind = "json_schema"
|
||||
self.json_schema = schema
|
||||
elif isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
self.schema_kind = "pydantic"
|
||||
self.json_schema = schema.model_json_schema()
|
||||
elif is_dataclass(schema):
|
||||
self.schema_kind = "dataclass"
|
||||
self.json_schema = TypeAdapter(schema).json_schema()
|
||||
elif is_typeddict(schema):
|
||||
self.schema_kind = "typeddict"
|
||||
self.json_schema = TypeAdapter(schema).json_schema()
|
||||
else:
|
||||
msg = (
|
||||
f"Unsupported schema type: {type(schema)}. "
|
||||
f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class ToolStrategy(Generic[SchemaT]):
|
||||
"""Use a tool calling strategy for model responses."""
|
||||
|
||||
schema: type[SchemaT]
|
||||
"""Schema for the tool calls."""
|
||||
|
||||
schema_specs: list[_SchemaSpec[SchemaT]]
|
||||
"""Schema specs for the tool calls."""
|
||||
|
||||
tool_message_content: str | None
|
||||
"""The content of the tool message to be returned when the model calls an artificial structured output tool."""
|
||||
|
||||
handle_errors: Union[
|
||||
bool,
|
||||
str,
|
||||
type[Exception],
|
||||
tuple[type[Exception], ...],
|
||||
Callable[[Exception], str],
|
||||
]
|
||||
"""Error handling strategy for structured output via ToolStrategy. Default is True.
|
||||
|
||||
- True: Catch all errors with default error template
|
||||
- str: Catch all errors with this custom message
|
||||
- type[Exception]: Only catch this exception type with default message
|
||||
- tuple[type[Exception], ...]: Only catch these exception types with default message
|
||||
- Callable[[Exception], str]: Custom function that returns error message
|
||||
- False: No retry, let exceptions propagate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: type[SchemaT],
|
||||
*,
|
||||
tool_message_content: str | None = None,
|
||||
handle_errors: Union[
|
||||
bool,
|
||||
str,
|
||||
type[Exception],
|
||||
tuple[type[Exception], ...],
|
||||
Callable[[Exception], str],
|
||||
] = True,
|
||||
) -> None:
|
||||
"""Initialize ToolStrategy with schemas, tool message content, and error handling strategy."""
|
||||
self.schema = schema
|
||||
self.tool_message_content = tool_message_content
|
||||
self.handle_errors = handle_errors
|
||||
|
||||
def _iter_variants(schema: Any) -> Iterable[Any]:
|
||||
"""Yield leaf variants from Union and JSON Schema oneOf."""
|
||||
if get_origin(schema) in (UnionType, Union):
|
||||
for arg in get_args(schema):
|
||||
yield from _iter_variants(arg)
|
||||
return
|
||||
|
||||
if isinstance(schema, dict) and "oneOf" in schema:
|
||||
for sub in schema.get("oneOf", []):
|
||||
yield from _iter_variants(sub)
|
||||
return
|
||||
|
||||
yield schema
|
||||
|
||||
self.schema_specs = [_SchemaSpec(s) for s in _iter_variants(schema)]
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class ProviderStrategy(Generic[SchemaT]):
|
||||
"""Use the model provider's native structured output method."""
|
||||
|
||||
schema: type[SchemaT]
|
||||
"""Schema for native mode."""
|
||||
|
||||
schema_spec: _SchemaSpec[SchemaT]
|
||||
"""Schema spec for native mode."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: type[SchemaT],
|
||||
) -> None:
|
||||
"""Initialize ProviderStrategy with schema."""
|
||||
self.schema = schema
|
||||
self.schema_spec = _SchemaSpec(schema)
|
||||
|
||||
def to_model_kwargs(self) -> dict[str, Any]:
|
||||
"""Convert to kwargs to bind to a model to force structured output."""
|
||||
# OpenAI:
|
||||
# - see https://platform.openai.com/docs/guides/structured-outputs
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": self.schema_spec.name,
|
||||
"schema": self.schema_spec.json_schema,
|
||||
},
|
||||
}
|
||||
return {"response_format": response_format}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputToolBinding(Generic[SchemaT]):
|
||||
"""Information for tracking structured output tool metadata.
|
||||
|
||||
This contains all necessary information to handle structured responses
|
||||
generated via tool calls, including the original schema, its type classification,
|
||||
and the corresponding tool implementation used by the tools strategy.
|
||||
"""
|
||||
|
||||
schema: type[SchemaT]
|
||||
"""The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
|
||||
|
||||
schema_kind: SchemaKind
|
||||
"""Classification of the schema type for proper response construction."""
|
||||
|
||||
tool: BaseTool
|
||||
"""LangChain tool instance created from the schema for model binding."""
|
||||
|
||||
@classmethod
|
||||
def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
|
||||
"""Create an OutputToolBinding instance from a SchemaSpec.
|
||||
|
||||
Args:
|
||||
schema_spec: The SchemaSpec to convert
|
||||
|
||||
Returns:
|
||||
An OutputToolBinding instance with the appropriate tool created
|
||||
"""
|
||||
return cls(
|
||||
schema=schema_spec.schema,
|
||||
schema_kind=schema_spec.schema_kind,
|
||||
tool=StructuredTool(
|
||||
args_schema=schema_spec.json_schema,
|
||||
name=schema_spec.name,
|
||||
description=schema_spec.description,
|
||||
),
|
||||
)
|
||||
|
||||
def parse(self, tool_args: dict[str, Any]) -> SchemaT:
|
||||
"""Parse tool arguments according to the schema.
|
||||
|
||||
Args:
|
||||
tool_args: The arguments from the tool call
|
||||
|
||||
Returns:
|
||||
The parsed response according to the schema type
|
||||
|
||||
Raises:
|
||||
ValueError: If parsing fails
|
||||
"""
|
||||
return _parse_with_schema(self.schema, self.schema_kind, tool_args)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderStrategyBinding(Generic[SchemaT]):
|
||||
"""Information for tracking native structured output metadata.
|
||||
|
||||
This contains all necessary information to handle structured responses
|
||||
generated via native provider output, including the original schema,
|
||||
its type classification, and parsing logic for provider-enforced JSON.
|
||||
"""
|
||||
|
||||
schema: type[SchemaT]
|
||||
"""The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
|
||||
|
||||
schema_kind: SchemaKind
|
||||
"""Classification of the schema type for proper response construction."""
|
||||
|
||||
@classmethod
|
||||
def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
|
||||
"""Create a ProviderStrategyBinding instance from a SchemaSpec.
|
||||
|
||||
Args:
|
||||
schema_spec: The SchemaSpec to convert
|
||||
|
||||
Returns:
|
||||
A ProviderStrategyBinding instance for parsing native structured output
|
||||
"""
|
||||
return cls(
|
||||
schema=schema_spec.schema,
|
||||
schema_kind=schema_spec.schema_kind,
|
||||
)
|
||||
|
||||
def parse(self, response: AIMessage) -> SchemaT:
|
||||
"""Parse AIMessage content according to the schema.
|
||||
|
||||
Args:
|
||||
response: The AI message containing the structured output
|
||||
|
||||
Returns:
|
||||
The parsed response according to the schema
|
||||
|
||||
Raises:
|
||||
ValueError: If text extraction, JSON parsing or schema validation fails
|
||||
"""
|
||||
# Extract text content from AIMessage and parse as JSON
|
||||
raw_text = self._extract_text_content_from_message(response)
|
||||
|
||||
import json
|
||||
|
||||
try:
|
||||
data = json.loads(raw_text)
|
||||
except Exception as e:
|
||||
schema_name = getattr(self.schema, "__name__", "response_format")
|
||||
msg = f"Native structured output expected valid JSON for {schema_name}, but parsing failed: {e}."
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# Parse according to schema
|
||||
return _parse_with_schema(self.schema, self.schema_kind, data)
|
||||
|
||||
def _extract_text_content_from_message(self, message: AIMessage) -> str:
|
||||
"""Extract text content from an AIMessage.
|
||||
|
||||
Args:
|
||||
message: The AI message to extract text from
|
||||
|
||||
Returns:
|
||||
The extracted text content
|
||||
"""
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for c in content:
|
||||
if isinstance(c, dict):
|
||||
if c.get("type") == "text" and "text" in c:
|
||||
parts.append(str(c["text"]))
|
||||
elif "content" in c and isinstance(c["content"], str):
|
||||
parts.append(c["content"])
|
||||
else:
|
||||
parts.append(str(c))
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
ResponseFormat = Union[ToolStrategy[SchemaT], ProviderStrategy[SchemaT]]
|
||||
1174
libs/langchain_v1/langchain/agents/tool_node.py
Normal file
1174
libs/langchain_v1/langchain/agents/tool_node.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +0,0 @@
|
||||
from langchain.chains.documents import (
|
||||
create_map_reduce_chain,
|
||||
create_stuff_documents_chain,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_map_reduce_chain",
|
||||
"create_stuff_documents_chain",
|
||||
]
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Document extraction chains.
|
||||
|
||||
This module provides different strategies for extracting information from collections
|
||||
of documents using LangGraph and modern language models.
|
||||
|
||||
Available Strategies:
|
||||
- Stuff: Processes all documents together in a single context window
|
||||
- Map-Reduce: Processes documents in parallel (map), then combines results (reduce)
|
||||
"""
|
||||
|
||||
from langchain.chains.documents.map_reduce import create_map_reduce_chain
|
||||
from langchain.chains.documents.stuff import create_stuff_documents_chain
|
||||
|
||||
__all__ = [
|
||||
"create_map_reduce_chain",
|
||||
"create_stuff_documents_chain",
|
||||
]
|
||||
@@ -1,586 +0,0 @@
|
||||
"""Map-Reduce Extraction Implementation using LangGraph Send API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Send
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain._internal._documents import format_document_xml
|
||||
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
|
||||
from langchain._internal._typing import ContextT, StateNode
|
||||
from langchain._internal._utils import RunnableCallable
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
# Pycharm is unable to identify that AIMessage is used in the cast below
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
MessageLikeRepresentation,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExtractionResult(TypedDict):
|
||||
"""Result from processing a document or group of documents."""
|
||||
|
||||
indexes: list[int]
|
||||
"""Document indexes that contributed to this result."""
|
||||
result: Any
|
||||
"""Extracted result from the document(s)."""
|
||||
|
||||
|
||||
class MapReduceState(TypedDict):
|
||||
"""State for map-reduce extraction chain.
|
||||
|
||||
This state tracks the map-reduce process where documents are processed
|
||||
in parallel during the map phase, then combined in the reduce phase.
|
||||
"""
|
||||
|
||||
documents: list[Document]
|
||||
"""List of documents to process."""
|
||||
map_results: Annotated[list[ExtractionResult], operator.add]
|
||||
"""Individual results from the map phase."""
|
||||
result: NotRequired[Any]
|
||||
"""Final combined result from the reduce phase if applicable."""
|
||||
|
||||
|
||||
# The payload for the map phase is a list of documents and their indexes.
|
||||
# The current implementation only supports a single document per map operation,
|
||||
# but the structure allows for future expansion to process a group of documents.
|
||||
# A user would provide an input split function that returns groups of documents
|
||||
# to process together, if desired.
|
||||
class MapState(TypedDict):
|
||||
"""State for individual map operations."""
|
||||
|
||||
documents: list[Document]
|
||||
"""List of documents to process in map phase."""
|
||||
indexes: list[int]
|
||||
"""List of indexes of the documents in the original list."""
|
||||
|
||||
|
||||
class InputSchema(TypedDict):
|
||||
"""Input schema for the map-reduce extraction chain.
|
||||
|
||||
Defines the expected input format when invoking the extraction chain.
|
||||
"""
|
||||
|
||||
documents: list[Document]
|
||||
"""List of documents to process."""
|
||||
|
||||
|
||||
class OutputSchema(TypedDict):
|
||||
"""Output schema for the map-reduce extraction chain.
|
||||
|
||||
Defines the format of the final result returned by the chain.
|
||||
"""
|
||||
|
||||
map_results: list[ExtractionResult]
|
||||
"""List of individual extraction results from the map phase."""
|
||||
|
||||
result: Any
|
||||
"""Final combined result from all documents."""
|
||||
|
||||
|
||||
class MapReduceNodeUpdate(TypedDict):
|
||||
"""Update returned by map-reduce nodes."""
|
||||
|
||||
map_results: NotRequired[list[ExtractionResult]]
|
||||
"""Updated results after map phase."""
|
||||
result: NotRequired[Any]
|
||||
"""Final result after reduce phase."""
|
||||
|
||||
|
||||
class _MapReduceExtractor(Generic[ContextT]):
|
||||
"""Map-reduce extraction implementation using LangGraph Send API.
|
||||
|
||||
This implementation uses a language model to process documents through up
|
||||
to two phases:
|
||||
|
||||
1. **Map Phase**: Each document is processed independently by the LLM using
|
||||
the configured map_prompt to generate individual extraction results.
|
||||
2. **Reduce Phase (Optional)**: Individual results can optionally be
|
||||
combined using either:
|
||||
- The default LLM-based reducer with the configured reduce_prompt
|
||||
- A custom reducer function (which can be non-LLM based)
|
||||
- Skipped entirely by setting reduce=None
|
||||
|
||||
The map phase processes documents in parallel for efficiency, making this approach
|
||||
well-suited for large document collections. The reduce phase is flexible and can be
|
||||
customized or omitted based on your specific requirements.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[BaseChatModel, str],
|
||||
*,
|
||||
map_prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[
|
||||
[MapState, Runtime[ContextT]],
|
||||
list[MessageLikeRepresentation],
|
||||
],
|
||||
] = None,
|
||||
reduce_prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[
|
||||
[MapReduceState, Runtime[ContextT]],
|
||||
list[MessageLikeRepresentation],
|
||||
],
|
||||
] = None,
|
||||
reduce: Union[
|
||||
Literal["default_reducer"],
|
||||
None,
|
||||
StateNode,
|
||||
] = "default_reducer",
|
||||
context_schema: type[ContextT] | None = None,
|
||||
response_format: Optional[type[BaseModel]] = None,
|
||||
) -> None:
|
||||
"""Initialize the MapReduceExtractor.
|
||||
|
||||
Args:
|
||||
model: The language model either a chat model instance
|
||||
(e.g., `ChatAnthropic()`) or string identifier
|
||||
(e.g., `"anthropic:claude-sonnet-4-20250514"`)
|
||||
map_prompt: Prompt for individual document processing. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
reduce_prompt: Prompt for combining results. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
reduce: Controls the reduce behavior. Can be:
|
||||
- "default_reducer": Use the default LLM-based reduce step
|
||||
- None: Skip the reduce step entirely
|
||||
- Callable: Custom reduce function (sync or async)
|
||||
context_schema: Optional context schema for the LangGraph runtime.
|
||||
response_format: Optional pydantic BaseModel for structured output.
|
||||
"""
|
||||
if (reduce is None or callable(reduce)) and reduce_prompt is not None:
|
||||
msg = (
|
||||
"reduce_prompt must be None when reduce is None or a custom "
|
||||
"callable. Custom reduce functions handle their own logic and "
|
||||
"should not use reduce_prompt."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.response_format = response_format
|
||||
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
self.model = (
|
||||
model.with_structured_output(response_format) if response_format else model
|
||||
)
|
||||
self.map_prompt = map_prompt
|
||||
self.reduce_prompt = reduce_prompt
|
||||
self.reduce = reduce
|
||||
self.context_schema = context_schema
|
||||
|
||||
def _get_map_prompt(
|
||||
self, state: MapState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the LLM prompt for processing documents."""
|
||||
documents = state["documents"]
|
||||
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
|
||||
default_system = (
|
||||
"You are a helpful assistant that processes documents. "
|
||||
"Please process the following documents and provide a result."
|
||||
)
|
||||
|
||||
return resolve_prompt(
|
||||
self.map_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_system,
|
||||
)
|
||||
|
||||
async def _aget_map_prompt(
|
||||
self, state: MapState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the LLM prompt for processing documents in the map phase.
|
||||
|
||||
Async version.
|
||||
"""
|
||||
documents = state["documents"]
|
||||
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
|
||||
default_system = (
|
||||
"You are a helpful assistant that processes documents. "
|
||||
"Please process the following documents and provide a result."
|
||||
)
|
||||
|
||||
return await aresolve_prompt(
|
||||
self.map_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_system,
|
||||
)
|
||||
|
||||
def _get_reduce_prompt(
|
||||
self, state: MapReduceState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the LLM prompt for combining individual results.
|
||||
|
||||
Combines map results in the reduce phase.
|
||||
"""
|
||||
map_results = state.get("map_results", [])
|
||||
if not map_results:
|
||||
msg = (
|
||||
"Internal programming error: Results must exist when reducing. "
|
||||
"This indicates that the reduce node was reached without "
|
||||
"first processing the map nodes, which violates "
|
||||
"the expected graph execution order."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
results_text = "\n\n".join(
|
||||
f"Result {i + 1} (from documents "
|
||||
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
|
||||
for i, result in enumerate(map_results)
|
||||
)
|
||||
user_content = (
|
||||
f"Please combine the following results into a single, "
|
||||
f"comprehensive result:\n\n{results_text}"
|
||||
)
|
||||
default_system = (
|
||||
"You are a helpful assistant that combines multiple results. "
|
||||
"Given several individual results, create a single comprehensive "
|
||||
"result that captures the key information from all inputs while "
|
||||
"maintaining conciseness and coherence."
|
||||
)
|
||||
|
||||
return resolve_prompt(
|
||||
self.reduce_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_system,
|
||||
)
|
||||
|
||||
async def _aget_reduce_prompt(
|
||||
self, state: MapReduceState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the LLM prompt for combining individual results.
|
||||
|
||||
Async version of reduce phase.
|
||||
"""
|
||||
map_results = state.get("map_results", [])
|
||||
if not map_results:
|
||||
msg = (
|
||||
"Internal programming error: Results must exist when reducing. "
|
||||
"This indicates that the reduce node was reached without "
|
||||
"first processing the map nodes, which violates "
|
||||
"the expected graph execution order."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
results_text = "\n\n".join(
|
||||
f"Result {i + 1} (from documents "
|
||||
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
|
||||
for i, result in enumerate(map_results)
|
||||
)
|
||||
user_content = (
|
||||
f"Please combine the following results into a single, "
|
||||
f"comprehensive result:\n\n{results_text}"
|
||||
)
|
||||
default_system = (
|
||||
"You are a helpful assistant that combines multiple results. "
|
||||
"Given several individual results, create a single comprehensive "
|
||||
"result that captures the key information from all inputs while "
|
||||
"maintaining conciseness and coherence."
|
||||
)
|
||||
|
||||
return await aresolve_prompt(
|
||||
self.reduce_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_system,
|
||||
)
|
||||
|
||||
def create_map_node(self) -> RunnableCallable:
|
||||
"""Create a LangGraph node that processes individual documents using the LLM."""
|
||||
|
||||
def _map_node(
|
||||
state: MapState, runtime: Runtime[ContextT], config: RunnableConfig
|
||||
) -> dict[str, list[ExtractionResult]]:
|
||||
prompt = self._get_map_prompt(state, runtime)
|
||||
response = cast("AIMessage", self.model.invoke(prompt, config=config))
|
||||
result = response if self.response_format else response.text()
|
||||
extraction_result: ExtractionResult = {
|
||||
"indexes": state["indexes"],
|
||||
"result": result,
|
||||
}
|
||||
return {"map_results": [extraction_result]}
|
||||
|
||||
async def _amap_node(
|
||||
state: MapState,
|
||||
runtime: Runtime[ContextT],
|
||||
config: RunnableConfig,
|
||||
) -> dict[str, list[ExtractionResult]]:
|
||||
prompt = await self._aget_map_prompt(state, runtime)
|
||||
response = cast(
|
||||
"AIMessage", await self.model.ainvoke(prompt, config=config)
|
||||
)
|
||||
result = response if self.response_format else response.text()
|
||||
extraction_result: ExtractionResult = {
|
||||
"indexes": state["indexes"],
|
||||
"result": result,
|
||||
}
|
||||
return {"map_results": [extraction_result]}
|
||||
|
||||
return RunnableCallable(
|
||||
_map_node,
|
||||
_amap_node,
|
||||
trace=False,
|
||||
)
|
||||
|
||||
def create_reduce_node(self) -> RunnableCallable:
|
||||
"""Create a LangGraph node that combines individual results using the LLM."""
|
||||
|
||||
def _reduce_node(
|
||||
state: MapReduceState, runtime: Runtime[ContextT], config: RunnableConfig
|
||||
) -> MapReduceNodeUpdate:
|
||||
prompt = self._get_reduce_prompt(state, runtime)
|
||||
response = cast("AIMessage", self.model.invoke(prompt, config=config))
|
||||
result = response if self.response_format else response.text()
|
||||
return {"result": result}
|
||||
|
||||
async def _areduce_node(
|
||||
state: MapReduceState,
|
||||
runtime: Runtime[ContextT],
|
||||
config: RunnableConfig,
|
||||
) -> MapReduceNodeUpdate:
|
||||
prompt = await self._aget_reduce_prompt(state, runtime)
|
||||
response = cast(
|
||||
"AIMessage", await self.model.ainvoke(prompt, config=config)
|
||||
)
|
||||
result = response if self.response_format else response.text()
|
||||
return {"result": result}
|
||||
|
||||
return RunnableCallable(
|
||||
_reduce_node,
|
||||
_areduce_node,
|
||||
trace=False,
|
||||
)
|
||||
|
||||
def continue_to_map(self, state: MapReduceState) -> list[Send]:
|
||||
"""Generate Send objects for parallel map operations."""
|
||||
return [
|
||||
Send("map_process", {"documents": [doc], "indexes": [i]})
|
||||
for i, doc in enumerate(state["documents"])
|
||||
]
|
||||
|
||||
def build(
|
||||
self,
|
||||
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
|
||||
"""Build and compile the LangGraph for map-reduce summarization."""
|
||||
builder = StateGraph(
|
||||
MapReduceState,
|
||||
context_schema=self.context_schema,
|
||||
input_schema=InputSchema,
|
||||
output_schema=OutputSchema,
|
||||
)
|
||||
|
||||
builder.add_node("map_process", self.create_map_node())
|
||||
|
||||
builder.add_edge(START, "continue_to_map")
|
||||
# Add-conditional edges doesn't explicitly type Send
|
||||
builder.add_conditional_edges(
|
||||
"continue_to_map",
|
||||
self.continue_to_map, # type: ignore[arg-type]
|
||||
["map_process"],
|
||||
)
|
||||
|
||||
if self.reduce is None:
|
||||
builder.add_edge("map_process", END)
|
||||
elif self.reduce == "default_reducer":
|
||||
builder.add_node("reduce_process", self.create_reduce_node())
|
||||
builder.add_edge("map_process", "reduce_process")
|
||||
builder.add_edge("reduce_process", END)
|
||||
else:
|
||||
reduce_node = cast("StateNode", self.reduce)
|
||||
# The type is ignored here. Requires parameterizing with generics.
|
||||
builder.add_node("reduce_process", reduce_node) # type: ignore[arg-type]
|
||||
builder.add_edge("map_process", "reduce_process")
|
||||
builder.add_edge("reduce_process", END)
|
||||
|
||||
return builder
|
||||
|
||||
|
||||
def create_map_reduce_chain(
|
||||
model: Union[BaseChatModel, str],
|
||||
*,
|
||||
map_prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[[MapState, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
] = None,
|
||||
reduce_prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[[MapReduceState, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
] = None,
|
||||
reduce: Union[
|
||||
Literal["default_reducer"],
|
||||
None,
|
||||
StateNode,
|
||||
] = "default_reducer",
|
||||
context_schema: type[ContextT] | None = None,
|
||||
response_format: Optional[type[BaseModel]] = None,
|
||||
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
|
||||
"""Create a map-reduce document extraction chain.
|
||||
|
||||
This implementation uses a language model to extract information from documents
|
||||
through a flexible approach that efficiently handles large document collections
|
||||
by processing documents in parallel.
|
||||
|
||||
**Processing Flow:**
|
||||
1. **Map Phase**: Each document is independently processed by the LLM
|
||||
using the map_prompt to extract relevant information and generate
|
||||
individual results.
|
||||
2. **Reduce Phase (Optional)**: Individual extraction results can
|
||||
optionally be combined using:
|
||||
- The default LLM-based reducer with reduce_prompt (default behavior)
|
||||
- A custom reducer function (can be non-LLM based)
|
||||
- Skipped entirely by setting reduce=None
|
||||
3. **Output**: Returns the individual map results and optionally the final
|
||||
combined result.
|
||||
|
||||
Example:
|
||||
>>> from langchain_anthropic import ChatAnthropic
|
||||
>>> from langchain_core.documents import Document
|
||||
>>>
|
||||
>>> model = ChatAnthropic(
|
||||
... model="claude-sonnet-4-20250514",
|
||||
... temperature=0,
|
||||
... max_tokens=62_000,
|
||||
... timeout=None,
|
||||
... max_retries=2,
|
||||
... )
|
||||
>>> builder = create_map_reduce_chain(model)
|
||||
>>> chain = builder.compile()
|
||||
>>> docs = [
|
||||
... Document(page_content="First document content..."),
|
||||
... Document(page_content="Second document content..."),
|
||||
... Document(page_content="Third document content..."),
|
||||
... ]
|
||||
>>> result = chain.invoke({"documents": docs})
|
||||
>>> print(result["result"])
|
||||
|
||||
Example with string model:
|
||||
>>> builder = create_map_reduce_chain("anthropic:claude-sonnet-4-20250514")
|
||||
>>> chain = builder.compile()
|
||||
>>> result = chain.invoke({"documents": docs})
|
||||
>>> print(result["result"])
|
||||
|
||||
Example with structured output:
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ExtractionModel(BaseModel):
|
||||
title: str
|
||||
key_points: list[str]
|
||||
conclusion: str
|
||||
|
||||
builder = create_map_reduce_chain(
|
||||
model,
|
||||
response_format=ExtractionModel
|
||||
)
|
||||
chain = builder.compile()
|
||||
result = chain.invoke({"documents": docs})
|
||||
print(result["result"].title) # Access structured fields
|
||||
```
|
||||
|
||||
Example skipping the reduce phase:
|
||||
```python
|
||||
# Only perform map phase, skip combining results
|
||||
builder = create_map_reduce_chain(model, reduce=None)
|
||||
chain = builder.compile()
|
||||
result = chain.invoke({"documents": docs})
|
||||
# result["result"] will be None, only map_results are available
|
||||
for map_result in result["map_results"]:
|
||||
print(f"Document {map_result['indexes'][0]}: {map_result['result']}")
|
||||
```
|
||||
|
||||
Example with custom reducer:
|
||||
```python
|
||||
def custom_aggregator(state, runtime):
|
||||
# Custom non-LLM based reduction logic
|
||||
map_results = state["map_results"]
|
||||
combined_text = " | ".join(r["result"] for r in map_results)
|
||||
word_count = len(combined_text.split())
|
||||
return {
|
||||
"result": f"Combined {len(map_results)} results with "
|
||||
f"{word_count} total words"
|
||||
}
|
||||
|
||||
builder = create_map_reduce_chain(model, reduce=custom_aggregator)
|
||||
chain = builder.compile()
|
||||
result = chain.invoke({"documents": docs})
|
||||
print(result["result"]) # Custom aggregated result
|
||||
```
|
||||
|
||||
Args:
|
||||
model: The language model either a chat model instance
|
||||
(e.g., `ChatAnthropic()`) or string identifier
|
||||
(e.g., `"anthropic:claude-sonnet-4-20250514"`)
|
||||
map_prompt: Prompt for individual document processing. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
reduce_prompt: Prompt for combining results. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
reduce: Controls the reduce behavior. Can be:
|
||||
- "default_reducer": Use the default LLM-based reduce step
|
||||
- None: Skip the reduce step entirely
|
||||
- Callable: Custom reduce function (sync or async)
|
||||
context_schema: Optional context schema for the LangGraph runtime.
|
||||
response_format: Optional pydantic BaseModel for structured output.
|
||||
|
||||
Returns:
|
||||
A LangGraph that can be invoked with documents to get map-reduce
|
||||
extraction results.
|
||||
|
||||
.. note::
|
||||
This implementation is well-suited for large document collections as it
|
||||
processes documents in parallel during the map phase. The Send API enables
|
||||
efficient parallelization while maintaining clean state management.
|
||||
"""
|
||||
extractor = _MapReduceExtractor(
|
||||
model,
|
||||
map_prompt=map_prompt,
|
||||
reduce_prompt=reduce_prompt,
|
||||
reduce=reduce,
|
||||
context_schema=context_schema,
|
||||
response_format=response_format,
|
||||
)
|
||||
return extractor.build()
|
||||
|
||||
|
||||
__all__ = ["create_map_reduce_chain"]
|
||||
@@ -1,473 +0,0 @@
|
||||
"""Stuff documents chain for processing documents by putting them all in context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
# Used not only for type checking, but is fetched at runtime by Pydantic.
|
||||
from langchain_core.documents import Document as Document # noqa: TC002
|
||||
from langgraph.graph import START, StateGraph
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain._internal._documents import format_document_xml
|
||||
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
|
||||
from langchain._internal._typing import ContextT
|
||||
from langchain._internal._utils import RunnableCallable
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
# Used for type checking, but IDEs may not recognize it inside the cast.
|
||||
from langchain_core.messages import AIMessage as AIMessage
|
||||
from langchain_core.messages import MessageLikeRepresentation
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Default system prompts
|
||||
DEFAULT_INIT_PROMPT = (
|
||||
"You are a helpful assistant that summarizes text. "
|
||||
"Please provide a concise summary of the documents "
|
||||
"provided by the user."
|
||||
)
|
||||
|
||||
DEFAULT_STRUCTURED_INIT_PROMPT = (
|
||||
"You are a helpful assistant that extracts structured information from documents. "
|
||||
"Use the provided content and optional question to generate your output, formatted "
|
||||
"according to the predefined schema."
|
||||
)
|
||||
|
||||
DEFAULT_REFINE_PROMPT = (
|
||||
"You are a helpful assistant that refines summaries. "
|
||||
"Given an existing summary and new context, produce a refined summary "
|
||||
"that incorporates the new information while maintaining conciseness."
|
||||
)
|
||||
|
||||
DEFAULT_STRUCTURED_REFINE_PROMPT = (
|
||||
"You are a helpful assistant refining structured information extracted "
|
||||
"from documents. "
|
||||
"You are given a previous result and new document context. "
|
||||
"Update the output to reflect the new context, staying consistent with "
|
||||
"the expected schema."
|
||||
)
|
||||
|
||||
|
||||
def _format_documents_content(documents: list[Document]) -> str:
|
||||
"""Format documents into content string.
|
||||
|
||||
Args:
|
||||
documents: List of documents to format.
|
||||
|
||||
Returns:
|
||||
Formatted document content string.
|
||||
"""
|
||||
return "\n\n".join(format_document_xml(doc) for doc in documents)
|
||||
|
||||
|
||||
class ExtractionState(TypedDict):
|
||||
"""State for extraction chain.
|
||||
|
||||
This state tracks the extraction process where documents
|
||||
are processed in batch, with the result being refined if needed.
|
||||
"""
|
||||
|
||||
documents: list[Document]
|
||||
"""List of documents to process."""
|
||||
result: NotRequired[Any]
|
||||
"""Current result, refined with each document."""
|
||||
|
||||
|
||||
class InputSchema(TypedDict):
|
||||
"""Input schema for the extraction chain.
|
||||
|
||||
Defines the expected input format when invoking the extraction chain.
|
||||
"""
|
||||
|
||||
documents: list[Document]
|
||||
"""List of documents to process."""
|
||||
result: NotRequired[Any]
|
||||
"""Existing result to refine (optional)."""
|
||||
|
||||
|
||||
class OutputSchema(TypedDict):
|
||||
"""Output schema for the extraction chain.
|
||||
|
||||
Defines the format of the final result returned by the chain.
|
||||
"""
|
||||
|
||||
result: Any
|
||||
"""Result from processing the documents."""
|
||||
|
||||
|
||||
class ExtractionNodeUpdate(TypedDict):
|
||||
"""Update returned by processing nodes."""
|
||||
|
||||
result: NotRequired[Any]
|
||||
"""Updated result after processing a document."""
|
||||
|
||||
|
||||
class _Extractor(Generic[ContextT]):
|
||||
"""Stuff documents chain implementation.
|
||||
|
||||
This chain works by putting all the documents in the batch into the context
|
||||
window of the language model. It processes all documents together in a single
|
||||
request for extracting information or summaries. Can refine existing results
|
||||
when provided.
|
||||
|
||||
Important: This chain does not attempt to control for the size of the context
|
||||
window of the LLM. Ensure your documents fit within the model's context limits.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[BaseChatModel, str],
|
||||
*,
|
||||
prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[
|
||||
[ExtractionState, Runtime[ContextT]],
|
||||
list[MessageLikeRepresentation],
|
||||
],
|
||||
] = None,
|
||||
refine_prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[
|
||||
[ExtractionState, Runtime[ContextT]],
|
||||
list[MessageLikeRepresentation],
|
||||
],
|
||||
] = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
response_format: Optional[type[BaseModel]] = None,
|
||||
) -> None:
|
||||
"""Initialize the Extractor.
|
||||
|
||||
Args:
|
||||
model: The language model either a chat model instance
|
||||
(e.g., `ChatAnthropic()`) or string identifier
|
||||
(e.g., `"anthropic:claude-sonnet-4-20250514"`)
|
||||
prompt: Prompt for initial processing. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
refine_prompt: Prompt for refinement steps. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
context_schema: Optional context schema for the LangGraph runtime.
|
||||
response_format: Optional pydantic BaseModel for structured output.
|
||||
"""
|
||||
self.response_format = response_format
|
||||
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
self.model = (
|
||||
model.with_structured_output(response_format) if response_format else model
|
||||
)
|
||||
self.initial_prompt = prompt
|
||||
self.refine_prompt = refine_prompt
|
||||
self.context_schema = context_schema
|
||||
|
||||
def _get_initial_prompt(
|
||||
self, state: ExtractionState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the initial extraction prompt."""
|
||||
user_content = _format_documents_content(state["documents"])
|
||||
|
||||
# Choose default prompt based on structured output format
|
||||
default_prompt = (
|
||||
DEFAULT_STRUCTURED_INIT_PROMPT
|
||||
if self.response_format
|
||||
else DEFAULT_INIT_PROMPT
|
||||
)
|
||||
|
||||
return resolve_prompt(
|
||||
self.initial_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_prompt,
|
||||
)
|
||||
|
||||
async def _aget_initial_prompt(
|
||||
self, state: ExtractionState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the initial extraction prompt (async version)."""
|
||||
user_content = _format_documents_content(state["documents"])
|
||||
|
||||
# Choose default prompt based on structured output format
|
||||
default_prompt = (
|
||||
DEFAULT_STRUCTURED_INIT_PROMPT
|
||||
if self.response_format
|
||||
else DEFAULT_INIT_PROMPT
|
||||
)
|
||||
|
||||
return await aresolve_prompt(
|
||||
self.initial_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_prompt,
|
||||
)
|
||||
|
||||
def _get_refine_prompt(
|
||||
self, state: ExtractionState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the refinement prompt."""
|
||||
# Result should be guaranteed to exist at refinement stage
|
||||
if "result" not in state or state["result"] == "":
|
||||
msg = (
|
||||
"Internal programming error: Result must exist when refining. "
|
||||
"This indicates that the refinement node was reached without "
|
||||
"first processing the initial result node, which violates "
|
||||
"the expected graph execution order."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
new_context = _format_documents_content(state["documents"])
|
||||
|
||||
user_content = (
|
||||
f"Previous result:\n{state['result']}\n\n"
|
||||
f"New context:\n{new_context}\n\n"
|
||||
f"Please provide a refined result."
|
||||
)
|
||||
|
||||
# Choose default prompt based on structured output format
|
||||
default_prompt = (
|
||||
DEFAULT_STRUCTURED_REFINE_PROMPT
|
||||
if self.response_format
|
||||
else DEFAULT_REFINE_PROMPT
|
||||
)
|
||||
|
||||
return resolve_prompt(
|
||||
self.refine_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_prompt,
|
||||
)
|
||||
|
||||
async def _aget_refine_prompt(
|
||||
self, state: ExtractionState, runtime: Runtime[ContextT]
|
||||
) -> list[MessageLikeRepresentation]:
|
||||
"""Generate the refinement prompt (async version)."""
|
||||
# Result should be guaranteed to exist at refinement stage
|
||||
if "result" not in state or state["result"] == "":
|
||||
msg = (
|
||||
"Internal programming error: Result must exist when refining. "
|
||||
"This indicates that the refinement node was reached without "
|
||||
"first processing the initial result node, which violates "
|
||||
"the expected graph execution order."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
new_context = _format_documents_content(state["documents"])
|
||||
|
||||
user_content = (
|
||||
f"Previous result:\n{state['result']}\n\n"
|
||||
f"New context:\n{new_context}\n\n"
|
||||
f"Please provide a refined result."
|
||||
)
|
||||
|
||||
# Choose default prompt based on structured output format
|
||||
default_prompt = (
|
||||
DEFAULT_STRUCTURED_REFINE_PROMPT
|
||||
if self.response_format
|
||||
else DEFAULT_REFINE_PROMPT
|
||||
)
|
||||
|
||||
return await aresolve_prompt(
|
||||
self.refine_prompt,
|
||||
state,
|
||||
runtime,
|
||||
user_content,
|
||||
default_prompt,
|
||||
)
|
||||
|
||||
def create_document_processor_node(self) -> RunnableCallable:
|
||||
"""Create the main document processing node.
|
||||
|
||||
The node handles both initial processing and refinement of results.
|
||||
|
||||
Refinement is done by providing the existing result and new context.
|
||||
|
||||
If the workflow is run with a checkpointer enabled, the result will be
|
||||
persisted and available for a given thread id.
|
||||
"""
|
||||
|
||||
def _process_node(
|
||||
state: ExtractionState, runtime: Runtime[ContextT], config: RunnableConfig
|
||||
) -> ExtractionNodeUpdate:
|
||||
# Handle empty document list
|
||||
if not state["documents"]:
|
||||
return {}
|
||||
|
||||
# Determine if this is initial processing or refinement
|
||||
if "result" not in state or state["result"] == "":
|
||||
# Initial processing
|
||||
prompt = self._get_initial_prompt(state, runtime)
|
||||
response = cast("AIMessage", self.model.invoke(prompt, config=config))
|
||||
result = response if self.response_format else response.text()
|
||||
return {"result": result}
|
||||
# Refinement
|
||||
prompt = self._get_refine_prompt(state, runtime)
|
||||
response = cast("AIMessage", self.model.invoke(prompt, config=config))
|
||||
result = response if self.response_format else response.text()
|
||||
return {"result": result}
|
||||
|
||||
async def _aprocess_node(
|
||||
state: ExtractionState,
|
||||
runtime: Runtime[ContextT],
|
||||
config: RunnableConfig,
|
||||
) -> ExtractionNodeUpdate:
|
||||
# Handle empty document list
|
||||
if not state["documents"]:
|
||||
return {}
|
||||
|
||||
# Determine if this is initial processing or refinement
|
||||
if "result" not in state or state["result"] == "":
|
||||
# Initial processing
|
||||
prompt = await self._aget_initial_prompt(state, runtime)
|
||||
response = cast(
|
||||
"AIMessage", await self.model.ainvoke(prompt, config=config)
|
||||
)
|
||||
result = response if self.response_format else response.text()
|
||||
return {"result": result}
|
||||
# Refinement
|
||||
prompt = await self._aget_refine_prompt(state, runtime)
|
||||
response = cast(
|
||||
"AIMessage", await self.model.ainvoke(prompt, config=config)
|
||||
)
|
||||
result = response if self.response_format else response.text()
|
||||
return {"result": result}
|
||||
|
||||
return RunnableCallable(
|
||||
_process_node,
|
||||
_aprocess_node,
|
||||
trace=False,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
|
||||
"""Build and compile the LangGraph for batch document extraction."""
|
||||
builder = StateGraph(
|
||||
ExtractionState,
|
||||
context_schema=self.context_schema,
|
||||
input_schema=InputSchema,
|
||||
output_schema=OutputSchema,
|
||||
)
|
||||
builder.add_edge(START, "process")
|
||||
builder.add_node("process", self.create_document_processor_node())
|
||||
return builder
|
||||
|
||||
|
||||
def create_stuff_documents_chain(
|
||||
model: Union[BaseChatModel, str],
|
||||
*,
|
||||
prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
] = None,
|
||||
refine_prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
] = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
response_format: Optional[type[BaseModel]] = None,
|
||||
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
|
||||
"""Create a stuff documents chain for processing documents.
|
||||
|
||||
This chain works by putting all the documents in the batch into the context
|
||||
window of the language model. It processes all documents together in a single
|
||||
request for extracting information or summaries. Can refine existing results
|
||||
when provided. The default prompts are optimized for summarization tasks, but
|
||||
can be customized for other extraction tasks via the prompt parameters or
|
||||
response_format.
|
||||
|
||||
Strategy:
|
||||
1. Put all documents into the context window
|
||||
2. Process all documents together in a single request
|
||||
3. If an existing result is provided, refine it with all documents at once
|
||||
4. Return the result
|
||||
|
||||
Important:
|
||||
This chain does not attempt to control for the size of the context
|
||||
window of the LLM. Ensure your documents fit within the model's context limits.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.documents import Document
|
||||
|
||||
model = init_chat_model("anthropic:claude-sonnet-4-20250514")
|
||||
builder = create_stuff_documents_chain(model)
|
||||
chain = builder.compile()
|
||||
docs = [
|
||||
Document(page_content="First document content..."),
|
||||
Document(page_content="Second document content..."),
|
||||
Document(page_content="Third document content..."),
|
||||
]
|
||||
result = chain.invoke({"documents": docs})
|
||||
print(result["result"])
|
||||
|
||||
# Structured summary/extraction by passing a schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Summary(BaseModel):
|
||||
title: str
|
||||
key_points: list[str]
|
||||
|
||||
builder = create_stuff_documents_chain(model, response_format=Summary)
|
||||
chain = builder.compile()
|
||||
result = chain.invoke({"documents": docs})
|
||||
print(result["result"].title) # Access structured fields
|
||||
```
|
||||
|
||||
Args:
|
||||
model: The language model for document processing.
|
||||
prompt: Prompt for initial processing. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
refine_prompt: Prompt for refinement steps. Can be:
|
||||
- str: A system message string
|
||||
- None: Use default system message
|
||||
- Callable: A function that takes (state, runtime) and returns messages
|
||||
context_schema: Optional context schema for the LangGraph runtime.
|
||||
response_format: Optional pydantic BaseModel for structured output.
|
||||
|
||||
Returns:
|
||||
A LangGraph that can be invoked with documents to extract information.
|
||||
|
||||
.. note::
|
||||
This is a "stuff" documents chain that puts all documents into the context
|
||||
window and processes them together. It supports refining existing results.
|
||||
Default prompts are optimized for summarization but can be customized for
|
||||
other tasks. Important: Does not control for context window size.
|
||||
"""
|
||||
extractor = _Extractor(
|
||||
model,
|
||||
prompt=prompt,
|
||||
refine_prompt=refine_prompt,
|
||||
context_schema=context_schema,
|
||||
response_format=response_format,
|
||||
)
|
||||
return extractor.build()
|
||||
|
||||
|
||||
__all__ = ["create_stuff_documents_chain"]
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Chat models."""
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from langchain.chat_models.base import init_chat_model
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Factory functions for chat models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
@@ -5,21 +7,20 @@ from importlib import util
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeAlias,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
||||
from typing_extensions import TypeAlias, override
|
||||
from typing_extensions import override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -31,9 +32,9 @@ if TYPE_CHECKING:
|
||||
def init_chat_model(
|
||||
model: str,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
model_provider: str | None = None,
|
||||
configurable_fields: Literal[None] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
config_prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatModel: ...
|
||||
|
||||
@@ -42,20 +43,20 @@ def init_chat_model(
|
||||
def init_chat_model(
|
||||
model: Literal[None] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
model_provider: str | None = None,
|
||||
configurable_fields: Literal[None] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
config_prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: Optional[str] = None,
|
||||
model: str | None = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
model_provider: str | None = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
|
||||
config_prefix: Optional[str] = None,
|
||||
config_prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
|
||||
@@ -64,13 +65,11 @@ def init_chat_model(
|
||||
# name to the supported list in the docstring below. Do *not* change the order of the
|
||||
# existing providers.
|
||||
def init_chat_model(
|
||||
model: Optional[str] = None,
|
||||
model: str | None = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Optional[
|
||||
Union[Literal["any"], list[str], tuple[str, ...]]
|
||||
] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
model_provider: str | None = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
|
||||
config_prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[BaseChatModel, _ConfigurableModel]:
|
||||
"""Initialize a ChatModel from the model name and provider.
|
||||
@@ -326,7 +325,7 @@ def init_chat_model(
|
||||
def _init_chat_model_helper(
|
||||
model: str,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
model_provider: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatModel:
|
||||
model, model_provider = _parse_model(model, model_provider)
|
||||
@@ -446,9 +445,7 @@ def _init_chat_model_helper(
|
||||
|
||||
return ChatPerplexity(model=model, **kwargs)
|
||||
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
||||
msg = (
|
||||
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
||||
)
|
||||
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@@ -476,7 +473,7 @@ _SUPPORTED_PROVIDERS = {
|
||||
}
|
||||
|
||||
|
||||
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
||||
def _attempt_infer_model_provider(model_name: str) -> str | None:
|
||||
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
|
||||
return "openai"
|
||||
if model_name.startswith("claude"):
|
||||
@@ -500,31 +497,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
|
||||
if (
|
||||
not model_provider
|
||||
and ":" in model
|
||||
and model.split(":")[0] in _SUPPORTED_PROVIDERS
|
||||
):
|
||||
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
|
||||
if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
|
||||
model_provider = model.split(":")[0]
|
||||
model = ":".join(model.split(":")[1:])
|
||||
model_provider = model_provider or _attempt_infer_model_provider(model)
|
||||
if not model_provider:
|
||||
msg = (
|
||||
f"Unable to infer model provider for {model=}, please specify "
|
||||
f"model_provider directly."
|
||||
f"Unable to infer model provider for {model=}, please specify model_provider directly."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
model_provider = model_provider.replace("-", "_").lower()
|
||||
return model, model_provider
|
||||
|
||||
|
||||
def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
|
||||
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
|
||||
if not util.find_spec(pkg):
|
||||
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
|
||||
msg = (
|
||||
f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
||||
)
|
||||
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
@@ -539,16 +529,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
default_config: Optional[dict] = None,
|
||||
default_config: dict | None = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
|
||||
config_prefix: str = "",
|
||||
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
|
||||
) -> None:
|
||||
self._default_config: dict = default_config or {}
|
||||
self._configurable_fields: Union[Literal["any"], list[str]] = (
|
||||
configurable_fields
|
||||
if configurable_fields == "any"
|
||||
else list(configurable_fields)
|
||||
configurable_fields if configurable_fields == "any" else list(configurable_fields)
|
||||
)
|
||||
self._config_prefix = (
|
||||
config_prefix + "_"
|
||||
@@ -589,14 +577,14 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
msg += "."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
|
||||
def _model(self, config: RunnableConfig | None = None) -> Runnable:
|
||||
params = {**self._default_config, **self._model_params(config)}
|
||||
model = _init_chat_model_helper(**params)
|
||||
for name, args, kwargs in self._queued_declarative_operations:
|
||||
model = getattr(model, name)(*args, **kwargs)
|
||||
return model
|
||||
|
||||
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
|
||||
def _model_params(self, config: RunnableConfig | None) -> dict:
|
||||
config = ensure_config(config)
|
||||
model_params = {
|
||||
_remove_prefix(k, self._config_prefix): v
|
||||
@@ -604,14 +592,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
if k.startswith(self._config_prefix)
|
||||
}
|
||||
if self._configurable_fields != "any":
|
||||
model_params = {
|
||||
k: v for k, v in model_params.items() if k in self._configurable_fields
|
||||
}
|
||||
model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
|
||||
return model_params
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel:
|
||||
"""Bind config to a Runnable, returning a new Runnable."""
|
||||
@@ -662,7 +648,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._model(config).invoke(input, config=config, **kwargs)
|
||||
@@ -671,7 +657,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return await self._model(config).ainvoke(input, config=config, **kwargs)
|
||||
@@ -680,8 +666,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Any]:
|
||||
yield from self._model(config).stream(input, config=config, **kwargs)
|
||||
|
||||
@@ -689,8 +675,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def astream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Any]:
|
||||
async for x in self._model(config).astream(input, config=config, **kwargs):
|
||||
yield x
|
||||
@@ -698,10 +684,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def batch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
**kwargs: Any | None,
|
||||
) -> list[Any]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
@@ -726,10 +712,10 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
**kwargs: Any | None,
|
||||
) -> list[Any]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
@@ -754,7 +740,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def batch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
|
||||
config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -783,7 +769,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def abatch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
|
||||
config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -817,8 +803,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[LanguageModelInput],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Any]:
|
||||
yield from self._model(config).transform(input, config=config, **kwargs)
|
||||
|
||||
@@ -826,8 +812,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[LanguageModelInput],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Any]:
|
||||
async for x in self._model(config).atransform(input, config=config, **kwargs):
|
||||
yield x
|
||||
@@ -836,16 +822,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
diff: Literal[True] = True,
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[RunLogPatch]: ...
|
||||
|
||||
@@ -853,16 +839,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
diff: Literal[False],
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[RunLog]: ...
|
||||
|
||||
@@ -870,16 +856,16 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
diff: bool = True,
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
|
||||
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
|
||||
@@ -901,15 +887,15 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
version: Literal["v1", "v2"] = "v2",
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
async for x in self._model(config).astream_events(
|
||||
@@ -931,7 +917,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
return self.__getattr__("bind_tools")(tools, **kwargs)
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Document."""
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Embeddings."""
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain.embeddings.base import init_embeddings
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Factory functions for embeddings."""
|
||||
|
||||
import functools
|
||||
from importlib import util
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables import Runnable
|
||||
@@ -19,9 +21,7 @@ _SUPPORTED_PROVIDERS = {
|
||||
|
||||
def _get_provider_list() -> str:
|
||||
"""Get formatted list of providers and their packages."""
|
||||
return "\n".join(
|
||||
f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
|
||||
)
|
||||
return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
|
||||
|
||||
|
||||
def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||
@@ -82,7 +82,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||
def _infer_model_and_provider(
|
||||
model: str,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
provider: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
if not model.strip():
|
||||
msg = "Model name cannot be empty"
|
||||
@@ -117,17 +117,14 @@ def _infer_model_and_provider(
|
||||
def _check_pkg(pkg: str) -> None:
|
||||
"""Check if a package is installed."""
|
||||
if not util.find_spec(pkg):
|
||||
msg = (
|
||||
f"Could not import {pkg} python package. "
|
||||
f"Please install it with `pip install {pkg}`"
|
||||
)
|
||||
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
def init_embeddings(
|
||||
model: str,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
provider: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[Embeddings, Runnable[Any, list[float]]]:
|
||||
"""Initialize an embeddings model from a model name and optional provider.
|
||||
@@ -182,9 +179,7 @@ def init_embeddings(
|
||||
"""
|
||||
if not model:
|
||||
providers = _SUPPORTED_PROVIDERS.keys()
|
||||
msg = (
|
||||
f"Must specify model name. Supported providers are: {', '.join(providers)}"
|
||||
)
|
||||
msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
|
||||
raise ValueError(msg)
|
||||
|
||||
provider, model_name = _infer_model_and_provider(model, provider=provider)
|
||||
|
||||
@@ -13,7 +13,7 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Literal, Union, cast
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils.iter import batch_iterate
|
||||
@@ -21,7 +21,7 @@ from langchain_core.utils.iter import batch_iterate
|
||||
from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
@@ -147,8 +147,8 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
underlying_embeddings: Embeddings,
|
||||
document_embedding_store: BaseStore[str, list[float]],
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
|
||||
batch_size: int | None = None,
|
||||
query_embedding_store: BaseStore[str, list[float]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the embedder.
|
||||
|
||||
@@ -181,17 +181,15 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
|
||||
texts,
|
||||
)
|
||||
all_missing_indices: list[int] = [
|
||||
i for i, vector in enumerate(vectors) if vector is None
|
||||
]
|
||||
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
|
||||
|
||||
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
|
||||
self.document_embedding_store.mset(
|
||||
list(zip(missing_texts, missing_vectors)),
|
||||
list(zip(missing_texts, missing_vectors, strict=False)),
|
||||
)
|
||||
for index, updated_vector in zip(missing_indices, missing_vectors):
|
||||
for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
|
||||
vectors[index] = updated_vector
|
||||
|
||||
return cast(
|
||||
@@ -212,12 +210,8 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
Returns:
|
||||
A list of embeddings for the given texts.
|
||||
"""
|
||||
vectors: list[
|
||||
Union[list[float], None]
|
||||
] = await self.document_embedding_store.amget(texts)
|
||||
all_missing_indices: list[int] = [
|
||||
i for i, vector in enumerate(vectors) if vector is None
|
||||
]
|
||||
vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
|
||||
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
|
||||
|
||||
# batch_iterate supports None batch_size which returns all elements at once
|
||||
# as a single batch.
|
||||
@@ -227,9 +221,9 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
missing_texts,
|
||||
)
|
||||
await self.document_embedding_store.amset(
|
||||
list(zip(missing_texts, missing_vectors)),
|
||||
list(zip(missing_texts, missing_vectors, strict=False)),
|
||||
)
|
||||
for index, updated_vector in zip(missing_indices, missing_vectors):
|
||||
for index, updated_vector in zip(missing_indices, missing_vectors, strict=False):
|
||||
vectors[index] = updated_vector
|
||||
|
||||
return cast(
|
||||
@@ -290,7 +284,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
document_embedding_cache: ByteStore,
|
||||
*,
|
||||
namespace: str = "",
|
||||
batch_size: Optional[int] = None,
|
||||
batch_size: int | None = None,
|
||||
query_embedding_cache: Union[bool, ByteStore] = False,
|
||||
key_encoder: Union[
|
||||
Callable[[str], str],
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
"""Global values and configuration that apply to all of LangChain."""
|
||||
"""Global values and configuration that apply to all of LangChain.
|
||||
|
||||
TODO: will be removed in a future alpha version.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
"""Encoder-backed store implementation."""
|
||||
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@@ -62,37 +62,29 @@ class EncoderBackedStore(BaseStore[K, V]):
|
||||
self.value_serializer = value_serializer
|
||||
self.value_deserializer = value_deserializer
|
||||
|
||||
def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
|
||||
def mget(self, keys: Sequence[K]) -> list[V | None]:
|
||||
"""Get the values associated with the given keys."""
|
||||
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
|
||||
values = self.store.mget(encoded_keys)
|
||||
return [
|
||||
self.value_deserializer(value) if value is not None else value
|
||||
for value in values
|
||||
]
|
||||
return [self.value_deserializer(value) if value is not None else value for value in values]
|
||||
|
||||
async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
|
||||
async def amget(self, keys: Sequence[K]) -> list[V | None]:
|
||||
"""Get the values associated with the given keys."""
|
||||
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
|
||||
values = await self.store.amget(encoded_keys)
|
||||
return [
|
||||
self.value_deserializer(value) if value is not None else value
|
||||
for value in values
|
||||
]
|
||||
return [self.value_deserializer(value) if value is not None else value for value in values]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys."""
|
||||
encoded_pairs = [
|
||||
(self.key_encoder(key), self.value_serializer(value))
|
||||
for key, value in key_value_pairs
|
||||
(self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
|
||||
]
|
||||
self.store.mset(encoded_pairs)
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys."""
|
||||
encoded_pairs = [
|
||||
(self.key_encoder(key), self.value_serializer(value))
|
||||
for key, value in key_value_pairs
|
||||
(self.key_encoder(key), self.value_serializer(value)) for key, value in key_value_pairs
|
||||
]
|
||||
await self.store.amset(encoded_pairs)
|
||||
|
||||
@@ -109,7 +101,7 @@ class EncoderBackedStore(BaseStore[K, V]):
|
||||
def yield_keys(
|
||||
self,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
prefix: str | None = None,
|
||||
) -> Union[Iterator[K], Iterator[str]]:
|
||||
"""Get an iterator over keys that match the given prefix."""
|
||||
# For the time being this does not return K, but str
|
||||
@@ -119,7 +111,7 @@ class EncoderBackedStore(BaseStore[K, V]):
|
||||
async def ayield_keys(
|
||||
self,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
prefix: str | None = None,
|
||||
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
|
||||
"""Get an iterator over keys that match the given prefix."""
|
||||
# For the time being this does not return K, but str
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Store exceptions."""
|
||||
|
||||
from langchain_core.stores import InvalidKeyException
|
||||
|
||||
__all__ = ["InvalidKeyException"]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tools."""
|
||||
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
InjectedToolArg,
|
||||
|
||||
@@ -5,17 +5,16 @@ build-backend = "pdm.backend"
|
||||
[project]
|
||||
authors = []
|
||||
license = { text = "MIT" }
|
||||
requires-python = ">=3.9, <4.0"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"langchain-core<1.0.0,>=0.3.66",
|
||||
"langchain-text-splitters<1.0.0,>=0.3.8",
|
||||
"langgraph>=0.6.0",
|
||||
"langchain-core<2.0.0,>=0.3.75",
|
||||
"langchain-text-splitters<1.0.0,>=0.3.11",
|
||||
"langgraph>=0.6.7",
|
||||
"pydantic>=2.7.4",
|
||||
]
|
||||
|
||||
|
||||
name = "langchain"
|
||||
version = "1.0.0dev1"
|
||||
version = "1.0.0a4"
|
||||
description = "Building applications with LLMs through composability"
|
||||
readme = "README.md"
|
||||
|
||||
@@ -46,35 +45,33 @@ repository = "https://github.com/langchain-ai/langchain"
|
||||
[dependency-groups]
|
||||
test = [
|
||||
"pytest<9,>=8",
|
||||
"pytest-cov<5.0.0,>=4.0.0",
|
||||
"pytest-watcher<1.0.0,>=0.2.6",
|
||||
"pytest-asyncio<1.0.0,>=0.23.2",
|
||||
"pytest-socket<1.0.0,>=0.6.0",
|
||||
"syrupy<5.0.0,>=4.0.2",
|
||||
"pytest-xdist<4.0.0,>=3.6.1",
|
||||
"blockbuster<1.6,>=1.5.18",
|
||||
"pytest-cov>=4.0.0",
|
||||
"pytest-watcher>=0.2.6",
|
||||
"pytest-asyncio>=0.23.2",
|
||||
"pytest-socket>=0.6.0",
|
||||
"syrupy>=4.0.2",
|
||||
"pytest-xdist>=3.6.1",
|
||||
"langchain-tests",
|
||||
"langchain-core",
|
||||
"langchain-text-splitters",
|
||||
"langchain-openai",
|
||||
"toml>=0.10.2",
|
||||
"pytest-mock"
|
||||
]
|
||||
codespell = ["codespell<3.0.0,>=2.2.0"]
|
||||
lint = [
|
||||
"ruff<0.13,>=0.12.2",
|
||||
"mypy<1.16,>=1.15",
|
||||
"ruff>=0.12.2",
|
||||
]
|
||||
typing = [
|
||||
"mypy<1.18,>=1.17.1",
|
||||
"types-toml>=0.10.8.20240310",
|
||||
]
|
||||
|
||||
test_integration = [
|
||||
"vcrpy>=7.0",
|
||||
"urllib3<2; python_version < \"3.10\"",
|
||||
"wrapt<2.0.0,>=1.15.0",
|
||||
"python-dotenv<2.0.0,>=1.0.0",
|
||||
"cassio<1.0.0,>=0.1.0",
|
||||
"langchainhub<1.0.0,>=0.1.16",
|
||||
"wrapt>=1.15.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"cassio>=0.1.0",
|
||||
"langchainhub>=0.1.16",
|
||||
"langchain-core",
|
||||
"langchain-text-splitters",
|
||||
]
|
||||
@@ -86,22 +83,19 @@ langchain-text-splitters = { path = "../text-splitters", editable = true }
|
||||
langchain-openai = { path = "../partners/openai", editable = true }
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
target-version = "py310"
|
||||
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
|
||||
line-length = 100
|
||||
|
||||
[tool.mypy]
|
||||
strict = "True"
|
||||
strict_bytes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
strict = true
|
||||
ignore_missing_imports = true
|
||||
enable_error_code = "deprecated"
|
||||
report_deprecated_as_note = "True"
|
||||
exclude = ["tests/unit_tests/agents/*", "tests/integration_tests/agents/*"]
|
||||
|
||||
# TODO: activate for 'strict' checking
|
||||
disallow_untyped_calls = "False"
|
||||
disallow_any_generics = "False"
|
||||
disallow_untyped_decorators = "False"
|
||||
warn_return_any = "False"
|
||||
strict_equality = "False"
|
||||
disallow_any_generics = false
|
||||
warn_return_any = false
|
||||
|
||||
[tool.codespell]
|
||||
skip = ".git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,*.trig"
|
||||
@@ -113,9 +107,6 @@ select = [
|
||||
"ALL"
|
||||
]
|
||||
ignore = [
|
||||
"D100", # pydocstyle: Missing docstring in public module
|
||||
"D104", # pydocstyle: Missing docstring in public package
|
||||
"D105", # pydocstyle: Missing docstring in magic method
|
||||
"COM812", # Messes with the formatter
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
@@ -133,17 +124,27 @@ flake8-annotations.allow-star-arg-any = true
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = [
|
||||
"D", # Documentation rules
|
||||
"D1", # Documentation rules
|
||||
"PLC0415", # Imports should be at the top. Not always desirable for tests
|
||||
]
|
||||
"langchain/agents/*" = [
|
||||
"ANN401", # we use Any right now, need to narrow
|
||||
"E501", # line too long, needs to fix
|
||||
"A002", # input is shadowing builtin
|
||||
"A001", # input is shadowing builtin
|
||||
"B904", # use from for exceptions
|
||||
"PLR2004", # magic values are fine for this case
|
||||
"C901", # too complex
|
||||
"TRY004", # type error exception
|
||||
"PLR0912", # too many branches
|
||||
"PLR0911", # too many return statements
|
||||
]
|
||||
"tests/unit_tests/agents/*" = ["ALL"]
|
||||
"tests/integration_tests/agents/*" = ["ALL"]
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"scripts/check_imports.py" = ["ALL"]
|
||||
|
||||
"langchain/globals.py" = [
|
||||
"PLW"
|
||||
]
|
||||
|
||||
"langchain/chat_models/base.py" = [
|
||||
"ANN",
|
||||
"C901",
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Integration tests for the agents module."""
|
||||
@@ -0,0 +1,79 @@
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
|
||||
|
||||
class WeatherBaseModel(BaseModel):
|
||||
"""Weather response."""
|
||||
|
||||
temperature: float = Field(description="The temperature in fahrenheit")
|
||||
condition: str = Field(description="Weather condition")
|
||||
|
||||
|
||||
def get_weather(city: str) -> str: # noqa: ARG001
|
||||
"""Get the weather for a city."""
|
||||
return "The weather is sunny and 75°F."
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai")
|
||||
def test_inference_to_native_output() -> None:
|
||||
"""Test that native output is inferred when a model supports it."""
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
model = ChatOpenAI(model="gpt-5")
|
||||
agent = create_agent(
|
||||
model,
|
||||
prompt=(
|
||||
"You are a helpful weather assistant. Please call the get_weather tool, "
|
||||
"then use the WeatherReport tool to generate the final response."
|
||||
),
|
||||
tools=[get_weather],
|
||||
response_format=WeatherBaseModel,
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert isinstance(response["structured_response"], WeatherBaseModel)
|
||||
assert response["structured_response"].temperature == 75.0
|
||||
assert response["structured_response"].condition.lower() == "sunny"
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
assert [m.type for m in response["messages"]] == [
|
||||
"human", # "What's the weather?"
|
||||
"ai", # "What's the weather?"
|
||||
"tool", # "The weather is sunny and 75°F."
|
||||
"ai", # structured response
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai")
|
||||
def test_inference_to_tool_output() -> None:
|
||||
"""Test that tool output is inferred when a model supports it."""
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
model = ChatOpenAI(model="gpt-4")
|
||||
agent = create_agent(
|
||||
model,
|
||||
prompt=(
|
||||
"You are a helpful weather assistant. Please call the get_weather tool, "
|
||||
"then use the WeatherReport tool to generate the final response."
|
||||
),
|
||||
tools=[get_weather],
|
||||
response_format=ToolStrategy(WeatherBaseModel),
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert isinstance(response["structured_response"], WeatherBaseModel)
|
||||
assert response["structured_response"].temperature == 75.0
|
||||
assert response["structured_response"].condition.lower() == "sunny"
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
assert [m.type for m in response["messages"]] == [
|
||||
"human", # "What's the weather?"
|
||||
"ai", # "What's the weather?"
|
||||
"tool", # "The weather is sunny and 75°F."
|
||||
"ai", # structured response
|
||||
"tool", # artificial tool message
|
||||
]
|
||||
@@ -12,7 +12,9 @@ class FakeEmbeddings(Embeddings):
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Return simple embeddings.
|
||||
Embeddings encode each text as its index."""
|
||||
|
||||
Embeddings encode each text as its index.
|
||||
"""
|
||||
return [[1.0] * 9 + [float(i)] for i in range(len(texts))]
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
@@ -20,9 +22,11 @@ class FakeEmbeddings(Embeddings):
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Return constant query embeddings.
|
||||
|
||||
Embeddings are identical to embed_documents(texts)[0].
|
||||
Distance to each text will be that text's index,
|
||||
as it was passed to embed_documents."""
|
||||
as it was passed to embed_documents.
|
||||
"""
|
||||
return [1.0] * 9 + [0.0]
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
@@ -30,8 +34,11 @@ class FakeEmbeddings(Embeddings):
|
||||
|
||||
|
||||
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||
vectors for the same texts."""
|
||||
"""Consistent fake embeddings.
|
||||
|
||||
Fake embeddings which remember all the texts seen so far to return consistent
|
||||
vectors for the same texts.
|
||||
"""
|
||||
|
||||
def __init__(self, dimensionality: int = 10) -> None:
|
||||
self.known_texts: list[str] = []
|
||||
@@ -50,25 +57,24 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
return out_vectors
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown."""
|
||||
"""Return consistent embeddings.
|
||||
|
||||
Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class AngularTwoDimensionalEmbeddings(Embeddings):
|
||||
"""
|
||||
From angles (as strings in units of pi) to unit embedding vectors on a circle.
|
||||
"""
|
||||
"""From angles (as strings in units of pi) to unit embedding vectors on a circle."""
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Make a list of texts into a list of embedding vectors.
|
||||
"""
|
||||
"""Make a list of texts into a list of embedding vectors."""
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""
|
||||
Convert input text to a 'vector' (list of floats).
|
||||
"""Convert input text to a 'vector' (list of floats).
|
||||
|
||||
If the text is a number, use it as the angle for the
|
||||
unit vector in units of pi.
|
||||
Any other input text becomes the singular result [0, 0] !
|
||||
|
||||
@@ -31,9 +31,7 @@ async def test_init_chat_model_chain() -> None:
|
||||
chain = prompt | model_with_config
|
||||
output = chain.invoke({"input": "bar"})
|
||||
assert isinstance(output, AIMessage)
|
||||
events = [
|
||||
event async for event in chain.astream_events({"input": "bar"}, version="v2")
|
||||
]
|
||||
events = [event async for event in chain.astream_events({"input": "bar"}, version="v2")]
|
||||
assert events
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,533 @@
|
||||
# serializer version: 1
|
||||
# name: test_create_agent_diagram
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model_request;
|
||||
model_request --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.1
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopOne_before_model(NoopOne.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne_before_model -.-> __end__;
|
||||
NoopOne_before_model -.-> model_request;
|
||||
__start__ --> NoopOne_before_model;
|
||||
model_request --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.10
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopTen_before_model(NoopTen.before_model)
|
||||
NoopTen_after_model(NoopTen.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopTen_after_model -.-> NoopTen_before_model;
|
||||
NoopTen_after_model -.-> __end__;
|
||||
NoopTen_before_model -.-> __end__;
|
||||
NoopTen_before_model -.-> model_request;
|
||||
__start__ --> NoopTen_before_model;
|
||||
model_request --> NoopTen_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.11
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopTen_before_model(NoopTen.before_model)
|
||||
NoopTen_after_model(NoopTen.after_model)
|
||||
NoopEleven_before_model(NoopEleven.before_model)
|
||||
NoopEleven_after_model(NoopEleven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEleven_after_model -.-> NoopTen_after_model;
|
||||
NoopEleven_after_model -.-> NoopTen_before_model;
|
||||
NoopEleven_after_model -.-> __end__;
|
||||
NoopEleven_before_model -.-> NoopTen_before_model;
|
||||
NoopEleven_before_model -.-> __end__;
|
||||
NoopEleven_before_model -.-> model_request;
|
||||
NoopTen_after_model -.-> NoopTen_before_model;
|
||||
NoopTen_after_model -.-> __end__;
|
||||
NoopTen_before_model -.-> NoopEleven_before_model;
|
||||
NoopTen_before_model -.-> __end__;
|
||||
__start__ --> NoopTen_before_model;
|
||||
model_request --> NoopEleven_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.2
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopOne_before_model(NoopOne.before_model)
|
||||
NoopTwo_before_model(NoopTwo.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne_before_model -.-> NoopTwo_before_model;
|
||||
NoopOne_before_model -.-> __end__;
|
||||
NoopTwo_before_model -.-> NoopOne_before_model;
|
||||
NoopTwo_before_model -.-> __end__;
|
||||
NoopTwo_before_model -.-> model_request;
|
||||
__start__ --> NoopOne_before_model;
|
||||
model_request --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.3
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopOne_before_model(NoopOne.before_model)
|
||||
NoopTwo_before_model(NoopTwo.before_model)
|
||||
NoopThree_before_model(NoopThree.before_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopOne_before_model -.-> NoopTwo_before_model;
|
||||
NoopOne_before_model -.-> __end__;
|
||||
NoopThree_before_model -.-> NoopOne_before_model;
|
||||
NoopThree_before_model -.-> __end__;
|
||||
NoopThree_before_model -.-> model_request;
|
||||
NoopTwo_before_model -.-> NoopOne_before_model;
|
||||
NoopTwo_before_model -.-> NoopThree_before_model;
|
||||
NoopTwo_before_model -.-> __end__;
|
||||
__start__ --> NoopOne_before_model;
|
||||
model_request --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.4
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopFour_after_model(NoopFour.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFour_after_model -.-> __end__;
|
||||
NoopFour_after_model -.-> model_request;
|
||||
__start__ --> model_request;
|
||||
model_request --> NoopFour_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.5
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopFour_after_model(NoopFour.after_model)
|
||||
NoopFive_after_model(NoopFive.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFive_after_model -.-> NoopFour_after_model;
|
||||
NoopFive_after_model -.-> __end__;
|
||||
NoopFive_after_model -.-> model_request;
|
||||
NoopFour_after_model -.-> __end__;
|
||||
NoopFour_after_model -.-> model_request;
|
||||
__start__ --> model_request;
|
||||
model_request --> NoopFive_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.6
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopFour_after_model(NoopFour.after_model)
|
||||
NoopFive_after_model(NoopFive.after_model)
|
||||
NoopSix_after_model(NoopSix.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopFive_after_model -.-> NoopFour_after_model;
|
||||
NoopFive_after_model -.-> __end__;
|
||||
NoopFive_after_model -.-> model_request;
|
||||
NoopFour_after_model -.-> __end__;
|
||||
NoopFour_after_model -.-> model_request;
|
||||
NoopSix_after_model -.-> NoopFive_after_model;
|
||||
NoopSix_after_model -.-> __end__;
|
||||
NoopSix_after_model -.-> model_request;
|
||||
__start__ --> model_request;
|
||||
model_request --> NoopSix_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.7
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> model_request;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopSeven_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.8
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopEight_before_model -.-> model_request;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopEight_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram.9
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
NoopNine_before_model(NoopNine.before_model)
|
||||
NoopNine_after_model(NoopNine.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_before_model -.-> NoopNine_before_model;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopNine_after_model -.-> NoopEight_after_model;
|
||||
NoopNine_after_model -.-> NoopSeven_before_model;
|
||||
NoopNine_after_model -.-> __end__;
|
||||
NoopNine_before_model -.-> NoopSeven_before_model;
|
||||
NoopNine_before_model -.-> __end__;
|
||||
NoopNine_before_model -.-> model_request;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopNine_after_model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[memory]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_after_model -.-> tools;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopEight_before_model -.-> model_request;
|
||||
NoopEight_before_model -.-> tools;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_after_model -.-> tools;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> tools;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopEight_after_model;
|
||||
tools -.-> NoopSeven_before_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_after_model -.-> tools;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopEight_before_model -.-> model_request;
|
||||
NoopEight_before_model -.-> tools;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_after_model -.-> tools;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> tools;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopEight_after_model;
|
||||
tools -.-> NoopSeven_before_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres_pipe]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_after_model -.-> tools;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopEight_before_model -.-> model_request;
|
||||
NoopEight_before_model -.-> tools;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_after_model -.-> tools;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> tools;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopEight_after_model;
|
||||
tools -.-> NoopSeven_before_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[postgres_pool]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_after_model -.-> tools;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopEight_before_model -.-> model_request;
|
||||
NoopEight_before_model -.-> tools;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_after_model -.-> tools;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> tools;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopEight_after_model;
|
||||
tools -.-> NoopSeven_before_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_jump[sqlite]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven_before_model(NoopSeven.before_model)
|
||||
NoopSeven_after_model(NoopSeven.after_model)
|
||||
NoopEight_before_model(NoopEight.before_model)
|
||||
NoopEight_after_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight_after_model -.-> NoopSeven_after_model;
|
||||
NoopEight_after_model -.-> NoopSeven_before_model;
|
||||
NoopEight_after_model -.-> __end__;
|
||||
NoopEight_after_model -.-> tools;
|
||||
NoopEight_before_model -.-> NoopSeven_before_model;
|
||||
NoopEight_before_model -.-> __end__;
|
||||
NoopEight_before_model -.-> model_request;
|
||||
NoopEight_before_model -.-> tools;
|
||||
NoopSeven_after_model -.-> NoopSeven_before_model;
|
||||
NoopSeven_after_model -.-> __end__;
|
||||
NoopSeven_after_model -.-> tools;
|
||||
NoopSeven_before_model -.-> NoopEight_before_model;
|
||||
NoopSeven_before_model -.-> __end__;
|
||||
NoopSeven_before_model -.-> tools;
|
||||
__start__ --> NoopSeven_before_model;
|
||||
model_request --> NoopEight_after_model;
|
||||
tools -.-> NoopSeven_before_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,83 @@
|
||||
# serializer version: 1
|
||||
# name: test_react_agent_graph_structure[None-None-tools0]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> agent;
|
||||
agent --> __end__;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[None-None-tools1]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> agent;
|
||||
agent -.-> __end__;
|
||||
agent -.-> tools;
|
||||
tools --> agent;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[None-pre_model_hook-tools0]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> pre_model_hook;
|
||||
pre_model_hook --> agent;
|
||||
agent --> __end__;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[None-pre_model_hook-tools1]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> pre_model_hook;
|
||||
agent -.-> __end__;
|
||||
agent -.-> tools;
|
||||
pre_model_hook --> agent;
|
||||
tools --> pre_model_hook;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[post_model_hook-None-tools0]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> agent;
|
||||
agent --> post_model_hook;
|
||||
post_model_hook --> __end__;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[post_model_hook-None-tools1]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> agent;
|
||||
agent --> post_model_hook;
|
||||
post_model_hook -.-> __end__;
|
||||
post_model_hook -.-> agent;
|
||||
post_model_hook -.-> tools;
|
||||
tools --> agent;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[post_model_hook-pre_model_hook-tools0]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> pre_model_hook;
|
||||
agent --> post_model_hook;
|
||||
pre_model_hook --> agent;
|
||||
post_model_hook --> __end__;
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_react_agent_graph_structure[post_model_hook-pre_model_hook-tools1]
|
||||
'''
|
||||
graph TD;
|
||||
__start__ --> pre_model_hook;
|
||||
agent --> post_model_hook;
|
||||
post_model_hook -.-> __end__;
|
||||
post_model_hook -.-> pre_model_hook;
|
||||
post_model_hook -.-> tools;
|
||||
pre_model_hook --> agent;
|
||||
tools --> pre_model_hook;
|
||||
|
||||
'''
|
||||
# ---
|
||||
18
libs/langchain_v1/tests/unit_tests/agents/any_str.py
Normal file
18
libs/langchain_v1/tests/unit_tests/agents/any_str.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
|
||||
class AnyStr(str):
|
||||
def __init__(self, prefix: Union[str, re.Pattern] = "") -> None:
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, str) and (
|
||||
other.startswith(self.prefix)
|
||||
if isinstance(self.prefix, str)
|
||||
else self.prefix.match(other)
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((str(self), self.prefix))
|
||||
@@ -0,0 +1,17 @@
|
||||
name: langgraph-tests
|
||||
services:
|
||||
postgres-test:
|
||||
image: postgres:16
|
||||
ports:
|
||||
- "5442:5432"
|
||||
environment:
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
healthcheck:
|
||||
test: pg_isready -U postgres
|
||||
start_period: 10s
|
||||
timeout: 1s
|
||||
retries: 5
|
||||
interval: 60s
|
||||
start_interval: 1s
|
||||
16
libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml
Normal file
16
libs/langchain_v1/tests/unit_tests/agents/compose-redis.yml
Normal file
@@ -0,0 +1,16 @@
|
||||
name: langgraph-tests-redis
|
||||
services:
|
||||
redis-test:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
healthcheck:
|
||||
test: redis-cli ping
|
||||
start_period: 10s
|
||||
timeout: 1s
|
||||
retries: 5
|
||||
interval: 5s
|
||||
start_interval: 1s
|
||||
tmpfs:
|
||||
- /data # Use tmpfs for faster testing
|
||||
194
libs/langchain_v1/tests/unit_tests/agents/conftest.py
Normal file
194
libs/langchain_v1/tests/unit_tests/agents/conftest.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import os
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.store.base import BaseStore
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from .conftest_checkpointer import (
|
||||
_checkpointer_memory,
|
||||
_checkpointer_postgres,
|
||||
_checkpointer_postgres_aio,
|
||||
_checkpointer_postgres_aio_pipe,
|
||||
_checkpointer_postgres_aio_pool,
|
||||
_checkpointer_postgres_pipe,
|
||||
_checkpointer_postgres_pool,
|
||||
_checkpointer_sqlite,
|
||||
_checkpointer_sqlite_aio,
|
||||
)
|
||||
from .conftest_store import (
|
||||
_store_memory,
|
||||
_store_postgres,
|
||||
_store_postgres_aio,
|
||||
_store_postgres_aio_pipe,
|
||||
_store_postgres_aio_pool,
|
||||
_store_postgres_pipe,
|
||||
_store_postgres_pool,
|
||||
)
|
||||
|
||||
# Global variables for checkpointer and store configurations
|
||||
FAST_MODE = os.getenv("LANGGRAPH_TEST_FAST", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
SYNC_CHECKPOINTER_PARAMS = (
|
||||
["memory"]
|
||||
if FAST_MODE
|
||||
else [
|
||||
"memory",
|
||||
"sqlite",
|
||||
"postgres",
|
||||
"postgres_pipe",
|
||||
"postgres_pool",
|
||||
]
|
||||
)
|
||||
|
||||
ASYNC_CHECKPOINTER_PARAMS = (
|
||||
["memory"]
|
||||
if FAST_MODE
|
||||
else [
|
||||
"memory",
|
||||
"sqlite_aio",
|
||||
"postgres_aio",
|
||||
"postgres_aio_pipe",
|
||||
"postgres_aio_pool",
|
||||
]
|
||||
)
|
||||
|
||||
SYNC_STORE_PARAMS = (
|
||||
["in_memory"]
|
||||
if FAST_MODE
|
||||
else [
|
||||
"in_memory",
|
||||
"postgres",
|
||||
"postgres_pipe",
|
||||
"postgres_pool",
|
||||
]
|
||||
)
|
||||
|
||||
ASYNC_STORE_PARAMS = (
|
||||
["in_memory"]
|
||||
if FAST_MODE
|
||||
else [
|
||||
"in_memory",
|
||||
"postgres_aio",
|
||||
"postgres_aio_pipe",
|
||||
"postgres_aio_pool",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def anyio_backend() -> str:
|
||||
return "asyncio"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
|
||||
side_effect = (UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000))
|
||||
return mocker.patch("uuid.uuid4", side_effect=side_effect)
|
||||
|
||||
|
||||
# checkpointer fixtures
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=SYNC_STORE_PARAMS,
|
||||
)
|
||||
def sync_store(request: pytest.FixtureRequest) -> Iterator[BaseStore]:
|
||||
store_name = request.param
|
||||
if store_name is None:
|
||||
yield None
|
||||
elif store_name == "in_memory":
|
||||
with _store_memory() as store:
|
||||
yield store
|
||||
elif store_name == "postgres":
|
||||
with _store_postgres() as store:
|
||||
yield store
|
||||
elif store_name == "postgres_pipe":
|
||||
with _store_postgres_pipe() as store:
|
||||
yield store
|
||||
elif store_name == "postgres_pool":
|
||||
with _store_postgres_pool() as store:
|
||||
yield store
|
||||
else:
|
||||
msg = f"Unknown store {store_name}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=ASYNC_STORE_PARAMS,
|
||||
)
|
||||
async def async_store(request: pytest.FixtureRequest) -> AsyncIterator[BaseStore]:
|
||||
store_name = request.param
|
||||
if store_name is None:
|
||||
yield None
|
||||
elif store_name == "in_memory":
|
||||
with _store_memory() as store:
|
||||
yield store
|
||||
elif store_name == "postgres_aio":
|
||||
async with _store_postgres_aio() as store:
|
||||
yield store
|
||||
elif store_name == "postgres_aio_pipe":
|
||||
async with _store_postgres_aio_pipe() as store:
|
||||
yield store
|
||||
elif store_name == "postgres_aio_pool":
|
||||
async with _store_postgres_aio_pool() as store:
|
||||
yield store
|
||||
else:
|
||||
msg = f"Unknown store {store_name}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=SYNC_CHECKPOINTER_PARAMS,
|
||||
)
|
||||
def sync_checkpointer(
|
||||
request: pytest.FixtureRequest,
|
||||
) -> Iterator[BaseCheckpointSaver]:
|
||||
checkpointer_name = request.param
|
||||
if checkpointer_name == "memory":
|
||||
with _checkpointer_memory() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "sqlite":
|
||||
with _checkpointer_sqlite() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "postgres":
|
||||
with _checkpointer_postgres() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "postgres_pipe":
|
||||
with _checkpointer_postgres_pipe() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "postgres_pool":
|
||||
with _checkpointer_postgres_pool() as checkpointer:
|
||||
yield checkpointer
|
||||
else:
|
||||
msg = f"Unknown checkpointer: {checkpointer_name}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=ASYNC_CHECKPOINTER_PARAMS,
|
||||
)
|
||||
async def async_checkpointer(
|
||||
request: pytest.FixtureRequest,
|
||||
) -> AsyncIterator[BaseCheckpointSaver]:
|
||||
checkpointer_name = request.param
|
||||
if checkpointer_name == "memory":
|
||||
with _checkpointer_memory() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "sqlite_aio":
|
||||
async with _checkpointer_sqlite_aio() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "postgres_aio":
|
||||
async with _checkpointer_postgres_aio() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "postgres_aio_pipe":
|
||||
async with _checkpointer_postgres_aio_pipe() as checkpointer:
|
||||
yield checkpointer
|
||||
elif checkpointer_name == "postgres_aio_pool":
|
||||
async with _checkpointer_postgres_aio_pool() as checkpointer:
|
||||
yield checkpointer
|
||||
else:
|
||||
msg = f"Unknown checkpointer: {checkpointer_name}"
|
||||
raise NotImplementedError(msg)
|
||||
@@ -0,0 +1,64 @@
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
from .memory_assert import (
|
||||
MemorySaverAssertImmutable,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _checkpointer_memory():
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _checkpointer_memory_aio():
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
# Placeholder functions for other checkpointer types that aren't available
|
||||
@contextmanager
|
||||
def _checkpointer_sqlite():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _checkpointer_postgres():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _checkpointer_postgres_pipe():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _checkpointer_postgres_pool():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _checkpointer_sqlite_aio():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _checkpointer_postgres_aio():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _checkpointer_postgres_aio_pipe():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _checkpointer_postgres_aio_pool():
|
||||
# Fallback to memory for now
|
||||
yield MemorySaverAssertImmutable()
|
||||
58
libs/langchain_v1/tests/unit_tests/agents/conftest_store.py
Normal file
58
libs/langchain_v1/tests/unit_tests/agents/conftest_store.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _store_memory():
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _store_memory_aio():
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
# Placeholder functions for other store types that aren't available
|
||||
@contextmanager
|
||||
def _store_postgres():
|
||||
# Fallback to memory for now
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _store_postgres_pipe():
|
||||
# Fallback to memory for now
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _store_postgres_pool():
|
||||
# Fallback to memory for now
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _store_postgres_aio():
|
||||
# Fallback to memory for now
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _store_postgres_aio_pipe():
|
||||
# Fallback to memory for now
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _store_postgres_aio_pool():
|
||||
# Fallback to memory for now
|
||||
store = InMemoryStore()
|
||||
yield store
|
||||
56
libs/langchain_v1/tests/unit_tests/agents/memory_assert.py
Normal file
56
libs/langchain_v1/tests/unit_tests/agents/memory_assert.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
||||
from langgraph.checkpoint.base import (
|
||||
ChannelVersions,
|
||||
Checkpoint,
|
||||
CheckpointMetadata,
|
||||
SerializerProtocol,
|
||||
)
|
||||
from langgraph.checkpoint.memory import InMemorySaver, PersistentDict
|
||||
from langgraph.pregel._checkpoint import copy_checkpoint
|
||||
|
||||
|
||||
class MemorySaverAssertImmutable(InMemorySaver):
|
||||
storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
serde: SerializerProtocol | None = None,
|
||||
put_sleep: float | None = None,
|
||||
) -> None:
|
||||
_, filename = tempfile.mkstemp()
|
||||
super().__init__(serde=serde, factory=partial(PersistentDict, filename=filename))
|
||||
self.storage_for_copies = defaultdict(lambda: defaultdict(dict))
|
||||
self.put_sleep = put_sleep
|
||||
self.stack.callback(os.remove, filename)
|
||||
|
||||
def put(
|
||||
self,
|
||||
config: dict,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> None:
|
||||
if self.put_sleep:
|
||||
import time
|
||||
|
||||
time.sleep(self.put_sleep)
|
||||
# assert checkpoint hasn't been modified since last written
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
checkpoint_ns = config["configurable"]["checkpoint_ns"]
|
||||
if saved := super().get(config):
|
||||
assert (
|
||||
self.serde.loads_typed(
|
||||
self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]]
|
||||
)
|
||||
== saved
|
||||
)
|
||||
self.storage_for_copies[thread_id][checkpoint_ns][checkpoint["id"]] = (
|
||||
self.serde.dumps_typed(copy_checkpoint(checkpoint))
|
||||
)
|
||||
# call super to write checkpoint
|
||||
return super().put(config, checkpoint, metadata, new_versions)
|
||||
28
libs/langchain_v1/tests/unit_tests/agents/messages.py
Normal file
28
libs/langchain_v1/tests/unit_tests/agents/messages.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Redefined messages as a work-around for pydantic issue with AnyStr.
|
||||
|
||||
The code below creates version of pydantic models
|
||||
that will work in unit tests with AnyStr as id field
|
||||
Please note that the `id` field is assigned AFTER the model is created
|
||||
to workaround an issue with pydantic ignoring the __eq__ method on
|
||||
subclassed strings.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
|
||||
from .any_str import AnyStr
|
||||
|
||||
|
||||
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
|
||||
"""Create a human message with an any id field."""
|
||||
message = HumanMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
|
||||
|
||||
def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage:
|
||||
"""Create a tool message with an any id field."""
|
||||
message = ToolMessage(**kwargs)
|
||||
message.id = AnyStr()
|
||||
return message
|
||||
111
libs/langchain_v1/tests/unit_tests/agents/model.py
Normal file
111
libs/langchain_v1/tests/unit_tests/agents/model.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
StructuredResponseT = TypeVar("StructuredResponseT")
|
||||
|
||||
|
||||
class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
|
||||
tool_calls: Union[list[list[ToolCall]], list[list[dict]]] | None = None
|
||||
structured_response: StructuredResponseT | None = None
|
||||
index: int = 0
|
||||
tool_style: Literal["openai", "anthropic"] = "openai"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
rf = kwargs.get("response_format")
|
||||
is_native = isinstance(rf, dict) and rf.get("type") == "json_schema"
|
||||
|
||||
if self.tool_calls:
|
||||
if is_native:
|
||||
tool_calls = (
|
||||
self.tool_calls[self.index] if self.index < len(self.tool_calls) else []
|
||||
)
|
||||
else:
|
||||
tool_calls = self.tool_calls[self.index % len(self.tool_calls)]
|
||||
else:
|
||||
tool_calls = []
|
||||
|
||||
if is_native and not tool_calls:
|
||||
if isinstance(self.structured_response, BaseModel):
|
||||
content_obj = self.structured_response.model_dump()
|
||||
elif is_dataclass(self.structured_response):
|
||||
content_obj = asdict(self.structured_response)
|
||||
elif isinstance(self.structured_response, dict):
|
||||
content_obj = self.structured_response
|
||||
message = AIMessage(content=json.dumps(content_obj), id=str(self.index))
|
||||
else:
|
||||
messages_string = "-".join([m.content for m in messages])
|
||||
message = AIMessage(
|
||||
content=messages_string,
|
||||
id=str(self.index),
|
||||
tool_calls=tool_calls.copy(),
|
||||
)
|
||||
self.index += 1
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-tool-call-model"
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
if len(tools) == 0:
|
||||
msg = "Must provide at least one tool"
|
||||
raise ValueError(msg)
|
||||
|
||||
tool_dicts = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_dicts.append(tool)
|
||||
continue
|
||||
if not isinstance(tool, BaseTool):
|
||||
msg = "Only BaseTool and dict is supported by FakeToolCallingModel.bind_tools"
|
||||
raise TypeError(msg)
|
||||
|
||||
# NOTE: this is a simplified tool spec for testing purposes only
|
||||
if self.tool_style == "openai":
|
||||
tool_dicts.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif self.tool_style == "anthropic":
|
||||
tool_dicts.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
}
|
||||
)
|
||||
|
||||
return self.bind(tools=tool_dicts)
|
||||
@@ -0,0 +1,87 @@
|
||||
[
|
||||
{
|
||||
"name": "updated structured response",
|
||||
"responseFormat": [
|
||||
{
|
||||
"title": "role_schema_structured_output",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"role": { "type": "string" }
|
||||
},
|
||||
"required": ["name", "role"]
|
||||
},
|
||||
{
|
||||
"title": "department_schema_structured_output",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"department": { "type": "string" }
|
||||
},
|
||||
"required": ["name", "department"]
|
||||
}
|
||||
],
|
||||
"assertionsByInvocation": [
|
||||
{
|
||||
"prompt": "What is the role of Sabine?",
|
||||
"toolsWithExpectedCalls": {
|
||||
"getEmployeeRole": 1,
|
||||
"getEmployeeDepartment": 0
|
||||
},
|
||||
"expectedLastMessage": "Returning structured response: {'name': 'Sabine', 'role': 'Developer'}",
|
||||
"expectedStructuredResponse": { "name": "Sabine", "role": "Developer" },
|
||||
"llmRequestCount": 2
|
||||
},
|
||||
{
|
||||
"prompt": "In which department does Henrik work?",
|
||||
"toolsWithExpectedCalls": {
|
||||
"getEmployeeRole": 1,
|
||||
"getEmployeeDepartment": 1
|
||||
},
|
||||
"expectedLastMessage": "Returning structured response: {'name': 'Henrik', 'department': 'IT'}",
|
||||
"expectedStructuredResponse": { "name": "Henrik", "department": "IT" },
|
||||
"llmRequestCount": 4
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "asking for information that does not fit into the response format",
|
||||
"responseFormat": [
|
||||
{
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"role": { "type": "string" }
|
||||
},
|
||||
"required": ["name", "role"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"department": { "type": "string" }
|
||||
},
|
||||
"required": ["name", "department"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"assertionsByInvocation": [
|
||||
{
|
||||
"prompt": "How much does Saskia earn?",
|
||||
"toolsWithExpectedCalls": {
|
||||
"getEmployeeRole": 1,
|
||||
"getEmployeeDepartment": 0
|
||||
},
|
||||
"expectedLastMessage": "Returning structured response: {'name': 'Saskia', 'role': 'Software Engineer'}",
|
||||
"expectedStructuredResponse": {
|
||||
"name": "Saskia",
|
||||
"role": "Software Engineer"
|
||||
},
|
||||
"llmRequestCount": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,48 @@
|
||||
[
|
||||
{
|
||||
"name": "Scenario: NO return_direct, NO response_format",
|
||||
"returnDirect": false,
|
||||
"responseFormat": null,
|
||||
"expectedToolCalls": 10,
|
||||
"expectedLastMessage": "Attempts: 10",
|
||||
"expectedStructuredResponse": null
|
||||
},
|
||||
{
|
||||
"name": "Scenario: NO return_direct, YES response_format",
|
||||
"returnDirect": false,
|
||||
"responseFormat": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"attempts": { "type": "number" },
|
||||
"succeeded": { "type": "boolean" }
|
||||
},
|
||||
"required": ["attempts", "succeeded"]
|
||||
},
|
||||
"expectedToolCalls": 10,
|
||||
"expectedLastMessage": "Returning structured response: {'attempts': 10, 'succeeded': True}",
|
||||
"expectedStructuredResponse": { "attempts": 10, "succeeded": true }
|
||||
},
|
||||
{
|
||||
"name": "Scenario: YES return_direct, NO response_format",
|
||||
"returnDirect": true,
|
||||
"responseFormat": null,
|
||||
"expectedToolCalls": 1,
|
||||
"expectedLastMessage": "{\"status\": \"pending\", \"attempts\": 1}",
|
||||
"expectedStructuredResponse": null
|
||||
},
|
||||
{
|
||||
"name": "Scenario: YES return_direct, YES response_format",
|
||||
"returnDirect": true,
|
||||
"responseFormat": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"attempts": { "type": "number" },
|
||||
"succeeded": { "type": "boolean" }
|
||||
},
|
||||
"required": ["attempts", "succeeded"]
|
||||
},
|
||||
"expectedToolCalls": 1,
|
||||
"expectedLastMessage": "{\"status\": \"pending\", \"attempts\": 1}",
|
||||
"expectedStructuredResponse": null
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,712 @@
|
||||
import pytest
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
RemoveMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.prebuilt.interrupt import ActionRequest, HumanInterruptConfig
|
||||
|
||||
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_create_agent_diagram(
|
||||
snapshot: SnapshotAssertion,
|
||||
):
|
||||
class NoopOne(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopTwo(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopThree(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopFour(AgentMiddleware):
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopFive(AgentMiddleware):
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopSix(AgentMiddleware):
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopSeven(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopEight(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopNine(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopTen(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
def modify_model_request(self, request, state):
|
||||
pass
|
||||
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
class NoopEleven(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
pass
|
||||
|
||||
def modify_model_request(self, request, state):
|
||||
pass
|
||||
|
||||
def after_model(self, state):
|
||||
pass
|
||||
|
||||
agent_zero = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
assert agent_zero.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_one = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopOne()],
|
||||
)
|
||||
|
||||
assert agent_one.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_two = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopOne(), NoopTwo()],
|
||||
)
|
||||
|
||||
assert agent_two.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_three = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopOne(), NoopTwo(), NoopThree()],
|
||||
)
|
||||
|
||||
assert agent_three.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_four = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopFour()],
|
||||
)
|
||||
|
||||
assert agent_four.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_five = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopFour(), NoopFive()],
|
||||
)
|
||||
|
||||
assert agent_five.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_six = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopFour(), NoopFive(), NoopSix()],
|
||||
)
|
||||
|
||||
assert agent_six.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_seven = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven()],
|
||||
)
|
||||
|
||||
assert agent_seven.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_eight = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven(), NoopEight()],
|
||||
)
|
||||
|
||||
assert agent_eight.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_nine = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven(), NoopEight(), NoopNine()],
|
||||
)
|
||||
|
||||
assert agent_nine.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_ten = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopTen()],
|
||||
)
|
||||
|
||||
assert agent_ten.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
agent_eleven = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopTen(), NoopEleven()],
|
||||
)
|
||||
|
||||
assert agent_eleven.compile().get_graph().draw_mermaid() == snapshot
|
||||
|
||||
|
||||
def test_create_agent_invoke(
|
||||
snapshot: SnapshotAssertion,
|
||||
sync_checkpointer: BaseCheckpointSaver,
|
||||
):
|
||||
calls = []
|
||||
|
||||
class NoopSeven(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
calls.append("NoopSeven.before_model")
|
||||
|
||||
def modify_model_request(self, request, state):
|
||||
calls.append("NoopSeven.modify_model_request")
|
||||
return request
|
||||
|
||||
def after_model(self, state):
|
||||
calls.append("NoopSeven.after_model")
|
||||
|
||||
class NoopEight(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
calls.append("NoopEight.before_model")
|
||||
|
||||
def modify_model_request(self, request, state):
|
||||
calls.append("NoopEight.modify_model_request")
|
||||
return request
|
||||
|
||||
def after_model(self, state):
|
||||
calls.append("NoopEight.after_model")
|
||||
|
||||
@tool
|
||||
def my_tool(input: str) -> str:
|
||||
"""A great tool"""
|
||||
calls.append("my_tool")
|
||||
return input.upper()
|
||||
|
||||
agent_one = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{"args": {"input": "yo"}, "id": "1", "name": "my_tool"},
|
||||
],
|
||||
[],
|
||||
]
|
||||
),
|
||||
tools=[my_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven(), NoopEight()],
|
||||
).compile(checkpointer=sync_checkpointer)
|
||||
|
||||
thread1 = {"configurable": {"thread_id": "1"}}
|
||||
assert agent_one.invoke({"messages": ["hello"]}, thread1) == {
|
||||
"messages": [
|
||||
_AnyIdHumanMessage(content="hello"),
|
||||
AIMessage(
|
||||
content="You are a helpful assistant.-hello",
|
||||
additional_kwargs={},
|
||||
response_metadata={},
|
||||
id="0",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "my_tool",
|
||||
"args": {"input": "yo"},
|
||||
"id": "1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
_AnyIdToolMessage(content="YO", name="my_tool", tool_call_id="1"),
|
||||
AIMessage(
|
||||
content="You are a helpful assistant.-hello-You are a helpful assistant.-hello-YO",
|
||||
additional_kwargs={},
|
||||
response_metadata={},
|
||||
id="1",
|
||||
),
|
||||
],
|
||||
}
|
||||
assert calls == [
|
||||
"NoopSeven.before_model",
|
||||
"NoopEight.before_model",
|
||||
"NoopSeven.modify_model_request",
|
||||
"NoopEight.modify_model_request",
|
||||
"NoopEight.after_model",
|
||||
"NoopSeven.after_model",
|
||||
"my_tool",
|
||||
"NoopSeven.before_model",
|
||||
"NoopEight.before_model",
|
||||
"NoopSeven.modify_model_request",
|
||||
"NoopEight.modify_model_request",
|
||||
"NoopEight.after_model",
|
||||
"NoopSeven.after_model",
|
||||
]
|
||||
|
||||
|
||||
def test_create_agent_jump(
|
||||
snapshot: SnapshotAssertion,
|
||||
sync_checkpointer: BaseCheckpointSaver,
|
||||
):
|
||||
calls = []
|
||||
|
||||
class NoopSeven(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
calls.append("NoopSeven.before_model")
|
||||
|
||||
def modify_model_request(self, request, state):
|
||||
calls.append("NoopSeven.modify_model_request")
|
||||
return request
|
||||
|
||||
def after_model(self, state):
|
||||
calls.append("NoopSeven.after_model")
|
||||
|
||||
class NoopEight(AgentMiddleware):
|
||||
def before_model(self, state) -> dict[str, Any]:
|
||||
calls.append("NoopEight.before_model")
|
||||
return {"jump_to": END}
|
||||
|
||||
def modify_model_request(self, request, state) -> ModelRequest:
|
||||
calls.append("NoopEight.modify_model_request")
|
||||
return request
|
||||
|
||||
def after_model(self, state):
|
||||
calls.append("NoopEight.after_model")
|
||||
|
||||
@tool
|
||||
def my_tool(input: str) -> str:
|
||||
"""A great tool"""
|
||||
calls.append("my_tool")
|
||||
return input.upper()
|
||||
|
||||
agent_one = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[[ToolCall(id="1", name="my_tool", args={"input": "yo"})]],
|
||||
),
|
||||
tools=[my_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopSeven(), NoopEight()],
|
||||
).compile(checkpointer=sync_checkpointer)
|
||||
|
||||
if isinstance(sync_checkpointer, InMemorySaver):
|
||||
assert agent_one.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
thread1 = {"configurable": {"thread_id": "1"}}
|
||||
assert agent_one.invoke({"messages": []}, thread1) == {"messages": []}
|
||||
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
|
||||
|
||||
|
||||
# Tests for HumanInTheLoopMiddleware
|
||||
def test_human_in_the_loop_middleware_initialization() -> None:
|
||||
"""Test HumanInTheLoopMiddleware initialization."""
|
||||
tool_configs = {
|
||||
"test_tool": HumanInterruptConfig(
|
||||
allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
|
||||
)
|
||||
}
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs, message_prefix="Custom prefix")
|
||||
|
||||
assert middleware.tool_configs == tool_configs
|
||||
assert middleware.message_prefix == "Custom prefix"
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
||||
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
|
||||
tool_configs = {
|
||||
"test_tool": HumanInterruptConfig(
|
||||
allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
|
||||
)
|
||||
}
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
|
||||
|
||||
# Test with no messages
|
||||
state: dict[str, Any] = {"messages": []}
|
||||
result = middleware.after_model(state)
|
||||
assert result is None
|
||||
|
||||
# Test with message but no tool calls
|
||||
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
|
||||
result = middleware.after_model(state)
|
||||
assert result is None
|
||||
|
||||
# Test with tool calls that don't require interrupts
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
result = middleware.after_model(state)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_interrupt_responses() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with different interrupt response types."""
|
||||
tool_configs = {
|
||||
"test_tool": HumanInterruptConfig(
|
||||
allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
|
||||
)
|
||||
}
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
# Test accept response
|
||||
def mock_accept(requests):
|
||||
return [{"type": "accept", "args": None}]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert result["messages"][0] == ai_message
|
||||
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||
|
||||
# Test edit response
|
||||
def mock_edit(requests):
|
||||
return [
|
||||
{"type": "edit", "args": ActionRequest(action="test_tool", args={"input": "edited"})}
|
||||
]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||
|
||||
# Test ignore response
|
||||
def mock_ignore(requests):
|
||||
return [{"type": "ignore", "args": None}]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_ignore):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "__end__"
|
||||
|
||||
# Test response type
|
||||
def mock_response(requests):
|
||||
return [{"type": "response", "args": "Custom response"}]
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
|
||||
):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert result["messages"][0]["role"] == "tool"
|
||||
assert result["messages"][0]["content"] == "Custom response"
|
||||
|
||||
# Test unknown response type
|
||||
def mock_unknown(requests):
|
||||
return [{"type": "unknown", "args": None}]
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
|
||||
with pytest.raises(ValueError, match="Unknown response type: unknown"):
|
||||
middleware.after_model(state)
|
||||
|
||||
|
||||
# Tests for AnthropicPromptCachingMiddleware
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware initialization."""
|
||||
# Test with custom values
|
||||
middleware = AnthropicPromptCachingMiddleware(
|
||||
type="ephemeral", ttl="1h", min_messages_to_cache=5
|
||||
)
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "1h"
|
||||
assert middleware.min_messages_to_cache == 5
|
||||
|
||||
# Test with default values
|
||||
middleware = AnthropicPromptCachingMiddleware()
|
||||
assert middleware.type == "ephemeral"
|
||||
assert middleware.ttl == "5m"
|
||||
assert middleware.min_messages_to_cache == 0
|
||||
|
||||
|
||||
# Tests for SummarizationMiddleware
|
||||
def test_summarization_middleware_initialization() -> None:
|
||||
"""Test SummarizationMiddleware initialization."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model,
|
||||
max_tokens_before_summary=1000,
|
||||
messages_to_keep=10,
|
||||
summary_prompt="Custom prompt: {messages}",
|
||||
summary_prefix="Custom prefix:",
|
||||
)
|
||||
|
||||
assert middleware.model == model
|
||||
assert middleware.max_tokens_before_summary == 1000
|
||||
assert middleware.messages_to_keep == 10
|
||||
assert middleware.summary_prompt == "Custom prompt: {messages}"
|
||||
assert middleware.summary_prefix == "Custom prefix:"
|
||||
|
||||
# Test with string model
|
||||
with patch(
|
||||
"langchain.agents.middleware.summarization.init_chat_model",
|
||||
return_value=FakeToolCallingModel(),
|
||||
):
|
||||
middleware = SummarizationMiddleware(model="fake-model")
|
||||
assert isinstance(middleware.model, FakeToolCallingModel)
|
||||
|
||||
|
||||
def test_summarization_middleware_no_summarization_cases() -> None:
|
||||
"""Test SummarizationMiddleware when summarization is not needed or disabled."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000)
|
||||
|
||||
# Test when summarization is disabled
|
||||
middleware_disabled = SummarizationMiddleware(model=model, max_tokens_before_summary=None)
|
||||
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]}
|
||||
result = middleware_disabled.before_model(state)
|
||||
assert result is None
|
||||
|
||||
# Test when token count is below threshold
|
||||
def mock_token_counter(messages):
|
||||
return 500 # Below threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
result = middleware.before_model(state)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_summarization_middleware_helper_methods() -> None:
|
||||
"""Test SummarizationMiddleware helper methods."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000)
|
||||
|
||||
# Test message ID assignment
|
||||
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
middleware._ensure_message_ids(messages)
|
||||
for msg in messages:
|
||||
assert msg.id is not None
|
||||
|
||||
# Test message partitioning
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
HumanMessage(content="4"),
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
to_summarize, preserved = middleware._partition_messages(messages, 2)
|
||||
assert len(to_summarize) == 2
|
||||
assert len(preserved) == 3
|
||||
assert to_summarize == messages[:2]
|
||||
assert preserved == messages[2:]
|
||||
|
||||
# Test summary message building
|
||||
summary = "This is a test summary"
|
||||
new_messages = middleware._build_new_messages(summary)
|
||||
assert len(new_messages) == 1
|
||||
assert isinstance(new_messages[0], HumanMessage)
|
||||
assert "Here is a summary of the conversation to date:" in new_messages[0].content
|
||||
assert summary in new_messages[0].content
|
||||
|
||||
# Test tool call detection
|
||||
ai_message_no_tools = AIMessage(content="Hello")
|
||||
assert not middleware._has_tool_calls(ai_message_no_tools)
|
||||
|
||||
ai_message_with_tools = AIMessage(
|
||||
content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}]
|
||||
)
|
||||
assert middleware._has_tool_calls(ai_message_with_tools)
|
||||
|
||||
human_message = HumanMessage(content="Hello")
|
||||
assert not middleware._has_tool_calls(human_message)
|
||||
|
||||
|
||||
def test_summarization_middleware_tool_call_safety() -> None:
|
||||
"""Test SummarizationMiddleware tool call safety logic."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, max_tokens_before_summary=1000, messages_to_keep=3
|
||||
)
|
||||
|
||||
# Test safe cutoff point detection with tool calls
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]),
|
||||
ToolMessage(content="3", tool_call_id="1"),
|
||||
HumanMessage(content="4"),
|
||||
]
|
||||
|
||||
# Safe cutoff (doesn't separate AI/Tool pair)
|
||||
is_safe = middleware._is_safe_cutoff_point(messages, 0)
|
||||
assert is_safe is True
|
||||
|
||||
# Unsafe cutoff (separates AI/Tool pair)
|
||||
is_safe = middleware._is_safe_cutoff_point(messages, 2)
|
||||
assert is_safe is False
|
||||
|
||||
# Test tool call ID extraction
|
||||
ids = middleware._extract_tool_call_ids(messages[1])
|
||||
assert ids == {"1"}
|
||||
|
||||
|
||||
def test_summarization_middleware_summary_creation() -> None:
|
||||
"""Test SummarizationMiddleware summary creation."""
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "mock"
|
||||
|
||||
middleware = SummarizationMiddleware(model=MockModel(), max_tokens_before_summary=1000)
|
||||
|
||||
# Test normal summary creation
|
||||
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
summary = middleware._create_summary(messages)
|
||||
assert summary == "Generated summary"
|
||||
|
||||
# Test empty messages
|
||||
summary = middleware._create_summary([])
|
||||
assert summary == "No previous conversation history."
|
||||
|
||||
# Test error handling
|
||||
class ErrorModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
raise Exception("Model error")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "mock"
|
||||
|
||||
middleware_error = SummarizationMiddleware(model=ErrorModel(), max_tokens_before_summary=1000)
|
||||
summary = middleware_error._create_summary(messages)
|
||||
assert "Error generating summary: Model error" in summary
|
||||
|
||||
|
||||
def test_summarization_middleware_full_workflow() -> None:
|
||||
"""Test SummarizationMiddleware complete summarization workflow."""
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
return "mock"
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockModel(), max_tokens_before_summary=1000, messages_to_keep=2
|
||||
)
|
||||
|
||||
# Mock high token count to trigger summarization
|
||||
def mock_token_counter(messages):
|
||||
return 1500 # Above threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
HumanMessage(content="4"),
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state)
|
||||
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) > 0
|
||||
|
||||
# Should have RemoveMessage for cleanup
|
||||
assert isinstance(result["messages"][0], RemoveMessage)
|
||||
assert result["messages"][0].id == REMOVE_ALL_MESSAGES
|
||||
|
||||
# Should have summary message
|
||||
summary_message = None
|
||||
for msg in result["messages"]:
|
||||
if isinstance(msg, HumanMessage) and "summary of the conversation" in msg.content:
|
||||
summary_message = msg
|
||||
break
|
||||
|
||||
assert summary_message is not None
|
||||
assert "Generated summary" in summary_message.content
|
||||
1645
libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py
Normal file
1645
libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,58 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
|
||||
def tool() -> None:
|
||||
"""Testing tool."""
|
||||
|
||||
|
||||
def pre_model_hook() -> None:
|
||||
"""Pre-model hook."""
|
||||
|
||||
|
||||
def post_model_hook() -> None:
|
||||
"""Post-model hook."""
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
"""Response format for the agent."""
|
||||
|
||||
result: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [[], [tool]])
|
||||
@pytest.mark.parametrize("pre_model_hook", [None, pre_model_hook])
|
||||
@pytest.mark.parametrize("post_model_hook", [None, post_model_hook])
|
||||
def test_react_agent_graph_structure(
|
||||
snapshot: SnapshotAssertion,
|
||||
tools: list[Callable],
|
||||
pre_model_hook: Union[Callable, None],
|
||||
post_model_hook: Union[Callable, None],
|
||||
) -> None:
|
||||
agent = create_agent(
|
||||
model,
|
||||
tools=tools,
|
||||
pre_model_hook=pre_model_hook,
|
||||
post_model_hook=post_model_hook,
|
||||
)
|
||||
try:
|
||||
assert agent.get_graph().draw_mermaid(with_styles=False) == snapshot
|
||||
except Exception as e:
|
||||
msg = (
|
||||
"The graph structure has changed. Please update the snapshot."
|
||||
"Configuration used:\n"
|
||||
f"tools: {tools}, "
|
||||
f"pre_model_hook: {pre_model_hook}, "
|
||||
f"post_model_hook: {post_model_hook}, "
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
@@ -0,0 +1,704 @@
|
||||
"""Test suite for create_agent with structured output response_format permutations."""
|
||||
|
||||
import pytest
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.structured_output import (
|
||||
MultipleStructuredOutputsError,
|
||||
ProviderStrategy,
|
||||
StructuredOutputValidationError,
|
||||
ToolStrategy,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
# Test data models
|
||||
class WeatherBaseModel(BaseModel):
|
||||
"""Weather response."""
|
||||
|
||||
temperature: float = Field(description="The temperature in fahrenheit")
|
||||
condition: str = Field(description="Weather condition")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeatherDataclass:
|
||||
"""Weather response."""
|
||||
|
||||
temperature: float
|
||||
condition: str
|
||||
|
||||
|
||||
class WeatherTypedDict(TypedDict):
|
||||
"""Weather response."""
|
||||
|
||||
temperature: float
|
||||
condition: str
|
||||
|
||||
|
||||
weather_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {"type": "number", "description": "Temperature in fahrenheit"},
|
||||
"condition": {"type": "string", "description": "Weather condition"},
|
||||
},
|
||||
"title": "weather_schema",
|
||||
"required": ["temperature", "condition"],
|
||||
}
|
||||
|
||||
|
||||
class LocationResponse(BaseModel):
|
||||
city: str = Field(description="The city name")
|
||||
country: str = Field(description="The country name")
|
||||
|
||||
|
||||
class LocationTypedDict(TypedDict):
|
||||
city: str
|
||||
country: str
|
||||
|
||||
|
||||
location_json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
"country": {"type": "string", "description": "The country name"},
|
||||
},
|
||||
"title": "location_schema",
|
||||
"required": ["city", "country"],
|
||||
}
|
||||
|
||||
|
||||
def get_weather() -> str:
|
||||
"""Get the weather."""
|
||||
|
||||
return "The weather is sunny and 75°F."
|
||||
|
||||
|
||||
def get_location() -> str:
|
||||
"""Get the current location."""
|
||||
|
||||
return "You are in New York, USA."
|
||||
|
||||
|
||||
# Standardized test data
|
||||
WEATHER_DATA = {"temperature": 75.0, "condition": "sunny"}
|
||||
LOCATION_DATA = {"city": "New York", "country": "USA"}
|
||||
|
||||
# Standardized expected responses
|
||||
EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(**WEATHER_DATA)
|
||||
EXPECTED_WEATHER_DATACLASS = WeatherDataclass(**WEATHER_DATA)
|
||||
EXPECTED_WEATHER_DICT: WeatherTypedDict = {"temperature": 75.0, "condition": "sunny"}
|
||||
EXPECTED_LOCATION = LocationResponse(**LOCATION_DATA)
|
||||
EXPECTED_LOCATION_DICT: LocationTypedDict = {"city": "New York", "country": "USA"}
|
||||
|
||||
|
||||
class TestResponseFormatAsModel:
|
||||
def test_pydantic_model(self) -> None:
|
||||
"""Test response_format as Pydantic model."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=WeatherBaseModel)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_dataclass(self) -> None:
|
||||
"""Test response_format as dataclass."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherDataclass",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=WeatherDataclass)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_typed_dict(self) -> None:
|
||||
"""Test response_format as TypedDict."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherTypedDict",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=WeatherTypedDict)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_json_schema(self) -> None:
|
||||
"""Test response_format as JSON schema."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "weather_schema",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=weather_json_schema)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
|
||||
class TestResponseFormatAsToolStrategy:
|
||||
def test_pydantic_model(self) -> None:
|
||||
"""Test response_format as ToolStrategy with Pydantic model."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=ToolStrategy(WeatherBaseModel))
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_dataclass(self) -> None:
|
||||
"""Test response_format as ToolStrategy with dataclass."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherDataclass",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=ToolStrategy(WeatherDataclass))
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_typed_dict(self) -> None:
|
||||
"""Test response_format as ToolStrategy with TypedDict."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherTypedDict",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(model, [get_weather], response_format=ToolStrategy(WeatherTypedDict))
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_json_schema(self) -> None:
|
||||
"""Test response_format as ToolStrategy with JSON schema."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "weather_schema",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model, [get_weather], response_format=ToolStrategy(weather_json_schema)
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
def test_union_of_json_schemas(self) -> None:
|
||||
"""Test response_format as ToolStrategy with union of JSON schemas."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "weather_schema",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[get_weather, get_location],
|
||||
response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
# Test with LocationResponse
|
||||
tool_calls_location = [
|
||||
[{"args": {}, "id": "1", "name": "get_location"}],
|
||||
[
|
||||
{
|
||||
"name": "location_schema",
|
||||
"id": "2",
|
||||
"args": LOCATION_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
|
||||
|
||||
agent_location = create_agent(
|
||||
model_location,
|
||||
[get_weather, get_location],
|
||||
response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
|
||||
)
|
||||
response_location = agent_location.invoke({"messages": [HumanMessage("Where am I?")]})
|
||||
|
||||
assert response_location["structured_response"] == EXPECTED_LOCATION_DICT
|
||||
assert len(response_location["messages"]) == 5
|
||||
|
||||
def test_union_of_types(self) -> None:
|
||||
"""Test response_format as ToolStrategy with Union of various types."""
|
||||
# Test with WeatherBaseModel
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[Union[WeatherBaseModel, LocationResponse]](
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[get_weather, get_location],
|
||||
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
assert len(response["messages"]) == 5
|
||||
|
||||
# Test with LocationResponse
|
||||
tool_calls_location = [
|
||||
[{"args": {}, "id": "1", "name": "get_location"}],
|
||||
[
|
||||
{
|
||||
"name": "LocationResponse",
|
||||
"id": "2",
|
||||
"args": LOCATION_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
|
||||
|
||||
agent_location = create_agent(
|
||||
model_location,
|
||||
[get_weather, get_location],
|
||||
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
||||
)
|
||||
response_location = agent_location.invoke({"messages": [HumanMessage("Where am I?")]})
|
||||
|
||||
assert response_location["structured_response"] == EXPECTED_LOCATION
|
||||
assert len(response_location["messages"]) == 5
|
||||
|
||||
def test_multiple_structured_outputs_error_without_retry(self) -> None:
|
||||
"""Test that MultipleStructuredOutputsError is raised when model returns multiple structured tool calls without retry."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
{
|
||||
"name": "LocationResponse",
|
||||
"id": "2",
|
||||
"args": LOCATION_DATA,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
Union[WeatherBaseModel, LocationResponse],
|
||||
handle_errors=False,
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
MultipleStructuredOutputsError,
|
||||
match=".*WeatherBaseModel.*LocationResponse.*",
|
||||
):
|
||||
agent.invoke({"messages": [HumanMessage("Give me weather and location")]})
|
||||
|
||||
def test_multiple_structured_outputs_with_retry(self) -> None:
|
||||
"""Test that retry handles multiple structured output tool calls."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
{
|
||||
"name": "LocationResponse",
|
||||
"id": "2",
|
||||
"args": LOCATION_DATA,
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "3",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
Union[WeatherBaseModel, LocationResponse],
|
||||
handle_errors=True,
|
||||
),
|
||||
)
|
||||
|
||||
response = agent.invoke({"messages": [HumanMessage("Give me weather")]})
|
||||
|
||||
# HumanMessage, AIMessage, ToolMessage, ToolMessage, AI, ToolMessage
|
||||
assert len(response["messages"]) == 6
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
|
||||
def test_structured_output_parsing_error_without_retry(self) -> None:
|
||||
"""Test that StructuredOutputParsingError is raised when tool args fail to parse without retry."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": {"invalid": "data"},
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
WeatherBaseModel,
|
||||
handle_errors=False,
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
StructuredOutputValidationError,
|
||||
match=".*WeatherBaseModel.*",
|
||||
):
|
||||
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
def test_structured_output_parsing_error_with_retry(self) -> None:
|
||||
"""Test that retry handles parsing errors for structured output."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": {"invalid": "data"},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
WeatherBaseModel,
|
||||
handle_errors=(StructuredOutputValidationError,),
|
||||
),
|
||||
)
|
||||
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
# HumanMessage, AIMessage, ToolMessage, AIMessage, ToolMessage
|
||||
assert len(response["messages"]) == 5
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
|
||||
def test_retry_with_custom_function(self) -> None:
|
||||
"""Test retry with custom message generation."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
{
|
||||
"name": "LocationResponse",
|
||||
"id": "2",
|
||||
"args": LOCATION_DATA,
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "3",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
def custom_message(exception: Exception) -> str:
|
||||
if isinstance(exception, MultipleStructuredOutputsError):
|
||||
return "Custom error: Multiple outputs not allowed"
|
||||
return "Custom error"
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
Union[WeatherBaseModel, LocationResponse],
|
||||
handle_errors=custom_message,
|
||||
),
|
||||
)
|
||||
|
||||
response = agent.invoke({"messages": [HumanMessage("Give me weather")]})
|
||||
|
||||
# HumanMessage, AIMessage, ToolMessage, ToolMessage, AI, ToolMessage
|
||||
assert len(response["messages"]) == 6
|
||||
assert response["messages"][2].content == "Custom error: Multiple outputs not allowed"
|
||||
assert response["messages"][3].content == "Custom error: Multiple outputs not allowed"
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
|
||||
def test_retry_with_custom_string_message(self) -> None:
|
||||
"""Test retry with custom static string message."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": {"invalid": "data"},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
WeatherBaseModel,
|
||||
handle_errors="Please provide valid weather data with temperature and condition.",
|
||||
),
|
||||
)
|
||||
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert len(response["messages"]) == 5
|
||||
assert (
|
||||
response["messages"][2].content
|
||||
== "Please provide valid weather data with temperature and condition."
|
||||
)
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
|
||||
|
||||
class TestResponseFormatAsProviderStrategy:
|
||||
def test_pydantic_model(self) -> None:
|
||||
"""Test response_format as ProviderStrategy with Pydantic model."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[WeatherBaseModel](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
def test_dataclass(self) -> None:
|
||||
"""Test response_format as ProviderStrategy with dataclass."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[WeatherDataclass](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model, [get_weather], response_format=ProviderStrategy(WeatherDataclass)
|
||||
)
|
||||
response = agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather?")]},
|
||||
)
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
def test_typed_dict(self) -> None:
|
||||
"""Test response_format as ProviderStrategy with TypedDict."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[WeatherTypedDict](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model, [get_weather], response_format=ProviderStrategy(WeatherTypedDict)
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
def test_json_schema(self) -> None:
|
||||
"""Test response_format as ProviderStrategy with JSON schema."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[dict](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model, [get_weather], response_format=ProviderStrategy(weather_json_schema)
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
|
||||
def test_union_of_types() -> None:
|
||||
"""Test response_format as ProviderStrategy with Union (if supported)."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "2",
|
||||
"args": WEATHER_DATA,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[Union[WeatherBaseModel, LocationResponse]](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[get_weather, get_location],
|
||||
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
||||
)
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
assert len(response["messages"]) == 5
|
||||
140
libs/langchain_v1/tests/unit_tests/agents/test_responses.py
Normal file
140
libs/langchain_v1/tests/unit_tests/agents/test_responses.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Unit tests for langgraph.prebuilt.responses module."""
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip this test since langgraph.prebuilt.responses is not available
|
||||
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
|
||||
|
||||
|
||||
class _TestModel(BaseModel):
|
||||
"""A test model for structured output."""
|
||||
|
||||
name: str
|
||||
age: int
|
||||
email: str = "default@example.com"
|
||||
|
||||
|
||||
class CustomModel(BaseModel):
|
||||
"""Custom model with a custom docstring."""
|
||||
|
||||
value: float
|
||||
description: str
|
||||
|
||||
|
||||
class EmptyDocModel(BaseModel):
|
||||
# No custom docstring, should have no description in tool
|
||||
data: str
|
||||
|
||||
|
||||
class TestUsingToolStrategy:
|
||||
"""Test UsingToolStrategy dataclass."""
|
||||
|
||||
def test_basic_creation(self) -> None:
|
||||
"""Test basic UsingToolStrategy creation."""
|
||||
strategy = ToolStrategy(schema=_TestModel)
|
||||
assert strategy.schema == _TestModel
|
||||
assert strategy.tool_message_content is None
|
||||
assert len(strategy.schema_specs) == 1
|
||||
|
||||
def test_multiple_schemas(self) -> None:
|
||||
"""Test UsingToolStrategy with multiple schemas."""
|
||||
strategy = ToolStrategy(schema=Union[_TestModel, CustomModel])
|
||||
assert len(strategy.schema_specs) == 2
|
||||
assert strategy.schema_specs[0].schema == _TestModel
|
||||
assert strategy.schema_specs[1].schema == CustomModel
|
||||
|
||||
def test_schema_with_tool_message_content(self) -> None:
|
||||
"""Test UsingToolStrategy with tool message content."""
|
||||
strategy = ToolStrategy(schema=_TestModel, tool_message_content="custom message")
|
||||
assert strategy.schema == _TestModel
|
||||
assert strategy.tool_message_content == "custom message"
|
||||
assert len(strategy.schema_specs) == 1
|
||||
|
||||
|
||||
class TestOutputToolBinding:
|
||||
"""Test OutputToolBinding dataclass and its methods."""
|
||||
|
||||
def test_from_schema_spec_basic(self) -> None:
|
||||
"""Test basic OutputToolBinding creation from SchemaSpec."""
|
||||
schema_spec = _SchemaSpec(schema=_TestModel)
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
|
||||
assert tool_binding.schema == _TestModel
|
||||
assert tool_binding.schema_kind == "pydantic"
|
||||
assert tool_binding.tool is not None
|
||||
assert tool_binding.tool.name == "_TestModel"
|
||||
|
||||
def test_from_schema_spec_with_custom_name(self) -> None:
|
||||
"""Test OutputToolBinding creation with custom name."""
|
||||
schema_spec = _SchemaSpec(schema=_TestModel, name="custom_tool_name")
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
assert tool_binding.tool.name == "custom_tool_name"
|
||||
|
||||
def test_from_schema_spec_with_custom_description(self) -> None:
|
||||
"""Test OutputToolBinding creation with custom description."""
|
||||
schema_spec = _SchemaSpec(schema=_TestModel, description="Custom tool description")
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
|
||||
assert tool_binding.tool.description == "Custom tool description"
|
||||
|
||||
def test_from_schema_spec_with_model_docstring(self) -> None:
|
||||
"""Test OutputToolBinding creation using model docstring as description."""
|
||||
schema_spec = _SchemaSpec(schema=CustomModel)
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
|
||||
assert tool_binding.tool.description == "Custom model with a custom docstring."
|
||||
|
||||
@pytest.mark.skip(reason="Need to fix bug in langchain-core for inheritance of doc-strings.")
|
||||
def test_from_schema_spec_empty_docstring(self) -> None:
|
||||
"""Test OutputToolBinding creation with model that has default docstring."""
|
||||
|
||||
# Create a model with the same docstring as BaseModel
|
||||
class DefaultDocModel(BaseModel):
|
||||
# This should have the same docstring as BaseModel
|
||||
pass
|
||||
|
||||
schema_spec = _SchemaSpec(schema=DefaultDocModel)
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
|
||||
# Should use empty description when model has default BaseModel docstring
|
||||
assert tool_binding.tool.description == ""
|
||||
|
||||
def test_parse_payload_pydantic_success(self) -> None:
|
||||
"""Test successful parsing for Pydantic model."""
|
||||
schema_spec = _SchemaSpec(schema=_TestModel)
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
|
||||
tool_args = {"name": "John", "age": 30}
|
||||
result = tool_binding.parse(tool_args)
|
||||
|
||||
assert isinstance(result, _TestModel)
|
||||
assert result.name == "John"
|
||||
assert result.age == 30
|
||||
assert result.email == "default@example.com" # default value
|
||||
|
||||
def test_parse_payload_pydantic_validation_error(self) -> None:
|
||||
"""Test parsing failure for invalid Pydantic data."""
|
||||
schema_spec = _SchemaSpec(schema=_TestModel)
|
||||
tool_binding = OutputToolBinding.from_schema_spec(schema_spec)
|
||||
|
||||
# Missing required field 'name'
|
||||
tool_args = {"age": 30}
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to parse data to _TestModel"):
|
||||
tool_binding.parse(tool_args)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error conditions."""
|
||||
|
||||
def test_empty_schemas_list(self) -> None:
|
||||
"""Test UsingToolStrategy with empty schemas list."""
|
||||
strategy = ToolStrategy(EmptyDocModel)
|
||||
assert len(strategy.schema_specs) == 1
|
||||
|
||||
@pytest.mark.skip(reason="Need to fix bug in langchain-core for inheritance of doc-strings.")
|
||||
def test_base_model_doc_constant(self) -> None:
|
||||
"""Test that BASE_MODEL_DOC constant is set correctly."""
|
||||
binding = OutputToolBinding.from_schema_spec(_SchemaSpec(EmptyDocModel))
|
||||
assert binding.tool.name == "EmptyDocModel"
|
||||
assert binding.tool.description[:5] == "" # Should be empty for default docstring
|
||||
147
libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py
Normal file
147
libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip this test since langgraph.prebuilt.responses is not available
|
||||
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
|
||||
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except ImportError:
|
||||
skip_openai_integration_tests = True
|
||||
else:
|
||||
skip_openai_integration_tests = False
|
||||
|
||||
AGENT_PROMPT = "You are an HR assistant."
|
||||
|
||||
|
||||
class ToolCalls(BaseSchema):
|
||||
get_employee_role: int
|
||||
get_employee_department: int
|
||||
|
||||
|
||||
class AssertionByInvocation(BaseSchema):
|
||||
prompt: str
|
||||
tools_with_expected_calls: ToolCalls
|
||||
expected_last_message: str
|
||||
expected_structured_response: Optional[Dict[str, Any]]
|
||||
llm_request_count: int
|
||||
|
||||
|
||||
class TestCase(BaseSchema):
|
||||
name: str
|
||||
response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
|
||||
assertions_by_invocation: List[AssertionByInvocation]
|
||||
|
||||
|
||||
class Employee(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
department: str
|
||||
|
||||
|
||||
EMPLOYEES: list[Employee] = [
|
||||
Employee(name="Sabine", role="Developer", department="IT"),
|
||||
Employee(name="Henrik", role="Product Manager", department="IT"),
|
||||
Employee(name="Jessica", role="HR", department="People"),
|
||||
]
|
||||
|
||||
TEST_CASES = load_spec("responses", as_model=TestCase)
|
||||
|
||||
|
||||
def _make_tool(fn, *, name: str, description: str):
|
||||
mock = MagicMock(side_effect=lambda *, name: fn(name=name))
|
||||
InputModel = create_model(f"{name}_input", name=(str, ...))
|
||||
|
||||
@tool(name, description=description, args_schema=InputModel)
|
||||
def _wrapped(name: str):
|
||||
return mock(name=name)
|
||||
|
||||
return {"tool": _wrapped, "mock": mock}
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
|
||||
@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
|
||||
def test_responses_integration_matrix(case: TestCase) -> None:
|
||||
if case.name == "asking for information that does not fit into the response format":
|
||||
pytest.xfail(
|
||||
"currently failing due to undefined behavior when model cannot conform to any of the structured response formats."
|
||||
)
|
||||
|
||||
def get_employee_role(*, name: str) -> Optional[str]:
|
||||
for e in EMPLOYEES:
|
||||
if e.name == name:
|
||||
return e.role
|
||||
return None
|
||||
|
||||
def get_employee_department(*, name: str) -> Optional[str]:
|
||||
for e in EMPLOYEES:
|
||||
if e.name == name:
|
||||
return e.department
|
||||
return None
|
||||
|
||||
role_tool = _make_tool(
|
||||
get_employee_role,
|
||||
name="get_employee_role",
|
||||
description="Get the employee role by name",
|
||||
)
|
||||
dept_tool = _make_tool(
|
||||
get_employee_department,
|
||||
name="get_employee_department",
|
||||
description="Get the employee department by name",
|
||||
)
|
||||
|
||||
response_format_spec = case.response_format
|
||||
if isinstance(response_format_spec, dict):
|
||||
response_format_spec = [response_format_spec]
|
||||
# Unwrap nested schema objects
|
||||
response_format_spec = [item.get("schema", item) for item in response_format_spec]
|
||||
if len(response_format_spec) == 1:
|
||||
tool_output = ToolStrategy(response_format_spec[0])
|
||||
else:
|
||||
tool_output = ToolStrategy({"oneOf": response_format_spec})
|
||||
|
||||
llm_request_count = 0
|
||||
|
||||
for assertion in case.assertions_by_invocation:
|
||||
|
||||
def on_request(request: httpx.Request) -> None:
|
||||
nonlocal llm_request_count
|
||||
llm_request_count += 1
|
||||
|
||||
http_client = httpx.Client(
|
||||
event_hooks={"request": [on_request]},
|
||||
)
|
||||
|
||||
model = ChatOpenAI(
|
||||
model="gpt-4o",
|
||||
temperature=0,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
tools=[role_tool["tool"], dept_tool["tool"]],
|
||||
prompt=AGENT_PROMPT,
|
||||
response_format=tool_output,
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(assertion.prompt)]})
|
||||
|
||||
# Count tool calls
|
||||
assert role_tool["mock"].call_count == assertion.tools_with_expected_calls.get_employee_role
|
||||
assert (
|
||||
dept_tool["mock"].call_count
|
||||
== assertion.tools_with_expected_calls.get_employee_department
|
||||
)
|
||||
|
||||
# Count LLM calls
|
||||
assert llm_request_count == assertion.llm_request_count
|
||||
|
||||
# Check last message content
|
||||
last_message = result["messages"][-1]
|
||||
assert last_message.content == assertion.expected_last_message
|
||||
|
||||
# Check structured response
|
||||
structured_response_json = result["structured_response"]
|
||||
assert structured_response_json == assertion.expected_structured_response
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip this test since langgraph.prebuilt.responses is not available
|
||||
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)
|
||||
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except ImportError:
|
||||
skip_openai_integration_tests = True
|
||||
else:
|
||||
skip_openai_integration_tests = False
|
||||
|
||||
AGENT_PROMPT = """
|
||||
You are a strict polling bot.
|
||||
|
||||
- Only use the "poll_job" tool until it returns { status: "succeeded" }.
|
||||
- If status is "pending", call the tool again. Do not produce a final answer.
|
||||
- When it is "succeeded", return exactly: "Attempts: <number>" with no extra text.
|
||||
"""
|
||||
|
||||
|
||||
class TestCase(BaseSchema):
|
||||
name: str
|
||||
return_direct: bool
|
||||
response_format: Optional[Dict[str, Any]]
|
||||
expected_tool_calls: int
|
||||
expected_last_message: str
|
||||
expected_structured_response: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
TEST_CASES = load_spec("return_direct", as_model=TestCase)
|
||||
|
||||
|
||||
def _make_tool(return_direct: bool):
|
||||
attempts = 0
|
||||
|
||||
def _side_effect():
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
return {
|
||||
"status": "succeeded" if attempts >= 10 else "pending",
|
||||
"attempts": attempts,
|
||||
}
|
||||
|
||||
mock = MagicMock(side_effect=_side_effect)
|
||||
|
||||
@tool(
|
||||
"pollJob",
|
||||
description=(
|
||||
"Check the status of a long-running job. "
|
||||
"Returns { status: 'pending' | 'succeeded', attempts: number }."
|
||||
),
|
||||
return_direct=return_direct,
|
||||
)
|
||||
def _wrapped():
|
||||
return mock()
|
||||
|
||||
return {"tool": _wrapped, "mock": mock}
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai_integration_tests, reason="OpenAI integration tests are disabled.")
|
||||
@pytest.mark.parametrize("case", TEST_CASES, ids=[c.name for c in TEST_CASES])
|
||||
def test_return_direct_integration_matrix(case: TestCase) -> None:
|
||||
poll_tool = _make_tool(case.return_direct)
|
||||
|
||||
model = ChatOpenAI(
|
||||
model="gpt-4o",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
if case.response_format:
|
||||
agent = create_agent(
|
||||
model,
|
||||
tools=[poll_tool["tool"]],
|
||||
prompt=AGENT_PROMPT,
|
||||
response_format=ToolStrategy(case.response_format),
|
||||
)
|
||||
else:
|
||||
agent = create_agent(
|
||||
model,
|
||||
tools=[poll_tool["tool"]],
|
||||
prompt=AGENT_PROMPT,
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage("Poll the job until it's done and tell me how many attempts it took.")
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Count tool calls
|
||||
assert poll_tool["mock"].call_count == case.expected_tool_calls
|
||||
|
||||
# Check last message content
|
||||
last_message = result["messages"][-1]
|
||||
assert last_message.content == case.expected_last_message
|
||||
|
||||
# Check structured response
|
||||
if case.expected_structured_response is not None:
|
||||
structured_response_json = result["structured_response"]
|
||||
assert structured_response_json == case.expected_structured_response
|
||||
else:
|
||||
assert "structured_response" not in result
|
||||
1482
libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py
Normal file
1482
libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py
Normal file
File diff suppressed because it is too large
Load Diff
21
libs/langchain_v1/tests/unit_tests/agents/utils.py
Normal file
21
libs/langchain_v1/tests/unit_tests/agents/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
|
||||
class BaseSchema(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
alias_generator=to_camel,
|
||||
populate_by_name=True,
|
||||
from_attributes=True,
|
||||
)
|
||||
|
||||
|
||||
def load_spec(spec_name: str, as_model: type[BaseModel]) -> list[BaseModel]:
|
||||
with (Path(__file__).parent / "specifications" / f"{spec_name}.json").open(
|
||||
"r", encoding="utf-8"
|
||||
) as f:
|
||||
data = json.load(f)
|
||||
return [as_model(**item) for item in data]
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@@ -38,7 +38,7 @@ def test_all_imports() -> None:
|
||||
("mixtral-8x7b-32768", "groq"),
|
||||
],
|
||||
)
|
||||
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
|
||||
def test_init_chat_model(model_name: str, model_provider: str | None) -> None:
|
||||
llm1: BaseChatModel = init_chat_model(
|
||||
model_name,
|
||||
model_provider=model_provider,
|
||||
@@ -222,7 +222,7 @@ def test_configurable_with_default() -> None:
|
||||
config={"configurable": {"my_model_model": "claude-3-sonnet-20240229"}}
|
||||
)
|
||||
|
||||
""" # noqa: E501
|
||||
"""
|
||||
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
|
||||
for method in (
|
||||
"invoke",
|
||||
|
||||
@@ -1,42 +1,9 @@
|
||||
"""Configuration for unit tests."""
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Sequence
|
||||
from importlib import util
|
||||
|
||||
import pytest
|
||||
from blockbuster import blockbuster_ctx
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def blockbuster() -> Iterator[None]:
|
||||
with blockbuster_ctx("langchain") as bb:
|
||||
bb.functions["io.TextIOWrapper.read"].can_block_in(
|
||||
"langchain/__init__.py",
|
||||
"<module>",
|
||||
)
|
||||
|
||||
for func in ["os.stat", "os.path.abspath"]:
|
||||
(
|
||||
bb.functions[func]
|
||||
.can_block_in("langchain_core/runnables/base.py", "__repr__")
|
||||
.can_block_in(
|
||||
"langchain_core/beta/runnables/context.py",
|
||||
"aconfig_with_context",
|
||||
)
|
||||
)
|
||||
|
||||
for func in ["os.stat", "io.TextIOWrapper.read"]:
|
||||
bb.functions[func].can_block_in(
|
||||
"langsmith/client.py",
|
||||
"_default_retry_config",
|
||||
)
|
||||
|
||||
for bb_function in bb.functions.values():
|
||||
bb_function.can_block_in(
|
||||
"freezegun/api.py",
|
||||
"_get_cached_module_attributes",
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
@@ -53,9 +20,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(
|
||||
config: pytest.Config, items: Sequence[pytest.Function]
|
||||
) -> None:
|
||||
def pytest_collection_modifyitems(config: pytest.Config, items: Sequence[pytest.Function]) -> None:
|
||||
"""Add implementations for handling custom markers.
|
||||
|
||||
At the moment, this adds support for a custom `requires` marker.
|
||||
|
||||
@@ -113,9 +113,7 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
|
||||
vectors = await cache_embeddings.aembed_documents(texts)
|
||||
expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
|
||||
assert vectors == expected_vectors
|
||||
keys = [
|
||||
key async for key in cache_embeddings.document_embedding_store.ayield_keys()
|
||||
]
|
||||
keys = [key async for key in cache_embeddings.document_embedding_store.ayield_keys()]
|
||||
assert len(keys) == 4
|
||||
# UUID is expected to be the same for the same text
|
||||
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
|
||||
@@ -128,10 +126,7 @@ async def test_aembed_documents_batch(
|
||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||
with contextlib.suppress(ValueError):
|
||||
await cache_embeddings_batch.aembed_documents(texts)
|
||||
keys = [
|
||||
key
|
||||
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
|
||||
]
|
||||
keys = [key async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()]
|
||||
# only the first batch of three embeddings should exist
|
||||
assert len(keys) == 3
|
||||
# UUID is expected to be the same for the same text
|
||||
|
||||
@@ -13,9 +13,7 @@ def test_import_all() -> None:
|
||||
library_code = PKG_ROOT / "langchain"
|
||||
for path in library_code.rglob("*.py"):
|
||||
# Calculate the relative path to the module
|
||||
module_name = (
|
||||
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
|
||||
)
|
||||
module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
|
||||
if module_name.endswith("__init__"):
|
||||
# Without init
|
||||
module_name = module_name.rsplit(".", 1)[0]
|
||||
@@ -39,9 +37,7 @@ def test_import_all_using_dir() -> None:
|
||||
library_code = PKG_ROOT / "langchain"
|
||||
for path in library_code.rglob("*.py"):
|
||||
# Calculate the relative path to the module
|
||||
module_name = (
|
||||
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
|
||||
)
|
||||
module_name = path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
|
||||
if module_name.endswith("__init__"):
|
||||
# Without init
|
||||
module_name = module_name.rsplit(".", 1)[0]
|
||||
|
||||
2490
libs/langchain_v1/uv.lock
generated
2490
libs/langchain_v1/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -999,24 +999,27 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
.. dropdown:: Extended caching
|
||||
|
||||
.. versionadded:: 0.3.15
|
||||
|
||||
The cache lifetime is 5 minutes by default. If this is too short, you can
|
||||
apply one hour caching by enabling the ``'extended-cache-ttl-2025-04-11'``
|
||||
beta header:
|
||||
apply one hour caching by setting ``ttl`` to ``'1h'``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
llm = ChatAnthropic(
|
||||
model="claude-3-7-sonnet-20250219",
|
||||
betas=["extended-cache-ttl-2025-04-11"],
|
||||
)
|
||||
|
||||
and specifying ``"cache_control": {"type": "ephemeral", "ttl": "1h"}``.
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{long_text}",
|
||||
"cache_control": {"type": "ephemeral", "ttl": "1h"},
|
||||
},
|
||||
],
|
||||
}]
|
||||
|
||||
.. important::
|
||||
Specifying a `ttl` key under `cache_control` will not work unless the
|
||||
beta header is set!
|
||||
response = llm.invoke(messages)
|
||||
|
||||
Details of cached token counts will be included on the ``InputTokenDetails``
|
||||
of response's ``usage_metadata``:
|
||||
@@ -1432,23 +1435,6 @@ class ChatAnthropic(BaseChatModel):
|
||||
# If cache_control is provided in kwargs, add it to last message
|
||||
# and content block.
|
||||
if "cache_control" in kwargs and formatted_messages:
|
||||
cache_control = kwargs["cache_control"]
|
||||
|
||||
# Validate TTL usage requires extended cache TTL beta header
|
||||
if (
|
||||
isinstance(cache_control, dict)
|
||||
and "ttl" in cache_control
|
||||
and (
|
||||
not self.betas or "extended-cache-ttl-2025-04-11" not in self.betas
|
||||
)
|
||||
):
|
||||
msg = (
|
||||
"Specifying a 'ttl' under 'cache_control' requires enabling "
|
||||
"the 'extended-cache-ttl-2025-04-11' beta header. "
|
||||
"Set betas=['extended-cache-ttl-2025-04-11'] when initializing "
|
||||
"ChatAnthropic."
|
||||
)
|
||||
warnings.warn(msg, stacklevel=2)
|
||||
if isinstance(formatted_messages[-1]["content"], list):
|
||||
formatted_messages[-1]["content"][-1]["cache_control"] = kwargs.pop(
|
||||
"cache_control"
|
||||
|
||||
@@ -128,6 +128,9 @@ markers = [
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # Tests need assertions
|
||||
|
||||
@@ -395,7 +395,7 @@ def function() -> Callable:
|
||||
arg1: foo
|
||||
arg2: one of 'bar', 'baz'
|
||||
|
||||
""" # noqa: D401
|
||||
"""
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
13
libs/partners/anthropic/uv.lock
generated
13
libs/partners/anthropic/uv.lock
generated
@@ -505,7 +505,7 @@ typing = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.3.74"
|
||||
version = "0.3.75"
|
||||
source = { editable = "../../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -582,28 +582,27 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.25.0,<1" },
|
||||
{ name = "httpx", specifier = ">=0.28.1,<1" },
|
||||
{ name = "langchain-core", editable = "../../core" },
|
||||
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.2" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
||||
{ name = "pytest", specifier = ">=7,<9" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.20,<1" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.20,<2" },
|
||||
{ name = "pytest-benchmark" },
|
||||
{ name = "pytest-codspeed" },
|
||||
{ name = "pytest-recording" },
|
||||
{ name = "pytest-socket", specifier = ">=0.6.0,<1" },
|
||||
{ name = "pytest-socket", specifier = ">=0.7.0,<1" },
|
||||
{ name = "syrupy", specifier = ">=4,<5" },
|
||||
{ name = "vcrpy", specifier = ">=7.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.8,<0.13" }]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.10,<0.13" }]
|
||||
test = [{ name = "langchain-core", editable = "../../core" }]
|
||||
test-integration = []
|
||||
typing = [
|
||||
{ name = "langchain-core", editable = "../../core" },
|
||||
{ name = "mypy", specifier = ">=1,<2" },
|
||||
{ name = "mypy", specifier = ">=1.17.1,<2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user