mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-18 04:25:22 +00:00
Compare commits
84 Commits
langchain-
...
cc/summari
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c938b787f | ||
|
|
c63f23d233 | ||
|
|
b7091d391d | ||
|
|
7a2952210e | ||
|
|
7549845d82 | ||
|
|
fa18f8eda0 | ||
|
|
878f033ed7 | ||
|
|
4065106c2e | ||
|
|
12df938ace | ||
|
|
65ee43cc10 | ||
|
|
fe7c000fc1 | ||
|
|
dad50e5624 | ||
|
|
0a6d01e61d | ||
|
|
c6f8b0875a | ||
|
|
4c3800d743 | ||
|
|
7fe1c4b78f | ||
|
|
c375732396 | ||
|
|
b2db842cd4 | ||
|
|
9c21f83e82 | ||
|
|
880652b713 | ||
|
|
4ab94579ad | ||
|
|
eb0545a173 | ||
|
|
a2e389de9f | ||
|
|
01573c1375 | ||
|
|
2ba3ce81a6 | ||
|
|
4e4e5d7337 | ||
|
|
2a863727f9 | ||
|
|
30e2260e26 | ||
|
|
cbaea351b2 | ||
|
|
f070217c3b | ||
|
|
0915682c12 | ||
|
|
68ab9a1e56 | ||
|
|
47b79c30c0 | ||
|
|
5899f980aa | ||
|
|
b0bf4afe81 | ||
|
|
33e5d01f7c | ||
|
|
ee3373afc2 | ||
|
|
b296f103a9 | ||
|
|
525d5c0169 | ||
|
|
c4b6ba254e | ||
|
|
b7d1831f9d | ||
|
|
328ba36601 | ||
|
|
6f677ef5c1 | ||
|
|
d47d41cbd3 | ||
|
|
32bbe99efc | ||
|
|
990e346c46 | ||
|
|
9b7792631d | ||
|
|
558a8fe25b | ||
|
|
52b1516d44 | ||
|
|
8a3bb73c05 | ||
|
|
099c042395 | ||
|
|
2d4f00a451 | ||
|
|
9bd401a6d4 | ||
|
|
6aa3794b74 | ||
|
|
189dcf7295 | ||
|
|
1bc88028e6 | ||
|
|
d2942351ce | ||
|
|
83c078f363 | ||
|
|
26d39ffc4a | ||
|
|
421e2ceeee | ||
|
|
275dcbf69f | ||
|
|
9f87b27a5b | ||
|
|
b2e1196e29 | ||
|
|
2dc1396380 | ||
|
|
77941ab3ce | ||
|
|
ee19a30dde | ||
|
|
5d799b3174 | ||
|
|
8f33a985a2 | ||
|
|
78eeccef0e | ||
|
|
3d415441e8 | ||
|
|
74385e0ebd | ||
|
|
2bfbc29ccc | ||
|
|
ef79c26f18 | ||
|
|
fbe32c8e89 | ||
|
|
2511c28f92 | ||
|
|
637bb1cbbc | ||
|
|
3dfea96ec1 | ||
|
|
68643153e5 | ||
|
|
462762f75b | ||
|
|
4f3729c004 | ||
|
|
ba428cdf54 | ||
|
|
69c7d1b01b | ||
|
|
733299ec13 | ||
|
|
e1adf781c6 |
77
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
77
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -8,16 +8,15 @@ body:
|
||||
value: |
|
||||
Thank you for taking the time to file a bug report.
|
||||
|
||||
Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
|
||||
For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
|
||||
|
||||
Relevant links to check before filing a bug report to see if your issue has already been reported, fixed or
|
||||
if there's another way to solve your problem:
|
||||
Check these before submitting to see if your issue has already been reported, fixed or if there's another way to solve your problem:
|
||||
|
||||
* [LangChain Forum](https://forum.langchain.com/),
|
||||
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
|
||||
* [API Reference](https://reference.langchain.com/python/),
|
||||
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
|
||||
* [API Reference Documentation](https://reference.langchain.com/python/),
|
||||
* [LangChain ChatBot](https://chat.langchain.com/)
|
||||
* [GitHub search](https://github.com/langchain-ai/langchain),
|
||||
* [LangChain Forum](https://forum.langchain.com/),
|
||||
- type: checkboxes
|
||||
id: checks
|
||||
attributes:
|
||||
@@ -36,16 +35,48 @@ body:
|
||||
required: true
|
||||
- label: This is not related to the langchain-community package.
|
||||
required: true
|
||||
- label: I read what a minimal reproducible example is (https://stackoverflow.com/help/minimal-reproducible-example).
|
||||
required: true
|
||||
- label: I posted a self-contained, minimal, reproducible example. A maintainer can copy it and run it AS IS.
|
||||
required: true
|
||||
- type: checkboxes
|
||||
id: package
|
||||
attributes:
|
||||
label: Package (Required)
|
||||
description: |
|
||||
Which `langchain` package(s) is this bug related to? Select at least one.
|
||||
|
||||
Note that if the package you are reporting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).
|
||||
|
||||
Please report issues for other packages to their respective repositories.
|
||||
options:
|
||||
- label: langchain
|
||||
- label: langchain-openai
|
||||
- label: langchain-anthropic
|
||||
- label: langchain-classic
|
||||
- label: langchain-core
|
||||
- label: langchain-cli
|
||||
- label: langchain-model-profiles
|
||||
- label: langchain-tests
|
||||
- label: langchain-text-splitters
|
||||
- label: langchain-chroma
|
||||
- label: langchain-deepseek
|
||||
- label: langchain-exa
|
||||
- label: langchain-fireworks
|
||||
- label: langchain-groq
|
||||
- label: langchain-huggingface
|
||||
- label: langchain-mistralai
|
||||
- label: langchain-nomic
|
||||
- label: langchain-ollama
|
||||
- label: langchain-perplexity
|
||||
- label: langchain-prompty
|
||||
- label: langchain-qdrant
|
||||
- label: langchain-xai
|
||||
- label: Other / not sure / general
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Example Code
|
||||
label: Example Code (Python)
|
||||
description: |
|
||||
Please add a self-contained, [minimal, reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) with your use case.
|
||||
|
||||
@@ -53,15 +84,12 @@ body:
|
||||
|
||||
**Important!**
|
||||
|
||||
* Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
|
||||
* Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.
|
||||
* Use code tags (e.g., ```python ... ```) to correctly [format your code](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting).
|
||||
* INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).
|
||||
* Avoid screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
|
||||
* Reduce your code to the minimum required to reproduce the issue if possible.
|
||||
|
||||
(This will be automatically formatted into code, so no need for backticks.)
|
||||
render: python
|
||||
placeholder: |
|
||||
The following code:
|
||||
|
||||
```python
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
def bad_code(inputs) -> int:
|
||||
@@ -69,17 +97,14 @@ body:
|
||||
|
||||
chain = RunnableLambda(bad_code)
|
||||
chain.invoke('Hello!')
|
||||
```
|
||||
- type: textarea
|
||||
id: error
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Error Message and Stack Trace (if applicable)
|
||||
description: |
|
||||
If you are reporting an error, please include the full error message and stack trace.
|
||||
placeholder: |
|
||||
Exception + full stack trace
|
||||
If you are reporting an error, please copy and paste the full error message and
|
||||
stack trace.
|
||||
(This will be automatically formatted into code, so no need for backticks.)
|
||||
render: shell
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
@@ -99,9 +124,7 @@ body:
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. Do NOT skip this step and please don't trim
|
||||
the output. Most users don't include enough information here and it makes it harder
|
||||
for us to help you.
|
||||
Please share your system info with us.
|
||||
|
||||
Run the following command in your terminal and paste the output here:
|
||||
|
||||
@@ -113,8 +136,6 @@ body:
|
||||
from langchain_core import sys_info
|
||||
sys_info.print_sys_info()
|
||||
```
|
||||
|
||||
alternatively, put the entire output of `pip freeze` here.
|
||||
placeholder: |
|
||||
python -m langchain_core.sys_info
|
||||
validations:
|
||||
|
||||
11
.github/ISSUE_TEMPLATE/config.yml
vendored
11
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,9 +1,18 @@
|
||||
blank_issues_enabled: false
|
||||
version: 2.1
|
||||
contact_links:
|
||||
- name: 📚 Documentation
|
||||
- name: 📚 Documentation issue
|
||||
url: https://github.com/langchain-ai/docs/issues/new?template=01-langchain.yml
|
||||
about: Report an issue related to the LangChain documentation
|
||||
- name: 💬 LangChain Forum
|
||||
url: https://forum.langchain.com/
|
||||
about: General community discussions and support
|
||||
- name: 📚 LangChain Documentation
|
||||
url: https://docs.langchain.com/oss/python/langchain/overview
|
||||
about: View the official LangChain documentation
|
||||
- name: 📚 API Reference Documentation
|
||||
url: https://reference.langchain.com/python/
|
||||
about: View the official LangChain API reference documentation
|
||||
- name: 💬 LangChain Forum
|
||||
url: https://forum.langchain.com/
|
||||
about: Ask questions and get help from the community
|
||||
|
||||
40
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
40
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
@@ -13,11 +13,11 @@ body:
|
||||
Relevant links to check before filing a feature request to see if your request has already been made or
|
||||
if there's another way to achieve what you want:
|
||||
|
||||
* [LangChain Forum](https://forum.langchain.com/),
|
||||
* [LangChain documentation with the integrated search](https://docs.langchain.com/oss/python/langchain/overview),
|
||||
* [API Reference](https://reference.langchain.com/python/),
|
||||
* [Documentation](https://docs.langchain.com/oss/python/langchain/overview),
|
||||
* [API Reference Documentation](https://reference.langchain.com/python/),
|
||||
* [LangChain ChatBot](https://chat.langchain.com/)
|
||||
* [GitHub search](https://github.com/langchain-ai/langchain),
|
||||
* [LangChain Forum](https://forum.langchain.com/),
|
||||
- type: checkboxes
|
||||
id: checks
|
||||
attributes:
|
||||
@@ -34,6 +34,40 @@ body:
|
||||
required: true
|
||||
- label: This is not related to the langchain-community package.
|
||||
required: true
|
||||
- type: checkboxes
|
||||
id: package
|
||||
attributes:
|
||||
label: Package (Required)
|
||||
description: |
|
||||
Which `langchain` package(s) is this request related to? Select at least one.
|
||||
|
||||
Note that if the package you are requesting for is not listed here, it is not in this repository (e.g. `langchain-google-genai` is in [`langchain-ai/langchain-google`](https://github.com/langchain-ai/langchain-google/)).
|
||||
|
||||
Please submit feature requests for other packages to their respective repositories.
|
||||
options:
|
||||
- label: langchain
|
||||
- label: langchain-openai
|
||||
- label: langchain-anthropic
|
||||
- label: langchain-classic
|
||||
- label: langchain-core
|
||||
- label: langchain-cli
|
||||
- label: langchain-model-profiles
|
||||
- label: langchain-tests
|
||||
- label: langchain-text-splitters
|
||||
- label: langchain-chroma
|
||||
- label: langchain-deepseek
|
||||
- label: langchain-exa
|
||||
- label: langchain-fireworks
|
||||
- label: langchain-groq
|
||||
- label: langchain-huggingface
|
||||
- label: langchain-mistralai
|
||||
- label: langchain-nomic
|
||||
- label: langchain-ollama
|
||||
- label: langchain-perplexity
|
||||
- label: langchain-prompty
|
||||
- label: langchain-qdrant
|
||||
- label: langchain-xai
|
||||
- label: Other / not sure / general
|
||||
- type: textarea
|
||||
id: feature-description
|
||||
validations:
|
||||
|
||||
30
.github/ISSUE_TEMPLATE/privileged.yml
vendored
30
.github/ISSUE_TEMPLATE/privileged.yml
vendored
@@ -18,3 +18,33 @@ body:
|
||||
attributes:
|
||||
label: Issue Content
|
||||
description: Add the content of the issue here.
|
||||
- type: checkboxes
|
||||
id: package
|
||||
attributes:
|
||||
label: Package (Required)
|
||||
description: |
|
||||
Please select package(s) that this issue is related to.
|
||||
options:
|
||||
- label: langchain
|
||||
- label: langchain-openai
|
||||
- label: langchain-anthropic
|
||||
- label: langchain-classic
|
||||
- label: langchain-core
|
||||
- label: langchain-cli
|
||||
- label: langchain-model-profiles
|
||||
- label: langchain-tests
|
||||
- label: langchain-text-splitters
|
||||
- label: langchain-chroma
|
||||
- label: langchain-deepseek
|
||||
- label: langchain-exa
|
||||
- label: langchain-fireworks
|
||||
- label: langchain-groq
|
||||
- label: langchain-huggingface
|
||||
- label: langchain-mistralai
|
||||
- label: langchain-nomic
|
||||
- label: langchain-ollama
|
||||
- label: langchain-perplexity
|
||||
- label: langchain-prompty
|
||||
- label: langchain-qdrant
|
||||
- label: langchain-xai
|
||||
- label: Other / not sure / general
|
||||
|
||||
48
.github/ISSUE_TEMPLATE/task.yml
vendored
48
.github/ISSUE_TEMPLATE/task.yml
vendored
@@ -25,13 +25,13 @@ body:
|
||||
label: Task Description
|
||||
description: |
|
||||
Provide a clear and detailed description of the task.
|
||||
|
||||
|
||||
What needs to be done? Be specific about the scope and requirements.
|
||||
placeholder: |
|
||||
This task involves...
|
||||
|
||||
|
||||
The goal is to...
|
||||
|
||||
|
||||
Specific requirements:
|
||||
- ...
|
||||
- ...
|
||||
@@ -43,7 +43,7 @@ body:
|
||||
label: Acceptance Criteria
|
||||
description: |
|
||||
Define the criteria that must be met for this task to be considered complete.
|
||||
|
||||
|
||||
What are the specific deliverables or outcomes expected?
|
||||
placeholder: |
|
||||
This task will be complete when:
|
||||
@@ -58,15 +58,15 @@ body:
|
||||
label: Context and Background
|
||||
description: |
|
||||
Provide any relevant context, background information, or links to related issues/PRs.
|
||||
|
||||
|
||||
Why is this task needed? What problem does it solve?
|
||||
placeholder: |
|
||||
Background:
|
||||
- ...
|
||||
|
||||
|
||||
Related issues/PRs:
|
||||
- #...
|
||||
|
||||
|
||||
Additional context:
|
||||
- ...
|
||||
validations:
|
||||
@@ -77,15 +77,45 @@ body:
|
||||
label: Dependencies
|
||||
description: |
|
||||
List any dependencies or blockers for this task.
|
||||
|
||||
|
||||
Are there other tasks, issues, or external factors that need to be completed first?
|
||||
placeholder: |
|
||||
This task depends on:
|
||||
- [ ] Issue #...
|
||||
- [ ] PR #...
|
||||
- [ ] External dependency: ...
|
||||
|
||||
|
||||
Blocked by:
|
||||
- ...
|
||||
validations:
|
||||
required: false
|
||||
- type: checkboxes
|
||||
id: package
|
||||
attributes:
|
||||
label: Package (Required)
|
||||
description: |
|
||||
Please select package(s) that this task is related to.
|
||||
options:
|
||||
- label: langchain
|
||||
- label: langchain-openai
|
||||
- label: langchain-anthropic
|
||||
- label: langchain-classic
|
||||
- label: langchain-core
|
||||
- label: langchain-cli
|
||||
- label: langchain-model-profiles
|
||||
- label: langchain-tests
|
||||
- label: langchain-text-splitters
|
||||
- label: langchain-chroma
|
||||
- label: langchain-deepseek
|
||||
- label: langchain-exa
|
||||
- label: langchain-fireworks
|
||||
- label: langchain-groq
|
||||
- label: langchain-huggingface
|
||||
- label: langchain-mistralai
|
||||
- label: langchain-nomic
|
||||
- label: langchain-ollama
|
||||
- label: langchain-perplexity
|
||||
- label: langchain-prompty
|
||||
- label: langchain-qdrant
|
||||
- label: langchain-xai
|
||||
- label: Other / not sure / general
|
||||
|
||||
38
.github/PULL_REQUEST_TEMPLATE.md
vendored
38
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,28 +1,30 @@
|
||||
(Replace this entire block of text)
|
||||
|
||||
Thank you for contributing to LangChain! Follow these steps to mark your pull request as ready for review. **If any of these steps are not completed, your PR will not be considered for review.**
|
||||
Read the full contributing guidelines: https://docs.langchain.com/oss/python/contributing/overview
|
||||
|
||||
Thank you for contributing to LangChain! Follow these steps to have your pull request considered as ready for review.
|
||||
|
||||
1. PR title: Should follow the format: TYPE(SCOPE): DESCRIPTION
|
||||
|
||||
- [ ] **PR title**: Follows the format: {TYPE}({SCOPE}): {DESCRIPTION}
|
||||
- Examples:
|
||||
- fix(anthropic): resolve flag parsing error
|
||||
- feat(core): add multi-tenant support
|
||||
- fix(cli): resolve flag parsing error
|
||||
- docs(openai): update API usage examples
|
||||
- Allowed `{TYPE}` values:
|
||||
- feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert, release
|
||||
- Allowed `{SCOPE}` values (optional):
|
||||
- core, cli, langchain, standard-tests, text-splitters, docs, anthropic, chroma, deepseek, exa, fireworks, groq, huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant, xai, infra
|
||||
- Once you've written the title, please delete this checklist item; do not include it in the PR.
|
||||
- test(openai): update API usage tests
|
||||
- Allowed TYPE and SCOPE values: https://github.com/langchain-ai/langchain/blob/master/.github/workflows/pr_lint.yml#L15-L33
|
||||
|
||||
- [ ] **PR message**: ***Delete this entire checklist*** and replace with
|
||||
- **Description:** a description of the change. Include a [closing keyword](https://docs.github.com/en/issues/tracking-your-work-with-issues/using-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) if applicable to a relevant issue.
|
||||
- **Issue:** the issue # it fixes, if applicable (e.g. Fixes #123)
|
||||
- **Dependencies:** any dependencies required for this change
|
||||
2. PR description:
|
||||
|
||||
- [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. **We will not consider a PR unless these three are passing in CI.** See [contribution guidelines](https://docs.langchain.com/oss/python/contributing) for more.
|
||||
- Write 1-2 sentences summarizing the change.
|
||||
- If this PR addresses a specific issue, please include "Fixes #ISSUE_NUMBER" in the description to automatically close the issue when the PR is merged.
|
||||
- If there are any breaking changes, please clearly describe them.
|
||||
- If this PR depends on another PR being merged first, please include "Depends on #PR_NUMBER" in the description.
|
||||
|
||||
3. Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified.
|
||||
|
||||
- We will not consider a PR unless these three are passing in CI.
|
||||
|
||||
Additional guidelines:
|
||||
|
||||
- Most PRs should not touch more than one package.
|
||||
- Please do not add dependencies to `pyproject.toml` files (even optional ones) unless they are **required** for unit tests. Likewise, please do not update the `uv.lock` files unless you are adding a required dependency.
|
||||
- Changes should be backwards compatible.
|
||||
- Make sure optional dependencies are imported within a function.
|
||||
- We ask that if you use generative AI for your contribution, you include a disclaimer.
|
||||
- PRs should not touch more than one package unless absolutely necessary.
|
||||
- Do not update the `uv.lock` files or add dependencies to `pyproject.toml` files (even optional ones) unless you have explicit permission from a maintainer to do so.
|
||||
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
name: "Python ${{ inputs.python-version }}"
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: "🐍 Set up Python ${{ inputs.python-version }} + UV"
|
||||
uses: "./.github/actions/uv_setup"
|
||||
|
||||
2
.github/workflows/_lint.yml
vendored
2
.github/workflows/_lint.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: "📋 Checkout Code"
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: "🐍 Set up Python ${{ inputs.python-version }} + UV"
|
||||
uses: "./.github/actions/uv_setup"
|
||||
|
||||
16
.github/workflows/_release.yml
vendored
16
.github/workflows/_release.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
version: ${{ steps.check-version.outputs.version }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python + uv
|
||||
uses: "./.github/actions/uv_setup"
|
||||
@@ -105,7 +105,7 @@ jobs:
|
||||
outputs:
|
||||
release-body: ${{ steps.generate-release-body.outputs.release-body }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: langchain-ai/langchain
|
||||
path: langchain
|
||||
@@ -206,7 +206,7 @@ jobs:
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
@@ -237,7 +237,7 @@ jobs:
|
||||
contents: read
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
# We explicitly *don't* set up caching here. This ensures our tests are
|
||||
# maximally sensitive to catching breakage.
|
||||
@@ -396,7 +396,7 @@ jobs:
|
||||
contents: read
|
||||
strategy:
|
||||
matrix:
|
||||
partner: [openai, anthropic]
|
||||
partner: [anthropic]
|
||||
fail-fast: false # Continue testing other partners if one fails
|
||||
env:
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
@@ -412,7 +412,7 @@ jobs:
|
||||
AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME }}
|
||||
LANGCHAIN_TESTS_USER_AGENT: ${{ secrets.LANGCHAIN_TESTS_USER_AGENT }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
# We implement this conditional as Github Actions does not have good support
|
||||
# for conditionally needing steps. https://github.com/actions/runner/issues/491
|
||||
@@ -492,7 +492,7 @@ jobs:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python + uv
|
||||
uses: "./.github/actions/uv_setup"
|
||||
@@ -532,7 +532,7 @@ jobs:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python + uv
|
||||
uses: "./.github/actions/uv_setup"
|
||||
|
||||
2
.github/workflows/_test.yml
vendored
2
.github/workflows/_test.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
name: "Python ${{ inputs.python-version }}"
|
||||
steps:
|
||||
- name: "📋 Checkout Code"
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: "🐍 Set up Python ${{ inputs.python-version }} + UV"
|
||||
uses: "./.github/actions/uv_setup"
|
||||
|
||||
2
.github/workflows/_test_pydantic.yml
vendored
2
.github/workflows/_test_pydantic.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
name: "Pydantic ~=${{ inputs.pydantic-version }}"
|
||||
steps:
|
||||
- name: "📋 Checkout Code"
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: "🐍 Set up Python ${{ inputs.python-version }} + UV"
|
||||
uses: "./.github/actions/uv_setup"
|
||||
|
||||
107
.github/workflows/auto-label-by-package.yml
vendored
Normal file
107
.github/workflows/auto-label-by-package.yml
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
name: Auto Label Issues by Package
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened, edited]
|
||||
|
||||
jobs:
|
||||
label-by-package:
|
||||
permissions:
|
||||
issues: write
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Sync package labels
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const body = context.payload.issue.body || "";
|
||||
|
||||
// Extract text under "### Package"
|
||||
const match = body.match(/### Package\s+([\s\S]*?)\n###/i);
|
||||
if (!match) return;
|
||||
|
||||
const packageSection = match[1].trim();
|
||||
|
||||
// Mapping table for package names to labels
|
||||
const mapping = {
|
||||
"langchain": "langchain",
|
||||
"langchain-openai": "openai",
|
||||
"langchain-anthropic": "anthropic",
|
||||
"langchain-classic": "langchain-classic",
|
||||
"langchain-core": "core",
|
||||
"langchain-cli": "cli",
|
||||
"langchain-model-profiles": "model-profiles",
|
||||
"langchain-tests": "standard-tests",
|
||||
"langchain-text-splitters": "text-splitters",
|
||||
"langchain-chroma": "chroma",
|
||||
"langchain-deepseek": "deepseek",
|
||||
"langchain-exa": "exa",
|
||||
"langchain-fireworks": "fireworks",
|
||||
"langchain-groq": "groq",
|
||||
"langchain-huggingface": "huggingface",
|
||||
"langchain-mistralai": "mistralai",
|
||||
"langchain-nomic": "nomic",
|
||||
"langchain-ollama": "ollama",
|
||||
"langchain-perplexity": "perplexity",
|
||||
"langchain-prompty": "prompty",
|
||||
"langchain-qdrant": "qdrant",
|
||||
"langchain-xai": "xai",
|
||||
};
|
||||
|
||||
// All possible package labels we manage
|
||||
const allPackageLabels = Object.values(mapping);
|
||||
const selectedLabels = [];
|
||||
|
||||
// Check if this is checkbox format (multiple selection)
|
||||
const checkboxMatches = packageSection.match(/- \[x\]\s+([^\n\r]+)/gi);
|
||||
if (checkboxMatches) {
|
||||
// Handle checkbox format
|
||||
for (const match of checkboxMatches) {
|
||||
const packageName = match.replace(/- \[x\]\s+/i, '').trim();
|
||||
const label = mapping[packageName];
|
||||
if (label && !selectedLabels.includes(label)) {
|
||||
selectedLabels.push(label);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle dropdown format (single selection)
|
||||
const label = mapping[packageSection];
|
||||
if (label) {
|
||||
selectedLabels.push(label);
|
||||
}
|
||||
}
|
||||
|
||||
// Get current issue labels
|
||||
const issue = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number
|
||||
});
|
||||
|
||||
const currentLabels = issue.data.labels.map(label => label.name);
|
||||
const currentPackageLabels = currentLabels.filter(label => allPackageLabels.includes(label));
|
||||
|
||||
// Determine labels to add and remove
|
||||
const labelsToAdd = selectedLabels.filter(label => !currentPackageLabels.includes(label));
|
||||
const labelsToRemove = currentPackageLabels.filter(label => !selectedLabels.includes(label));
|
||||
|
||||
// Add new labels
|
||||
if (labelsToAdd.length > 0) {
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
labels: labelsToAdd
|
||||
});
|
||||
}
|
||||
|
||||
// Remove old labels
|
||||
for (const label of labelsToRemove) {
|
||||
await github.rest.issues.removeLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
name: label
|
||||
});
|
||||
}
|
||||
2
.github/workflows/check_core_versions.yml
vendored
2
.github/workflows/check_core_versions.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: "✅ Verify pyproject.toml & version.py Match"
|
||||
run: |
|
||||
|
||||
6
.github/workflows/check_diffs.yml
vendored
6
.github/workflows/check_diffs.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
if: ${{ !contains(github.event.pull_request.labels.*.name, 'ci-ignore') }}
|
||||
steps:
|
||||
- name: "📋 Checkout Code"
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
- name: "🐍 Setup Python 3.11"
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
@@ -141,7 +141,7 @@ jobs:
|
||||
run:
|
||||
working-directory: ${{ matrix.job-configs.working-directory }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: "🐍 Set up Python ${{ matrix.job-configs.python-version }} + UV"
|
||||
uses: "./.github/actions/uv_setup"
|
||||
@@ -182,7 +182,7 @@ jobs:
|
||||
job-configs: ${{ fromJson(needs.build.outputs.codspeed) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: "📦 Install UV Package Manager"
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
6
.github/workflows/integration_tests.yml
vendored
6
.github/workflows/integration_tests.yml
vendored
@@ -71,14 +71,14 @@ jobs:
|
||||
working-directory: ${{ fromJSON(needs.compute-matrix.outputs.matrix).working-directory }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
path: langchain
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: langchain-ai/langchain-google
|
||||
path: langchain-google
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: langchain-ai/langchain-aws
|
||||
path: langchain-aws
|
||||
|
||||
13
.github/workflows/pr_lint.yml
vendored
13
.github/workflows/pr_lint.yml
vendored
@@ -26,11 +26,13 @@
|
||||
# * revert — reverts a previous commit
|
||||
# * release — prepare a new release
|
||||
#
|
||||
# Allowed Scopes (optional):
|
||||
# core, cli, langchain, langchain_v1, langchain-classic, standard-tests,
|
||||
# text-splitters, docs, anthropic, chroma, deepseek, exa, fireworks, groq,
|
||||
# huggingface, mistralai, nomic, ollama, openai, perplexity, prompty, qdrant,
|
||||
# xai, infra, deps
|
||||
# Allowed Scope(s) (optional):
|
||||
# core, cli, langchain, langchain_v1, langchain-classic, model-profiles,
|
||||
# standard-tests, text-splitters, docs, anthropic, chroma, deepseek, exa,
|
||||
# fireworks, groq, huggingface, mistralai, nomic, ollama, openai,
|
||||
# perplexity, prompty, qdrant, xai, infra, deps
|
||||
#
|
||||
# Multiple scopes can be used by separating them with a comma.
|
||||
#
|
||||
# Rules:
|
||||
# 1. The 'Type' must start with a lowercase letter.
|
||||
@@ -100,6 +102,7 @@ jobs:
|
||||
qdrant
|
||||
xai
|
||||
infra
|
||||
deps
|
||||
requireScope: false
|
||||
disallowScopes: |
|
||||
release
|
||||
|
||||
4
.github/workflows/v03_api_doc_build.yml
vendored
4
.github/workflows/v03_api_doc_build.yml
vendored
@@ -23,12 +23,12 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: v0.3
|
||||
path: langchain
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: langchain-ai/langchain-api-docs-html
|
||||
path: langchain-api-docs-html
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -163,3 +163,6 @@ node_modules
|
||||
|
||||
prof
|
||||
virtualenv/
|
||||
scratch/
|
||||
|
||||
.langgraph_api/
|
||||
|
||||
405
AGENTS.md
405
AGENTS.md
@@ -1,255 +1,58 @@
|
||||
# Global Development Guidelines for LangChain Projects
|
||||
# Global development guidelines for the LangChain monorepo
|
||||
|
||||
## Core Development Principles
|
||||
This document provides context to understand the LangChain Python project and assist with development.
|
||||
|
||||
### 1. Maintain Stable Public Interfaces ⚠️ CRITICAL
|
||||
## Project architecture and context
|
||||
|
||||
**Always attempt to preserve function signatures, argument positions, and names for exported/public methods.**
|
||||
### Monorepo structure
|
||||
|
||||
❌ **Bad - Breaking Change:**
|
||||
This is a Python monorepo with multiple independently versioned packages that use `uv`.
|
||||
|
||||
```python
|
||||
def get_user(id, verbose=False): # Changed from `user_id`
|
||||
pass
|
||||
```txt
|
||||
langchain/
|
||||
├── libs/
|
||||
│ ├── core/ # `langchain-core` primitives and base abstractions
|
||||
│ ├── langchain/ # `langchain-classic` (legacy, no new features)
|
||||
│ ├── langchain_v1/ # Actively maintained `langchain` package
|
||||
│ ├── partners/ # Third-party integrations
|
||||
│ │ ├── openai/ # OpenAI models and embeddings
|
||||
│ │ ├── anthropic/ # Anthropic (Claude) integration
|
||||
│ │ ├── ollama/ # Local model support
|
||||
│ │ └── ... (other integrations maintained by the LangChain team)
|
||||
│ ├── text-splitters/ # Document chunking utilities
|
||||
│ ├── standard-tests/ # Shared test suite for integrations
|
||||
│ ├── model-profiles/ # Model configuration profiles
|
||||
│ └── cli/ # Command-line interface tools
|
||||
├── .github/ # CI/CD workflows and templates
|
||||
├── .vscode/ # VSCode IDE standard settings and recommended extensions
|
||||
└── README.md # Information about LangChain
|
||||
```
|
||||
|
||||
✅ **Good - Stable Interface:**
|
||||
- **Core layer** (`langchain-core`): Base abstractions, interfaces, and protocols. Users should not need to know about this layer directly.
|
||||
- **Implementation layer** (`langchain`): Concrete implementations and high-level public utilities
|
||||
- **Integration layer** (`partners/`): Third-party service integrations. Note that this monorepo is not exhaustive of all LangChain integrations; some are maintained in separate repos, such as `langchain-ai/langchain-google` and `langchain-ai/langchain-aws`. Usually these repos are cloned at the same level as this monorepo, so if needed, you can refer to their code directly by navigating to `../langchain-google/` from this monorepo.
|
||||
- **Testing layer** (`standard-tests/`): Standardized integration tests for partner integrations
|
||||
|
||||
```python
|
||||
def get_user(user_id: str, verbose: bool = False) -> User:
|
||||
"""Retrieve user by ID with optional verbose output."""
|
||||
pass
|
||||
```
|
||||
### Development tools & commands
|
||||
|
||||
**Before making ANY changes to public APIs:**
|
||||
- `uv` – Fast Python package installer and resolver (replaces pip/poetry)
|
||||
- `make` – Task runner for common development commands. Feel free to look at the `Makefile` for available commands and usage patterns.
|
||||
- `ruff` – Fast Python linter and formatter
|
||||
- `mypy` – Static type checking
|
||||
- `pytest` – Testing framework
|
||||
|
||||
- Check if the function/class is exported in `__init__.py`
|
||||
- Look for existing usage patterns in tests and examples
|
||||
- Use keyword-only arguments for new parameters: `*, new_param: str = "default"`
|
||||
- Mark experimental features clearly with docstring warnings (using MkDocs Material admonitions, like `!!! warning`)
|
||||
This monorepo uses `uv` for dependency management. Local development uses editable installs: `[tool.uv.sources]`
|
||||
|
||||
🧠 *Ask yourself:* "Would this change break someone's code if they used it last week?"
|
||||
|
||||
### 2. Code Quality Standards
|
||||
|
||||
**All Python code MUST include type hints and return types.**
|
||||
|
||||
❌ **Bad:**
|
||||
|
||||
```python
|
||||
def p(u, d):
|
||||
return [x for x in u if x not in d]
|
||||
```
|
||||
|
||||
✅ **Good:**
|
||||
|
||||
```python
|
||||
def filter_unknown_users(users: list[str], known_users: set[str]) -> list[str]:
|
||||
"""Filter out users that are not in the known users set.
|
||||
|
||||
Args:
|
||||
users: List of user identifiers to filter.
|
||||
known_users: Set of known/valid user identifiers.
|
||||
|
||||
Returns:
|
||||
List of users that are not in the known_users set.
|
||||
"""
|
||||
return [user for user in users if user not in known_users]
|
||||
```
|
||||
|
||||
**Style Requirements:**
|
||||
|
||||
- Use descriptive, **self-explanatory variable names**. Avoid overly short or cryptic identifiers.
|
||||
- Attempt to break up complex functions (>20 lines) into smaller, focused functions where it makes sense
|
||||
- Avoid unnecessary abstraction or premature optimization
|
||||
- Follow existing patterns in the codebase you're modifying
|
||||
|
||||
### 3. Testing Requirements
|
||||
|
||||
**Every new feature or bugfix MUST be covered by unit tests.**
|
||||
|
||||
**Test Organization:**
|
||||
|
||||
- Unit tests: `tests/unit_tests/` (no network calls allowed)
|
||||
- Integration tests: `tests/integration_tests/` (network calls permitted)
|
||||
- Use `pytest` as the testing framework
|
||||
|
||||
**Test Quality Checklist:**
|
||||
|
||||
- [ ] Tests fail when your new logic is broken
|
||||
- [ ] Happy path is covered
|
||||
- [ ] Edge cases and error conditions are tested
|
||||
- [ ] Use fixtures/mocks for external dependencies
|
||||
- [ ] Tests are deterministic (no flaky tests)
|
||||
|
||||
Checklist questions:
|
||||
|
||||
- [ ] Does the test suite fail if your new logic is broken?
|
||||
- [ ] Are all expected behaviors exercised (happy path, invalid input, etc)?
|
||||
- [ ] Do tests use fixtures or mocks where needed?
|
||||
|
||||
```python
|
||||
def test_filter_unknown_users():
|
||||
"""Test filtering unknown users from a list."""
|
||||
users = ["alice", "bob", "charlie"]
|
||||
known_users = {"alice", "bob"}
|
||||
|
||||
result = filter_unknown_users(users, known_users)
|
||||
|
||||
assert result == ["charlie"]
|
||||
assert len(result) == 1
|
||||
```
|
||||
|
||||
### 4. Security and Risk Assessment
|
||||
|
||||
**Security Checklist:**
|
||||
|
||||
- No `eval()`, `exec()`, or `pickle` on user-controlled input
|
||||
- Proper exception handling (no bare `except:`) and use a `msg` variable for error messages
|
||||
- Remove unreachable/commented code before committing
|
||||
- Race conditions or resource leaks (file handles, sockets, threads).
|
||||
- Ensure proper resource cleanup (file handles, connections)
|
||||
|
||||
❌ **Bad:**
|
||||
|
||||
```python
|
||||
def load_config(path):
|
||||
with open(path) as f:
|
||||
return eval(f.read()) # ⚠️ Never eval config
|
||||
```
|
||||
|
||||
✅ **Good:**
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
def load_config(path: str) -> dict:
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
```
|
||||
|
||||
### 5. Documentation Standards
|
||||
|
||||
**Use Google-style docstrings with Args section for all public functions.**
|
||||
|
||||
❌ **Insufficient Documentation:**
|
||||
|
||||
```python
|
||||
def send_email(to, msg):
|
||||
"""Send an email to a recipient."""
|
||||
```
|
||||
|
||||
✅ **Complete Documentation:**
|
||||
|
||||
```python
|
||||
def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
|
||||
"""
|
||||
Send an email to a recipient with specified priority.
|
||||
|
||||
Args:
|
||||
to: The email address of the recipient.
|
||||
msg: The message body to send.
|
||||
priority: Email priority level (`'low'`, `'normal'`, `'high'`).
|
||||
|
||||
Returns:
|
||||
`True` if email was sent successfully, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
`InvalidEmailError`: If the email address format is invalid.
|
||||
`SMTPConnectionError`: If unable to connect to email server.
|
||||
"""
|
||||
```
|
||||
|
||||
**Documentation Guidelines:**
|
||||
|
||||
- Types go in function signatures, NOT in docstrings
|
||||
- If a default is present, DO NOT repeat it in the docstring unless there is post-processing or it is set conditionally.
|
||||
- Focus on "why" rather than "what" in descriptions
|
||||
- Document all parameters, return values, and exceptions
|
||||
- Keep descriptions concise but clear
|
||||
- Ensure American English spelling (e.g., "behavior", not "behaviour")
|
||||
|
||||
📌 *Tip:* Keep descriptions concise but clear. Only document return values if non-obvious.
|
||||
|
||||
### 6. Architectural Improvements
|
||||
|
||||
**When you encounter code that could be improved, suggest better designs:**
|
||||
|
||||
❌ **Poor Design:**
|
||||
|
||||
```python
|
||||
def process_data(data, db_conn, email_client, logger):
|
||||
# Function doing too many things
|
||||
validated = validate_data(data)
|
||||
result = db_conn.save(validated)
|
||||
email_client.send_notification(result)
|
||||
logger.log(f"Processed {len(data)} items")
|
||||
return result
|
||||
```
|
||||
|
||||
✅ **Better Design:**
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ProcessingResult:
|
||||
"""Result of data processing operation."""
|
||||
items_processed: int
|
||||
success: bool
|
||||
errors: List[str] = field(default_factory=list)
|
||||
|
||||
class DataProcessor:
|
||||
"""Handles data validation, storage, and notification."""
|
||||
|
||||
def __init__(self, db_conn: Database, email_client: EmailClient):
|
||||
self.db = db_conn
|
||||
self.email = email_client
|
||||
|
||||
def process(self, data: List[dict]) -> ProcessingResult:
|
||||
"""Process and store data with notifications."""
|
||||
validated = self._validate_data(data)
|
||||
result = self.db.save(validated)
|
||||
self._notify_completion(result)
|
||||
return result
|
||||
```
|
||||
|
||||
**Design Improvement Areas:**
|
||||
|
||||
If there's a **cleaner**, **more scalable**, or **simpler** design, highlight it and suggest improvements that would:
|
||||
|
||||
- Reduce code duplication through shared utilities
|
||||
- Make unit testing easier
|
||||
- Improve separation of concerns (single responsibility)
|
||||
- Make unit testing easier through dependency injection
|
||||
- Add clarity without adding complexity
|
||||
- Prefer dataclasses for structured data
|
||||
|
||||
## Development Tools & Commands
|
||||
|
||||
### Package Management
|
||||
|
||||
```bash
|
||||
# Add package
|
||||
uv add package-name
|
||||
|
||||
# Sync project dependencies
|
||||
uv sync
|
||||
uv lock
|
||||
```
|
||||
|
||||
### Testing
|
||||
Each package in `libs/` has its own `pyproject.toml` and `uv.lock`.
|
||||
|
||||
```bash
|
||||
# Run unit tests (no network)
|
||||
make test
|
||||
|
||||
# Don't run integration tests, as API keys must be set
|
||||
|
||||
# Run specific test file
|
||||
uv run --group test pytest tests/unit_tests/test_specific.py
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Lint code
|
||||
make lint
|
||||
@@ -261,66 +64,118 @@ make format
|
||||
uv run --group lint mypy .
|
||||
```
|
||||
|
||||
### Dependency Management Patterns
|
||||
#### Key config files
|
||||
|
||||
**Local Development Dependencies:**
|
||||
- pyproject.toml: Main workspace configuration with dependency groups
|
||||
- uv.lock: Locked dependencies for reproducible builds
|
||||
- Makefile: Development tasks
|
||||
|
||||
```toml
|
||||
[tool.uv.sources]
|
||||
langchain-core = { path = "../core", editable = true }
|
||||
langchain-tests = { path = "../standard-tests", editable = true }
|
||||
```
|
||||
#### Commit standards
|
||||
|
||||
**For tools, use the `@tool` decorator from `langchain_core.tools`:**
|
||||
Suggest PR titles that follow the Conventional Commits format. Refer to `.github/workflows/pr_lint.yml` for allowed types and scopes.
|
||||
|
||||
```python
|
||||
from langchain_core.tools import tool
|
||||
#### Pull request guidelines
|
||||
|
||||
@tool
|
||||
def search_database(query: str) -> str:
|
||||
"""Search the database for relevant information.
|
||||
- Always add a disclaimer to the PR description mentioning how AI agents are involved with the contribution.
|
||||
- Describe the "why" of the changes, why the proposed solution is the right one. Limit prose.
|
||||
- Highlight areas of the proposed changes that require careful review.
|
||||
|
||||
## Core development principles
|
||||
|
||||
### Maintain stable public interfaces
|
||||
|
||||
CRITICAL: Always attempt to preserve function signatures, argument positions, and names for exported/public methods. Do not make breaking changes.
|
||||
|
||||
**Before making ANY changes to public APIs:**
|
||||
|
||||
- Check if the function/class is exported in `__init__.py`
|
||||
- Look for existing usage patterns in tests and examples
|
||||
- Use keyword-only arguments for new parameters: `*, new_param: str = "default"`
|
||||
- Mark experimental features clearly with docstring warnings (using MkDocs Material admonitions, like `!!! warning`)
|
||||
|
||||
Ask: "Would this change break someone's code if they used it last week?"
|
||||
|
||||
### Code quality standards
|
||||
|
||||
All Python code MUST include type hints and return types.
|
||||
|
||||
```python title="Example"
|
||||
def filter_unknown_users(users: list[str], known_users: set[str]) -> list[str]:
|
||||
"""Single line description of the function.
|
||||
|
||||
Any additional context about the function can go here.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
users: List of user identifiers to filter.
|
||||
known_users: Set of known/valid user identifiers.
|
||||
|
||||
Returns:
|
||||
List of users that are not in the known_users set.
|
||||
"""
|
||||
# Implementation here
|
||||
return results
|
||||
```
|
||||
|
||||
## Commit Standards
|
||||
- Use descriptive, self-explanatory variable names.
|
||||
- Follow existing patterns in the codebase you're modifying
|
||||
- Attempt to break up complex functions (>20 lines) into smaller, focused functions where it makes sense
|
||||
|
||||
**Use Conventional Commits format for PR titles:**
|
||||
### Testing requirements
|
||||
|
||||
- `feat(core): add multi-tenant support`
|
||||
- `fix(cli): resolve flag parsing error`
|
||||
- `docs: update API usage examples`
|
||||
- `docs(openai): update API usage examples`
|
||||
Every new feature or bugfix MUST be covered by unit tests.
|
||||
|
||||
## Framework-Specific Guidelines
|
||||
- Unit tests: `tests/unit_tests/` (no network calls allowed)
|
||||
- Integration tests: `tests/integration_tests/` (network calls permitted)
|
||||
- We use `pytest` as the testing framework; if in doubt, check other existing tests for examples.
|
||||
- The testing file structure should mirror the source code structure.
|
||||
|
||||
- Follow the existing patterns in `langchain-core` for base abstractions
|
||||
- Use `langchain_core.callbacks` for execution tracking
|
||||
- Implement proper streaming support where applicable
|
||||
- Avoid deprecated components like legacy `LLMChain`
|
||||
**Checklist:**
|
||||
|
||||
### Partner Integrations
|
||||
- [ ] Tests fail when your new logic is broken
|
||||
- [ ] Happy path is covered
|
||||
- [ ] Edge cases and error conditions are tested
|
||||
- [ ] Use fixtures/mocks for external dependencies
|
||||
- [ ] Tests are deterministic (no flaky tests)
|
||||
- [ ] Does the test suite fail if your new logic is broken?
|
||||
|
||||
- Follow the established patterns in existing partner libraries
|
||||
- Implement standard interfaces (`BaseChatModel`, `BaseEmbeddings`, etc.)
|
||||
- Include comprehensive integration tests
|
||||
- Document API key requirements and authentication
|
||||
### Security and risk assessment
|
||||
|
||||
---
|
||||
- No `eval()`, `exec()`, or `pickle` on user-controlled input
|
||||
- Proper exception handling (no bare `except:`) and use a `msg` variable for error messages
|
||||
- Remove unreachable/commented code before committing
|
||||
- Race conditions or resource leaks (file handles, sockets, threads).
|
||||
- Ensure proper resource cleanup (file handles, connections)
|
||||
|
||||
## Quick Reference Checklist
|
||||
### Documentation standards
|
||||
|
||||
Before submitting code changes:
|
||||
Use Google-style docstrings with Args section for all public functions.
|
||||
|
||||
- [ ] **Breaking Changes**: Verified no public API changes
|
||||
- [ ] **Type Hints**: All functions have complete type annotations
|
||||
- [ ] **Tests**: New functionality is fully tested
|
||||
- [ ] **Security**: No dangerous patterns (eval, silent failures, etc.)
|
||||
- [ ] **Documentation**: Google-style docstrings for public functions
|
||||
- [ ] **Code Quality**: `make lint` and `make format` pass
|
||||
- [ ] **Architecture**: Suggested improvements where applicable
|
||||
- [ ] **Commit Message**: Follows Conventional Commits format
|
||||
```python title="Example"
|
||||
def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
|
||||
"""Send an email to a recipient with specified priority.
|
||||
|
||||
Any additional context about the function can go here.
|
||||
|
||||
Args:
|
||||
to: The email address of the recipient.
|
||||
msg: The message body to send.
|
||||
priority: Email priority level.
|
||||
|
||||
Returns:
|
||||
`True` if email was sent successfully, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
InvalidEmailError: If the email address format is invalid.
|
||||
SMTPConnectionError: If unable to connect to email server.
|
||||
"""
|
||||
```
|
||||
|
||||
- Types go in function signatures, NOT in docstrings
|
||||
- If a default is present, DO NOT repeat it in the docstring unless there is post-processing or it is set conditionally.
|
||||
- Focus on "why" rather than "what" in descriptions
|
||||
- Document all parameters, return values, and exceptions
|
||||
- Keep descriptions concise but clear
|
||||
- Ensure American English spelling (e.g., "behavior", not "behaviour")
|
||||
|
||||
## Additional resources
|
||||
|
||||
- **Documentation:** https://docs.langchain.com/oss/python/langchain/overview and source at https://github.com/langchain-ai/docs or `../docs/`. Prefer the local install and use file search tools for best results. If needed, use the docs MCP server as defined in `.mcp.json` for programmatic access.
|
||||
- **Contributing Guide:** [`.github/CONTRIBUTING.md`](https://docs.langchain.com/oss/python/contributing/overview)
|
||||
|
||||
405
CLAUDE.md
405
CLAUDE.md
@@ -1,255 +1,58 @@
|
||||
# Global Development Guidelines for LangChain Projects
|
||||
# Global development guidelines for the LangChain monorepo
|
||||
|
||||
## Core Development Principles
|
||||
This document provides context to understand the LangChain Python project and assist with development.
|
||||
|
||||
### 1. Maintain Stable Public Interfaces ⚠️ CRITICAL
|
||||
## Project architecture and context
|
||||
|
||||
**Always attempt to preserve function signatures, argument positions, and names for exported/public methods.**
|
||||
### Monorepo structure
|
||||
|
||||
❌ **Bad - Breaking Change:**
|
||||
This is a Python monorepo with multiple independently versioned packages that use `uv`.
|
||||
|
||||
```python
|
||||
def get_user(id, verbose=False): # Changed from `user_id`
|
||||
pass
|
||||
```txt
|
||||
langchain/
|
||||
├── libs/
|
||||
│ ├── core/ # `langchain-core` primitives and base abstractions
|
||||
│ ├── langchain/ # `langchain-classic` (legacy, no new features)
|
||||
│ ├── langchain_v1/ # Actively maintained `langchain` package
|
||||
│ ├── partners/ # Third-party integrations
|
||||
│ │ ├── openai/ # OpenAI models and embeddings
|
||||
│ │ ├── anthropic/ # Anthropic (Claude) integration
|
||||
│ │ ├── ollama/ # Local model support
|
||||
│ │ └── ... (other integrations maintained by the LangChain team)
|
||||
│ ├── text-splitters/ # Document chunking utilities
|
||||
│ ├── standard-tests/ # Shared test suite for integrations
|
||||
│ ├── model-profiles/ # Model configuration profiles
|
||||
│ └── cli/ # Command-line interface tools
|
||||
├── .github/ # CI/CD workflows and templates
|
||||
├── .vscode/ # VSCode IDE standard settings and recommended extensions
|
||||
└── README.md # Information about LangChain
|
||||
```
|
||||
|
||||
✅ **Good - Stable Interface:**
|
||||
- **Core layer** (`langchain-core`): Base abstractions, interfaces, and protocols. Users should not need to know about this layer directly.
|
||||
- **Implementation layer** (`langchain`): Concrete implementations and high-level public utilities
|
||||
- **Integration layer** (`partners/`): Third-party service integrations. Note that this monorepo is not exhaustive of all LangChain integrations; some are maintained in separate repos, such as `langchain-ai/langchain-google` and `langchain-ai/langchain-aws`. Usually these repos are cloned at the same level as this monorepo, so if needed, you can refer to their code directly by navigating to `../langchain-google/` from this monorepo.
|
||||
- **Testing layer** (`standard-tests/`): Standardized integration tests for partner integrations
|
||||
|
||||
```python
|
||||
def get_user(user_id: str, verbose: bool = False) -> User:
|
||||
"""Retrieve user by ID with optional verbose output."""
|
||||
pass
|
||||
```
|
||||
### Development tools & commands
|
||||
|
||||
**Before making ANY changes to public APIs:**
|
||||
- `uv` – Fast Python package installer and resolver (replaces pip/poetry)
|
||||
- `make` – Task runner for common development commands. Feel free to look at the `Makefile` for available commands and usage patterns.
|
||||
- `ruff` – Fast Python linter and formatter
|
||||
- `mypy` – Static type checking
|
||||
- `pytest` – Testing framework
|
||||
|
||||
- Check if the function/class is exported in `__init__.py`
|
||||
- Look for existing usage patterns in tests and examples
|
||||
- Use keyword-only arguments for new parameters: `*, new_param: str = "default"`
|
||||
- Mark experimental features clearly with docstring warnings (using MkDocs Material admonitions, like `!!! warning`)
|
||||
This monorepo uses `uv` for dependency management. Local development uses editable installs: `[tool.uv.sources]`
|
||||
|
||||
🧠 *Ask yourself:* "Would this change break someone's code if they used it last week?"
|
||||
|
||||
### 2. Code Quality Standards
|
||||
|
||||
**All Python code MUST include type hints and return types.**
|
||||
|
||||
❌ **Bad:**
|
||||
|
||||
```python
|
||||
def p(u, d):
|
||||
return [x for x in u if x not in d]
|
||||
```
|
||||
|
||||
✅ **Good:**
|
||||
|
||||
```python
|
||||
def filter_unknown_users(users: list[str], known_users: set[str]) -> list[str]:
|
||||
"""Filter out users that are not in the known users set.
|
||||
|
||||
Args:
|
||||
users: List of user identifiers to filter.
|
||||
known_users: Set of known/valid user identifiers.
|
||||
|
||||
Returns:
|
||||
List of users that are not in the known_users set.
|
||||
"""
|
||||
return [user for user in users if user not in known_users]
|
||||
```
|
||||
|
||||
**Style Requirements:**
|
||||
|
||||
- Use descriptive, **self-explanatory variable names**. Avoid overly short or cryptic identifiers.
|
||||
- Attempt to break up complex functions (>20 lines) into smaller, focused functions where it makes sense
|
||||
- Avoid unnecessary abstraction or premature optimization
|
||||
- Follow existing patterns in the codebase you're modifying
|
||||
|
||||
### 3. Testing Requirements
|
||||
|
||||
**Every new feature or bugfix MUST be covered by unit tests.**
|
||||
|
||||
**Test Organization:**
|
||||
|
||||
- Unit tests: `tests/unit_tests/` (no network calls allowed)
|
||||
- Integration tests: `tests/integration_tests/` (network calls permitted)
|
||||
- Use `pytest` as the testing framework
|
||||
|
||||
**Test Quality Checklist:**
|
||||
|
||||
- [ ] Tests fail when your new logic is broken
|
||||
- [ ] Happy path is covered
|
||||
- [ ] Edge cases and error conditions are tested
|
||||
- [ ] Use fixtures/mocks for external dependencies
|
||||
- [ ] Tests are deterministic (no flaky tests)
|
||||
|
||||
Checklist questions:
|
||||
|
||||
- [ ] Does the test suite fail if your new logic is broken?
|
||||
- [ ] Are all expected behaviors exercised (happy path, invalid input, etc)?
|
||||
- [ ] Do tests use fixtures or mocks where needed?
|
||||
|
||||
```python
|
||||
def test_filter_unknown_users():
|
||||
"""Test filtering unknown users from a list."""
|
||||
users = ["alice", "bob", "charlie"]
|
||||
known_users = {"alice", "bob"}
|
||||
|
||||
result = filter_unknown_users(users, known_users)
|
||||
|
||||
assert result == ["charlie"]
|
||||
assert len(result) == 1
|
||||
```
|
||||
|
||||
### 4. Security and Risk Assessment
|
||||
|
||||
**Security Checklist:**
|
||||
|
||||
- No `eval()`, `exec()`, or `pickle` on user-controlled input
|
||||
- Proper exception handling (no bare `except:`) and use a `msg` variable for error messages
|
||||
- Remove unreachable/commented code before committing
|
||||
- Race conditions or resource leaks (file handles, sockets, threads).
|
||||
- Ensure proper resource cleanup (file handles, connections)
|
||||
|
||||
❌ **Bad:**
|
||||
|
||||
```python
|
||||
def load_config(path):
|
||||
with open(path) as f:
|
||||
return eval(f.read()) # ⚠️ Never eval config
|
||||
```
|
||||
|
||||
✅ **Good:**
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
def load_config(path: str) -> dict:
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
```
|
||||
|
||||
### 5. Documentation Standards
|
||||
|
||||
**Use Google-style docstrings with Args section for all public functions.**
|
||||
|
||||
❌ **Insufficient Documentation:**
|
||||
|
||||
```python
|
||||
def send_email(to, msg):
|
||||
"""Send an email to a recipient."""
|
||||
```
|
||||
|
||||
✅ **Complete Documentation:**
|
||||
|
||||
```python
|
||||
def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
|
||||
"""
|
||||
Send an email to a recipient with specified priority.
|
||||
|
||||
Args:
|
||||
to: The email address of the recipient.
|
||||
msg: The message body to send.
|
||||
priority: Email priority level (`'low'`, `'normal'`, `'high'`).
|
||||
|
||||
Returns:
|
||||
`True` if email was sent successfully, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
`InvalidEmailError`: If the email address format is invalid.
|
||||
`SMTPConnectionError`: If unable to connect to email server.
|
||||
"""
|
||||
```
|
||||
|
||||
**Documentation Guidelines:**
|
||||
|
||||
- Types go in function signatures, NOT in docstrings
|
||||
- If a default is present, DO NOT repeat it in the docstring unless there is post-processing or it is set conditionally.
|
||||
- Focus on "why" rather than "what" in descriptions
|
||||
- Document all parameters, return values, and exceptions
|
||||
- Keep descriptions concise but clear
|
||||
- Ensure American English spelling (e.g., "behavior", not "behaviour")
|
||||
|
||||
📌 *Tip:* Keep descriptions concise but clear. Only document return values if non-obvious.
|
||||
|
||||
### 6. Architectural Improvements
|
||||
|
||||
**When you encounter code that could be improved, suggest better designs:**
|
||||
|
||||
❌ **Poor Design:**
|
||||
|
||||
```python
|
||||
def process_data(data, db_conn, email_client, logger):
|
||||
# Function doing too many things
|
||||
validated = validate_data(data)
|
||||
result = db_conn.save(validated)
|
||||
email_client.send_notification(result)
|
||||
logger.log(f"Processed {len(data)} items")
|
||||
return result
|
||||
```
|
||||
|
||||
✅ **Better Design:**
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ProcessingResult:
|
||||
"""Result of data processing operation."""
|
||||
items_processed: int
|
||||
success: bool
|
||||
errors: List[str] = field(default_factory=list)
|
||||
|
||||
class DataProcessor:
|
||||
"""Handles data validation, storage, and notification."""
|
||||
|
||||
def __init__(self, db_conn: Database, email_client: EmailClient):
|
||||
self.db = db_conn
|
||||
self.email = email_client
|
||||
|
||||
def process(self, data: List[dict]) -> ProcessingResult:
|
||||
"""Process and store data with notifications."""
|
||||
validated = self._validate_data(data)
|
||||
result = self.db.save(validated)
|
||||
self._notify_completion(result)
|
||||
return result
|
||||
```
|
||||
|
||||
**Design Improvement Areas:**
|
||||
|
||||
If there's a **cleaner**, **more scalable**, or **simpler** design, highlight it and suggest improvements that would:
|
||||
|
||||
- Reduce code duplication through shared utilities
|
||||
- Make unit testing easier
|
||||
- Improve separation of concerns (single responsibility)
|
||||
- Make unit testing easier through dependency injection
|
||||
- Add clarity without adding complexity
|
||||
- Prefer dataclasses for structured data
|
||||
|
||||
## Development Tools & Commands
|
||||
|
||||
### Package Management
|
||||
|
||||
```bash
|
||||
# Add package
|
||||
uv add package-name
|
||||
|
||||
# Sync project dependencies
|
||||
uv sync
|
||||
uv lock
|
||||
```
|
||||
|
||||
### Testing
|
||||
Each package in `libs/` has its own `pyproject.toml` and `uv.lock`.
|
||||
|
||||
```bash
|
||||
# Run unit tests (no network)
|
||||
make test
|
||||
|
||||
# Don't run integration tests, as API keys must be set
|
||||
|
||||
# Run specific test file
|
||||
uv run --group test pytest tests/unit_tests/test_specific.py
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Lint code
|
||||
make lint
|
||||
@@ -261,66 +64,118 @@ make format
|
||||
uv run --group lint mypy .
|
||||
```
|
||||
|
||||
### Dependency Management Patterns
|
||||
#### Key config files
|
||||
|
||||
**Local Development Dependencies:**
|
||||
- pyproject.toml: Main workspace configuration with dependency groups
|
||||
- uv.lock: Locked dependencies for reproducible builds
|
||||
- Makefile: Development tasks
|
||||
|
||||
```toml
|
||||
[tool.uv.sources]
|
||||
langchain-core = { path = "../core", editable = true }
|
||||
langchain-tests = { path = "../standard-tests", editable = true }
|
||||
```
|
||||
#### Commit standards
|
||||
|
||||
**For tools, use the `@tool` decorator from `langchain_core.tools`:**
|
||||
Suggest PR titles that follow the Conventional Commits format. Refer to `.github/workflows/pr_lint.yml` for allowed types and scopes.
|
||||
|
||||
```python
|
||||
from langchain_core.tools import tool
|
||||
#### Pull request guidelines
|
||||
|
||||
@tool
|
||||
def search_database(query: str) -> str:
|
||||
"""Search the database for relevant information.
|
||||
- Always add a disclaimer to the PR description mentioning how AI agents are involved with the contribution.
|
||||
- Describe the "why" of the changes, why the proposed solution is the right one. Limit prose.
|
||||
- Highlight areas of the proposed changes that require careful review.
|
||||
|
||||
## Core development principles
|
||||
|
||||
### Maintain stable public interfaces
|
||||
|
||||
CRITICAL: Always attempt to preserve function signatures, argument positions, and names for exported/public methods. Do not make breaking changes.
|
||||
|
||||
**Before making ANY changes to public APIs:**
|
||||
|
||||
- Check if the function/class is exported in `__init__.py`
|
||||
- Look for existing usage patterns in tests and examples
|
||||
- Use keyword-only arguments for new parameters: `*, new_param: str = "default"`
|
||||
- Mark experimental features clearly with docstring warnings (using MkDocs Material admonitions, like `!!! warning`)
|
||||
|
||||
Ask: "Would this change break someone's code if they used it last week?"
|
||||
|
||||
### Code quality standards
|
||||
|
||||
All Python code MUST include type hints and return types.
|
||||
|
||||
```python title="Example"
|
||||
def filter_unknown_users(users: list[str], known_users: set[str]) -> list[str]:
|
||||
"""Single line description of the function.
|
||||
|
||||
Any additional context about the function can go here.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
users: List of user identifiers to filter.
|
||||
known_users: Set of known/valid user identifiers.
|
||||
|
||||
Returns:
|
||||
List of users that are not in the known_users set.
|
||||
"""
|
||||
# Implementation here
|
||||
return results
|
||||
```
|
||||
|
||||
## Commit Standards
|
||||
- Use descriptive, self-explanatory variable names.
|
||||
- Follow existing patterns in the codebase you're modifying
|
||||
- Attempt to break up complex functions (>20 lines) into smaller, focused functions where it makes sense
|
||||
|
||||
**Use Conventional Commits format for PR titles:**
|
||||
### Testing requirements
|
||||
|
||||
- `feat(core): add multi-tenant support`
|
||||
- `fix(cli): resolve flag parsing error`
|
||||
- `docs: update API usage examples`
|
||||
- `docs(openai): update API usage examples`
|
||||
Every new feature or bugfix MUST be covered by unit tests.
|
||||
|
||||
## Framework-Specific Guidelines
|
||||
- Unit tests: `tests/unit_tests/` (no network calls allowed)
|
||||
- Integration tests: `tests/integration_tests/` (network calls permitted)
|
||||
- We use `pytest` as the testing framework; if in doubt, check other existing tests for examples.
|
||||
- The testing file structure should mirror the source code structure.
|
||||
|
||||
- Follow the existing patterns in `langchain-core` for base abstractions
|
||||
- Use `langchain_core.callbacks` for execution tracking
|
||||
- Implement proper streaming support where applicable
|
||||
- Avoid deprecated components like legacy `LLMChain`
|
||||
**Checklist:**
|
||||
|
||||
### Partner Integrations
|
||||
- [ ] Tests fail when your new logic is broken
|
||||
- [ ] Happy path is covered
|
||||
- [ ] Edge cases and error conditions are tested
|
||||
- [ ] Use fixtures/mocks for external dependencies
|
||||
- [ ] Tests are deterministic (no flaky tests)
|
||||
- [ ] Does the test suite fail if your new logic is broken?
|
||||
|
||||
- Follow the established patterns in existing partner libraries
|
||||
- Implement standard interfaces (`BaseChatModel`, `BaseEmbeddings`, etc.)
|
||||
- Include comprehensive integration tests
|
||||
- Document API key requirements and authentication
|
||||
### Security and risk assessment
|
||||
|
||||
---
|
||||
- No `eval()`, `exec()`, or `pickle` on user-controlled input
|
||||
- Proper exception handling (no bare `except:`) and use a `msg` variable for error messages
|
||||
- Remove unreachable/commented code before committing
|
||||
- Check for race conditions or resource leaks (file handles, sockets, threads).
|
||||
- Ensure proper resource cleanup (file handles, connections)
|
||||
|
||||
## Quick Reference Checklist
|
||||
### Documentation standards
|
||||
|
||||
Before submitting code changes:
|
||||
Use Google-style docstrings with Args section for all public functions.
|
||||
|
||||
- [ ] **Breaking Changes**: Verified no public API changes
|
||||
- [ ] **Type Hints**: All functions have complete type annotations
|
||||
- [ ] **Tests**: New functionality is fully tested
|
||||
- [ ] **Security**: No dangerous patterns (eval, silent failures, etc.)
|
||||
- [ ] **Documentation**: Google-style docstrings for public functions
|
||||
- [ ] **Code Quality**: `make lint` and `make format` pass
|
||||
- [ ] **Architecture**: Suggested improvements where applicable
|
||||
- [ ] **Commit Message**: Follows Conventional Commits format
|
||||
```python title="Example"
|
||||
def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
|
||||
"""Send an email to a recipient with specified priority.
|
||||
|
||||
Any additional context about the function can go here.
|
||||
|
||||
Args:
|
||||
to: The email address of the recipient.
|
||||
msg: The message body to send.
|
||||
priority: Email priority level.
|
||||
|
||||
Returns:
|
||||
`True` if email was sent successfully, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
InvalidEmailError: If the email address format is invalid.
|
||||
SMTPConnectionError: If unable to connect to email server.
|
||||
"""
|
||||
```
|
||||
|
||||
- Types go in function signatures, NOT in docstrings
|
||||
- If a default is present, DO NOT repeat it in the docstring unless there is post-processing or it is set conditionally.
|
||||
- Focus on "why" rather than "what" in descriptions
|
||||
- Document all parameters, return values, and exceptions
|
||||
- Keep descriptions concise but clear
|
||||
- Ensure American English spelling (e.g., "behavior", not "behaviour")
|
||||
|
||||
## Additional resources
|
||||
|
||||
- **Documentation:** https://docs.langchain.com/oss/python/langchain/overview and source at https://github.com/langchain-ai/docs or `../docs/`. Prefer the local install and use file search tools for best results. If needed, use the docs MCP server as defined in `.mcp.json` for programmatic access.
|
||||
- **Contributing Guide:** [`.github/CONTRIBUTING.md`](https://docs.langchain.com/oss/python/contributing/overview)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
# Migrating
|
||||
|
||||
Please see the following guides for migrating LangChain code:
|
||||
|
||||
* Migrate to [LangChain v1.0](https://docs.langchain.com/oss/python/migrate/langchain-v1)
|
||||
* Migrate to [LangChain v0.3](https://python.langchain.com/docs/versions/v0_3/)
|
||||
* Migrate to [LangChain v0.2](https://python.langchain.com/docs/versions/v0_2/)
|
||||
* Migrating from [LangChain 0.0.x Chains](https://python.langchain.com/docs/versions/migrating_chains/)
|
||||
* Upgrade to [LangGraph Memory](https://python.langchain.com/docs/versions/migrating_memory/)
|
||||
@@ -6,9 +6,8 @@ import hashlib
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from git import Repo
|
||||
|
||||
@@ -18,6 +17,9 @@ from langchain_cli.constants import (
|
||||
DEFAULT_GIT_SUBDIRECTORY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
if TYPE_CHECKING:
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .file import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Folder:
|
||||
def __init__(self, name: str, *files: Folder | File) -> None:
|
||||
|
||||
@@ -34,7 +34,7 @@ The LangChain ecosystem is built on top of `langchain-core`. Some of the benefit
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/).
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_core/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -5,13 +5,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -39,7 +39,6 @@ from langchain_core.tracers.context import (
|
||||
tracing_v2_callback_var,
|
||||
)
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
|
||||
@@ -52,6 +51,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,16 +6,9 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
TypedDict,
|
||||
@@ -29,6 +22,16 @@ from langchain_core.exceptions import LangChainException
|
||||
from langchain_core.indexing.base import DocumentIndex, RecordManager
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
# Magic UUID to use as a namespace for hashing.
|
||||
# Used to try and generate a unique UUID for each document
|
||||
# from hashing the document content and metadata.
|
||||
@@ -299,6 +302,7 @@ def index(
|
||||
are not able to specify the uid of the document.
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.25"
|
||||
|
||||
Added `scoped_full` cleanup mode.
|
||||
|
||||
!!! warning
|
||||
@@ -637,6 +641,7 @@ async def aindex(
|
||||
are not able to specify the uid of the document.
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.25"
|
||||
|
||||
Added `scoped_full` cleanup mode.
|
||||
|
||||
!!! warning
|
||||
|
||||
@@ -53,6 +53,10 @@ if TYPE_CHECKING:
|
||||
ParrotFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM, BaseLLM
|
||||
from langchain_core.language_models.model_profile import (
|
||||
ModelProfile,
|
||||
ModelProfileRegistry,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"LLM",
|
||||
@@ -68,6 +72,8 @@ __all__ = (
|
||||
"LanguageModelInput",
|
||||
"LanguageModelLike",
|
||||
"LanguageModelOutput",
|
||||
"ModelProfile",
|
||||
"ModelProfileRegistry",
|
||||
"ParrotFakeChatModel",
|
||||
"SimpleChatModel",
|
||||
"get_tokenizer",
|
||||
@@ -90,6 +96,8 @@ _dynamic_imports = {
|
||||
"GenericFakeChatModel": "fake_chat_models",
|
||||
"ParrotFakeChatModel": "fake_chat_models",
|
||||
"LLM": "llms",
|
||||
"ModelProfile": "model_profile",
|
||||
"ModelProfileRegistry": "model_profile",
|
||||
"BaseLLM": "llms",
|
||||
"is_openai_data_block": "_utils",
|
||||
}
|
||||
|
||||
@@ -140,6 +140,7 @@ def _normalize_messages(
|
||||
- LangChain v0 standard content blocks for backward compatibility
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 1.0.0"
|
||||
|
||||
In previous versions, this function returned messages in LangChain v0 format.
|
||||
Now, it returns messages in LangChain v1 format, which upgraded chat models now
|
||||
expect to receive when passing back in message history. For backward
|
||||
|
||||
@@ -299,6 +299,9 @@ class BaseLanguageModel(
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
|
||||
This should be overridden by model-specific implementations to provide accurate
|
||||
token counts via model-specific tokenizers.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
@@ -317,9 +320,17 @@ class BaseLanguageModel(
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
|
||||
This should be overridden by model-specific implementations to provide accurate
|
||||
token counts via model-specific tokenizers.
|
||||
|
||||
!!! note
|
||||
The base implementation of `get_num_tokens_from_messages` ignores tool
|
||||
schemas.
|
||||
|
||||
* The base implementation of `get_num_tokens_from_messages` ignores tool
|
||||
schemas.
|
||||
* The base implementation of `get_num_tokens_from_messages` adds additional
|
||||
prefixes to messages to represent user roles, which will add to the
|
||||
overall token count. Model-specific implementations may choose to
|
||||
handle this differently.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
@@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
@@ -34,6 +33,7 @@ from langchain_core.language_models.base import (
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.language_models.model_profile import ModelProfile
|
||||
from langchain_core.load import dumpd, dumps
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -76,8 +76,6 @@ from langchain_core.utils.utils import LC_ID_PREFIX, from_env
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
|
||||
from langchain_model_profiles import ModelProfile # type: ignore[import-untyped]
|
||||
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -91,7 +89,10 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
|
||||
try:
|
||||
metadata["body"] = response.json()
|
||||
except Exception:
|
||||
metadata["body"] = getattr(response, "text", None)
|
||||
try:
|
||||
metadata["body"] = getattr(response, "text", None)
|
||||
except Exception:
|
||||
metadata["body"] = None
|
||||
if hasattr(response, "headers"):
|
||||
try:
|
||||
metadata["headers"] = dict(response.headers)
|
||||
@@ -332,10 +333,25 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
[`langchain-openai`](https://pypi.org/project/langchain-openai)) can also use this
|
||||
field to roll out new content formats in a backward-compatible way.
|
||||
|
||||
!!! version-added "Added in `langchain-core` 1.0"
|
||||
!!! version-added "Added in `langchain-core` 1.0.0"
|
||||
|
||||
"""
|
||||
|
||||
profile: ModelProfile | None = Field(default=None, exclude=True)
|
||||
"""Profile detailing model capabilities.
|
||||
|
||||
!!! warning "Beta feature"
|
||||
This is a beta feature. The format of model profiles is subject to change.
|
||||
|
||||
If not specified, automatically loaded from the provider package on initialization
|
||||
if data is available.
|
||||
|
||||
Example profile data includes context window sizes, supported modalities, or support
|
||||
for tool calling, structured output, and other features.
|
||||
|
||||
!!! version-added "Added in `langchain-core` 1.1.0"
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
@@ -1616,7 +1632,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
# }
|
||||
```
|
||||
|
||||
Example: `dict` schema (`include_raw=False`):
|
||||
Example: Dictionary schema (`include_raw=False`):
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
@@ -1644,6 +1660,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.2.26"
|
||||
|
||||
Added support for `TypedDict` class.
|
||||
|
||||
""" # noqa: E501
|
||||
@@ -1685,40 +1702,6 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
@beta()
|
||||
def profile(self) -> ModelProfile:
|
||||
"""Return profiling information for the model.
|
||||
|
||||
This property relies on the `langchain-model-profiles` package to retrieve chat
|
||||
model capabilities, such as context window sizes and supported features.
|
||||
|
||||
Raises:
|
||||
ImportError: If `langchain-model-profiles` is not installed.
|
||||
|
||||
Returns:
|
||||
A `ModelProfile` object containing profiling information for the model.
|
||||
"""
|
||||
try:
|
||||
from langchain_model_profiles import get_model_profile # noqa: PLC0415
|
||||
except ImportError as err:
|
||||
informative_error_message = (
|
||||
"To access model profiling information, please install the "
|
||||
"`langchain-model-profiles` package: "
|
||||
"`pip install langchain-model-profiles`."
|
||||
)
|
||||
raise ImportError(informative_error_message) from err
|
||||
|
||||
provider_id = self._llm_type
|
||||
model_name = (
|
||||
# Model name is not standardized across integrations. New integrations
|
||||
# should prefer `model`.
|
||||
getattr(self, "model", None)
|
||||
or getattr(self, "model_name", None)
|
||||
or getattr(self, "model_id", "")
|
||||
)
|
||||
return get_model_profile(provider_id, model_name) or {}
|
||||
|
||||
|
||||
class SimpleChatModel(BaseChatModel):
|
||||
"""Simplified implementation for a chat model to inherit from.
|
||||
|
||||
84
libs/core/langchain_core/language_models/model_profile.py
Normal file
84
libs/core/langchain_core/language_models/model_profile.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Model profile types and utilities."""
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class ModelProfile(TypedDict, total=False):
|
||||
"""Model profile.
|
||||
|
||||
!!! warning "Beta feature"
|
||||
This is a beta feature. The format of model profiles is subject to change.
|
||||
|
||||
Provides information about chat model capabilities, such as context window sizes
|
||||
and supported features.
|
||||
"""
|
||||
|
||||
# --- Input constraints ---
|
||||
|
||||
max_input_tokens: int
|
||||
"""Maximum context window (tokens)"""
|
||||
|
||||
image_inputs: bool
|
||||
"""Whether image inputs are supported."""
|
||||
# TODO: add more detail about formats?
|
||||
|
||||
image_url_inputs: bool
|
||||
"""Whether [image URL inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
|
||||
pdf_inputs: bool
|
||||
"""Whether [PDF inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
# TODO: add more detail about formats? e.g. bytes or base64
|
||||
|
||||
audio_inputs: bool
|
||||
"""Whether [audio inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
# TODO: add more detail about formats? e.g. bytes or base64
|
||||
|
||||
video_inputs: bool
|
||||
"""Whether [video inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
# TODO: add more detail about formats? e.g. bytes or base64
|
||||
|
||||
image_tool_message: bool
|
||||
"""Whether images can be included in tool messages."""
|
||||
|
||||
pdf_tool_message: bool
|
||||
"""Whether PDFs can be included in tool messages."""
|
||||
|
||||
# --- Output constraints ---
|
||||
|
||||
max_output_tokens: int
|
||||
"""Maximum output tokens"""
|
||||
|
||||
reasoning_output: bool
|
||||
"""Whether the model supports [reasoning / chain-of-thought](https://docs.langchain.com/oss/python/langchain/models#reasoning)"""
|
||||
|
||||
image_outputs: bool
|
||||
"""Whether [image outputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
|
||||
audio_outputs: bool
|
||||
"""Whether [audio outputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
|
||||
video_outputs: bool
|
||||
"""Whether [video outputs](https://docs.langchain.com/oss/python/langchain/models#multimodal)
|
||||
are supported."""
|
||||
|
||||
# --- Tool calling ---
|
||||
tool_calling: bool
|
||||
"""Whether the model supports [tool calling](https://docs.langchain.com/oss/python/langchain/models#tool-calling)"""
|
||||
|
||||
tool_choice: bool
|
||||
"""Whether the model supports [tool choice](https://docs.langchain.com/oss/python/langchain/models#forcing-tool-calls)"""
|
||||
|
||||
# --- Structured output ---
|
||||
structured_output: bool
|
||||
"""Whether the model supports a native [structured output](https://docs.langchain.com/oss/python/langchain/models#structured-outputs)
|
||||
feature"""
|
||||
|
||||
|
||||
ModelProfileRegistry = dict[str, ModelProfile]
|
||||
"""Registry mapping model identifiers or names to their ModelProfile."""
|
||||
@@ -61,13 +61,15 @@ class Reviver:
|
||||
"""Initialize the reviver.
|
||||
|
||||
Args:
|
||||
secrets_map: A map of secrets to load. If a secret is not found in
|
||||
the map, it will be loaded from the environment if `secrets_from_env`
|
||||
is True.
|
||||
secrets_map: A map of secrets to load.
|
||||
|
||||
If a secret is not found in the map, it will be loaded from the
|
||||
environment if `secrets_from_env` is `True`.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
secrets_from_env: Whether to load secrets from the environment.
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
|
||||
You can use this to override default mappings or add new mappings.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
"""
|
||||
@@ -195,13 +197,15 @@ def loads(
|
||||
|
||||
Args:
|
||||
text: The string to load.
|
||||
secrets_map: A map of secrets to load. If a secret is not found in
|
||||
the map, it will be loaded from the environment if `secrets_from_env`
|
||||
is True.
|
||||
secrets_map: A map of secrets to load.
|
||||
|
||||
If a secret is not found in the map, it will be loaded from the environment
|
||||
if `secrets_from_env` is `True`.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
secrets_from_env: Whether to load secrets from the environment.
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
|
||||
You can use this to override default mappings or add new mappings.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
|
||||
@@ -237,13 +241,15 @@ def load(
|
||||
|
||||
Args:
|
||||
obj: The object to load.
|
||||
secrets_map: A map of secrets to load. If a secret is not found in
|
||||
the map, it will be loaded from the environment if `secrets_from_env`
|
||||
is True.
|
||||
secrets_map: A map of secrets to load.
|
||||
|
||||
If a secret is not found in the map, it will be loaded from the environment
|
||||
if `secrets_from_env` is `True`.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
secrets_from_env: Whether to load secrets from the environment.
|
||||
additional_import_mappings: A dictionary of additional namespace mappings
|
||||
|
||||
You can use this to override default mappings or add new mappings.
|
||||
ignore_unserializable_fields: Whether to ignore unserializable fields.
|
||||
|
||||
@@ -265,8 +271,6 @@ def load(
|
||||
return reviver(loaded_obj)
|
||||
if isinstance(obj, list):
|
||||
return [_load(o) for o in obj]
|
||||
if isinstance(obj, str) and obj in reviver.secrets_map:
|
||||
return reviver.secrets_map[obj]
|
||||
return obj
|
||||
|
||||
return _load(obj)
|
||||
|
||||
@@ -124,9 +124,11 @@ class UsageMetadata(TypedDict):
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
|
||||
@@ -5,11 +5,9 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, cast, overload
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core._api.deprecation import warn_deprecated
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@@ -17,6 +15,9 @@ from langchain_core.utils.interactive_env import is_interactive_env
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
|
||||
@@ -12,10 +12,11 @@ the implementation in `BaseMessage`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from langchain_core.language_models._utils import (
|
||||
@@ -14,6 +13,8 @@ from langchain_core.language_models._utils import (
|
||||
from langchain_core.messages import content as types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
|
||||
|
||||
@@ -654,7 +654,7 @@ class PlainTextContentBlock(TypedDict):
|
||||
|
||||
!!! note
|
||||
Title and context are optional fields that may be passed to the model. See
|
||||
Anthropic [example](https://docs.claude.com/en/docs/build-with-claude/citations#citable-vs-non-citable-content).
|
||||
Anthropic [example](https://platform.claude.com/docs/en/build-with-claude/citations#citable-vs-non-citable-content).
|
||||
|
||||
!!! note "Factory function"
|
||||
`create_plaintext_block` may also be used as a factory to create a
|
||||
|
||||
@@ -738,8 +738,10 @@ def trim_messages(
|
||||
Set to `len` to count the number of **messages** in the chat history.
|
||||
|
||||
!!! note
|
||||
|
||||
Use `count_tokens_approximately` to get fast, approximate token
|
||||
counts.
|
||||
|
||||
This is recommended for using `trim_messages` on the hot path, where
|
||||
exact token counting is not necessary.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
def _parser_exception(
|
||||
self, e: Exception, json_object: dict
|
||||
) -> OutputParserException:
|
||||
json_string = json.dumps(json_object)
|
||||
json_string = json.dumps(json_object, ensure_ascii=False)
|
||||
name = self.pydantic_object.__name__
|
||||
msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
|
||||
return OutputParserException(msg, llm_output=json_string)
|
||||
|
||||
@@ -2,15 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""A single chat generation output.
|
||||
|
||||
@@ -6,7 +6,7 @@ import contextlib
|
||||
import json
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Mapping
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@@ -33,6 +33,8 @@ from langchain_core.runnables.config import ensure_config
|
||||
from langchain_core.utils.pydantic import create_model_v2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
|
||||
@@ -903,23 +903,28 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
5. A string which is shorthand for `("human", template)`; e.g.,
|
||||
`"{user_input}"`
|
||||
template_format: Format of the template.
|
||||
input_variables: A list of the names of the variables whose values are
|
||||
required as inputs to the prompt.
|
||||
optional_variables: A list of the names of the variables for placeholder
|
||||
or MessagePlaceholder that are optional.
|
||||
**kwargs: Additional keyword arguments passed to `BasePromptTemplate`,
|
||||
including (but not limited to):
|
||||
|
||||
These variables are auto inferred from the prompt and user need not
|
||||
provide them.
|
||||
partial_variables: A dictionary of the partial variables the prompt
|
||||
template carries.
|
||||
- `input_variables`: A list of the names of the variables whose values
|
||||
are required as inputs to the prompt.
|
||||
- `optional_variables`: A list of the names of the variables for
|
||||
placeholder or `MessagesPlaceholder` that are optional.
|
||||
|
||||
Partial variables populate the template so that you don't need to pass
|
||||
them in every time you call the prompt.
|
||||
validate_template: Whether to validate the template.
|
||||
input_types: A dictionary of the types of the variables the prompt template
|
||||
expects.
|
||||
These variables are auto inferred from the prompt and user need not
|
||||
provide them.
|
||||
|
||||
If not provided, all variables are assumed to be strings.
|
||||
- `partial_variables`: A dictionary of the partial variables the prompt
|
||||
template carries.
|
||||
|
||||
Partial variables populate the template so that you don't need to
|
||||
pass them in every time you call the prompt.
|
||||
|
||||
- `validate_template`: Whether to validate the template.
|
||||
- `input_types`: A dictionary of the types of the variables the prompt
|
||||
template expects.
|
||||
|
||||
If not provided, all variables are assumed to be strings.
|
||||
|
||||
Examples:
|
||||
Instantiation from a list of message templates:
|
||||
|
||||
@@ -6,10 +6,10 @@ from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from collections.abc import Callable, Sequence
|
||||
from string import Formatter
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
@@ -16,10 +15,70 @@ from langchain_core.utils import get_colored_text, mustache
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
from jinja2 import meta
|
||||
from jinja2.exceptions import SecurityError
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
class _RestrictedSandboxedEnvironment(SandboxedEnvironment):
|
||||
"""A more restrictive Jinja2 sandbox that blocks all attribute/method access.
|
||||
|
||||
This sandbox only allows simple variable lookups, no attribute or method access.
|
||||
This prevents template injection attacks via methods like parse_raw().
|
||||
"""
|
||||
|
||||
def is_safe_attribute(self, _obj: Any, _attr: str, _value: Any) -> bool:
|
||||
"""Block ALL attribute access for security.
|
||||
|
||||
Only allow accessing variables directly from the context dict,
|
||||
no attribute access on those objects.
|
||||
|
||||
Args:
|
||||
_obj: The object being accessed (unused, always blocked).
|
||||
_attr: The attribute name (unused, always blocked).
|
||||
_value: The attribute value (unused, always blocked).
|
||||
|
||||
Returns:
|
||||
False - all attribute access is blocked.
|
||||
"""
|
||||
# Block all attribute access
|
||||
return False
|
||||
|
||||
def is_safe_callable(self, _obj: Any) -> bool:
|
||||
"""Block all method calls for security.
|
||||
|
||||
Args:
|
||||
_obj: The object being checked (unused, always blocked).
|
||||
|
||||
Returns:
|
||||
False - all callables are blocked.
|
||||
"""
|
||||
return False
|
||||
|
||||
def getattr(self, obj: Any, attribute: str) -> Any:
|
||||
"""Override getattr to block all attribute access.
|
||||
|
||||
Args:
|
||||
obj: The object.
|
||||
attribute: The attribute name.
|
||||
|
||||
Returns:
|
||||
Never returns.
|
||||
|
||||
Raises:
|
||||
SecurityError: Always, to block attribute access.
|
||||
"""
|
||||
msg = (
|
||||
f"Access to attributes is not allowed in templates. "
|
||||
f"Attempted to access '{attribute}' on {type(obj).__name__}. "
|
||||
f"Use only simple variable names like {{{{variable}}}} "
|
||||
f"without dots or methods."
|
||||
)
|
||||
raise SecurityError(msg)
|
||||
|
||||
_HAS_JINJA2 = True
|
||||
except ImportError:
|
||||
_HAS_JINJA2 = False
|
||||
@@ -59,14 +118,10 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
# This uses a sandboxed environment to prevent arbitrary code execution.
|
||||
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
||||
# Please treat this sand-boxing as a best-effort approach rather than
|
||||
# a guarantee of security.
|
||||
# We recommend never using jinja2 templates with untrusted inputs.
|
||||
# https://jinja.palletsprojects.com/en/3.1.x/sandbox/
|
||||
# approach not a guarantee of security.
|
||||
return SandboxedEnvironment().from_string(template).render(**kwargs)
|
||||
# Use a restricted sandbox that blocks ALL attribute/method access
|
||||
# Only simple variable lookups like {{variable}} are allowed
|
||||
# Attribute access like {{variable.attr}} or {{variable.method()}} is blocked
|
||||
return _RestrictedSandboxedEnvironment().from_string(template).render(**kwargs)
|
||||
|
||||
|
||||
def validate_jinja2(template: str, input_variables: list[str]) -> None:
|
||||
@@ -101,7 +156,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
||||
"Please install it with `pip install jinja2`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
env = Environment() # noqa: S701
|
||||
env = _RestrictedSandboxedEnvironment()
|
||||
ast = env.parse(template)
|
||||
return meta.find_undeclared_variables(ast)
|
||||
|
||||
@@ -271,6 +326,30 @@ def get_template_variables(template: str, template_format: str) -> list[str]:
|
||||
msg = f"Unsupported template format: {template_format}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# For f-strings, block attribute access and indexing syntax
|
||||
# This prevents template injection attacks via accessing dangerous attributes
|
||||
if template_format == "f-string":
|
||||
for var in input_variables:
|
||||
# Formatter().parse() returns field names with dots/brackets if present
|
||||
# e.g., "obj.attr" or "obj[0]" - we need to block these
|
||||
if "." in var or "[" in var or "]" in var:
|
||||
msg = (
|
||||
f"Invalid variable name {var!r} in f-string template. "
|
||||
f"Variable names cannot contain attribute "
|
||||
f"access (.) or indexing ([])."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Block variable names that are all digits (e.g., "0", "100")
|
||||
# These are interpreted as positional arguments, not keyword arguments
|
||||
if var.isdigit():
|
||||
msg = (
|
||||
f"Invalid variable name {var!r} in f-string template. "
|
||||
f"Variable names cannot be all digits as they are interpreted "
|
||||
f"as positional arguments."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return sorted(input_variables)
|
||||
|
||||
|
||||
|
||||
@@ -49,7 +49,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
structured_output_kwargs: additional kwargs for structured output.
|
||||
template_format: template format for the prompt.
|
||||
"""
|
||||
schema_ = schema_ or kwargs.pop("schema")
|
||||
schema_ = schema_ or kwargs.pop("schema", None)
|
||||
if not schema_:
|
||||
err_msg = (
|
||||
"Must pass in a non-empty structured output schema. Received: "
|
||||
f"{schema_}"
|
||||
)
|
||||
raise ValueError(err_msg)
|
||||
structured_output_kwargs = structured_output_kwargs or {}
|
||||
for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
|
||||
structured_output_kwargs[k] = kwargs.pop(k)
|
||||
|
||||
@@ -707,51 +707,53 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
def pick(self, keys: str | list[str]) -> RunnableSerializable[Any, Any]:
|
||||
"""Pick keys from the output `dict` of this `Runnable`.
|
||||
|
||||
Pick a single key:
|
||||
!!! example "Pick a single key"
|
||||
|
||||
```python
|
||||
import json
|
||||
```python
|
||||
import json
|
||||
|
||||
from langchain_core.runnables import RunnableLambda, RunnableMap
|
||||
from langchain_core.runnables import RunnableLambda, RunnableMap
|
||||
|
||||
as_str = RunnableLambda(str)
|
||||
as_json = RunnableLambda(json.loads)
|
||||
chain = RunnableMap(str=as_str, json=as_json)
|
||||
as_str = RunnableLambda(str)
|
||||
as_json = RunnableLambda(json.loads)
|
||||
chain = RunnableMap(str=as_str, json=as_json)
|
||||
|
||||
chain.invoke("[1, 2, 3]")
|
||||
# -> {"str": "[1, 2, 3]", "json": [1, 2, 3]}
|
||||
chain.invoke("[1, 2, 3]")
|
||||
# -> {"str": "[1, 2, 3]", "json": [1, 2, 3]}
|
||||
|
||||
json_only_chain = chain.pick("json")
|
||||
json_only_chain.invoke("[1, 2, 3]")
|
||||
# -> [1, 2, 3]
|
||||
```
|
||||
json_only_chain = chain.pick("json")
|
||||
json_only_chain.invoke("[1, 2, 3]")
|
||||
# -> [1, 2, 3]
|
||||
```
|
||||
|
||||
Pick a list of keys:
|
||||
!!! example "Pick a list of keys"
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import json
|
||||
|
||||
from langchain_core.runnables import RunnableLambda, RunnableMap
|
||||
from langchain_core.runnables import RunnableLambda, RunnableMap
|
||||
|
||||
as_str = RunnableLambda(str)
|
||||
as_json = RunnableLambda(json.loads)
|
||||
as_str = RunnableLambda(str)
|
||||
as_json = RunnableLambda(json.loads)
|
||||
|
||||
|
||||
def as_bytes(x: Any) -> bytes:
|
||||
return bytes(x, "utf-8")
|
||||
def as_bytes(x: Any) -> bytes:
|
||||
return bytes(x, "utf-8")
|
||||
|
||||
|
||||
chain = RunnableMap(str=as_str, json=as_json, bytes=RunnableLambda(as_bytes))
|
||||
chain = RunnableMap(
|
||||
str=as_str, json=as_json, bytes=RunnableLambda(as_bytes)
|
||||
)
|
||||
|
||||
chain.invoke("[1, 2, 3]")
|
||||
# -> {"str": "[1, 2, 3]", "json": [1, 2, 3], "bytes": b"[1, 2, 3]"}
|
||||
chain.invoke("[1, 2, 3]")
|
||||
# -> {"str": "[1, 2, 3]", "json": [1, 2, 3], "bytes": b"[1, 2, 3]"}
|
||||
|
||||
json_and_bytes_chain = chain.pick(["json", "bytes"])
|
||||
json_and_bytes_chain.invoke("[1, 2, 3]")
|
||||
# -> {"json": [1, 2, 3], "bytes": b"[1, 2, 3]"}
|
||||
```
|
||||
json_and_bytes_chain = chain.pick(["json", "bytes"])
|
||||
json_and_bytes_chain.invoke("[1, 2, 3]")
|
||||
# -> {"json": [1, 2, 3], "bytes": b"[1, 2, 3]"}
|
||||
```
|
||||
|
||||
Args:
|
||||
keys: A key or list of keys to pick from the output dict.
|
||||
@@ -1372,48 +1374,50 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
).with_config({"run_name": "my_template", "tags": ["my_template"]})
|
||||
```
|
||||
|
||||
For instance:
|
||||
!!! example
|
||||
|
||||
```python
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
```python
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
|
||||
async def reverse(s: str) -> str:
|
||||
return s[::-1]
|
||||
async def reverse(s: str) -> str:
|
||||
return s[::-1]
|
||||
|
||||
|
||||
chain = RunnableLambda(func=reverse)
|
||||
chain = RunnableLambda(func=reverse)
|
||||
|
||||
events = [event async for event in chain.astream_events("hello", version="v2")]
|
||||
events = [
|
||||
event async for event in chain.astream_events("hello", version="v2")
|
||||
]
|
||||
|
||||
# Will produce the following events
|
||||
# (run_id and parent_ids have been omitted for brevity):
|
||||
[
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "reverse",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": "olleh"},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "reverse",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": "olleh"},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "reverse",
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
```
|
||||
# Will produce the following events
|
||||
# (run_id and parent_ids have been omitted for brevity):
|
||||
[
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "reverse",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": "olleh"},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "reverse",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": "olleh"},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "reverse",
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
```
|
||||
|
||||
```python title="Example: Dispatch Custom Event"
|
||||
```python title="Dispatch custom event"
|
||||
from langchain_core.callbacks.manager import (
|
||||
adispatch_custom_event,
|
||||
)
|
||||
@@ -1447,10 +1451,13 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
Args:
|
||||
input: The input to the `Runnable`.
|
||||
config: The config to use for the `Runnable`.
|
||||
version: The version of the schema to use either `'v2'` or `'v1'`.
|
||||
version: The version of the schema to use, either `'v2'` or `'v1'`.
|
||||
|
||||
Users should use `'v2'`.
|
||||
|
||||
`'v1'` is for backwards compatibility and will be deprecated
|
||||
in `0.4.0`.
|
||||
|
||||
No default will be assigned until the API is stabilized.
|
||||
Custom events will only be surfaced in `'v2'`.
|
||||
include_names: Only include events from `Runnable` objects with matching names.
|
||||
@@ -1460,6 +1467,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
exclude_types: Exclude events from `Runnable` objects with matching types.
|
||||
exclude_tags: Exclude events from `Runnable` objects with matching tags.
|
||||
**kwargs: Additional keyword arguments to pass to the `Runnable`.
|
||||
|
||||
These will be passed to `astream_log` as this implementation
|
||||
of `astream_events` is built on top of `astream_log`.
|
||||
|
||||
@@ -2476,82 +2484,82 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
Returns:
|
||||
A `BaseTool` instance.
|
||||
|
||||
Typed dict input:
|
||||
!!! example "`TypedDict` input"
|
||||
|
||||
```python
|
||||
from typing_extensions import TypedDict
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
```python
|
||||
from typing_extensions import TypedDict
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
|
||||
class Args(TypedDict):
|
||||
a: int
|
||||
b: list[int]
|
||||
class Args(TypedDict):
|
||||
a: int
|
||||
b: list[int]
|
||||
|
||||
|
||||
def f(x: Args) -> str:
|
||||
return str(x["a"] * max(x["b"]))
|
||||
def f(x: Args) -> str:
|
||||
return str(x["a"] * max(x["b"]))
|
||||
|
||||
|
||||
runnable = RunnableLambda(f)
|
||||
as_tool = runnable.as_tool()
|
||||
as_tool.invoke({"a": 3, "b": [1, 2]})
|
||||
```
|
||||
runnable = RunnableLambda(f)
|
||||
as_tool = runnable.as_tool()
|
||||
as_tool.invoke({"a": 3, "b": [1, 2]})
|
||||
```
|
||||
|
||||
`dict` input, specifying schema via `args_schema`:
|
||||
!!! example "`dict` input, specifying schema via `args_schema`"
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
```python
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
def f(x: dict[str, Any]) -> str:
|
||||
return str(x["a"] * max(x["b"]))
|
||||
def f(x: dict[str, Any]) -> str:
|
||||
return str(x["a"] * max(x["b"]))
|
||||
|
||||
class FSchema(BaseModel):
|
||||
\"\"\"Apply a function to an integer and list of integers.\"\"\"
|
||||
class FSchema(BaseModel):
|
||||
\"\"\"Apply a function to an integer and list of integers.\"\"\"
|
||||
|
||||
a: int = Field(..., description="Integer")
|
||||
b: list[int] = Field(..., description="List of ints")
|
||||
a: int = Field(..., description="Integer")
|
||||
b: list[int] = Field(..., description="List of ints")
|
||||
|
||||
runnable = RunnableLambda(f)
|
||||
as_tool = runnable.as_tool(FSchema)
|
||||
as_tool.invoke({"a": 3, "b": [1, 2]})
|
||||
```
|
||||
runnable = RunnableLambda(f)
|
||||
as_tool = runnable.as_tool(FSchema)
|
||||
as_tool.invoke({"a": 3, "b": [1, 2]})
|
||||
```
|
||||
|
||||
`dict` input, specifying schema via `arg_types`:
|
||||
!!! example "`dict` input, specifying schema via `arg_types`"
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
```python
|
||||
from typing import Any
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
|
||||
def f(x: dict[str, Any]) -> str:
|
||||
return str(x["a"] * max(x["b"]))
|
||||
def f(x: dict[str, Any]) -> str:
|
||||
return str(x["a"] * max(x["b"]))
|
||||
|
||||
|
||||
runnable = RunnableLambda(f)
|
||||
as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
|
||||
as_tool.invoke({"a": 3, "b": [1, 2]})
|
||||
```
|
||||
runnable = RunnableLambda(f)
|
||||
as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
|
||||
as_tool.invoke({"a": 3, "b": [1, 2]})
|
||||
```
|
||||
|
||||
`str` input:
|
||||
!!! example "`str` input"
|
||||
|
||||
```python
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
```python
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
|
||||
def f(x: str) -> str:
|
||||
return x + "a"
|
||||
def f(x: str) -> str:
|
||||
return x + "a"
|
||||
|
||||
|
||||
def g(x: str) -> str:
|
||||
return x + "z"
|
||||
def g(x: str) -> str:
|
||||
return x + "z"
|
||||
|
||||
|
||||
runnable = RunnableLambda(f) | g
|
||||
as_tool = runnable.as_tool()
|
||||
as_tool.invoke("b")
|
||||
```
|
||||
runnable = RunnableLambda(f) | g
|
||||
as_tool = runnable.as_tool()
|
||||
as_tool.invoke("b")
|
||||
```
|
||||
"""
|
||||
# Avoid circular import
|
||||
from langchain_core.tools import convert_runnable_to_tool # noqa: PLC0415
|
||||
@@ -2603,29 +2611,33 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
Returns:
|
||||
A new `Runnable` with the fields configured.
|
||||
|
||||
```python
|
||||
from langchain_core.runnables import ConfigurableField
|
||||
from langchain_openai import ChatOpenAI
|
||||
!!! example
|
||||
|
||||
model = ChatOpenAI(max_tokens=20).configurable_fields(
|
||||
max_tokens=ConfigurableField(
|
||||
id="output_token_number",
|
||||
name="Max tokens in the output",
|
||||
description="The maximum number of tokens in the output",
|
||||
```python
|
||||
from langchain_core.runnables import ConfigurableField
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
model = ChatOpenAI(max_tokens=20).configurable_fields(
|
||||
max_tokens=ConfigurableField(
|
||||
id="output_token_number",
|
||||
name="Max tokens in the output",
|
||||
description="The maximum number of tokens in the output",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# max_tokens = 20
|
||||
print("max_tokens_20: ", model.invoke("tell me something about chess").content)
|
||||
# max_tokens = 20
|
||||
print(
|
||||
"max_tokens_20: ", model.invoke("tell me something about chess").content
|
||||
)
|
||||
|
||||
# max_tokens = 200
|
||||
print(
|
||||
"max_tokens_200: ",
|
||||
model.with_config(configurable={"output_token_number": 200})
|
||||
.invoke("tell me something about chess")
|
||||
.content,
|
||||
)
|
||||
```
|
||||
# max_tokens = 200
|
||||
print(
|
||||
"max_tokens_200: ",
|
||||
model.with_config(configurable={"output_token_number": 200})
|
||||
.invoke("tell me something about chess")
|
||||
.content,
|
||||
)
|
||||
```
|
||||
"""
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.runnables.configurable import ( # noqa: PLC0415
|
||||
@@ -2664,29 +2676,31 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
Returns:
|
||||
A new `Runnable` with the alternatives configured.
|
||||
|
||||
```python
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.runnables.utils import ConfigurableField
|
||||
from langchain_openai import ChatOpenAI
|
||||
!!! example
|
||||
|
||||
model = ChatAnthropic(
|
||||
model_name="claude-sonnet-4-5-20250929"
|
||||
).configurable_alternatives(
|
||||
ConfigurableField(id="llm"),
|
||||
default_key="anthropic",
|
||||
openai=ChatOpenAI(),
|
||||
)
|
||||
```python
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.runnables.utils import ConfigurableField
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# uses the default model ChatAnthropic
|
||||
print(model.invoke("which organization created you?").content)
|
||||
model = ChatAnthropic(
|
||||
model_name="claude-sonnet-4-5-20250929"
|
||||
).configurable_alternatives(
|
||||
ConfigurableField(id="llm"),
|
||||
default_key="anthropic",
|
||||
openai=ChatOpenAI(),
|
||||
)
|
||||
|
||||
# uses ChatOpenAI
|
||||
print(
|
||||
model.with_config(configurable={"llm": "openai"})
|
||||
.invoke("which organization created you?")
|
||||
.content
|
||||
)
|
||||
```
|
||||
# uses the default model ChatAnthropic
|
||||
print(model.invoke("which organization created you?").content)
|
||||
|
||||
# uses ChatOpenAI
|
||||
print(
|
||||
model.with_config(configurable={"llm": "openai"})
|
||||
.invoke("which organization created you?")
|
||||
.content
|
||||
)
|
||||
```
|
||||
"""
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.runnables.configurable import ( # noqa: PLC0415
|
||||
|
||||
@@ -303,7 +303,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
|
||||
Args:
|
||||
input: The input to the `Runnable`.
|
||||
config: The configuration for the Runna`ble.
|
||||
config: The configuration for the `Runnable`.
|
||||
**kwargs: Additional keyword arguments to pass to the `Runnable`.
|
||||
|
||||
Yields:
|
||||
|
||||
@@ -47,54 +47,59 @@ class EmptyDict(TypedDict, total=False):
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
"""Configuration for a Runnable."""
|
||||
"""Configuration for a `Runnable`.
|
||||
|
||||
See the [reference docs](https://reference.langchain.com/python/langchain_core/runnables/#langchain_core.runnables.RunnableConfig)
|
||||
for more details.
|
||||
"""
|
||||
|
||||
tags: list[str]
|
||||
"""
|
||||
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
"""Tags for this call and any sub-calls (e.g. a Chain calling an LLM).
|
||||
|
||||
You can use these to filter calls.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any]
|
||||
"""
|
||||
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
"""Metadata for this call and any sub-calls (e.g. a Chain calling an LLM).
|
||||
|
||||
Keys should be strings, values should be JSON-serializable.
|
||||
"""
|
||||
|
||||
callbacks: Callbacks
|
||||
"""
|
||||
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
"""Callbacks for this call and any sub-calls (e.g. a Chain calling an LLM).
|
||||
|
||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||
"""
|
||||
|
||||
run_name: str
|
||||
"""
|
||||
Name for the tracer run for this call. Defaults to the name of the class.
|
||||
"""
|
||||
"""Name for the tracer run for this call.
|
||||
|
||||
Defaults to the name of the class."""
|
||||
|
||||
max_concurrency: int | None
|
||||
"""
|
||||
Maximum number of parallel calls to make. If not provided, defaults to
|
||||
`ThreadPoolExecutor`'s default.
|
||||
"""Maximum number of parallel calls to make.
|
||||
|
||||
If not provided, defaults to `ThreadPoolExecutor`'s default.
|
||||
"""
|
||||
|
||||
recursion_limit: int
|
||||
"""
|
||||
Maximum number of times a call can recurse. If not provided, defaults to `25`.
|
||||
"""Maximum number of times a call can recurse.
|
||||
|
||||
If not provided, defaults to `25`.
|
||||
"""
|
||||
|
||||
configurable: dict[str, Any]
|
||||
"""
|
||||
Runtime values for attributes previously made configurable on this `Runnable`,
|
||||
"""Runtime values for attributes previously made configurable on this `Runnable`,
|
||||
or sub-Runnables, through `configurable_fields` or `configurable_alternatives`.
|
||||
|
||||
Check `output_schema` for a description of the attributes that have been made
|
||||
configurable.
|
||||
"""
|
||||
|
||||
run_id: uuid.UUID | None
|
||||
"""
|
||||
Unique identifier for the tracer run for this call. If not provided, a new UUID
|
||||
will be generated.
|
||||
"""Unique identifier for the tracer run for this call.
|
||||
|
||||
If not provided, a new UUID will be generated.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@@ -22,7 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -642,6 +641,7 @@ class Graph:
|
||||
retry_delay: float = 1.0,
|
||||
frontmatter_config: dict[str, Any] | None = None,
|
||||
base_url: str | None = None,
|
||||
proxies: dict[str, str] | None = None,
|
||||
) -> bytes:
|
||||
"""Draw the graph as a PNG image using Mermaid.
|
||||
|
||||
@@ -674,11 +674,10 @@ class Graph:
|
||||
}
|
||||
```
|
||||
base_url: The base URL of the Mermaid server for rendering via API.
|
||||
|
||||
proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).
|
||||
|
||||
Returns:
|
||||
The PNG image as bytes.
|
||||
|
||||
"""
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.runnables.graph_mermaid import ( # noqa: PLC0415
|
||||
@@ -699,6 +698,7 @@ class Graph:
|
||||
padding=padding,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
proxies=proxies,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
try:
|
||||
@@ -20,6 +19,8 @@ except ImportError:
|
||||
_HAS_GRANDALF = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langchain_core.runnables.graph import Edge as LangEdge
|
||||
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ def draw_mermaid_png(
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
base_url: str | None = None,
|
||||
proxies: dict[str, str] | None = None,
|
||||
) -> bytes:
|
||||
"""Draws a Mermaid graph as PNG using provided syntax.
|
||||
|
||||
@@ -293,6 +294,7 @@ def draw_mermaid_png(
|
||||
max_retries: Maximum number of retries (MermaidDrawMethod.API).
|
||||
retry_delay: Delay between retries (MermaidDrawMethod.API).
|
||||
base_url: Base URL for the Mermaid.ink API.
|
||||
proxies: HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`).
|
||||
|
||||
Returns:
|
||||
PNG image bytes.
|
||||
@@ -314,6 +316,7 @@ def draw_mermaid_png(
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_url=base_url,
|
||||
proxies=proxies,
|
||||
)
|
||||
else:
|
||||
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
|
||||
@@ -405,6 +408,7 @@ def _render_mermaid_using_api(
|
||||
file_type: Literal["jpeg", "png", "webp"] | None = "png",
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
proxies: dict[str, str] | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> bytes:
|
||||
"""Renders Mermaid graph using the Mermaid.INK API."""
|
||||
@@ -445,7 +449,7 @@ def _render_mermaid_using_api(
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.get(image_url, timeout=10)
|
||||
response = requests.get(image_url, timeout=10, proxies=proxies)
|
||||
if response.status_code == requests.codes.ok:
|
||||
img_bytes = response.content
|
||||
if output_file_path is not None:
|
||||
|
||||
@@ -7,8 +7,7 @@ import asyncio
|
||||
import inspect
|
||||
import sys
|
||||
import textwrap
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from contextvars import Context
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from itertools import groupby
|
||||
@@ -31,9 +30,11 @@ if TYPE_CHECKING:
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
)
|
||||
from contextvars import Context
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
|
||||
|
||||
@@ -386,6 +386,8 @@ class ToolException(Exception): # noqa: N818
|
||||
|
||||
ArgsSchema = TypeBaseModel | dict[str, Any]
|
||||
|
||||
_EMPTY_SET: frozenset[str] = frozenset()
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]):
|
||||
"""Base class for all LangChain tools.
|
||||
@@ -569,6 +571,11 @@ class ChildTool(BaseTool):
|
||||
self.name, full_schema, fields, fn_description=self.description
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
# base implementation doesn't manage injected args
|
||||
return _EMPTY_SET
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
@override
|
||||
@@ -649,6 +656,7 @@ class ChildTool(BaseTool):
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
if issubclass(input_args, BaseModel):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
@@ -664,6 +672,7 @@ class ChildTool(BaseTool):
|
||||
result = input_args.model_validate(tool_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
# Check args_schema for InjectedToolCallId
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
@@ -683,9 +692,25 @@ class ChildTool(BaseTool):
|
||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
return {
|
||||
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
|
||||
validated_input = {
|
||||
k: getattr(result, k) for k in result_dict if k in tool_input
|
||||
}
|
||||
for k in self._injected_args_keys:
|
||||
if k == "tool_call_id":
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
validated_input[k] = tool_call_id
|
||||
if k in tool_input:
|
||||
injected_val = tool_input[k]
|
||||
validated_input[k] = injected_val
|
||||
return validated_input
|
||||
return tool_input
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import textwrap
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import signature
|
||||
@@ -21,10 +22,12 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
_EMPTY_SET,
|
||||
FILTERED_ARGS,
|
||||
ArgsSchema,
|
||||
BaseTool,
|
||||
_get_runnable_config_param,
|
||||
_is_injected_arg_type,
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
@@ -241,6 +244,17 @@ class StructuredTool(BaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
fn = self.func or self.coroutine
|
||||
if fn is None:
|
||||
return _EMPTY_SET
|
||||
return frozenset(
|
||||
k
|
||||
for k, v in signature(fn).parameters.items()
|
||||
if _is_injected_arg_type(v.annotation)
|
||||
)
|
||||
|
||||
|
||||
def _filter_schema_args(func: Callable) -> list[str]:
|
||||
filter_args = list(FILTERED_ARGS)
|
||||
|
||||
@@ -15,12 +15,6 @@ from typing import (
|
||||
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -31,6 +25,12 @@ if TYPE_CHECKING:
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import logging
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -33,6 +32,8 @@ from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -352,6 +353,7 @@ def convert_to_openai_function(
|
||||
ValueError: If function is not in a supported format.
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.16"
|
||||
|
||||
`description` and `parameters` keys are now optional. Only `name` is
|
||||
required and guaranteed to be part of the output.
|
||||
"""
|
||||
@@ -476,15 +478,18 @@ def convert_to_openai_tool(
|
||||
OpenAI tool-calling API.
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.16"
|
||||
|
||||
`description` and `parameters` keys are now optional. Only `name` is
|
||||
required and guaranteed to be part of the output.
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.44"
|
||||
|
||||
Return OpenAI Responses API-style tools unchanged. This includes
|
||||
any dict with `"type"` in `"file_search"`, `"function"`,
|
||||
`"computer_use_preview"`, `"web_search_preview"`.
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.63"
|
||||
|
||||
Added support for OpenAI's image generation built-in tool.
|
||||
"""
|
||||
# Import locally to prevent circular import
|
||||
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
value = match.group(2)
|
||||
|
||||
@@ -170,28 +170,33 @@ def dereference_refs(
|
||||
full_schema: dict | None = None,
|
||||
skip_keys: Sequence[str] | None = None,
|
||||
) -> dict:
|
||||
"""Resolve and inline JSON Schema $ref references in a schema object.
|
||||
"""Resolve and inline JSON Schema `$ref` references in a schema object.
|
||||
|
||||
This function processes a JSON Schema and resolves all $ref references by replacing
|
||||
them with the actual referenced content. It handles both simple references and
|
||||
complex cases like circular references and mixed $ref objects that contain
|
||||
additional properties alongside the $ref.
|
||||
This function processes a JSON Schema and resolves all `$ref` references by
|
||||
replacing them with the actual referenced content.
|
||||
|
||||
Handles both simple references and complex cases like circular references and mixed
|
||||
`$ref` objects that contain additional properties alongside the `$ref`.
|
||||
|
||||
Args:
|
||||
schema_obj: The JSON Schema object or fragment to process. This can be a
|
||||
complete schema or just a portion of one.
|
||||
full_schema: The complete schema containing all definitions that $refs might
|
||||
point to. If not provided, defaults to schema_obj (useful when the
|
||||
schema is self-contained).
|
||||
skip_keys: Controls recursion behavior and reference resolution depth:
|
||||
- If `None` (Default): Only recurse under '$defs' and use shallow reference
|
||||
resolution (break cycles but don't deep-inline nested refs)
|
||||
- If provided (even as []): Recurse under all keys and use deep reference
|
||||
resolution (fully inline all nested references)
|
||||
schema_obj: The JSON Schema object or fragment to process.
|
||||
|
||||
This can be a complete schema or just a portion of one.
|
||||
full_schema: The complete schema containing all definitions that `$refs` might
|
||||
point to.
|
||||
|
||||
If not provided, defaults to `schema_obj` (useful when the schema is
|
||||
self-contained).
|
||||
skip_keys: Controls recursion behavior and reference resolution depth.
|
||||
|
||||
- If `None` (Default): Only recurse under `'$defs'` and use shallow
|
||||
reference resolution (break cycles but don't deep-inline nested refs)
|
||||
- If provided (even as `[]`): Recurse under all keys and use deep reference
|
||||
resolution (fully inline all nested references)
|
||||
|
||||
Returns:
|
||||
A new dictionary with all $ref references resolved and inlined. The original
|
||||
schema_obj is not modified.
|
||||
A new dictionary with all $ref references resolved and inlined.
|
||||
The original `schema_obj` is not modified.
|
||||
|
||||
Examples:
|
||||
Basic reference resolution:
|
||||
@@ -203,7 +208,8 @@ def dereference_refs(
|
||||
>>> result = dereference_refs(schema)
|
||||
>>> result["properties"]["name"] # {"type": "string"}
|
||||
|
||||
Mixed $ref with additional properties:
|
||||
Mixed `$ref` with additional properties:
|
||||
|
||||
>>> schema = {
|
||||
... "properties": {
|
||||
... "name": {"$ref": "#/$defs/base", "description": "User name"}
|
||||
@@ -215,6 +221,7 @@ def dereference_refs(
|
||||
# {"type": "string", "minLength": 1, "description": "User name"}
|
||||
|
||||
Handling circular references:
|
||||
|
||||
>>> schema = {
|
||||
... "properties": {"user": {"$ref": "#/$defs/User"}},
|
||||
... "$defs": {
|
||||
@@ -227,10 +234,11 @@ def dereference_refs(
|
||||
>>> result = dereference_refs(schema) # Won't cause infinite recursion
|
||||
|
||||
!!! note
|
||||
|
||||
- Circular references are handled gracefully by breaking cycles
|
||||
- Mixed $ref objects (with both $ref and other properties) are supported
|
||||
- Additional properties in mixed $refs override resolved properties
|
||||
- The $defs section is preserved in the output by default
|
||||
- Mixed `$ref` objects (with both `$ref` and other properties) are supported
|
||||
- Additional properties in mixed `$refs` override resolved properties
|
||||
- The `$defs` section is preserved in the output by default
|
||||
"""
|
||||
full = full_schema or schema_obj
|
||||
keys_to_skip = list(skip_keys) if skip_keys is not None else ["$defs"]
|
||||
|
||||
@@ -374,15 +374,29 @@ def _get_key(
|
||||
if resolved_scope in (0, False):
|
||||
return resolved_scope
|
||||
# Move into the scope
|
||||
try:
|
||||
# Try subscripting (Normal dictionaries)
|
||||
resolved_scope = cast("dict[str, Any]", resolved_scope)[child]
|
||||
except (TypeError, AttributeError):
|
||||
if isinstance(resolved_scope, dict):
|
||||
try:
|
||||
resolved_scope = getattr(resolved_scope, child)
|
||||
except (TypeError, AttributeError):
|
||||
# Try as a list
|
||||
resolved_scope = resolved_scope[int(child)] # type: ignore[index]
|
||||
resolved_scope = resolved_scope[child]
|
||||
except (KeyError, TypeError):
|
||||
# Key not found - will be caught by outer try-except
|
||||
msg = f"Key {child!r} not found in dict"
|
||||
raise KeyError(msg) from None
|
||||
elif isinstance(resolved_scope, (list, tuple)):
|
||||
try:
|
||||
resolved_scope = resolved_scope[int(child)]
|
||||
except (ValueError, IndexError, TypeError):
|
||||
# Invalid index - will be caught by outer try-except
|
||||
msg = f"Invalid index {child!r} for list/tuple"
|
||||
raise IndexError(msg) from None
|
||||
else:
|
||||
# Reject everything else for security
|
||||
# This prevents traversing into arbitrary Python objects
|
||||
msg = (
|
||||
f"Cannot traverse into {type(resolved_scope).__name__}. "
|
||||
"Mustache templates only support dict, list, and tuple. "
|
||||
f"Got: {type(resolved_scope)}"
|
||||
)
|
||||
raise TypeError(msg) # noqa: TRY301
|
||||
|
||||
try:
|
||||
# This allows for custom falsy data types
|
||||
@@ -393,8 +407,9 @@ def _get_key(
|
||||
if resolved_scope in (0, False):
|
||||
return resolved_scope
|
||||
return resolved_scope or ""
|
||||
except (AttributeError, KeyError, IndexError, ValueError):
|
||||
except (AttributeError, KeyError, IndexError, ValueError, TypeError):
|
||||
# We couldn't find the key in the current scope
|
||||
# TypeError: Attempted to traverse into non-dict/list type
|
||||
# We'll try again on the next pass
|
||||
pass
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from functools import lru_cache, wraps
|
||||
from types import GenericAlias
|
||||
@@ -41,10 +40,12 @@ from pydantic.json_schema import (
|
||||
)
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import create_model as create_model_v1
|
||||
from pydantic.v1.fields import ModelField
|
||||
from typing_extensions import deprecated, override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic_core import core_schema
|
||||
|
||||
PYDANTIC_VERSION = version.parse(pydantic.__version__)
|
||||
|
||||
@@ -11,7 +11,6 @@ import logging
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from itertools import cycle
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -29,7 +28,7 @@ from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Iterable, Iterator, Sequence
|
||||
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@@ -295,8 +294,9 @@ class VectorStore(ABC):
|
||||
|
||||
Args:
|
||||
query: Input text.
|
||||
search_type: Type of search to perform. Can be `'similarity'`, `'mmr'`, or
|
||||
`'similarity_score_threshold'`.
|
||||
search_type: Type of search to perform.
|
||||
|
||||
Can be `'similarity'`, `'mmr'`, or `'similarity_score_threshold'`.
|
||||
**kwargs: Arguments to pass to the search method.
|
||||
|
||||
Returns:
|
||||
@@ -329,8 +329,9 @@ class VectorStore(ABC):
|
||||
|
||||
Args:
|
||||
query: Input text.
|
||||
search_type: Type of search to perform. Can be `'similarity'`, `'mmr'`, or
|
||||
`'similarity_score_threshold'`.
|
||||
search_type: Type of search to perform.
|
||||
|
||||
Can be `'similarity'`, `'mmr'`, or `'similarity_score_threshold'`.
|
||||
**kwargs: Arguments to pass to the search method.
|
||||
|
||||
Returns:
|
||||
@@ -461,9 +462,10 @@ class VectorStore(ABC):
|
||||
Args:
|
||||
query: Input text.
|
||||
k: Number of `Document` objects to return.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include
|
||||
`score_threshold`, An optional floating point value between `0` to `1`
|
||||
to filter the resulting set of retrieved docs
|
||||
**kwargs: Kwargs to be passed to similarity search.
|
||||
|
||||
Should include `score_threshold`, an optional floating point value
|
||||
between `0` to `1` to filter the resulting set of retrieved docs.
|
||||
|
||||
Returns:
|
||||
List of tuples of `(doc, similarity_score)`
|
||||
@@ -488,9 +490,10 @@ class VectorStore(ABC):
|
||||
Args:
|
||||
query: Input text.
|
||||
k: Number of `Document` objects to return.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include
|
||||
`score_threshold`, An optional floating point value between `0` to `1`
|
||||
to filter the resulting set of retrieved docs
|
||||
**kwargs: Kwargs to be passed to similarity search.
|
||||
|
||||
Should include `score_threshold`, an optional floating point value
|
||||
between `0` to `1` to filter the resulting set of retrieved docs.
|
||||
|
||||
Returns:
|
||||
List of tuples of `(doc, similarity_score)`
|
||||
@@ -512,9 +515,10 @@ class VectorStore(ABC):
|
||||
Args:
|
||||
query: Input text.
|
||||
k: Number of `Document` objects to return.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include
|
||||
`score_threshold`, An optional floating point value between `0` to `1`
|
||||
to filter the resulting set of retrieved docs
|
||||
**kwargs: Kwargs to be passed to similarity search.
|
||||
|
||||
Should include `score_threshold`, an optional floating point value
|
||||
between `0` to `1` to filter the resulting set of retrieved docs.
|
||||
|
||||
Returns:
|
||||
List of tuples of `(doc, similarity_score)`.
|
||||
@@ -561,9 +565,10 @@ class VectorStore(ABC):
|
||||
Args:
|
||||
query: Input text.
|
||||
k: Number of `Document` objects to return.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include
|
||||
`score_threshold`, An optional floating point value between `0` to `1`
|
||||
to filter the resulting set of retrieved docs
|
||||
**kwargs: Kwargs to be passed to similarity search.
|
||||
|
||||
Should include `score_threshold`, an optional floating point value
|
||||
between `0` to `1` to filter the resulting set of retrieved docs.
|
||||
|
||||
Returns:
|
||||
List of tuples of `(doc, similarity_score)`
|
||||
@@ -901,13 +906,15 @@ class VectorStore(ABC):
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to pass to the search function.
|
||||
|
||||
Can include:
|
||||
|
||||
* `search_type`: Defines the type of search that the Retriever should
|
||||
perform. Can be `'similarity'` (default), `'mmr'`, or
|
||||
`'similarity_score_threshold'`.
|
||||
* `search_kwargs`: Keyword arguments to pass to the search function. Can
|
||||
include things like:
|
||||
* `search_kwargs`: Keyword arguments to pass to the search function.
|
||||
|
||||
Can include things like:
|
||||
|
||||
* `k`: Amount of documents to return (Default: `4`)
|
||||
* `score_threshold`: Minimum relevance threshold
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -20,7 +19,7 @@ from langchain_core.vectorstores.utils import _cosine_similarity as cosine_simil
|
||||
from langchain_core.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""langchain-core version information and utilities."""
|
||||
|
||||
VERSION = "1.0.4"
|
||||
VERSION = "1.1.0"
|
||||
|
||||
@@ -9,7 +9,7 @@ license = {text = "MIT"}
|
||||
readme = "README.md"
|
||||
authors = []
|
||||
|
||||
version = "1.0.4"
|
||||
version = "1.1.0"
|
||||
requires-python = ">=3.10.0,<4.0.0"
|
||||
dependencies = [
|
||||
"langsmith>=0.3.45,<1.0.0",
|
||||
@@ -36,7 +36,6 @@ typing = [
|
||||
"mypy>=1.18.1,<1.19.0",
|
||||
"types-pyyaml>=6.0.12.2,<7.0.0.0",
|
||||
"types-requests>=2.28.11.5,<3.0.0.0",
|
||||
"langchain-model-profiles",
|
||||
"langchain-text-splitters",
|
||||
]
|
||||
dev = [
|
||||
@@ -58,7 +57,6 @@ test = [
|
||||
"blockbuster>=1.5.18,<1.6.0",
|
||||
"numpy>=1.26.4; python_version<'3.13'",
|
||||
"numpy>=2.1.0; python_version>='3.13'",
|
||||
"langchain-model-profiles",
|
||||
"langchain-tests",
|
||||
"pytest-benchmark",
|
||||
"pytest-codspeed",
|
||||
@@ -66,7 +64,6 @@ test = [
|
||||
test_integration = []
|
||||
|
||||
[tool.uv.sources]
|
||||
langchain-model-profiles = { path = "../model-profiles" }
|
||||
langchain-tests = { path = "../standard-tests" }
|
||||
langchain-text-splitters = { path = "../text-splitters" }
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from langchain_core.language_models import (
|
||||
ParrotFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.chat_models import _generate_response_from_error
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModelError,
|
||||
GenericFakeChatModel,
|
||||
@@ -1221,16 +1222,99 @@ def test_get_ls_params() -> None:
|
||||
|
||||
def test_model_profiles() -> None:
|
||||
model = GenericFakeChatModel(messages=iter([]))
|
||||
profile = model.profile
|
||||
assert profile == {}
|
||||
assert model.profile is None
|
||||
|
||||
class MyModel(GenericFakeChatModel):
|
||||
model: str = "gpt-5"
|
||||
model_with_profile = GenericFakeChatModel(
|
||||
messages=iter([]), profile={"max_input_tokens": 100}
|
||||
)
|
||||
assert model_with_profile.profile == {"max_input_tokens": 100}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "openai-chat"
|
||||
|
||||
model = MyModel(messages=iter([]))
|
||||
profile = model.profile
|
||||
assert profile
|
||||
class MockResponse:
|
||||
"""Mock response for testing _generate_response_from_error."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int = 400,
|
||||
headers: dict[str, str] | None = None,
|
||||
json_data: dict[str, Any] | None = None,
|
||||
json_raises: type[Exception] | None = None,
|
||||
text_raises: type[Exception] | None = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self._json_data = json_data
|
||||
self._json_raises = json_raises
|
||||
self._text_raises = text_raises
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
if self._json_raises:
|
||||
msg = "JSON parsing failed"
|
||||
raise self._json_raises(msg)
|
||||
return self._json_data or {}
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
if self._text_raises:
|
||||
msg = "Text access failed"
|
||||
raise self._text_raises(msg)
|
||||
return ""
|
||||
|
||||
|
||||
class MockAPIError(Exception):
|
||||
"""Mock API error with response attribute."""
|
||||
|
||||
def __init__(self, message: str, response: MockResponse | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
if response is not None:
|
||||
self.response = response
|
||||
|
||||
|
||||
def test_generate_response_from_error_with_valid_json() -> None:
|
||||
"""Test `_generate_response_from_error` with valid JSON response."""
|
||||
response = MockResponse(
|
||||
status_code=400,
|
||||
headers={"content-type": "application/json"},
|
||||
json_data={"error": {"message": "Bad request", "type": "invalid_request"}},
|
||||
)
|
||||
error = MockAPIError("API Error", response=response)
|
||||
|
||||
generations = _generate_response_from_error(error)
|
||||
|
||||
assert len(generations) == 1
|
||||
generation = generations[0]
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.message, AIMessage)
|
||||
assert generation.message.content == ""
|
||||
|
||||
metadata = generation.message.response_metadata
|
||||
assert metadata["body"] == {
|
||||
"error": {"message": "Bad request", "type": "invalid_request"}
|
||||
}
|
||||
assert metadata["headers"] == {"content-type": "application/json"}
|
||||
assert metadata["status_code"] == 400
|
||||
|
||||
|
||||
def test_generate_response_from_error_handles_streaming_response_failure() -> None:
|
||||
# Simulates scenario where accessing response.json() or response.text
|
||||
# raises ResponseNotRead on streaming responses
|
||||
response = MockResponse(
|
||||
status_code=400,
|
||||
headers={"content-type": "application/json"},
|
||||
json_raises=Exception, # Simulates ResponseNotRead or similar
|
||||
text_raises=Exception,
|
||||
)
|
||||
error = MockAPIError("API Error", response=response)
|
||||
|
||||
# This should NOT raise an exception, but should handle it gracefully
|
||||
generations = _generate_response_from_error(error)
|
||||
|
||||
assert len(generations) == 1
|
||||
generation = generations[0]
|
||||
metadata = generation.message.response_metadata
|
||||
|
||||
# When both fail, body should be None instead of raising an exception
|
||||
assert metadata["body"] is None
|
||||
assert metadata["headers"] == {"content-type": "application/json"}
|
||||
assert metadata["status_code"] == 400
|
||||
|
||||
@@ -18,6 +18,8 @@ EXPECTED_ALL = [
|
||||
"FakeStreamingListLLM",
|
||||
"FakeListLLM",
|
||||
"ParrotFakeChatModel",
|
||||
"ModelProfile",
|
||||
"ModelProfileRegistry",
|
||||
"is_openai_data_block",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Test for Serializable base class."""
|
||||
|
||||
from langchain_core.load.load import load
|
||||
|
||||
|
||||
def test_load_with_string_secrets() -> None:
|
||||
obj = {"api_key": "__SECRET_API_KEY__"}
|
||||
secrets_map = {"__SECRET_API_KEY__": "hello"}
|
||||
result = load(obj, secrets_map=secrets_map)
|
||||
|
||||
assert result["api_key"] == "hello"
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Test groq block translator."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.base import _extract_reasoning_from_additional_kwargs
|
||||
from langchain_core.messages.block_translators import PROVIDER_TRANSLATORS
|
||||
from langchain_core.messages.block_translators.groq import (
|
||||
_parse_code_json,
|
||||
translate_content,
|
||||
)
|
||||
|
||||
|
||||
def test_groq_translator_registered() -> None:
|
||||
"""Test that groq translator is properly registered."""
|
||||
assert "groq" in PROVIDER_TRANSLATORS
|
||||
assert "translate_content" in PROVIDER_TRANSLATORS["groq"]
|
||||
assert "translate_content_chunk" in PROVIDER_TRANSLATORS["groq"]
|
||||
|
||||
|
||||
def test_extract_reasoning_from_additional_kwargs_exists() -> None:
|
||||
"""Test that _extract_reasoning_from_additional_kwargs can be imported."""
|
||||
# Verify it's callable
|
||||
assert callable(_extract_reasoning_from_additional_kwargs)
|
||||
|
||||
|
||||
def test_groq_translate_content_basic() -> None:
|
||||
"""Test basic groq content translation."""
|
||||
# Test with simple text message
|
||||
message = AIMessage(content="Hello world")
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0]["type"] == "text"
|
||||
assert blocks[0]["text"] == "Hello world"
|
||||
|
||||
|
||||
def test_groq_translate_content_with_reasoning() -> None:
|
||||
"""Test groq content translation with reasoning content."""
|
||||
# Test with reasoning content in additional_kwargs
|
||||
message = AIMessage(
|
||||
content="Final answer",
|
||||
additional_kwargs={"reasoning_content": "Let me think about this..."},
|
||||
)
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
assert len(blocks) == 2
|
||||
|
||||
# First block should be reasoning
|
||||
assert blocks[0]["type"] == "reasoning"
|
||||
assert blocks[0]["reasoning"] == "Let me think about this..."
|
||||
|
||||
# Second block should be text
|
||||
assert blocks[1]["type"] == "text"
|
||||
assert blocks[1]["text"] == "Final answer"
|
||||
|
||||
|
||||
def test_groq_translate_content_with_tool_calls() -> None:
|
||||
"""Test groq content translation with tool calls."""
|
||||
# Test with tool calls
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "search",
|
||||
"args": {"query": "test"},
|
||||
"id": "call_123",
|
||||
}
|
||||
],
|
||||
)
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0]["type"] == "tool_call"
|
||||
assert blocks[0]["name"] == "search"
|
||||
assert blocks[0]["args"] == {"query": "test"}
|
||||
assert blocks[0]["id"] == "call_123"
|
||||
|
||||
|
||||
def test_groq_translate_content_with_executed_tools() -> None:
|
||||
"""Test groq content translation with executed tools (built-in tools)."""
|
||||
# Test with executed_tools in additional_kwargs (Groq built-in tools)
|
||||
message = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"executed_tools": [
|
||||
{
|
||||
"type": "python",
|
||||
"arguments": '{"code": "print(\\"hello\\")"}',
|
||||
"output": "hello\\n",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
blocks = translate_content(message)
|
||||
|
||||
assert isinstance(blocks, list)
|
||||
# Should have server_tool_call and server_tool_result
|
||||
assert len(blocks) >= 2
|
||||
|
||||
# Check for server_tool_call
|
||||
tool_call_blocks = [
|
||||
cast("types.ServerToolCall", b)
|
||||
for b in blocks
|
||||
if b.get("type") == "server_tool_call"
|
||||
]
|
||||
assert len(tool_call_blocks) == 1
|
||||
assert tool_call_blocks[0]["name"] == "code_interpreter"
|
||||
assert "code" in tool_call_blocks[0]["args"]
|
||||
|
||||
# Check for server_tool_result
|
||||
tool_result_blocks = [
|
||||
cast("types.ServerToolResult", b)
|
||||
for b in blocks
|
||||
if b.get("type") == "server_tool_result"
|
||||
]
|
||||
assert len(tool_result_blocks) == 1
|
||||
assert tool_result_blocks[0]["output"] == "hello\\n"
|
||||
assert tool_result_blocks[0]["status"] == "success"
|
||||
|
||||
|
||||
def test_parse_code_json() -> None:
|
||||
"""Test the _parse_code_json helper function."""
|
||||
# Test valid code JSON
|
||||
result = _parse_code_json('{"code": "print(\'hello\')"}')
|
||||
assert result == {"code": "print('hello')"}
|
||||
|
||||
# Test code with unescaped quotes (Groq format)
|
||||
result = _parse_code_json('{"code": "print("hello")"}')
|
||||
assert result == {"code": 'print("hello")'}
|
||||
|
||||
# Test invalid format raises ValueError
|
||||
with pytest.raises(ValueError, match="Could not extract Python code"):
|
||||
_parse_code_json('{"invalid": "format"}')
|
||||
@@ -1320,9 +1320,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -2734,9 +2736,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
|
||||
@@ -1540,3 +1540,164 @@ def test_rendering_prompt_with_conditionals_no_empty_text_blocks() -> None:
|
||||
assert not [
|
||||
block for block in content if block["type"] == "text" and block["text"] == ""
|
||||
]
|
||||
|
||||
|
||||
def test_fstring_rejects_invalid_identifier_variable_names() -> None:
|
||||
"""Test that f-string templates block attribute access, indexing.
|
||||
|
||||
This validation prevents template injection attacks by blocking:
|
||||
- Attribute access like {msg.__class__}
|
||||
- Indexing like {msg[0]}
|
||||
- All-digit variable names like {0} or {100} (interpreted as positional args)
|
||||
|
||||
While allowing any other field names that Python's Formatter accepts.
|
||||
"""
|
||||
# Test that attribute access and indexing are blocked (security issue)
|
||||
invalid_templates = [
|
||||
"{msg.__class__}", # Attribute access with dunder
|
||||
"{msg.__class__.__name__}", # Multiple dunders
|
||||
"{msg.content}", # Attribute access
|
||||
"{msg[0]}", # Item access
|
||||
"{0}", # All-digit variable name (positional argument)
|
||||
"{100}", # All-digit variable name (positional argument)
|
||||
"{42}", # All-digit variable name (positional argument)
|
||||
]
|
||||
|
||||
for template_str in invalid_templates:
|
||||
with pytest.raises(ValueError, match="Invalid variable name") as exc_info:
|
||||
ChatPromptTemplate.from_messages(
|
||||
[("human", template_str)],
|
||||
template_format="f-string",
|
||||
)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Invalid variable name" in error_msg
|
||||
# Check for any of the expected error message parts
|
||||
assert (
|
||||
"attribute access" in error_msg
|
||||
or "indexing" in error_msg
|
||||
or "positional arguments" in error_msg
|
||||
)
|
||||
|
||||
# Valid templates - Python's Formatter accepts non-identifier field names
|
||||
valid_templates = [
|
||||
(
|
||||
"Hello {name} and {user_id}",
|
||||
{"name": "Alice", "user_id": "123"},
|
||||
"Hello Alice and 123",
|
||||
),
|
||||
("User: {user-name}", {"user-name": "Bob"}, "User: Bob"), # Hyphen allowed
|
||||
(
|
||||
"Value: {2fast}",
|
||||
{"2fast": "Charlie"},
|
||||
"Value: Charlie",
|
||||
), # Starts with digit allowed
|
||||
("Data: {my var}", {"my var": "Dave"}, "Data: Dave"), # Space allowed
|
||||
]
|
||||
|
||||
for template_str, kwargs, expected in valid_templates:
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[("human", template_str)],
|
||||
template_format="f-string",
|
||||
)
|
||||
result = template.invoke(kwargs)
|
||||
assert result.messages[0].content == expected # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_mustache_template_attribute_access_vulnerability() -> None:
|
||||
"""Test that Mustache template injection is blocked.
|
||||
|
||||
Verify the fix for security vulnerability GHSA-6qv9-48xg-fc7f
|
||||
|
||||
Previously, Mustache used getattr() as a fallback, allowing access to
|
||||
dangerous attributes like __class__, __globals__, etc.
|
||||
|
||||
The fix adds isinstance checks that reject non-dict/list types.
|
||||
When templates try to traverse Python objects, they get empty string
|
||||
per Mustache spec (better than the previous behavior of exposing internals).
|
||||
"""
|
||||
msg = HumanMessage("howdy")
|
||||
|
||||
# Template tries to access attributes on a Python object
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("human", "{{question.__class__.__name__}}")],
|
||||
template_format="mustache",
|
||||
)
|
||||
|
||||
# After the fix: returns empty string (attack blocked!)
|
||||
# Previously would return "HumanMessage" via getattr()
|
||||
result = prompt.invoke({"question": msg})
|
||||
assert result.messages[0].content == "" # type: ignore[attr-defined]
|
||||
|
||||
# Mustache still works correctly with actual dicts
|
||||
prompt_dict = ChatPromptTemplate.from_messages(
|
||||
[("human", "{{person.name}}")],
|
||||
template_format="mustache",
|
||||
)
|
||||
result_dict = prompt_dict.invoke({"person": {"name": "Alice"}})
|
||||
assert result_dict.messages[0].content == "Alice" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
def test_jinja2_template_attribute_access_is_blocked() -> None:
|
||||
"""Test that Jinja2 SandboxedEnvironment blocks dangerous attribute access.
|
||||
|
||||
This test verifies that Jinja2's sandbox successfully blocks access to
|
||||
dangerous dunder attributes like __class__, unlike Mustache.
|
||||
|
||||
GOOD: Jinja2 SandboxedEnvironment raises SecurityError when attempting
|
||||
to access __class__, __globals__, etc. This is expected behavior.
|
||||
"""
|
||||
msg = HumanMessage("howdy")
|
||||
|
||||
# Create a Jinja2 template that attempts to access __class__.__name__
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("human", "{{question.__class__.__name__}}")],
|
||||
template_format="jinja2",
|
||||
)
|
||||
|
||||
# Jinja2 sandbox should block this with SecurityError
|
||||
with pytest.raises(Exception, match="attribute") as exc_info:
|
||||
prompt.invoke(
|
||||
{"question": msg, "question.__class__.__name__": "safe_placeholder"}
|
||||
)
|
||||
|
||||
# Verify it's a SecurityError from Jinja2 blocking __class__ access
|
||||
error_msg = str(exc_info.value)
|
||||
assert (
|
||||
"SecurityError" in str(type(exc_info.value))
|
||||
or "access to attribute '__class__'" in error_msg
|
||||
), f"Expected SecurityError blocking __class__, got: {error_msg}"
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
def test_jinja2_blocks_all_attribute_access() -> None:
|
||||
"""Test that Jinja2 now blocks ALL attribute/method access for security.
|
||||
|
||||
After the fix, Jinja2 uses _RestrictedSandboxedEnvironment which blocks
|
||||
ALL attribute access, not just dunder attributes. This prevents the
|
||||
parse_raw() vulnerability.
|
||||
"""
|
||||
msg = HumanMessage("test content")
|
||||
|
||||
# Test 1: Simple variable access should still work
|
||||
prompt_simple = ChatPromptTemplate.from_messages(
|
||||
[("human", "Message: {{message}}")],
|
||||
template_format="jinja2",
|
||||
)
|
||||
result = prompt_simple.invoke({"message": "hello world"})
|
||||
assert "hello world" in result.messages[0].content # type: ignore[attr-defined]
|
||||
|
||||
# Test 2: Attribute access should now be blocked (including safe attributes)
|
||||
prompt_attr = ChatPromptTemplate.from_messages(
|
||||
[("human", "Content: {{msg.content}}")],
|
||||
template_format="jinja2",
|
||||
)
|
||||
with pytest.raises(Exception, match="attribute") as exc_info:
|
||||
prompt_attr.invoke({"msg": msg})
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert (
|
||||
"SecurityError" in str(type(exc_info.value))
|
||||
or "Access to attributes is not allowed" in error_msg
|
||||
), f"Expected SecurityError blocking attribute access, got: {error_msg}"
|
||||
|
||||
@@ -125,7 +125,9 @@ def test_structured_prompt_kwargs() -> None:
|
||||
|
||||
def test_structured_prompt_template_format() -> None:
|
||||
prompt = StructuredPrompt(
|
||||
[("human", "hi {{person.name}}")], schema={}, template_format="mustache"
|
||||
[("human", "hi {{person.name}}")],
|
||||
schema={"type": "object", "properties": {}, "title": "foo"},
|
||||
template_format="mustache",
|
||||
)
|
||||
assert prompt.messages[0].prompt.template_format == "mustache" # type: ignore[union-attr, union-attr]
|
||||
assert prompt.input_variables == ["person"]
|
||||
@@ -136,4 +138,8 @@ def test_structured_prompt_template_format() -> None:
|
||||
|
||||
def test_structured_prompt_template_empty_vars() -> None:
|
||||
with pytest.raises(ChevronError, match="empty tag"):
|
||||
StructuredPrompt([("human", "hi {{}}")], schema={}, template_format="mustache")
|
||||
StructuredPrompt(
|
||||
[("human", "hi {{}}")],
|
||||
schema={"type": "object", "properties": {}, "title": "foo"},
|
||||
template_format="mustache",
|
||||
)
|
||||
|
||||
@@ -1744,9 +1744,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
|
||||
@@ -3260,9 +3260,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -4736,9 +4738,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -6224,9 +6228,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -7568,9 +7574,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -9086,9 +9094,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -10475,9 +10485,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -11912,9 +11924,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
@@ -13350,9 +13364,11 @@
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain-core` 0.3.9"
|
||||
|
||||
Added `input_token_details` and `output_token_details`.
|
||||
|
||||
!!! note "LangSmith SDK"
|
||||
|
||||
The LangSmith SDK also has a `UsageMetadata` class. While the two share fields,
|
||||
LangSmith's `UsageMetadata` has additional fields to capture cost information
|
||||
used by the LangSmith platform.
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
import asyncio
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -3,21 +3,25 @@ from __future__ import annotations
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
||||
from inspect import isasyncgenfunction
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langsmith import Client, get_current_run_tree, traceable
|
||||
from langsmith.run_helpers import tracing_context
|
||||
from langsmith.run_trees import RunTree
|
||||
from langsmith.utils import get_env_var
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
||||
|
||||
from langsmith.run_trees import RunTree
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
|
||||
def _get_posts(client: Client) -> list:
|
||||
mock_calls = client.session.request.mock_calls # type: ignore[attr-defined]
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import textwrap
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
@@ -55,6 +56,7 @@ from langchain_core.tools.base import (
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
_DirectlyInjectedToolArg,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
get_all_basemodel_annotations,
|
||||
@@ -2331,6 +2333,101 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema_format", ["model", "json_schema"])
|
||||
def test_tool_allows_extra_runtime_args_with_custom_schema(
|
||||
schema_format: Literal["model", "json_schema"],
|
||||
) -> None:
|
||||
"""Ensure runtime args are preserved even if not in the args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
query: str
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@dataclass
|
||||
class MyRuntime(_DirectlyInjectedToolArg):
|
||||
some_obj: object
|
||||
|
||||
args_schema = (
|
||||
InputSchema if schema_format == "model" else InputSchema.model_json_schema()
|
||||
)
|
||||
|
||||
@tool(args_schema=args_schema)
|
||||
def runtime_tool(query: str, runtime: MyRuntime) -> str:
|
||||
"""Echo the query and capture runtime value."""
|
||||
captured["runtime"] = runtime
|
||||
return query
|
||||
|
||||
runtime_obj = object()
|
||||
runtime = MyRuntime(some_obj=runtime_obj)
|
||||
assert runtime_tool.invoke({"query": "hello", "runtime": runtime}) == "hello"
|
||||
assert captured["runtime"] is runtime
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id_with_custom_schema() -> None:
|
||||
"""Ensure InjectedToolCallId works with custom args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
x: int
|
||||
|
||||
@tool(args_schema=InputSchema)
|
||||
def injected_tool(
|
||||
x: int, tool_call_id: Annotated[str, InjectedToolCallId]
|
||||
) -> ToolMessage:
|
||||
"""Tool with injected tool_call_id and custom schema."""
|
||||
return ToolMessage(str(x), tool_call_id=tool_call_id)
|
||||
|
||||
# Test that tool_call_id is properly injected even though not in custom schema
|
||||
result = injected_tool.invoke(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"args": {"x": 42},
|
||||
"name": "injected_tool",
|
||||
"id": "test_call_id",
|
||||
}
|
||||
)
|
||||
assert result == ToolMessage("42", tool_call_id="test_call_id")
|
||||
|
||||
# Test that it still raises error when invoked without a ToolCall
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="When tool includes an InjectedToolCallId argument, "
|
||||
"tool must always be invoked with a full model ToolCall",
|
||||
):
|
||||
injected_tool.invoke({"x": 42})
|
||||
|
||||
|
||||
def test_tool_injected_arg_with_custom_schema() -> None:
|
||||
"""Ensure InjectedToolArg works with custom args schema."""
|
||||
|
||||
class InputSchema(BaseModel):
|
||||
query: str
|
||||
|
||||
class CustomContext:
|
||||
"""Custom context object to be injected."""
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@tool(args_schema=InputSchema)
|
||||
def search_tool(
|
||||
query: str, context: Annotated[CustomContext, InjectedToolArg]
|
||||
) -> str:
|
||||
"""Search with custom context."""
|
||||
captured["context"] = context
|
||||
return f"Results for {query} with context {context.value}"
|
||||
|
||||
# Test that context is properly injected even though not in custom schema
|
||||
ctx = CustomContext("test_context")
|
||||
result = search_tool.invoke({"query": "hello", "context": ctx})
|
||||
|
||||
assert result == "Results for hello with context test_context"
|
||||
assert captured["context"] is ctx
|
||||
assert captured["context"].value == "test_context"
|
||||
|
||||
|
||||
def test_tool_injected_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -23,6 +22,9 @@ from langchain_core.utils import (
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.utils import secret_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("package", "check_kwargs", "actual_version", "expected"),
|
||||
|
||||
50
libs/core/uv.lock
generated
50
libs/core/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.10.0, <4.0.0"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
|
||||
@@ -960,7 +960,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.4"
|
||||
version = "1.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -985,7 +985,6 @@ test = [
|
||||
{ name = "blockbuster" },
|
||||
{ name = "freezegun" },
|
||||
{ name = "grandalf" },
|
||||
{ name = "langchain-model-profiles" },
|
||||
{ name = "langchain-tests" },
|
||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
@@ -1001,7 +1000,6 @@ test = [
|
||||
{ name = "syrupy" },
|
||||
]
|
||||
typing = [
|
||||
{ name = "langchain-model-profiles" },
|
||||
{ name = "langchain-text-splitters" },
|
||||
{ name = "mypy" },
|
||||
{ name = "types-pyyaml" },
|
||||
@@ -1031,7 +1029,6 @@ test = [
|
||||
{ name = "blockbuster", specifier = ">=1.5.18,<1.6.0" },
|
||||
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
|
||||
{ name = "grandalf", specifier = ">=0.8.0,<1.0.0" },
|
||||
{ name = "langchain-model-profiles", directory = "../model-profiles" },
|
||||
{ name = "langchain-tests", directory = "../standard-tests" },
|
||||
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
||||
@@ -1048,56 +1045,15 @@ test = [
|
||||
]
|
||||
test-integration = []
|
||||
typing = [
|
||||
{ name = "langchain-model-profiles", directory = "../model-profiles" },
|
||||
{ name = "langchain-text-splitters", directory = "../text-splitters" },
|
||||
{ name = "mypy", specifier = ">=1.18.1,<1.19.0" },
|
||||
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
|
||||
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-model-profiles"
|
||||
version = "0.0.3"
|
||||
source = { directory = "../model-profiles" }
|
||||
dependencies = [
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0,<3.0.0" },
|
||||
{ name = "typing-extensions", specifier = ">=4.7.0,<5.0.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "httpx", specifier = ">=0.23.0,<1" }]
|
||||
lint = [
|
||||
{ name = "langchain", editable = "../langchain_v1" },
|
||||
{ name = "ruff", specifier = ">=0.12.2,<0.13.0" },
|
||||
]
|
||||
test = [
|
||||
{ name = "langchain", extras = ["openai"], editable = "../langchain_v1" },
|
||||
{ name = "langchain-core", editable = "." },
|
||||
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.23.2,<2.0.0" },
|
||||
{ name = "pytest-cov", specifier = ">=4.0.0,<8.0.0" },
|
||||
{ name = "pytest-mock" },
|
||||
{ name = "pytest-socket", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "pytest-watcher", specifier = ">=0.2.6,<1.0.0" },
|
||||
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
|
||||
{ name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
|
||||
{ name = "toml", specifier = ">=0.10.2,<1.0.0" },
|
||||
]
|
||||
test-integration = [{ name = "langchain-core", editable = "." }]
|
||||
typing = [
|
||||
{ name = "mypy", specifier = ">=1.18.1,<1.19.0" },
|
||||
{ name = "types-toml", specifier = ">=0.10.8.20240310,<1.0.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-tests"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2"
|
||||
source = { directory = "../standard-tests" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
|
||||
@@ -24,7 +24,7 @@ In most cases, you should be using the main [`langchain`](https://pypi.org/proje
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic).
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain_classic). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -319,9 +319,11 @@ def init_chat_model(
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langchain` 0.2.8"
|
||||
|
||||
Support for `configurable_fields` and `config_prefix` added.
|
||||
|
||||
!!! warning "Behavior changed in `langchain` 0.2.12"
|
||||
|
||||
Support for Ollama via langchain-ollama package added
|
||||
(`langchain_ollama.ChatOllama`). Previously,
|
||||
the now-deprecated langchain-community version of Ollama was imported
|
||||
@@ -331,9 +333,11 @@ def init_chat_model(
|
||||
(`model_provider="bedrock_converse"`).
|
||||
|
||||
!!! warning "Behavior changed in `langchain` 0.3.5"
|
||||
|
||||
Out of beta.
|
||||
|
||||
!!! warning "Behavior changed in `langchain` 0.3.19"
|
||||
|
||||
Support for Deepseek, IBM, Nvidia, and xAI models added.
|
||||
|
||||
""" # noqa: E501
|
||||
@@ -547,13 +551,17 @@ def _attempt_infer_model_provider(model_name: str) -> str | None:
|
||||
|
||||
|
||||
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
|
||||
if (
|
||||
not model_provider
|
||||
and ":" in model
|
||||
and model.split(":")[0] in _SUPPORTED_PROVIDERS
|
||||
):
|
||||
model_provider = model.split(":")[0]
|
||||
model = ":".join(model.split(":")[1:])
|
||||
if not model_provider and ":" in model:
|
||||
prefix, suffix = model.split(":", 1)
|
||||
if prefix in _SUPPORTED_PROVIDERS:
|
||||
model_provider = prefix
|
||||
model = suffix
|
||||
else:
|
||||
inferred = _attempt_infer_model_provider(prefix)
|
||||
if inferred:
|
||||
model_provider = inferred
|
||||
model = suffix
|
||||
|
||||
model_provider = model_provider or _attempt_infer_model_provider(model)
|
||||
if not model_provider:
|
||||
msg = (
|
||||
|
||||
@@ -32,13 +32,18 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
vectorstore: VectorStore
|
||||
"""The underlying `VectorStore` to use to store small chunks
|
||||
and their embedding vectors"""
|
||||
|
||||
byte_store: ByteStore | None = None
|
||||
"""The lower-level backing storage layer for the parent documents"""
|
||||
|
||||
docstore: BaseStore[str, Document]
|
||||
"""The storage interface for the parent documents"""
|
||||
|
||||
id_key: str = "doc_id"
|
||||
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the search function."""
|
||||
|
||||
search_type: SearchType = SearchType.similarity
|
||||
"""Type of search to perform (similarity / mmr)"""
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
|
||||
total: The total number of items to be processed.
|
||||
ncols: The character width of the progress bar.
|
||||
end_with: Last string to print after progress bar reaches end.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
self.total = total
|
||||
self.ncols = ncols
|
||||
|
||||
@@ -295,11 +295,7 @@ def _get_prompt(inputs: dict[str, Any]) -> str:
|
||||
|
||||
|
||||
class ChatModelInput(TypedDict):
|
||||
"""Input for a chat model.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages.
|
||||
"""
|
||||
"""Input for a chat model."""
|
||||
|
||||
messages: list[BaseMessage]
|
||||
|
||||
|
||||
@@ -108,8 +108,8 @@ class LLMStringRunMapper(StringRunMapper):
|
||||
The serialized output text from the first generation.
|
||||
|
||||
Raises:
|
||||
ValueError: If no generations are found in the outputs,
|
||||
or if the generations are empty.
|
||||
ValueError: If no generations are found in the outputs or if the generations
|
||||
are empty.
|
||||
"""
|
||||
if not outputs.get("generations"):
|
||||
msg = "Cannot evaluate LLM Run without generations."
|
||||
@@ -436,8 +436,8 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
||||
The instantiated evaluation chain.
|
||||
|
||||
Raises:
|
||||
If the run type is not supported, or if the evaluator requires a
|
||||
reference from the dataset but the reference key is not provided.
|
||||
ValueError: If the run type is not supported, or if the evaluator requires a
|
||||
reference from the dataset but the reference key is not provided.
|
||||
|
||||
"""
|
||||
# Configure how run inputs/predictions are passed to the evaluator
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: all start_services stop_services coverage test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
|
||||
.PHONY: all start_services stop_services coverage coverage_agents test test_fast extended_tests test_watch test_watch_extended integration_tests check_imports lint format lint_diff format_diff lint_package lint_tests help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
@@ -27,8 +27,17 @@ coverage:
|
||||
--cov-report term-missing:skip-covered \
|
||||
$(TEST_FILE)
|
||||
|
||||
# Run middleware and agent tests with coverage report.
|
||||
coverage_agents:
|
||||
uv run --group test pytest \
|
||||
tests/unit_tests/agents/middleware/ \
|
||||
tests/unit_tests/agents/test_*.py \
|
||||
--cov=langchain.agents \
|
||||
--cov-report=term-missing \
|
||||
--cov-report=html:htmlcov \
|
||||
|
||||
test:
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered; \
|
||||
make start_services && LANGGRAPH_TEST_FAST=0 uv run --no-sync --active --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE) --cov-report term-missing:skip-covered --snapshot-update; \
|
||||
EXIT_CODE=$$?; \
|
||||
make stop_services; \
|
||||
exit $$EXIT_CODE
|
||||
@@ -93,6 +102,7 @@ help:
|
||||
@echo 'lint - run linters'
|
||||
@echo '-- TESTS --'
|
||||
@echo 'coverage - run unit tests and generate coverage report'
|
||||
@echo 'coverage_agents - run middleware and agent tests with coverage report'
|
||||
@echo 'test - run unit tests with all services'
|
||||
@echo 'test_fast - run unit tests with in-memory services only'
|
||||
@echo 'tests - run unit tests (alias for "make test")'
|
||||
|
||||
@@ -26,7 +26,7 @@ LangChain [agents](https://docs.langchain.com/oss/python/langchain/agents) are b
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/).
|
||||
For full documentation, see the [API reference](https://reference.langchain.com/python/langchain/langchain/). For conceptual guides, tutorials, and examples on using LangChain, see the [LangChain Docs](https://docs.langchain.com/oss/python/langchain/overview).
|
||||
|
||||
## 📕 Releases & Versioning
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""Main entrypoint into LangChain."""
|
||||
|
||||
__version__ = "1.0.5"
|
||||
__version__ = "1.1.0"
|
||||
|
||||
@@ -63,6 +63,18 @@ if TYPE_CHECKING:
|
||||
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
|
||||
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
|
||||
# if model profile data are not available, these models are assumed to support
|
||||
# structured output
|
||||
"grok",
|
||||
"gpt-5",
|
||||
"gpt-4.1",
|
||||
"gpt-4o",
|
||||
"gpt-oss",
|
||||
"o3-pro",
|
||||
"o3-mini",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
|
||||
"""Normalize middleware return value to ModelResponse."""
|
||||
@@ -349,11 +361,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
return []
|
||||
|
||||
|
||||
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
|
||||
"""Check if a model supports provider-specific structured output.
|
||||
|
||||
Args:
|
||||
model: Model name string or `BaseChatModel` instance.
|
||||
tools: Optional list of tools provided to the agent. Needed because some models
|
||||
don't support structured output together with tool calling.
|
||||
|
||||
Returns:
|
||||
`True` if the model supports provider-specific structured output, `False` otherwise.
|
||||
@@ -362,11 +376,23 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
elif isinstance(model, BaseChatModel):
|
||||
model_name = getattr(model, "model_name", None)
|
||||
model_name = (
|
||||
getattr(model, "model_name", None)
|
||||
or getattr(model, "model", None)
|
||||
or getattr(model, "model_id", "")
|
||||
)
|
||||
model_profile = model.profile
|
||||
if (
|
||||
model_profile is not None
|
||||
and model_profile.get("structured_output")
|
||||
# We make an exception for Gemini models, which currently do not support
|
||||
# simultaneous tool use with structured output
|
||||
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
|
||||
):
|
||||
return True
|
||||
|
||||
return (
|
||||
"grok" in model_name.lower()
|
||||
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
|
||||
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
|
||||
if model_name
|
||||
else False
|
||||
)
|
||||
@@ -516,7 +542,7 @@ def create_agent( # noqa: PLR0915
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | None = None,
|
||||
system_prompt: str | SystemMessage | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
state_schema: type[AgentState[ResponseT]] | None = None,
|
||||
@@ -562,9 +588,9 @@ def create_agent( # noqa: PLR0915
|
||||
docs for more information.
|
||||
system_prompt: An optional system prompt for the LLM.
|
||||
|
||||
Prompts are converted to a
|
||||
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
|
||||
beginning of the message list.
|
||||
Can be a `str` (which will be converted to a `SystemMessage`) or a
|
||||
`SystemMessage` instance directly. The system message is added to the
|
||||
beginning of the message list when calling the model.
|
||||
middleware: A sequence of middleware instances to apply to the agent.
|
||||
|
||||
Middleware can intercept and modify agent behavior at various stages.
|
||||
@@ -659,6 +685,14 @@ def create_agent( # noqa: PLR0915
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
# Convert system_prompt to SystemMessage if needed
|
||||
system_message: SystemMessage | None = None
|
||||
if system_prompt is not None:
|
||||
if isinstance(system_prompt, SystemMessage):
|
||||
system_message = system_prompt
|
||||
else:
|
||||
system_message = SystemMessage(content=system_prompt)
|
||||
|
||||
# Handle tools being None or empty
|
||||
if tools is None:
|
||||
tools = []
|
||||
@@ -988,7 +1022,7 @@ def create_agent( # noqa: PLR0915
|
||||
effective_response_format: ResponseFormat | None
|
||||
if isinstance(request.response_format, AutoStrategy):
|
||||
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
|
||||
if _supports_provider_strategy(request.model):
|
||||
if _supports_provider_strategy(request.model, tools=request.tools):
|
||||
# Model supports provider strategy - use it
|
||||
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
|
||||
else:
|
||||
@@ -1009,7 +1043,7 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
# Bind model based on effective response format
|
||||
if isinstance(effective_response_format, ProviderStrategy):
|
||||
# Use provider-specific structured output
|
||||
# (Backward compatibility) Use OpenAI format structured output
|
||||
kwargs = effective_response_format.to_model_kwargs()
|
||||
return (
|
||||
request.model.bind_tools(
|
||||
@@ -1062,8 +1096,8 @@ def create_agent( # noqa: PLR0915
|
||||
# Get the bound model (with auto-detection if needed)
|
||||
model_, effective_response_format = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
if request.system_message:
|
||||
messages = [request.system_message, *messages]
|
||||
|
||||
output = model_.invoke(messages)
|
||||
|
||||
@@ -1082,7 +1116,7 @@ def create_agent( # noqa: PLR0915
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
system_message=system_message,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
@@ -1115,8 +1149,8 @@ def create_agent( # noqa: PLR0915
|
||||
# Get the bound model (with auto-detection if needed)
|
||||
model_, effective_response_format = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
if request.system_message:
|
||||
messages = [request.system_message, *messages]
|
||||
|
||||
output = await model_.ainvoke(messages)
|
||||
|
||||
@@ -1135,7 +1169,7 @@ def create_agent( # noqa: PLR0915
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
system_message=system_message,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
|
||||
@@ -11,6 +11,7 @@ from .human_in_the_loop import (
|
||||
)
|
||||
from .model_call_limit import ModelCallLimitMiddleware
|
||||
from .model_fallback import ModelFallbackMiddleware
|
||||
from .model_retry import ModelRetryMiddleware
|
||||
from .pii import PIIDetectionError, PIIMiddleware
|
||||
from .shell_tool import (
|
||||
CodexSandboxExecutionPolicy,
|
||||
@@ -57,6 +58,7 @@ __all__ = [
|
||||
"ModelFallbackMiddleware",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"ModelRetryMiddleware",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"RedactionRule",
|
||||
|
||||
123
libs/langchain_v1/langchain/agents/middleware/_retry.py
Normal file
123
libs/langchain_v1/langchain/agents/middleware/_retry.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Shared retry utilities for agent middleware.
|
||||
|
||||
This module contains common constants, utilities, and logic used by both
|
||||
model and tool retry middleware implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
# Type aliases
|
||||
RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
|
||||
"""Type for specifying which exceptions to retry on.
|
||||
|
||||
Can be either:
|
||||
- A tuple of exception types to retry on (based on `isinstance` checks)
|
||||
- A callable that takes an exception and returns `True` if it should be retried
|
||||
"""
|
||||
|
||||
OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
|
||||
"""Type for specifying failure handling behavior.
|
||||
|
||||
Can be either:
|
||||
- A literal action string (`'error'` or `'continue'`)
|
||||
- `'error'`: Re-raise the exception, stopping agent execution.
|
||||
- `'continue'`: Inject a message with the error details, allowing the agent to continue.
|
||||
For tool retries, a `ToolMessage` with the error details will be injected.
|
||||
For model retries, an `AIMessage` with the error details will be returned.
|
||||
- A callable that takes an exception and returns a string for error message content
|
||||
"""
|
||||
|
||||
|
||||
def validate_retry_params(
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    backoff_factor: float,
) -> None:
    """Reject negative retry settings.

    Args:
        max_retries: Maximum number of retry attempts.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.
        backoff_factor: Multiplier for exponential backoff.

    Raises:
        ValueError: If a parameter is negative. Parameters are checked in
            signature order and the first offender is reported.
    """
    # (value, message) pairs, kept in the order the checks must run.
    checks = (
        (max_retries, "max_retries must be >= 0"),
        (initial_delay, "initial_delay must be >= 0"),
        (max_delay, "max_delay must be >= 0"),
        (backoff_factor, "backoff_factor must be >= 0"),
    )
    for value, msg in checks:
        if value < 0:
            raise ValueError(msg)
|
||||
|
||||
|
||||
def should_retry_exception(
    exc: Exception,
    retry_on: RetryOn,
) -> bool:
    """Check if an exception should trigger a retry.

    Args:
        exc: The exception that occurred.
        retry_on: Either a tuple of exception types to retry on (matched with
            `isinstance`), or a callable that takes an exception and returns
            `True` if it should be retried.

            A single exception class is also accepted and is treated like a
            one-element tuple.

    Returns:
        `True` if the exception should be retried, `False` otherwise.
    """
    # An exception class is itself callable, so it has to be recognised before
    # the predicate branch. Otherwise `retry_on(exc)` would merely construct a
    # new exception instance, which is always truthy, and every error would be
    # retried regardless of its type.
    if isinstance(retry_on, type) and issubclass(retry_on, BaseException):
        return isinstance(exc, retry_on)
    if callable(retry_on):
        return retry_on(exc)
    return isinstance(exc, retry_on)
|
||||
|
||||
|
||||
def calculate_delay(
    retry_number: int,
    *,
    backoff_factor: float,
    initial_delay: float,
    max_delay: float,
    jitter: bool,
) -> float:
    """Calculate delay for a retry attempt with exponential backoff and optional jitter.

    Args:
        retry_number: The retry attempt number (0-indexed).
        backoff_factor: Multiplier for exponential backoff.

            Set to `0.0` for constant delay.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.

            Caps exponential backoff growth. The cap is applied before jitter,
            so a jittered delay may exceed it by up to 25%.
        jitter: Whether to add random jitter (±25%) to delay to avoid thundering herd.

    Returns:
        Delay in seconds before next retry.
    """
    if backoff_factor == 0.0:
        delay = initial_delay
    else:
        try:
            delay = initial_delay * (backoff_factor**retry_number)
        except OverflowError:
            # Float exponentiation raises instead of saturating once the growth
            # factor leaves the representable range (e.g. `2.0**1024`). The
            # uncapped delay would be far above any cap, so fall back to the
            # cap itself, unless the base delay is zero, in which case the
            # product is zero as well.
            delay = max_delay if initial_delay > 0 else 0.0

    # Cap at max_delay
    delay = min(delay, max_delay)

    if jitter and delay > 0:
        jitter_amount = delay * 0.25  # ±25% jitter
        delay = delay + random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
        # Ensure delay is not negative after jitter
        delay = max(0, delay)

    return delay
|
||||
@@ -10,6 +10,7 @@ chat model.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -17,7 +18,6 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
@@ -189,7 +189,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
configured thresholds.
|
||||
|
||||
Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
|
||||
`clear_tool_uses_20250919` behavior [(read more)](https://docs.claude.com/en/docs/agents-and-tools/tool-use/memory-tool).
|
||||
`clear_tool_uses_20250919` behavior [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
|
||||
"""
|
||||
|
||||
edits: list[ContextEdit]
|
||||
@@ -229,19 +229,18 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
system_msg = [request.system_message] if request.system_message else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
return handler(request)
|
||||
return handler(request.override(messages=edited_messages))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -257,19 +256,18 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
system_msg = [request.system_message] if request.system_message else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
return await handler(request)
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -120,9 +120,9 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
||||
|
||||
Args:
|
||||
root_path: Root directory to search.
|
||||
use_ripgrep: Whether to use ripgrep for search.
|
||||
use_ripgrep: Whether to use `ripgrep` for search.
|
||||
|
||||
Falls back to Python if ripgrep unavailable.
|
||||
Falls back to Python if `ripgrep` unavailable.
|
||||
max_file_size_mb: Maximum file size to search in MB.
|
||||
"""
|
||||
self.root_path = Path(root_path).resolve()
|
||||
|
||||
@@ -7,7 +7,7 @@ from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
|
||||
|
||||
|
||||
class Action(TypedDict):
|
||||
@@ -102,7 +102,7 @@ class HITLResponse(TypedDict):
|
||||
class _DescriptionFactory(Protocol):
|
||||
"""Callable that generates a description for a tool call."""
|
||||
|
||||
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
|
||||
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
||||
"""Generate a description for a tool call."""
|
||||
...
|
||||
|
||||
@@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict):
|
||||
def format_tool_description(
|
||||
tool_call: ToolCall,
|
||||
state: AgentState,
|
||||
runtime: Runtime
|
||||
runtime: Runtime[ContextT]
|
||||
) -> str:
|
||||
import json
|
||||
return (
|
||||
@@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict):
|
||||
"""JSON schema for the args associated with the action, if edits are allowed."""
|
||||
|
||||
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
|
||||
"""Human in the loop middleware."""
|
||||
|
||||
def __init__(
|
||||
@@ -204,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
tool_call: ToolCall,
|
||||
config: InterruptOnConfig,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> tuple[ActionRequest, ReviewConfig]:
|
||||
"""Create an ActionRequest and ReviewConfig for a tool call."""
|
||||
tool_name = tool_call["name"]
|
||||
@@ -277,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
@@ -287,36 +287,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
if not last_ai_msg or not last_ai_msg.tool_calls:
|
||||
return None
|
||||
|
||||
# Separate tool calls that need interrupts from those that don't
|
||||
interrupt_tool_calls: list[ToolCall] = []
|
||||
auto_approved_tool_calls = []
|
||||
|
||||
for tool_call in last_ai_msg.tool_calls:
|
||||
interrupt_tool_calls.append(tool_call) if tool_call[
|
||||
"name"
|
||||
] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
|
||||
|
||||
# If no interrupts needed, return early
|
||||
if not interrupt_tool_calls:
|
||||
return None
|
||||
|
||||
# Process all tool calls that require interrupts
|
||||
revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
|
||||
artificial_tool_messages: list[ToolMessage] = []
|
||||
|
||||
# Create action requests and review configs for all tools that need approval
|
||||
# Create action requests and review configs for tools that need approval
|
||||
action_requests: list[ActionRequest] = []
|
||||
review_configs: list[ReviewConfig] = []
|
||||
interrupt_indices: list[int] = []
|
||||
|
||||
for tool_call in interrupt_tool_calls:
|
||||
config = self.interrupt_on[tool_call["name"]]
|
||||
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
||||
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
||||
action_request, review_config = self._create_action_and_config(
|
||||
tool_call, config, state, runtime
|
||||
)
|
||||
action_requests.append(action_request)
|
||||
review_configs.append(review_config)
|
||||
interrupt_indices.append(idx)
|
||||
|
||||
# Create ActionRequest and ReviewConfig using helper method
|
||||
action_request, review_config = self._create_action_and_config(
|
||||
tool_call, config, state, runtime
|
||||
)
|
||||
action_requests.append(action_request)
|
||||
review_configs.append(review_config)
|
||||
# If no interrupts needed, return early
|
||||
if not action_requests:
|
||||
return None
|
||||
|
||||
# Create single HITLRequest with all actions and configs
|
||||
hitl_request = HITLRequest(
|
||||
@@ -325,35 +312,46 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
)
|
||||
|
||||
# Send interrupt and get response
|
||||
hitl_response: HITLResponse = interrupt(hitl_request)
|
||||
decisions = hitl_response["decisions"]
|
||||
decisions = interrupt(hitl_request)["decisions"]
|
||||
|
||||
# Validate that the number of decisions matches the number of interrupt tool calls
|
||||
if (decisions_len := len(decisions)) != (
|
||||
interrupt_tool_calls_len := len(interrupt_tool_calls)
|
||||
):
|
||||
if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
|
||||
msg = (
|
||||
f"Number of human decisions ({decisions_len}) does not match "
|
||||
f"number of hanging tool calls ({interrupt_tool_calls_len})."
|
||||
f"number of hanging tool calls ({interrupt_count})."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Process each decision using helper method
|
||||
for i, decision in enumerate(decisions):
|
||||
tool_call = interrupt_tool_calls[i]
|
||||
config = self.interrupt_on[tool_call["name"]]
|
||||
# Process decisions and rebuild tool calls in original order
|
||||
revised_tool_calls: list[ToolCall] = []
|
||||
artificial_tool_messages: list[ToolMessage] = []
|
||||
decision_idx = 0
|
||||
|
||||
revised_tool_call, tool_message = self._process_decision(decision, tool_call, config)
|
||||
if revised_tool_call:
|
||||
revised_tool_calls.append(revised_tool_call)
|
||||
if tool_message:
|
||||
artificial_tool_messages.append(tool_message)
|
||||
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
||||
if idx in interrupt_indices:
|
||||
# This was an interrupt tool call - process the decision
|
||||
config = self.interrupt_on[tool_call["name"]]
|
||||
decision = decisions[decision_idx]
|
||||
decision_idx += 1
|
||||
|
||||
revised_tool_call, tool_message = self._process_decision(
|
||||
decision, tool_call, config
|
||||
)
|
||||
if revised_tool_call is not None:
|
||||
revised_tool_calls.append(revised_tool_call)
|
||||
if tool_message:
|
||||
artificial_tool_messages.append(tool_message)
|
||||
else:
|
||||
# This was auto-approved - keep original
|
||||
revised_tool_calls.append(tool_call)
|
||||
|
||||
# Update the AI message to only include approved tool calls
|
||||
last_ai_msg.tool_calls = revised_tool_calls
|
||||
|
||||
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -133,6 +133,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
|
||||
`None` means no limit.
|
||||
exit_behavior: What to do when limits are exceeded.
|
||||
|
||||
- `'end'`: Jump to the end of the agent execution and
|
||||
inject an artificial AI message indicating that the limit was
|
||||
exceeded.
|
||||
|
||||
@@ -92,9 +92,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return handler(request)
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
@@ -127,9 +126,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return await handler(request)
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
300
libs/langchain_v1/langchain/agents/middleware/model_retry.py
Normal file
300
libs/langchain_v1/langchain/agents/middleware/model_retry.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Model retry middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain.agents.middleware._retry import (
|
||||
OnFailure,
|
||||
RetryOn,
|
||||
calculate_delay,
|
||||
should_retry_exception,
|
||||
validate_retry_params,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
|
||||
|
||||
class ModelRetryMiddleware(AgentMiddleware):
    """Middleware that re-runs failed model calls, pausing between attempts.

    Which exceptions are retried, how long to wait between attempts, and what
    happens once the retry budget is spent are all configurable.

    Examples:
        !!! example "Default settings (2 retries, exponential backoff)"

            ```python
            from langchain.agents import create_agent
            from langchain.agents.middleware import ModelRetryMiddleware

            agent = create_agent(model, tools=[search_tool], middleware=[ModelRetryMiddleware()])
            ```

        !!! example "Retry only selected exception types"

            ```python
            from anthropic import RateLimitError
            from openai import APITimeoutError

            retry = ModelRetryMiddleware(
                max_retries=4,
                retry_on=(APITimeoutError, RateLimitError),
                backoff_factor=1.5,
            )
            ```

        !!! example "Filter exceptions with a predicate"

            ```python
            from anthropic import APIStatusError


            def should_retry(exc: Exception) -> bool:
                # Only retry on 5xx errors
                if isinstance(exc, APIStatusError):
                    return 500 <= exc.status_code < 600
                return False


            retry = ModelRetryMiddleware(
                max_retries=3,
                retry_on=should_retry,
            )
            ```

        !!! example "Custom failure message"

            ```python
            def format_error(exc: Exception) -> str:
                return "Model temporarily unavailable. Please try again later."


            retry = ModelRetryMiddleware(
                max_retries=4,
                on_failure=format_error,
            )
            ```

        !!! example "Constant delay between attempts"

            ```python
            retry = ModelRetryMiddleware(
                max_retries=5,
                backoff_factor=0.0,  # No exponential growth
                initial_delay=2.0,  # Always wait 2 seconds
            )
            ```

        !!! example "Re-raise once retries are exhausted"

            ```python
            retry = ModelRetryMiddleware(
                max_retries=2,
                on_failure="error",  # Re-raise exception instead of returning message
            )
            ```
    """

    def __init__(
        self,
        *,
        max_retries: int = 2,
        retry_on: RetryOn = (Exception,),
        on_failure: OnFailure = "continue",
        backoff_factor: float = 2.0,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        jitter: bool = True,
    ) -> None:
        """Initialize `ModelRetryMiddleware`.

        Args:
            max_retries: How many times to retry after the initial call.

                Must be `>= 0`.
            retry_on: A tuple of exception types to retry on, or a callable that
                receives the exception and returns `True` when it should be
                retried.

                By default every exception is retried.
            on_failure: What to do when no further retry will be made.

                Options:

                - `'continue'`: Return an `AIMessage` describing the error so
                    the agent can carry on with an error response.
                - `'error'`: Re-raise the exception, stopping agent execution.
                - **Custom callable:** Receives the exception and returns the
                    string used as the `AIMessage` content.
            backoff_factor: Multiplier for exponential backoff.

                Each retry waits `initial_delay * (backoff_factor ** retry_number)`
                seconds.

                Set to `0.0` for constant delay.
            initial_delay: Delay in seconds before the first retry.
            max_delay: Upper bound in seconds applied to the backoff delay.
            jitter: Whether to add random jitter (`±25%`) to the delay to avoid
                thundering herd.

        Raises:
            ValueError: If `max_retries`, `initial_delay`, `max_delay` or
                `backoff_factor` is negative.
        """
        super().__init__()

        # Fail fast on bad configuration before storing anything.
        validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)

        self.tools = []  # No additional tools registered by this middleware

        # What to retry and what to do when giving up.
        self.max_retries = max_retries
        self.retry_on = retry_on
        self.on_failure = on_failure

        # Backoff schedule.
        self.backoff_factor = backoff_factor
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.jitter = jitter

    def _format_failure_message(self, exc: Exception, attempts_made: int) -> AIMessage:
        """Build the default `AIMessage` reported when the call is abandoned.

        Args:
            exc: The exception that caused the failure.
            attempts_made: Number of attempts actually made.

        Returns:
            `AIMessage` whose content names the attempt count, the exception
            type and the exception text.
        """
        noun = "attempts" if attempts_made != 1 else "attempt"
        return AIMessage(
            content=(
                f"Model call failed after {attempts_made} {noun} "
                f"with {type(exc).__name__}: {exc!s}"
            )
        )

    def _handle_failure(self, exc: Exception, attempts_made: int) -> ModelResponse:
        """Apply the configured `on_failure` policy to an abandoned call.

        Args:
            exc: The exception that caused the failure.
            attempts_made: Number of attempts actually made.

        Returns:
            `ModelResponse` wrapping a single `AIMessage` with the error details.

        Raises:
            Exception: If `on_failure` is `'error'`, re-raises the exception.
        """
        policy = self.on_failure
        if policy == "error":
            raise exc

        # A callable policy supplies the message text; otherwise use the default.
        error_message = (
            AIMessage(content=policy(exc))
            if callable(policy)
            else self._format_failure_message(exc, attempts_made)
        )
        return ModelResponse(result=[error_message])

    def _backoff_delay(self, attempt: int) -> float:
        """Return the pause in seconds to apply after the 0-indexed `attempt` fails."""
        return calculate_delay(
            attempt,
            backoff_factor=self.backoff_factor,
            initial_delay=self.initial_delay,
            max_delay=self.max_delay,
            jitter=self.jitter,
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse | AIMessage:
        """Run the model call, retrying on failure.

        Args:
            request: Model request with model, messages, state, and runtime.
            handler: Callable to execute the model (can be called multiple times).

        Returns:
            `ModelResponse` or `AIMessage` (the final result).
        """
        # One initial attempt plus up to `max_retries` retries.
        for attempt in range(self.max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:  # noqa: BLE001
                # The retry filter is consulted on every failure; give up when
                # it declines or when the retry budget is spent.
                if (
                    not should_retry_exception(exc, self.retry_on)
                    or attempt >= self.max_retries
                ):
                    return self._handle_failure(exc, attempt + 1)

                pause = self._backoff_delay(attempt)
                if pause > 0:
                    time.sleep(pause)

        # Unreachable: loop always returns via handler success or _handle_failure
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse | AIMessage:
        """Run the async model call, retrying on failure.

        Args:
            request: Model request with model, messages, state, and runtime.
            handler: Async callable that executes the model and returns `ModelResponse`.

        Returns:
            `ModelResponse` or `AIMessage` (the final result).
        """
        # One initial attempt plus up to `max_retries` retries.
        for attempt in range(self.max_retries + 1):
            try:
                return await handler(request)
            except Exception as exc:  # noqa: BLE001
                # The retry filter is consulted on every failure; give up when
                # it declines or when the retry budget is spent.
                if (
                    not should_retry_exception(exc, self.retry_on)
                    or attempt >= self.max_retries
                ):
                    return self._handle_failure(exc, attempt + 1)

                pause = self._backoff_delay(attempt)
                if pause > 0:
                    await asyncio.sleep(pause)

        # Unreachable: loop always returns via handler success or _handle_failure
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)
|
||||
@@ -15,7 +15,7 @@ import uuid
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import ToolException
|
||||
@@ -339,7 +339,7 @@ class _ShellToolInput(BaseModel):
|
||||
restart: bool | None = None
|
||||
"""Whether to restart the shell session."""
|
||||
|
||||
runtime: Annotated[Any, SkipJsonSchema] = None
|
||||
runtime: Annotated[Any, SkipJsonSchema()] = None
|
||||
"""The runtime for the shell tool.
|
||||
|
||||
Included as a workaround at the moment bc args_schema doesn't work with
|
||||
@@ -389,7 +389,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
shell_command: Sequence[str] | str | None = None,
|
||||
env: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the middleware.
|
||||
"""Initialize an instance of `ShellToolMiddleware`.
|
||||
|
||||
Args:
|
||||
workspace_root: Base directory for the shell session.
|
||||
@@ -445,7 +445,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
command: str | None = None,
|
||||
restart: bool = False,
|
||||
) -> ToolMessage | str:
|
||||
resources = self._ensure_resources(runtime.state)
|
||||
resources = self._get_or_create_resources(runtime.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
{"command": command, "restart": restart},
|
||||
@@ -491,7 +491,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
|
||||
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Start the shell session and run startup commands."""
|
||||
resources = self._create_resources()
|
||||
resources = self._get_or_create_resources(state)
|
||||
return {"shell_session_resources": resources}
|
||||
|
||||
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
@@ -500,7 +500,10 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
|
||||
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
|
||||
"""Run shutdown commands and release resources when an agent completes."""
|
||||
resources = self._ensure_resources(state)
|
||||
resources = state.get("shell_session_resources")
|
||||
if not isinstance(resources, _SessionResources):
|
||||
# Resources were never created, nothing to clean up
|
||||
return
|
||||
try:
|
||||
self._run_shutdown_commands(resources.session)
|
||||
finally:
|
||||
@@ -510,17 +513,26 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
"""Async run shutdown commands and release resources when an agent completes."""
|
||||
return self.after_agent(state, runtime)
|
||||
|
||||
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
"""Get existing resources from state or create new ones if they don't exist.
|
||||
|
||||
This method enables resumability by checking if resources already exist in the state
|
||||
(e.g., after an interrupt), and only creating new resources if they're not present.
|
||||
|
||||
Args:
|
||||
state: The agent state which may contain shell session resources.
|
||||
|
||||
Returns:
|
||||
Session resources, either retrieved from state or newly created.
|
||||
"""
|
||||
resources = state.get("shell_session_resources")
|
||||
if resources is not None and not isinstance(resources, _SessionResources):
|
||||
resources = None
|
||||
if resources is None:
|
||||
msg = (
|
||||
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
|
||||
"before invoking the shell tool."
|
||||
)
|
||||
raise ToolException(msg)
|
||||
return resources
|
||||
if isinstance(resources, _SessionResources):
|
||||
return resources
|
||||
|
||||
new_resources = self._create_resources()
|
||||
# Cast needed to make state dict-like for mutation
|
||||
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
|
||||
return new_resources
|
||||
|
||||
def _create_resources(self) -> _SessionResources:
|
||||
workspace = self._workspace_root
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Summarization middleware."""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from functools import cache, partial
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -51,13 +53,81 @@ Messages to summarize:
|
||||
{messages}
|
||||
</messages>""" # noqa: E501
|
||||
|
||||
SUMMARY_PREFIX = "## Previous conversation summary:"
|
||||
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Fraction of model's maximum input tokens.
|
||||
|
||||
Example:
|
||||
To specify 50% of the model's max input tokens:
|
||||
|
||||
```python
|
||||
("fraction", 0.5)
|
||||
```
|
||||
"""
|
||||
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
"""Absolute number of tokens.
|
||||
|
||||
Example:
|
||||
To specify 3000 tokens:
|
||||
|
||||
```python
|
||||
("tokens", 3000)
|
||||
```
|
||||
"""
|
||||
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
"""Absolute number of messages.
|
||||
|
||||
Example:
|
||||
To specify 50 messages:
|
||||
|
||||
```python
|
||||
("messages", 50)
|
||||
```
|
||||
"""
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
"""Union type for context size specifications.
|
||||
|
||||
Can be either:
|
||||
|
||||
- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
|
||||
fraction of the model's maximum input tokens.
|
||||
- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
|
||||
number of tokens.
|
||||
- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
|
||||
absolute number of messages.
|
||||
|
||||
Depending on use with `trigger` or `keep` parameters, this type indicates either
|
||||
when to trigger summarization or how much context to retain.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# ContextFraction
|
||||
context_size: ContextSize = ("fraction", 0.5)
|
||||
|
||||
# ContextTokens
|
||||
context_size: ContextSize = ("tokens", 3000)
|
||||
|
||||
# ContextMessages
|
||||
context_size: ContextSize = ("messages", 50)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
    """Tune parameters of approximate token counter based on model type.

    Args:
        model: Chat model whose `_llm_type` selects the tuning.

    Returns:
        `count_tokens_approximately` with `chars_per_token=3.3` pre-bound when
        the model reports the `"anthropic-chat"` type, otherwise the unmodified
        `count_tokens_approximately`.
    """
    if model._llm_type == "anthropic-chat":
        # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
        # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
        return partial(count_tokens_approximately, chars_per_token=3.3)
    return count_tokens_approximately
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
@@ -70,34 +140,129 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
model: str | BaseChatModel,
|
||||
max_tokens_before_summary: int | None = None,
|
||||
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
|
||||
*,
|
||||
trigger: ContextSize | list[ContextSize] | None = None,
|
||||
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
|
||||
token_counter: TokenCounter = count_tokens_approximately,
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
|
||||
summary_prefix: str = SUMMARY_PREFIX,
|
||||
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
|
||||
**deprecated_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize summarization middleware.
|
||||
|
||||
Args:
|
||||
model: The language model to use for generating summaries.
|
||||
max_tokens_before_summary: Token threshold to trigger summarization.
|
||||
If `None`, summarization is disabled.
|
||||
messages_to_keep: Number of recent messages to preserve after summarization.
|
||||
trigger: One or more thresholds that trigger summarization.
|
||||
|
||||
Provide a single
|
||||
[`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
|
||||
tuple or a list of tuples, in which case summarization runs when any
|
||||
threshold is met.
|
||||
|
||||
!!! example
|
||||
|
||||
```python
|
||||
# Trigger summarization when 50 messages is reached
|
||||
("messages", 50)
|
||||
|
||||
# Trigger summarization when 3000 tokens is reached
|
||||
("tokens", 3000)
|
||||
|
||||
# Trigger summarization either when 80% of model's max input tokens
|
||||
# is reached or when 100 messages is reached (whichever comes first)
|
||||
[("fraction", 0.8), ("messages", 100)]
|
||||
```
|
||||
|
||||
See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
|
||||
for more details.
|
||||
keep: Context retention policy applied after summarization.
|
||||
|
||||
Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
|
||||
tuple to specify how much history to preserve.
|
||||
|
||||
Defaults to keeping the most recent `20` messages.
|
||||
|
||||
Does not support multiple values like `trigger`.
|
||||
|
||||
!!! example
|
||||
|
||||
```python
|
||||
# Keep the most recent 20 messages
|
||||
("messages", 20)
|
||||
|
||||
# Keep the most recent 3000 tokens
|
||||
("tokens", 3000)
|
||||
|
||||
# Keep the most recent 30% of the model's max input tokens
|
||||
("fraction", 0.3)
|
||||
```
|
||||
token_counter: Function to count tokens in messages.
|
||||
summary_prompt: Prompt template for generating summaries.
|
||||
summary_prefix: Prefix added to system message when including summary.
|
||||
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
|
||||
the summarization call.
|
||||
|
||||
Pass `None` to skip trimming entirely.
|
||||
"""
|
||||
# Handle deprecated parameters
|
||||
if "max_tokens_before_summary" in deprecated_kwargs:
|
||||
value = deprecated_kwargs["max_tokens_before_summary"]
|
||||
warnings.warn(
|
||||
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if trigger is None and value is not None:
|
||||
trigger = ("tokens", value)
|
||||
|
||||
if "messages_to_keep" in deprecated_kwargs:
|
||||
value = deprecated_kwargs["messages_to_keep"]
|
||||
warnings.warn(
|
||||
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
|
||||
keep = ("messages", value)
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
self.model = model
|
||||
self.max_tokens_before_summary = max_tokens_before_summary
|
||||
self.messages_to_keep = messages_to_keep
|
||||
self.token_counter = token_counter
|
||||
if trigger is None:
|
||||
self.trigger: ContextSize | list[ContextSize] | None = None
|
||||
trigger_conditions: list[ContextSize] = []
|
||||
elif isinstance(trigger, list):
|
||||
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
|
||||
self.trigger = validated_list
|
||||
trigger_conditions = validated_list
|
||||
else:
|
||||
validated = self._validate_context_size(trigger, "trigger")
|
||||
self.trigger = validated
|
||||
trigger_conditions = [validated]
|
||||
self._trigger_conditions = trigger_conditions
|
||||
|
||||
self.keep = self._validate_context_size(keep, "keep")
|
||||
if token_counter is count_tokens_approximately:
|
||||
self.token_counter = _get_approximate_token_counter(self.model)
|
||||
else:
|
||||
self.token_counter = token_counter
|
||||
self.summary_prompt = summary_prompt
|
||||
self.summary_prefix = summary_prefix
|
||||
self.trim_tokens_to_summarize = trim_tokens_to_summarize
|
||||
|
||||
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
|
||||
if self.keep[0] == "fraction":
|
||||
requires_profile = True
|
||||
if requires_profile and self._get_profile_limits() is None:
|
||||
msg = (
|
||||
"Model profile information is required to use fractional token limits, "
|
||||
"and is unavailable for the specified model. Please use absolute token "
|
||||
"counts instead, or pass "
|
||||
'`\n\nChatModel(..., profile={"max_input_tokens": ...})`.\n\n'
|
||||
"with a desired integer value of the model's maximum input tokens."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Process messages before model invocation, potentially triggering summarization."""
|
||||
@@ -105,13 +270,10 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self._ensure_message_ids(messages)
|
||||
|
||||
total_tokens = self.token_counter(messages)
|
||||
if (
|
||||
self.max_tokens_before_summary is not None
|
||||
and total_tokens < self.max_tokens_before_summary
|
||||
):
|
||||
if not self._should_summarize(messages, total_tokens):
|
||||
return None
|
||||
|
||||
cutoff_index = self._find_safe_cutoff(messages)
|
||||
cutoff_index = self._determine_cutoff_index(messages)
|
||||
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
@@ -129,6 +291,158 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
]
|
||||
}
|
||||
|
||||
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
    """Process messages before model invocation, potentially triggering summarization.

    Args:
        state: Agent state; its `"messages"` entry is read.
        runtime: Unused.

    Returns:
        `None` when no trigger condition is met or when the computed cutoff
        index is not positive (nothing to summarize). Otherwise a state update
        whose `"messages"` list removes all existing messages and replaces
        them with the summary message(s) followed by the preserved messages.
    """
    messages = state["messages"]
    self._ensure_message_ids(messages)

    total_tokens = self.token_counter(messages)
    if not self._should_summarize(messages, total_tokens):
        return None

    cutoff_index = self._determine_cutoff_index(messages)

    if cutoff_index <= 0:
        return None

    messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)

    summary = await self._acreate_summary(messages_to_summarize)
    new_messages = self._build_new_messages(summary)

    return {
        "messages": [
            # Clear the whole history first, then re-add summary + kept tail.
            RemoveMessage(id=REMOVE_ALL_MESSAGES),
            *new_messages,
            *preserved_messages,
        ]
    }
|
||||
|
||||
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
    """Determine whether summarization should run for the current token usage.

    Args:
        messages: Current conversation messages; only their count is used here.
        total_tokens: Token count of `messages`, computed by the caller.

    Returns:
        `True` as soon as any configured trigger condition is met; `False`
        when none is met or when no trigger is configured.
    """
    if not self._trigger_conditions:
        return False

    for kind, value in self._trigger_conditions:
        if kind == "messages" and len(messages) >= value:
            return True
        if kind == "tokens" and total_tokens >= value:
            return True
        if kind == "fraction":
            max_input_tokens = self._get_profile_limits()
            if max_input_tokens is None:
                # No profile data: this condition cannot be evaluated, but
                # other configured conditions may still fire.
                continue
            # Clamp so a tiny fraction never yields a zero/negative threshold.
            threshold = max(int(max_input_tokens * value), 1)
            if total_tokens >= threshold:
                return True
    return False
|
||||
|
||||
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
    """Choose cutoff index respecting retention configuration.

    Args:
        messages: Current conversation messages.

    Returns:
        The index at which to split `messages`. For token- or fraction-based
        `keep` settings the token-based cutoff is used when available;
        otherwise a message-count-based safe cutoff is returned.
    """
    kind, value = self.keep
    if kind in {"tokens", "fraction"}:
        token_based_cutoff = self._find_token_based_cutoff(messages)
        if token_based_cutoff is not None:
            return token_based_cutoff
        # None cutoff -> model profile data not available (caught in __init__ but
        # here for safety), fallback to message count
        return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
    return self._find_safe_cutoff(messages, cast("int", value))
|
||||
|
||||
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
    """Find cutoff index based on target token retention.

    Args:
        messages: Current conversation messages.

    Returns:
        `None` when `keep` is not token- or fraction-based, or when a
        fractional budget is requested but no profile limit is available.
        `0` when `messages` is empty, when the whole history already fits the
        budget, when there is a single message, or when no safe in-budget
        cutoff is found. Otherwise the index from which messages are kept.
    """
    if not messages:
        return 0

    # Memoize per call: the same suffix may be measured by both the binary
    # search and the safe-point scan below.
    @cache
    def suffix_token_count(start_index: int) -> int:
        return self.token_counter(messages[start_index:])

    kind, value = self.keep
    if kind == "fraction":
        max_input_tokens = self._get_profile_limits()
        if max_input_tokens is None:
            return None
        target_token_count = int(max_input_tokens * value)
    elif kind == "tokens":
        target_token_count = int(value)
    else:
        return None

    if target_token_count <= 0:
        target_token_count = 1

    if suffix_token_count(0) <= target_token_count:
        return 0

    # Use binary search to identify the earliest message index that keeps the
    # suffix within the token budget.
    left, right = 0, len(messages)
    cutoff_candidate = len(messages)
    # Hard bound on iterations as a guard against a non-terminating search.
    max_iterations = len(messages).bit_length() + 1
    for _ in range(max_iterations):
        if left >= right:
            break

        mid = (left + right) // 2
        if suffix_token_count(mid) <= target_token_count:
            cutoff_candidate = mid
            right = mid
        else:
            left = mid + 1

    if cutoff_candidate == len(messages):
        cutoff_candidate = left

    if cutoff_candidate >= len(messages):
        if len(messages) == 1:
            return 0
        cutoff_candidate = len(messages) - 1

    # Walk forward from the candidate to the first index that is both a safe
    # cutoff point and still within the token budget.
    for i in range(cutoff_candidate, len(messages) + 1):
        if (
            self._is_safe_cutoff_point(messages, i)
            and suffix_token_count(i) <= target_token_count
        ):
            return i

    return 0
|
||||
|
||||
def _get_profile_limits(self) -> int | None:
    """Retrieve max input token limit from the model profile.

    Returns:
        The profile's `max_input_tokens` when the model exposes a mapping
        profile holding an integer value for that key; `None` when the model
        has no `profile` attribute, the profile is not a mapping, or the value
        is missing or not an integer.
    """
    try:
        profile = self.model.profile
    except AttributeError:
        return None

    if not isinstance(profile, Mapping):
        return None

    max_input_tokens = profile.get("max_input_tokens")

    # `bool` is a subclass of `int`; True/False is never a meaningful token
    # limit, so reject it explicitly rather than treating it as 1/0.
    if isinstance(max_input_tokens, bool) or not isinstance(max_input_tokens, int):
        return None

    return max_input_tokens
|
||||
|
||||
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
    """Validate context configuration tuples.

    Args:
        context: A `(kind, value)` tuple where `kind` is `"fraction"`,
            `"tokens"` or `"messages"`.
        parameter_name: Name of the parameter being validated; used only in
            error messages.

    Returns:
        `context`, unchanged, when it is valid.

    Raises:
        ValueError: If a fraction is outside `(0, 1]`, a token/message count
            is not greater than 0, or `kind` is not a supported type.
    """
    kind, value = context
    if kind == "fraction":
        # Accepted range is 0 (exclusive) to 1 (inclusive).
        if not 0 < value <= 1:
            msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
            raise ValueError(msg)
    elif kind in {"tokens", "messages"}:
        if value <= 0:
            msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
            raise ValueError(msg)
    else:
        msg = f"Unsupported context size type {kind} for {parameter_name}."
        raise ValueError(msg)
    return context
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [
|
||||
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
|
||||
@@ -151,16 +465,16 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
|
||||
"""Find safe cutoff point that preserves AI/Tool message pairs.
|
||||
|
||||
Returns the index where messages can be safely cut without separating
|
||||
related AI and Tool messages. Returns 0 if no safe cutoff is found.
|
||||
related AI and Tool messages. Returns `0` if no safe cutoff is found.
|
||||
"""
|
||||
if len(messages) <= self.messages_to_keep:
|
||||
if len(messages) <= messages_to_keep:
|
||||
return 0
|
||||
|
||||
target_cutoff = len(messages) - self.messages_to_keep
|
||||
target_cutoff = len(messages) - messages_to_keep
|
||||
|
||||
for i in range(target_cutoff, -1, -1):
|
||||
if self._is_safe_cutoff_point(messages, i):
|
||||
@@ -229,16 +543,35 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
try:
|
||||
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
|
||||
return cast("str", response.content).strip()
|
||||
return response.text.strip()
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
    """Generate summary for the given messages.

    Args:
        messages_to_summarize: Messages to condense into a summary.

    Returns:
        The stripped text of the model response. Fixed placeholder strings are
        returned when there is nothing to summarize or when trimming leaves no
        messages. If the model call raises, an
        `"Error generating summary: ..."` string is returned instead of
        propagating the exception.
    """
    if not messages_to_summarize:
        return "No previous conversation history."

    trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
    if not trimmed_messages:
        return "Previous conversation was too long to summarize."

    try:
        response = await self.model.ainvoke(
            self.summary_prompt.format(messages=trimmed_messages)
        )
        return response.text.strip()
    except Exception as e:  # noqa: BLE001
        # Deliberate best-effort: a failed summary must not abort the agent run.
        return f"Error generating summary: {e!s}"
|
||||
|
||||
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||
"""Trim messages to fit within summary generation limits."""
|
||||
try:
|
||||
if self.trim_tokens_to_summarize is None:
|
||||
return messages
|
||||
return trim_messages(
|
||||
messages,
|
||||
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
|
||||
max_tokens=self.trim_tokens_to_summarize,
|
||||
token_counter=self.token_counter,
|
||||
start_on="human",
|
||||
strategy="last",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user