Compare commits

...

38 Commits

Author SHA1 Message Date
Mason Daugherty
56379fb94c feat(docs): enhance documentation for name attribute in message classes 2025-08-08 13:18:40 -04:00
Mason Daugherty
7f989d3c3b feat(docs): clarify ToolMessage contentfield usage 2025-08-08 13:02:59 -04:00
Mason Daugherty
b7968c2b7d feat(docs): add link to artifact usage in ToolMessage 2025-08-08 12:51:15 -04:00
Mason Daugherty
2f0c6421a1 Merge branch 'master' into wip-v0.4 2025-08-08 10:21:44 -04:00
Mason Daugherty
c31236264e chore: formatting across codebase (#32466) 2025-08-08 10:20:10 -04:00
Chester Curme
cfe13f673a Merge branch 'master' into wip-v0.4
# Conflicts:
#	libs/core/langchain_core/version.py
#	libs/core/pyproject.toml
#	libs/core/uv.lock
#	libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
#	libs/partners/openai/uv.lock
2025-08-08 09:04:57 -04:00
ccurme
02001212b0 fix(openai): revert some changes (#32462)
Keep coverage on `output_version="v0"` (increasing coverage is being
managed in v0.4 branch).
2025-08-08 08:51:18 -04:00
Mason Daugherty
00244122bd feat(openai): minimal and verbosity (#32455) 2025-08-08 02:24:21 +00:00
Mason Daugherty
5599c59d4a chore: formatting across codebase (#32456)
To prevent polluting future PRs
2025-08-07 22:09:26 -04:00
ccurme
6727d6e8c8 release(core): 0.3.74 (#32454) 2025-08-07 16:39:01 -04:00
Michael Matloka
5036bd7adb fix(openai): don't crash get_num_tokens_from_messages on gpt-5 (#32451) 2025-08-07 16:33:19 -04:00
ccurme
ec2b34a02d feat(openai): custom tools (#32449) 2025-08-07 16:30:01 -04:00
Mason Daugherty
11d68a0b9e bump locks 2025-08-07 15:51:36 -04:00
Mason Daugherty
566774a893 Merge branch 'wip-v0.4' of github.com:langchain-ai/langchain into wip-v0.4 2025-08-07 15:50:40 -04:00
Mason Daugherty
255a6d668a feat: allow bypassing CI using PR label 2025-08-07 15:50:15 -04:00
Mason Daugherty
cbf4c0e565 Merge branch 'master' into wip-v0.4 2025-08-07 15:33:12 -04:00
Mason Daugherty
145d38f7dd test(openai): add tests for prompt_cache_key parameter and update docs (#32363)
Introduce tests to validate the behavior and inclusion of the
`prompt_cache_key` parameter in request payloads for the `ChatOpenAI`
model.
2025-08-07 15:29:47 -04:00
ccurme
68c70da33e fix(openai): add in output_text (#32450)
This property was deleted in `openai==1.99.2`.
2025-08-07 15:23:56 -04:00
Eugene Yurtsev
754528d23f feat(langchain): add stuff and map reduce chains (#32333)
* Add stuff and map reduce chains
* We'll need to rename and add unit tests to the chains prior to
official release
2025-08-07 15:20:05 -04:00
Mason Daugherty
dc66737f03 fix: docs and formatting (#32448) 2025-08-07 15:17:25 -04:00
Christophe Bornet
499dc35cfb chore(core): bump mypy version to 1.17 (#32390)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-08-07 13:26:29 -04:00
Mason Daugherty
42c1159991 feat: add TextAccessor, deprecate .text() as method (#32441)
Adds backward compat for `.text()` on messages while keeping `.text`
access

_The kicker:_

Any previous use of `.text()` will now need a `# type: ignore[operator]`
to silence type checkers. However, it will still behave as expected at
runtime. Deprecating in v0.4.0, to be removed in v2.0.0.
2025-08-07 12:16:31 -04:00
CLOVA Studio 개발
ac706c77d4 docs(docs): update v0.1.1 chatModel document on langchain-naver. (#32445)
## **Description:** 
This PR was requested after the `langchain-naver` partner-managed
packages were released
[v0.1.1](https://pypi.org/project/langchain-naver/0.1.1/).
So we've updated some our documents with the additional changed
features.

## **Dependencies:** 
https://github.com/langchain-ai/langchain/pull/30956

---------

Co-authored-by: 김필환[AI Studio Dev1] <pilhwan.kim@navercorp.com>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-08-07 15:45:50 +00:00
Tianyu Chen
8493887b6f docs: update Docker image name for jaguardb setup (#32438)
**Description**
Updated the quick setup instructions for JaguarDB in the documentation.
Replaced the outdated Docker image `jaguardb/jaguardb_with_http` with
the current recommended image `jaguardb/jaguardb` for pulling and
running the server.
2025-08-07 11:23:29 -04:00
Christophe Bornet
a647073b26 feat(standard-tests): add a property to set the name of the parameter for the number of results to return (#32443)
Not all retrievers use `k` as param name to set the number of results to
return. Even in LangChain itself. Eg:
bc4251b9e0/libs/core/langchain_core/indexing/in_memory.py (L31)

So it's helpful to be able to change it for a given retriever.
The change also adds hints to disable the tests if the retriever doesn't
support setting the param in the constructor or in the invoke method
(for instance, the `InMemoryDocumentIndex` in the link supports in the
constructor but not in the invoke method).

This change is backward compatible.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-07 11:22:24 -04:00
ccurme
e120604774 fix(infra): exclude pre-releases from previous version testing (#32447) 2025-08-07 10:18:59 -04:00
ccurme
06d8754b0b release(core): 0.3.73 (#32446) 2025-08-07 09:03:53 -04:00
ccurme
6e108c1cb4 feat(core): zero-out token costs for cache hits (#32437) 2025-08-07 08:49:34 -04:00
Mason Daugherty
cc6139860c fix: docs typing issues 2025-08-06 23:50:33 -04:00
Mason Daugherty
ae8f58ac6f fix(settings): update Python terminal settings and default interpreter path 2025-08-06 23:37:40 -04:00
Mason Daugherty
346731544b Merge branch 'master' into wip-v0.4 2025-08-06 18:24:10 -04:00
Mason Daugherty
c1b86cc929 feat: minor core work, v1 standard tests & (most of) v1 ollama (#32315)
Resolves #32215

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
2025-08-06 18:22:02 -04:00
Mason Daugherty
376f70be96 sync wip with master (#32436)
Co-authored-by: Kanav Bansal <13186335+bansalkanav@users.noreply.github.com>
Co-authored-by: Pranav Bhartiya <124018094+pranauww@users.noreply.github.com>
Co-authored-by: Nelson Sproul <nelson.sproul@gmail.com>
Co-authored-by: John Bledsoe <jmbledsoe@gmail.com>
2025-08-06 17:57:05 -04:00
ccurme
ac2de920b1 chore: increment versions for 0.4 branch (#32419) 2025-08-05 15:39:37 -04:00
ccurme
e02eed5489 feat: standard outputs (#32287)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
2025-08-05 15:17:32 -04:00
Chester Curme
5414527236 Merge branch 'master' into wip-v0.4 2025-08-04 11:55:14 -04:00
Chester Curme
881c6534a6 Merge branch 'master' into wip-v0.4
# Conflicts:
#	.github/workflows/_integration_test.yml
#	.github/workflows/_release.yml
#	.github/workflows/api_doc_build.yml
#	.github/workflows/people.yml
#	.github/workflows/run_notebooks.yml
#	.github/workflows/scheduled_test.yml
#	SECURITY.md
#	docs/docs/integrations/vectorstores/pgvectorstore.ipynb
#	libs/langchain_v1/langchain/chat_models/base.py
#	libs/langchain_v1/tests/integration_tests/chat_models/test_base.py
#	libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py
2025-07-30 13:16:17 -04:00
Mason Daugherty
5e9eb19a83 chore: update branch with changes from master (#32277)
Co-authored-by: Maxime Grenu <69890511+cluster2600@users.noreply.github.com>
Co-authored-by: Claude <claude@anthropic.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: jmaillefaud <jonathan.maillefaud@evooq.ch>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: tanwirahmad <tanwirahmad@users.noreply.github.com>
Co-authored-by: Christophe Bornet <cbornet@hotmail.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: niceg <79145285+growmuye@users.noreply.github.com>
Co-authored-by: Chaitanya varma <varmac301@gmail.com>
Co-authored-by: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Kanav Bansal <13186335+bansalkanav@users.noreply.github.com>
Co-authored-by: Aleksandr Filippov <71711753+alex-feel@users.noreply.github.com>
Co-authored-by: Alex Feel <afilippov@spotware.com>
2025-07-28 10:39:41 -04:00
226 changed files with 30045 additions and 5358 deletions

View File

@@ -15,12 +15,12 @@ You may use the button above, or follow these steps to open this repo in a Codes
1. Click **Create codespace on master**.
For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace).
## VS Code Dev Containers
[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain)
> [!NOTE]
> [!NOTE]
> If you click the link above you will open the main repo (`langchain-ai/langchain`) and *not* your local cloned repo. This is fine if you only want to run and test the library, but if you want to contribute you can use the link below and replace with your username and cloned repo name:
```txt

View File

@@ -4,7 +4,7 @@ services:
build:
dockerfile: libs/langchain/dev.Dockerfile
context: ..
networks:
- langchain-network

View File

@@ -129,4 +129,4 @@ For answers to common questions about this code of conduct, see the FAQ at
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
[translations]: https://www.contributor-covenant.org/translations

View File

@@ -5,7 +5,7 @@ body:
- type: markdown
attributes:
value: |
Thank you for taking the time to file a bug report.
Thank you for taking the time to file a bug report.
Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).
@@ -50,7 +50,7 @@ body:
If a maintainer can copy it, run it, and see it right away, there's a much higher chance that you'll be able to get help.
**Important!**
**Important!**
* Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
* Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.
@@ -58,14 +58,14 @@ body:
* INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).
placeholder: |
The following code:
The following code:
```python
from langchain_core.runnables import RunnableLambda
def bad_code(inputs) -> int:
raise NotImplementedError('For demo purpose')
chain = RunnableLambda(bad_code)
chain.invoke('Hello!')
```

View File

@@ -14,7 +14,7 @@ body:
Do **NOT** use this to ask usage questions or reporting issues with your code.
If you have usage questions or need help solving some problem,
If you have usage questions or need help solving some problem,
please use the [LangChain Forum](https://forum.langchain.com/).
If you're in the wrong place, here are some helpful links to find a better

View File

@@ -8,7 +8,7 @@ body:
If you are not a LangChain maintainer or were not asked directly by a maintainer to create an issue, then please start the conversation on the [LangChain Forum](https://forum.langchain.com/) instead.
You are a LangChain maintainer if you maintain any of the packages inside of the LangChain repository
You are a LangChain maintainer if you maintain any of the packages inside of the LangChain repository
or are a regular contributor to LangChain with previous merged pull requests.
- type: checkboxes
id: privileged

View File

@@ -4,4 +4,4 @@ RUN pip install httpx PyGithub "pydantic==2.0.2" pydantic-settings "pyyaml>=5.3.
COPY ./app /app
CMD ["python", "/app/main.py"]
CMD ["python", "/app/main.py"]

View File

@@ -4,8 +4,8 @@ description: "Generate the data for the LangChain People page"
author: "Jacob Lee <jacob@langchain.dev>"
inputs:
token:
description: 'User token, to read the GitHub API. Can be passed in using {{ secrets.LANGCHAIN_PEOPLE_GITHUB_TOKEN }}'
description: "User token, to read the GitHub API. Can be passed in using {{ secrets.LANGCHAIN_PEOPLE_GITHUB_TOKEN }}"
required: true
runs:
using: 'docker'
image: 'Dockerfile'
using: "docker"
image: "Dockerfile"

View File

@@ -3,14 +3,12 @@ import json
import os
import sys
from collections import defaultdict
from typing import Dict, List, Set
from pathlib import Path
from typing import Dict, List, Set
import tomllib
from packaging.requirements import Requirement
from get_min_versions import get_min_version_from_toml
from packaging.requirements import Requirement
LANGCHAIN_DIRS = [
"libs/core",
@@ -38,7 +36,7 @@ IGNORED_PARTNERS = [
]
PY_312_MAX_PACKAGES = [
"libs/partners/chroma", # https://github.com/chroma-core/chroma/issues/4382
"libs/partners/chroma", # https://github.com/chroma-core/chroma/issues/4382
]
@@ -85,9 +83,9 @@ def dependents_graph() -> dict:
for depline in extended_deps:
if depline.startswith("-e "):
# editable dependency
assert depline.startswith(
"-e ../partners/"
), "Extended test deps should only editable install partner packages"
assert depline.startswith("-e ../partners/"), (
"Extended test deps should only editable install partner packages"
)
partner = depline.split("partners/")[1]
dep = f"langchain-{partner}"
else:
@@ -271,7 +269,7 @@ if __name__ == "__main__":
dirs_to_run["extended-test"].add(dir_)
elif file.startswith("libs/standard-tests"):
# TODO: update to include all packages that rely on standard-tests (all partner packages)
# note: won't run on external repo partners
# Note: won't run on external repo partners
dirs_to_run["lint"].add("libs/standard-tests")
dirs_to_run["test"].add("libs/standard-tests")
dirs_to_run["lint"].add("libs/cli")
@@ -285,7 +283,7 @@ if __name__ == "__main__":
elif file.startswith("libs/cli"):
dirs_to_run["lint"].add("libs/cli")
dirs_to_run["test"].add("libs/cli")
elif file.startswith("libs/partners"):
partner_dir = file.split("/")[2]
if os.path.isdir(f"libs/partners/{partner_dir}") and [
@@ -303,7 +301,10 @@ if __name__ == "__main__":
f"Unknown lib: {file}. check_diff.py likely needs "
"an update for this new library!"
)
elif file.startswith("docs/") or file in ["pyproject.toml", "uv.lock"]: # docs or root uv files
elif file.startswith("docs/") or file in [
"pyproject.toml",
"uv.lock",
]: # docs or root uv files
docs_edited = True
dirs_to_run["lint"].add(".")

View File

@@ -1,4 +1,5 @@
import sys
import tomllib
if __name__ == "__main__":

View File

@@ -1,5 +1,5 @@
from collections import defaultdict
import sys
from collections import defaultdict
from typing import Optional
if sys.version_info >= (3, 11):
@@ -8,17 +8,13 @@ else:
# for python 3.10 and below, which doesnt have stdlib tomllib
import tomli as tomllib
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import Version
import requests
from packaging.version import parse
import re
from typing import List
import re
import requests
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import Version, parse
MIN_VERSION_LIBS = [
"langchain-core",
@@ -72,11 +68,13 @@ def get_minimum_version(package_name: str, spec_string: str) -> Optional[str]:
spec_string = re.sub(r"\^0\.0\.(\d+)", r"0.0.\1", spec_string)
# rewrite occurrences of ^0.y.z to >=0.y.z,<0.y+1 (can be anywhere in constraint string)
for y in range(1, 10):
spec_string = re.sub(rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y+1}", spec_string)
spec_string = re.sub(
rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y + 1}", spec_string
)
# rewrite occurrences of ^x.y.z to >=x.y.z,<x+1.0.0 (can be anywhere in constraint string)
for x in range(1, 10):
spec_string = re.sub(
rf"\^{x}\.(\d+)\.(\d+)", rf">={x}.\1.\2,<{x+1}", spec_string
rf"\^{x}\.(\d+)\.(\d+)", rf">={x}.\1.\2,<{x + 1}", spec_string
)
spec_set = SpecifierSet(spec_string)
@@ -169,12 +167,12 @@ def check_python_version(version_string, constraint_string):
# rewrite occurrences of ^0.y.z to >=0.y.z,<0.y+1.0 (can be anywhere in constraint string)
for y in range(1, 10):
constraint_string = re.sub(
rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y+1}.0", constraint_string
rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y + 1}.0", constraint_string
)
# rewrite occurrences of ^x.y.z to >=x.y.z,<x+1.0.0 (can be anywhere in constraint string)
for x in range(1, 10):
constraint_string = re.sub(
rf"\^{x}\.0\.(\d+)", rf">={x}.0.\1,<{x+1}.0.0", constraint_string
rf"\^{x}\.0\.(\d+)", rf">={x}.0.\1,<{x + 1}.0.0", constraint_string
)
try:

View File

@@ -3,9 +3,10 @@
import os
import shutil
import yaml
from pathlib import Path
from typing import Dict, Any
from typing import Any, Dict
import yaml
def load_packages_yaml() -> Dict[str, Any]:
@@ -28,7 +29,6 @@ def get_target_dir(package_name: str) -> Path:
def clean_target_directories(packages: list) -> None:
"""Remove old directories that will be replaced."""
for package in packages:
target_dir = get_target_dir(package["name"])
if target_dir.exists():
print(f"Removing {target_dir}")
@@ -38,7 +38,6 @@ def clean_target_directories(packages: list) -> None:
def move_libraries(packages: list) -> None:
"""Move libraries from their source locations to the target directories."""
for package in packages:
repo_name = package["repo"].split("/")[1]
source_path = package["path"]
target_dir = get_target_dir(package["name"])
@@ -68,23 +67,33 @@ def main():
package_yaml = load_packages_yaml()
# Clean target directories
clean_target_directories([
p
for p in package_yaml["packages"]
if (p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref"))
and p["repo"] != "langchain-ai/langchain"
and p["name"] != "langchain-ai21" # Skip AI21 due to dependency conflicts
])
clean_target_directories(
[
p
for p in package_yaml["packages"]
if (
p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref")
)
and p["repo"] != "langchain-ai/langchain"
and p["name"]
!= "langchain-ai21" # Skip AI21 due to dependency conflicts
]
)
# Move libraries to their new locations
move_libraries([
p
for p in package_yaml["packages"]
if not p.get("disabled", False)
and (p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref"))
and p["repo"] != "langchain-ai/langchain"
and p["name"] != "langchain-ai21" # Skip AI21 due to dependency conflicts
])
move_libraries(
[
p
for p in package_yaml["packages"]
if not p.get("disabled", False)
and (
p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref")
)
and p["repo"] != "langchain-ai/langchain"
and p["name"]
!= "langchain-ai21" # Skip AI21 due to dependency conflicts
]
)
# Delete ones without a pyproject.toml
for partner in Path("langchain/libs/partners").iterdir():

View File

@@ -81,56 +81,93 @@ import time
__version__ = "2022.12+dev"
# Update symlinks only if the platform supports not following them
UPDATE_SYMLINKS = bool(os.utime in getattr(os, 'supports_follow_symlinks', []))
UPDATE_SYMLINKS = bool(os.utime in getattr(os, "supports_follow_symlinks", []))
# Call os.path.normpath() only if not in a POSIX platform (Windows)
NORMALIZE_PATHS = (os.path.sep != '/')
NORMALIZE_PATHS = os.path.sep != "/"
# How many files to process in each batch when re-trying merge commits
STEPMISSING = 100
# (Extra) keywords for the os.utime() call performed by touch()
UTIME_KWS = {} if not UPDATE_SYMLINKS else {'follow_symlinks': False}
UTIME_KWS = {} if not UPDATE_SYMLINKS else {"follow_symlinks": False}
# Command-line interface ######################################################
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__.split('\n---')[0])
parser = argparse.ArgumentParser(description=__doc__.split("\n---")[0])
group = parser.add_mutually_exclusive_group()
group.add_argument('--quiet', '-q', dest='loglevel',
action="store_const", const=logging.WARNING, default=logging.INFO,
help="Suppress informative messages and summary statistics.")
group.add_argument('--verbose', '-v', action="count", help="""
group.add_argument(
"--quiet",
"-q",
dest="loglevel",
action="store_const",
const=logging.WARNING,
default=logging.INFO,
help="Suppress informative messages and summary statistics.",
)
group.add_argument(
"--verbose",
"-v",
action="count",
help="""
Print additional information for each processed file.
Specify twice to further increase verbosity.
""")
""",
)
parser.add_argument('--cwd', '-C', metavar="DIRECTORY", help="""
parser.add_argument(
"--cwd",
"-C",
metavar="DIRECTORY",
help="""
Run as if %(prog)s was started in directory %(metavar)s.
This affects how --work-tree, --git-dir and PATHSPEC arguments are handled.
See 'man 1 git' or 'git --help' for more information.
""")
""",
)
parser.add_argument('--git-dir', dest='gitdir', metavar="GITDIR", help="""
parser.add_argument(
"--git-dir",
dest="gitdir",
metavar="GITDIR",
help="""
Path to the git repository, by default auto-discovered by searching
the current directory and its parents for a .git/ subdirectory.
""")
""",
)
parser.add_argument('--work-tree', dest='workdir', metavar="WORKTREE", help="""
parser.add_argument(
"--work-tree",
dest="workdir",
metavar="WORKTREE",
help="""
Path to the work tree root, by default the parent of GITDIR if it's
automatically discovered, or the current directory if GITDIR is set.
""")
""",
)
parser.add_argument('--force', '-f', default=False, action="store_true", help="""
parser.add_argument(
"--force",
"-f",
default=False,
action="store_true",
help="""
Force updating files with uncommitted modifications.
Untracked files and uncommitted deletions, renames and additions are
always ignored.
""")
""",
)
parser.add_argument('--merge', '-m', default=False, action="store_true", help="""
parser.add_argument(
"--merge",
"-m",
default=False,
action="store_true",
help="""
Include merge commits.
Leads to more recent times and more files per commit, thus with the same
time, which may or may not be what you want.
@@ -138,71 +175,130 @@ def parse_args():
are found sooner, which can improve performance, sometimes substantially.
But as merge commits are usually huge, processing them may also take longer.
By default, merge commits are only used for files missing from regular commits.
""")
""",
)
parser.add_argument('--first-parent', default=False, action="store_true", help="""
parser.add_argument(
"--first-parent",
default=False,
action="store_true",
help="""
Consider only the first parent, the "main branch", when evaluating merge commits.
Only effective when merge commits are processed, either when --merge is
used or when finding missing files after the first regular log search.
See --skip-missing.
""")
""",
)
parser.add_argument('--skip-missing', '-s', dest="missing", default=True,
action="store_false", help="""
parser.add_argument(
"--skip-missing",
"-s",
dest="missing",
default=True,
action="store_false",
help="""
Do not try to find missing files.
If merge commits were not evaluated with --merge and some files were
not found in regular commits, by default %(prog)s searches for these
files again in the merge commits.
This option disables this retry, so files found only in merge commits
will not have their timestamp updated.
""")
""",
)
parser.add_argument('--no-directories', '-D', dest='dirs', default=True,
action="store_false", help="""
parser.add_argument(
"--no-directories",
"-D",
dest="dirs",
default=True,
action="store_false",
help="""
Do not update directory timestamps.
By default, use the time of its most recently created, renamed or deleted file.
Note that just modifying a file will NOT update its directory time.
""")
""",
)
parser.add_argument('--test', '-t', default=False, action="store_true",
help="Test run: do not actually update any file timestamp.")
parser.add_argument(
"--test",
"-t",
default=False,
action="store_true",
help="Test run: do not actually update any file timestamp.",
)
parser.add_argument('--commit-time', '-c', dest='commit_time', default=False,
action='store_true', help="Use commit time instead of author time.")
parser.add_argument(
"--commit-time",
"-c",
dest="commit_time",
default=False,
action="store_true",
help="Use commit time instead of author time.",
)
parser.add_argument('--oldest-time', '-o', dest='reverse_order', default=False,
action='store_true', help="""
parser.add_argument(
"--oldest-time",
"-o",
dest="reverse_order",
default=False,
action="store_true",
help="""
Update times based on the oldest, instead of the most recent commit of a file.
This reverses the order in which the git log is processed to emulate a
file "creation" date. Note this will be inaccurate for files deleted and
re-created at later dates.
""")
""",
)
parser.add_argument('--skip-older-than', metavar='SECONDS', type=int, help="""
parser.add_argument(
"--skip-older-than",
metavar="SECONDS",
type=int,
help="""
Ignore files that are currently older than %(metavar)s.
Useful in workflows that assume such files already have a correct timestamp,
as it may improve performance by processing fewer files.
""")
""",
)
parser.add_argument('--skip-older-than-commit', '-N', default=False,
action='store_true', help="""
parser.add_argument(
"--skip-older-than-commit",
"-N",
default=False,
action="store_true",
help="""
Ignore files older than the timestamp it would be updated to.
Such files may be considered "original", likely in the author's repository.
""")
""",
)
parser.add_argument('--unique-times', default=False, action="store_true", help="""
parser.add_argument(
"--unique-times",
default=False,
action="store_true",
help="""
Set the microseconds to a unique value per commit.
Allows telling apart changes that would otherwise have identical timestamps,
as git's time accuracy is in seconds.
""")
""",
)
parser.add_argument('pathspec', nargs='*', metavar='PATHSPEC', help="""
parser.add_argument(
"pathspec",
nargs="*",
metavar="PATHSPEC",
help="""
Only modify paths matching %(metavar)s, relative to current directory.
By default, update all but untracked files and submodules.
""")
""",
)
parser.add_argument('--version', '-V', action='version',
version='%(prog)s version {version}'.format(version=get_version()))
parser.add_argument(
"--version",
"-V",
action="version",
version="%(prog)s version {version}".format(version=get_version()),
)
args_ = parser.parse_args()
if args_.verbose:
@@ -212,17 +308,18 @@ def parse_args():
def get_version(version=__version__):
if not version.endswith('+dev'):
if not version.endswith("+dev"):
return version
try:
cwd = os.path.dirname(os.path.realpath(__file__))
return Git(cwd=cwd, errors=False).describe().lstrip('v')
return Git(cwd=cwd, errors=False).describe().lstrip("v")
except Git.Error:
return '-'.join((version, "unknown"))
return "-".join((version, "unknown"))
# Helper functions ############################################################
def setup_logging():
"""Add TRACE logging level and corresponding method, return the root logger"""
logging.TRACE = TRACE = logging.DEBUG // 2
@@ -255,11 +352,13 @@ def normalize(path):
if path and path[0] == '"':
# Python 2: path = path[1:-1].decode("string-escape")
# Python 3: https://stackoverflow.com/a/46650050/624066
path = (path[1:-1] # Remove enclosing double quotes
.encode('latin1') # Convert to bytes, required by 'unicode-escape'
.decode('unicode-escape') # Perform the actual octal-escaping decode
.encode('latin1') # 1:1 mapping to bytes, UTF-8 encoded
.decode('utf8', 'surrogateescape')) # Decode from UTF-8
path = (
path[1:-1] # Remove enclosing double quotes
.encode("latin1") # Convert to bytes, required by 'unicode-escape'
.decode("unicode-escape") # Perform the actual octal-escaping decode
.encode("latin1") # 1:1 mapping to bytes, UTF-8 encoded
.decode("utf8", "surrogateescape")
) # Decode from UTF-8
if NORMALIZE_PATHS:
# Make sure the slash matches the OS; for Windows we need a backslash
path = os.path.normpath(path)
@@ -282,12 +381,12 @@ def touch_ns(path, mtime_ns):
def isodate(secs: int):
# time.localtime() accepts floats, but discards fractional part
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(secs))
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(secs))
def isodate_ns(ns: int):
# for integers fromtimestamp() is equivalent and ~16% slower than isodate()
return datetime.datetime.fromtimestamp(ns / 1000000000).isoformat(sep=' ')
return datetime.datetime.fromtimestamp(ns / 1000000000).isoformat(sep=" ")
def get_mtime_ns(secs: int, idx: int):
@@ -305,35 +404,49 @@ def get_mtime_path(path):
# Git class and parse_log(), the heart of the script ##########################
class Git:
def __init__(self, workdir=None, gitdir=None, cwd=None, errors=True):
self.gitcmd = ['git']
self.gitcmd = ["git"]
self.errors = errors
self._proc = None
if workdir: self.gitcmd.extend(('--work-tree', workdir))
if gitdir: self.gitcmd.extend(('--git-dir', gitdir))
if cwd: self.gitcmd.extend(('-C', cwd))
if workdir:
self.gitcmd.extend(("--work-tree", workdir))
if gitdir:
self.gitcmd.extend(("--git-dir", gitdir))
if cwd:
self.gitcmd.extend(("-C", cwd))
self.workdir, self.gitdir = self._get_repo_dirs()
def ls_files(self, paths: list = None):
return (normalize(_) for _ in self._run('ls-files --full-name', paths))
return (normalize(_) for _ in self._run("ls-files --full-name", paths))
def ls_dirty(self, force=False):
return (normalize(_[3:].split(' -> ', 1)[-1])
for _ in self._run('status --porcelain')
if _[:2] != '??' and (not force or (_[0] in ('R', 'A')
or _[1] == 'D')))
return (
normalize(_[3:].split(" -> ", 1)[-1])
for _ in self._run("status --porcelain")
if _[:2] != "??" and (not force or (_[0] in ("R", "A") or _[1] == "D"))
)
def log(self, merge=False, first_parent=False, commit_time=False,
reverse_order=False, paths: list = None):
cmd = 'whatchanged --pretty={}'.format('%ct' if commit_time else '%at')
if merge: cmd += ' -m'
if first_parent: cmd += ' --first-parent'
if reverse_order: cmd += ' --reverse'
def log(
self,
merge=False,
first_parent=False,
commit_time=False,
reverse_order=False,
paths: list = None,
):
cmd = "whatchanged --pretty={}".format("%ct" if commit_time else "%at")
if merge:
cmd += " -m"
if first_parent:
cmd += " --first-parent"
if reverse_order:
cmd += " --reverse"
return self._run(cmd, paths)
def describe(self):
return self._run('describe --tags', check=True)[0]
return self._run("describe --tags", check=True)[0]
def terminate(self):
if self._proc is None:
@@ -345,18 +458,22 @@ class Git:
pass
def _get_repo_dirs(self):
return (os.path.normpath(_) for _ in
self._run('rev-parse --show-toplevel --absolute-git-dir', check=True))
return (
os.path.normpath(_)
for _ in self._run(
"rev-parse --show-toplevel --absolute-git-dir", check=True
)
)
def _run(self, cmdstr: str, paths: list = None, output=True, check=False):
cmdlist = self.gitcmd + shlex.split(cmdstr)
if paths:
cmdlist.append('--')
cmdlist.append("--")
cmdlist.extend(paths)
popen_args = dict(universal_newlines=True, encoding='utf8')
popen_args = dict(universal_newlines=True, encoding="utf8")
if not self.errors:
popen_args['stderr'] = subprocess.DEVNULL
log.trace("Executing: %s", ' '.join(cmdlist))
popen_args["stderr"] = subprocess.DEVNULL
log.trace("Executing: %s", " ".join(cmdlist))
if not output:
return subprocess.call(cmdlist, **popen_args)
if check:
@@ -379,30 +496,26 @@ def parse_log(filelist, dirlist, stats, git, merge=False, filterlist=None):
mtime = 0
datestr = isodate(0)
for line in git.log(
merge,
args.first_parent,
args.commit_time,
args.reverse_order,
filterlist
merge, args.first_parent, args.commit_time, args.reverse_order, filterlist
):
stats['loglines'] += 1
stats["loglines"] += 1
# Blank line between Date and list of files
if not line:
continue
# Date line
if line[0] != ':': # Faster than `not line.startswith(':')`
stats['commits'] += 1
if line[0] != ":": # Faster than `not line.startswith(':')`
stats["commits"] += 1
mtime = int(line)
if args.unique_times:
mtime = get_mtime_ns(mtime, stats['commits'])
mtime = get_mtime_ns(mtime, stats["commits"])
if args.debug:
datestr = isodate(mtime)
continue
# File line: three tokens if it describes a renaming, otherwise two
tokens = line.split('\t')
tokens = line.split("\t")
# Possible statuses:
# M: Modified (content changed)
@@ -411,7 +524,7 @@ def parse_log(filelist, dirlist, stats, git, merge=False, filterlist=None):
# T: Type changed: to/from regular file, symlinks, submodules
# R099: Renamed (moved), with % of unchanged content. 100 = pure rename
# Not possible in log: C=Copied, U=Unmerged, X=Unknown, B=pairing Broken
status = tokens[0].split(' ')[-1]
status = tokens[0].split(" ")[-1]
file = tokens[-1]
# Handles non-ASCII chars and OS path separator
@@ -419,56 +532,76 @@ def parse_log(filelist, dirlist, stats, git, merge=False, filterlist=None):
def do_file():
if args.skip_older_than_commit and get_mtime_path(file) <= mtime:
stats['skip'] += 1
stats["skip"] += 1
return
if args.debug:
log.debug("%d\t%d\t%d\t%s\t%s",
stats['loglines'], stats['commits'], stats['files'],
datestr, file)
log.debug(
"%d\t%d\t%d\t%s\t%s",
stats["loglines"],
stats["commits"],
stats["files"],
datestr,
file,
)
try:
touch(os.path.join(git.workdir, file), mtime)
stats['touches'] += 1
stats["touches"] += 1
except Exception as e:
log.error("ERROR: %s: %s", e, file)
stats['errors'] += 1
stats["errors"] += 1
def do_dir():
if args.debug:
log.debug("%d\t%d\t-\t%s\t%s",
stats['loglines'], stats['commits'],
datestr, "{}/".format(dirname or '.'))
log.debug(
"%d\t%d\t-\t%s\t%s",
stats["loglines"],
stats["commits"],
datestr,
"{}/".format(dirname or "."),
)
try:
touch(os.path.join(git.workdir, dirname), mtime)
stats['dirtouches'] += 1
stats["dirtouches"] += 1
except Exception as e:
log.error("ERROR: %s: %s", e, dirname)
stats['direrrors'] += 1
stats["direrrors"] += 1
if file in filelist:
stats['files'] -= 1
stats["files"] -= 1
filelist.remove(file)
do_file()
if args.dirs and status in ('A', 'D'):
if args.dirs and status in ("A", "D"):
dirname = os.path.dirname(file)
if dirname in dirlist:
dirlist.remove(dirname)
do_dir()
# All files done?
if not stats['files']:
if not stats["files"]:
git.terminate()
return
# Main Logic ##################################################################
def main():
start = time.time() # yes, Wall time. CPU time is not realistic for users.
stats = {_: 0 for _ in ('loglines', 'commits', 'touches', 'skip', 'errors',
'dirtouches', 'direrrors')}
stats = {
_: 0
for _ in (
"loglines",
"commits",
"touches",
"skip",
"errors",
"dirtouches",
"direrrors",
)
}
logging.basicConfig(level=args.loglevel, format='%(message)s')
logging.basicConfig(level=args.loglevel, format="%(message)s")
log.trace("Arguments: %s", args)
# First things first: Where and Who are we?
@@ -499,13 +632,16 @@ def main():
# Symlink (to file, to dir or broken - git handles the same way)
if not UPDATE_SYMLINKS and os.path.islink(fullpath):
log.warning("WARNING: Skipping symlink, no OS support for updates: %s",
path)
log.warning(
"WARNING: Skipping symlink, no OS support for updates: %s", path
)
continue
# skip files which are older than given threshold
if (args.skip_older_than
and start - get_mtime_path(fullpath) > args.skip_older_than):
if (
args.skip_older_than
and start - get_mtime_path(fullpath) > args.skip_older_than
):
continue
# Always add files relative to worktree root
@@ -519,15 +655,17 @@ def main():
else:
dirty = set(git.ls_dirty())
if dirty:
log.warning("WARNING: Modified files in the working directory were ignored."
"\nTo include such files, commit your changes or use --force.")
log.warning(
"WARNING: Modified files in the working directory were ignored."
"\nTo include such files, commit your changes or use --force."
)
filelist -= dirty
# Build dir list to be processed
dirlist = set(os.path.dirname(_) for _ in filelist) if args.dirs else set()
stats['totalfiles'] = stats['files'] = len(filelist)
log.info("{0:,} files to be processed in work dir".format(stats['totalfiles']))
stats["totalfiles"] = stats["files"] = len(filelist)
log.info("{0:,} files to be processed in work dir".format(stats["totalfiles"]))
if not filelist:
# Nothing to do. Exit silently and without errors, just like git does
@@ -544,10 +682,18 @@ def main():
if args.missing and not args.merge:
filterlist = list(filelist)
missing = len(filterlist)
log.info("{0:,} files not found in log, trying merge commits".format(missing))
log.info(
"{0:,} files not found in log, trying merge commits".format(missing)
)
for i in range(0, missing, STEPMISSING):
parse_log(filelist, dirlist, stats, git,
merge=True, filterlist=filterlist[i:i + STEPMISSING])
parse_log(
filelist,
dirlist,
stats,
git,
merge=True,
filterlist=filterlist[i : i + STEPMISSING],
)
# Still missing some?
for file in filelist:
@@ -556,29 +702,33 @@ def main():
# Final statistics
# Suggestion: use git-log --before=mtime to brag about skipped log entries
def log_info(msg, *a, width=13):
ifmt = '{:%d,}' % (width,) # not using 'n' for consistency with ffmt
ffmt = '{:%d,.2f}' % (width,)
ifmt = "{:%d,}" % (width,) # not using 'n' for consistency with ffmt
ffmt = "{:%d,.2f}" % (width,)
# %-formatting lacks a thousand separator, must pre-render with .format()
log.info(msg.replace('%d', ifmt).replace('%f', ffmt).format(*a))
log.info(msg.replace("%d", ifmt).replace("%f", ffmt).format(*a))
log_info(
"Statistics:\n"
"%f seconds\n"
"%d log lines processed\n"
"%d commits evaluated",
time.time() - start, stats['loglines'], stats['commits'])
"Statistics:\n%f seconds\n%d log lines processed\n%d commits evaluated",
time.time() - start,
stats["loglines"],
stats["commits"],
)
if args.dirs:
if stats['direrrors']: log_info("%d directory update errors", stats['direrrors'])
log_info("%d directories updated", stats['dirtouches'])
if stats["direrrors"]:
log_info("%d directory update errors", stats["direrrors"])
log_info("%d directories updated", stats["dirtouches"])
if stats['touches'] != stats['totalfiles']:
log_info("%d files", stats['totalfiles'])
if stats['skip']: log_info("%d files skipped", stats['skip'])
if stats['files']: log_info("%d files missing", stats['files'])
if stats['errors']: log_info("%d file update errors", stats['errors'])
if stats["touches"] != stats["totalfiles"]:
log_info("%d files", stats["totalfiles"])
if stats["skip"]:
log_info("%d files skipped", stats["skip"])
if stats["files"]:
log_info("%d files missing", stats["files"])
if stats["errors"]:
log_info("%d file update errors", stats["errors"])
log_info("%d files updated", stats['touches'])
log_info("%d files updated", stats["touches"])
if args.test:
log.info("TEST RUN - No files modified!")

View File

@@ -388,11 +388,12 @@ jobs:
- name: Test against ${{ matrix.partner }}
if: startsWith(inputs.working-directory, 'libs/core')
run: |
# Identify latest tag
# Identify latest tag, excluding pre-releases
LATEST_PACKAGE_TAG="$(
git ls-remote --tags origin "langchain-${{ matrix.partner }}*" \
| awk '{print $2}' \
| sed 's|refs/tags/||' \
| grep -Ev '==[^=]*(\.?dev[0-9]*|\.?rc[0-9]*)$' \
| sort -Vr \
| head -n 1
)"

View File

@@ -79,4 +79,4 @@ jobs:
# grep will exit non-zero if the target message isn't found,
# and `set -e` above will cause the step to fail.
echo "$STATUS" | grep 'nothing to commit, working tree clean'

View File

@@ -64,4 +64,4 @@ jobs:
# grep will exit non-zero if the target message isn't found,
# and `set -e` above will cause the step to fail.
echo "$STATUS" | grep 'nothing to commit, working tree clean'
echo "$STATUS" | grep 'nothing to commit, working tree clean'

View File

@@ -52,7 +52,6 @@ jobs:
run: |
# Get unique repositories
REPOS=$(echo "$REPOS_UNSORTED" | sort -u)
# Checkout each unique repository
for repo in $REPOS; do
# Validate repository format (allow any org with proper format)
@@ -68,7 +67,6 @@ jobs:
echo "Error: Invalid repository name: $REPO_NAME"
exit 1
fi
echo "Checking out $repo to $REPO_NAME"
git clone --depth 1 https://github.com/$repo.git $REPO_NAME
done

View File

@@ -30,6 +30,7 @@ jobs:
build:
name: 'Detect Changes & Set Matrix'
runs-on: ubuntu-latest
if: ${{ !contains(github.event.pull_request.labels.*.name, 'ci-ignore') }}
steps:
- name: '📋 Checkout Code'
uses: actions/checkout@v4

View File

@@ -20,6 +20,7 @@ jobs:
codspeed:
name: 'Benchmark'
runs-on: ubuntu-latest
if: ${{ !contains(github.event.pull_request.labels.*.name, 'codspeed-ignore') }}
strategy:
matrix:
include:

View File

@@ -11,4 +11,4 @@
"MD046": {
"style": "fenced"
}
}
}

View File

@@ -21,7 +21,7 @@
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",
"source.organizeImports.ruff": "explicit",
"source.fixAll": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff"
@@ -77,4 +77,6 @@
"editor.tabSize": 2,
"editor.insertSpaces": true
},
"python.terminal.activateEnvironment": false,
"python.defaultInterpreterPath": "./.venv/bin/python"
}

View File

@@ -63,4 +63,4 @@ Notebook | Description
[rag-locally-on-intel-cpu.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/rag-locally-on-intel-cpu.ipynb) | Perform Retrieval-Augmented-Generation (RAG) on locally downloaded open-source models using langchain and open source tools and execute it on Intel Xeon CPU. We showed an example of how to apply RAG on Llama 2 model and enable it to answer the queries related to Intel Q1 2024 earnings release.
[visual_RAG_vdms.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/visual_RAG_vdms.ipynb) | Performs Visual Retrieval-Augmented-Generation (RAG) using videos and scene descriptions generated by open source models.
[contextual_rag.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/contextual_rag.ipynb) | Performs contextual retrieval-augmented generation (RAG) prepending chunk-specific explanatory context to each chunk before embedding.
[rag-agents-locally-on-intel-cpu.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/local_rag_agents_intel_cpu.ipynb) | Build a RAG agent locally with open source models that routes questions through one of two paths to find answers. The agent generates answers based on documents retrieved from either the vector database or retrieved from web search. If the vector database lacks relevant information, the agent opts for web search. Open-source models for LLM and embeddings are used locally on an Intel Xeon CPU to execute this pipeline.
[rag-agents-locally-on-intel-cpu.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/local_rag_agents_intel_cpu.ipynb) | Build a RAG agent locally with open source models that routes questions through one of two paths to find answers. The agent generates answers based on documents retrieved from either the vector database or retrieved from web search. If the vector database lacks relevant information, the agent opts for web search. Open-source models for LLM and embeddings are used locally on an Intel Xeon CPU to execute this pipeline.

View File

@@ -97,7 +97,7 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
if type(type_) is typing_extensions._TypedDictMeta: # type: ignore
kind: ClassKind = "TypedDict"
elif type(type_) is typing._TypedDictMeta: # type: ignore
kind: ClassKind = "TypedDict"
kind = "TypedDict"
elif (
issubclass(type_, Runnable)
and issubclass(type_, BaseModel)
@@ -189,7 +189,7 @@ def _load_package_modules(
if isinstance(package_directory, str)
else package_directory
)
modules_by_namespace = {}
modules_by_namespace: Dict[str, ModuleMembers] = {}
# Get the high level package name
package_name = package_path.name
@@ -217,7 +217,11 @@ def _load_package_modules(
# Get the full namespace of the module
namespace = str(relative_module_name).replace(".py", "").replace("/", ".")
# Keep only the top level namespace
top_namespace = namespace.split(".")[0]
# (but make special exception for content_blocks and v1.messages)
if namespace == "messages.content_blocks" or namespace == "v1.messages":
top_namespace = namespace # Keep full namespace for content_blocks
else:
top_namespace = namespace.split(".")[0]
try:
# If submodule is present, we need to construct the paths in a slightly
@@ -283,7 +287,7 @@ def _construct_doc(
.. toctree::
:hidden:
:maxdepth: 2
"""
index_autosummary = """
"""
@@ -365,9 +369,9 @@ def _construct_doc(
module_doc += f"""\
:template: {template}
{class_["qualified_name"]}
"""
index_autosummary += f"""
{class_["qualified_name"]}
@@ -550,8 +554,8 @@ def _build_index(dirs: List[str]) -> None:
integrations = sorted(dir_ for dir_ in dirs if dir_ not in main_)
doc = """# LangChain Python API Reference
Welcome to the LangChain Python API reference. This is a reference for all
`langchain-x` packages.
Welcome to the LangChain Python API reference. This is a reference for all
`langchain-x` packages.
For user guides see [https://python.langchain.com](https://python.langchain.com).

View File

@@ -124,6 +124,47 @@ start "" htmlcov/index.html || open htmlcov/index.html
```
## Snapshot Testing
Some tests use [syrupy](https://github.com/tophat/syrupy) for snapshot testing, which captures the output of functions and compares them to stored snapshots. This is particularly useful for testing JSON schema generation and other structured outputs.
### Updating Snapshots
To update snapshots when the expected output has legitimately changed:
```bash
uv run --group test pytest path/to/test.py --snapshot-update
```
### Pydantic Version Compatibility Issues
Pydantic generates different JSON schemas across versions, which can cause snapshot test failures in CI when tests run with different Pydantic versions than what was used to generate the snapshots.
**Symptoms:**
- CI fails with snapshot mismatches showing differences like missing or extra fields.
- Tests pass locally but fail in CI with different Pydantic versions
**Solution:**
Locally update snapshots using the same Pydantic version that CI uses:
1. **Identify the failing Pydantic version** from CI logs (e.g., `2.7.0`, `2.8.0`, `2.9.0`)
2. **Update snapshots with that version:**
```bash
uv run --with "pydantic==2.9.0" --group test pytest tests/unit_tests/path/to/test.py::test_name --snapshot-update
```
3. **Verify compatibility across supported versions:**
```bash
# Test with the version you used to update
uv run --with "pydantic==2.9.0" --group test pytest tests/unit_tests/path/to/test.py::test_name
# Test with other supported versions
uv run --with "pydantic==2.8.0" --group test pytest tests/unit_tests/path/to/test.py::test_name
```
**Note:** Some tests use `@pytest.mark.skipif` decorators to only run with specific Pydantic version ranges (e.g., `PYDANTIC_VERSION_AT_LEAST_210`). Make sure to understand these constraints when updating snapshots.
## Coverage
Code coverage (i.e. the amount of code that is covered by unit tests) helps identify areas of the code that are potentially more or less brittle.

View File

@@ -122,13 +122,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
"# from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(temperature=0, model_name=\"gpt-4-turbo\")\n",

View File

@@ -74,12 +74,12 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "a88ff70c",
"metadata": {},
"outputs": [],
"source": [
"from langchain_experimental.text_splitter import SemanticChunker\n",
"# from langchain_experimental.text_splitter import SemanticChunker\n",
"from langchain_openai.embeddings import OpenAIEmbeddings\n",
"\n",
"text_splitter = SemanticChunker(OpenAIEmbeddings())"

View File

@@ -612,56 +612,11 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"id": "35ea904e-795f-411b-bef8-6484dbb6e35c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `{'query': \"df[['Age', 'Fare']].corr().iloc[0,1]\"}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m0.11232863699941621\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `python_repl_ast` with `{'query': \"df[['Fare', 'Survived']].corr().iloc[0,1]\"}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m0.2561785496289603\u001b[0m\u001b[32;1m\u001b[1;3mThe correlation between Age and Fare is approximately 0.112, and the correlation between Fare and Survival is approximately 0.256.\n",
"\n",
"Therefore, the correlation between Fare and Survival (0.256) is greater than the correlation between Age and Fare (0.112).\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"What's the correlation between age and fare? is that greater than the correlation between fare and survival?\",\n",
" 'output': 'The correlation between Age and Fare is approximately 0.112, and the correlation between Fare and Survival is approximately 0.256.\\n\\nTherefore, the correlation between Fare and Survival (0.256) is greater than the correlation between Age and Fare (0.112).'}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_experimental.agents import create_pandas_dataframe_agent\n",
"\n",
"agent = create_pandas_dataframe_agent(\n",
" llm, df, agent_type=\"openai-tools\", verbose=True, allow_dangerous_code=True\n",
")\n",
"agent.invoke(\n",
" {\n",
" \"input\": \"What's the correlation between age and fare? is that greater than the correlation between fare and survival?\"\n",
" }\n",
")"
]
"outputs": [],
"source": "from langchain_experimental.agents import create_pandas_dataframe_agent\n\nagent = create_pandas_dataframe_agent(\n llm, df, agent_type=\"openai-tools\", verbose=True, allow_dangerous_code=True\n)\nagent.invoke(\n {\n \"input\": \"What's the correlation between age and fare? is that greater than the correlation between fare and survival?\"\n }\n)"
},
{
"cell_type": "markdown",
@@ -786,4 +741,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -447,6 +447,163 @@
")"
]
},
{
"cell_type": "markdown",
"id": "c5d9d19d-8ab1-4d9d-b3a0-56ee4e89c528",
"metadata": {},
"source": [
"### Custom tools\n",
"\n",
":::info Requires ``langchain-openai>=0.3.29``\n",
"\n",
":::\n",
"\n",
"[Custom tools](https://platform.openai.com/docs/guides/function-calling#custom-tools) support tools with arbitrary string inputs. They can be particularly useful when you expect your string arguments to be long or complex."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a47c809b-852f-46bd-8b9e-d9534c17213d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_6894ff5747c0819d9b02fc5645b0be9c000169fd9fb68d99', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_7SYwMSQPbbEqFcKlKOpXeEux', 'input': 'print(3**3)', 'name': 'execute_code', 'type': 'custom_tool_call', 'id': 'ctc_6894ff5b9f54819d8155a63638d34103000169fd9fb68d99', 'status': 'completed'}]\n",
"Tool Calls:\n",
" execute_code (call_7SYwMSQPbbEqFcKlKOpXeEux)\n",
" Call ID: call_7SYwMSQPbbEqFcKlKOpXeEux\n",
" Args:\n",
" __arg1: print(3**3)\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: execute_code\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6894ff5db3b8819d9159b3a370a25843000169fd9fb68d99'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"\n",
"@custom_tool\n",
"def execute_code(code: str) -> str:\n",
" \"\"\"Execute python code.\"\"\"\n",
" return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [execute_code])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
" {\"messages\": [input_message]},\n",
" stream_mode=\"values\",\n",
"):\n",
" step[\"messages\"][-1].pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "5ef93be6-6d4c-4eea-acfd-248774074082",
"metadata": {},
"source": [
"<details>\n",
"<summary>Context-free grammars</summary>\n",
"\n",
"OpenAI supports the specification of a [context-free grammar](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for custom tool inputs in `lark` or `regex` format. See [OpenAI docs](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for details. The `format` parameter can be passed into `@custom_tool` as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2ae04586-be33-49c6-8947-7867801d868f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_689500828a8481a297ff0f98e328689c0681550c89797f43', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_jzH01RVhu6EFz7yUrOFXX55s', 'input': '3 * 3 * 3', 'name': 'do_math', 'type': 'custom_tool_call', 'id': 'ctc_6895008d57bc81a2b84d0993517a66b90681550c89797f43', 'status': 'completed'}]\n",
"Tool Calls:\n",
" do_math (call_jzH01RVhu6EFz7yUrOFXX55s)\n",
" Call ID: call_jzH01RVhu6EFz7yUrOFXX55s\n",
" Args:\n",
" __arg1: 3 * 3 * 3\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: do_math\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6895009776b881a2a25f0be8507d08f20681550c89797f43'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"grammar = \"\"\"\n",
"start: expr\n",
"expr: term (SP ADD SP term)* -> add\n",
"| term\n",
"term: factor (SP MUL SP factor)* -> mul\n",
"| factor\n",
"factor: INT\n",
"SP: \" \"\n",
"ADD: \"+\"\n",
"MUL: \"*\"\n",
"%import common.INT\n",
"\"\"\"\n",
"\n",
"format_ = {\"type\": \"grammar\", \"syntax\": \"lark\", \"definition\": grammar}\n",
"\n",
"\n",
"# highlight-next-line\n",
"@custom_tool(format=format_)\n",
"def do_math(input_string: str) -> str:\n",
" \"\"\"Do a mathematical operation.\"\"\"\n",
" return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [do_math])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
" {\"messages\": [input_message]},\n",
" stream_mode=\"values\",\n",
"):\n",
" step[\"messages\"][-1].pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "c63430c9-c7b0-4e92-a491-3f165dddeb8f",
"metadata": {},
"source": [
"</details>"
]
},
{
"cell_type": "markdown",
"id": "84833dd0-17e9-4269-82ed-550639d65751",

View File

@@ -132,12 +132,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.documents import Document\n",
"from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
"\n",
"# from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"# Define the LLMGraphTransformer\n",

View File

@@ -548,12 +548,12 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.documents import Document\n",
"from langchain_experimental.graph_transformers import LLMGraphTransformer"
"# from langchain_experimental.graph_transformers import LLMGraphTransformer"
]
},
{

View File

@@ -29,8 +29,8 @@
" Please refer to the instructions in:\n",
" [www.jaguardb.com](http://www.jaguardb.com)\n",
" For quick setup in docker environment:\n",
" docker pull jaguardb/jaguardb_with_http\n",
" docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb_with_http jaguardb/jaguardb_with_http\n",
" docker pull jaguardb/jaguardb\n",
" docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb jaguardb/jaguardb\n",
"\n",
"2. You must install the http client package for JaguarDB:\n",
" ```\n",

View File

@@ -35,6 +35,7 @@ embeddings.embed_query("What is the meaning of life?")
```
## LLMs
`__ModuleName__LLM` class exposes LLMs from __ModuleName__.
```python

View File

@@ -1,3 +1,3 @@
version: 0.0.1
patterns:
- name: github.com/getgrit/stdlib#*
- name: github.com/getgrit/stdlib#*

View File

@@ -27,16 +27,16 @@ langchain app add __package_name__
```
And add the following code to your `server.py` file:
```python
__app_route_code__
```
(Optional) Let's now configure LangSmith.
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
(Optional) Let's now configure LangSmith.
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
If you don't have access, you can skip this section
```shell
export LANGSMITH_TRACING=true
export LANGSMITH_API_KEY=<your-api-key>
@@ -49,11 +49,11 @@ If you are inside this directory, then you can spin up a LangServe instance dire
langchain serve
```
This will start the FastAPI app with a server is running locally at
This will start the FastAPI app with a server is running locally at
[http://localhost:8000](http://localhost:8000)
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
We can access the playground at [http://127.0.0.1:8000/__package_name__/playground](http://127.0.0.1:8000/__package_name__/playground)
We can access the playground at [http://127.0.0.1:8000/__package_name__/playground](http://127.0.0.1:8000/__package_name__/playground)
We can access the template from code with:
@@ -61,4 +61,4 @@ We can access the template from code with:
from langserve.client import RemoteRunnable
runnable = RemoteRunnable("http://localhost:8000/__package_name__")
```
```

View File

@@ -11,7 +11,7 @@ pip install -U langchain-cli
## Adding packages
```bash
# adding packages from
# adding packages from
# https://github.com/langchain-ai/langchain/tree/master/templates
langchain app add $PROJECT_NAME
@@ -31,10 +31,10 @@ langchain app remove my/custom/path/rag
```
## Setup LangSmith (Optional)
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
If you don't have access, you can skip this section
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
If you don't have access, you can skip this section
```shell
export LANGSMITH_TRACING=true

View File

@@ -144,7 +144,7 @@ def beta(
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
warn_if_direct_instance
)
return cast("T", obj)
return obj
elif isinstance(obj, property):
# note(erick): this block doesn't seem to be used?

View File

@@ -225,7 +225,7 @@ def deprecated(
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
warn_if_direct_instance
)
return cast("T", obj)
return obj
elif isinstance(obj, FieldInfoV1):
wrapped = None

View File

@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import Self
from langchain_core.v1.messages import AIMessage, AIMessageChunk, MessageV1
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
@@ -66,7 +68,9 @@ class LLMManagerMixin:
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -75,8 +79,8 @@ class LLMManagerMixin:
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
generated chunk, containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
@@ -84,7 +88,7 @@ class LLMManagerMixin:
def on_llm_end(
self,
response: LLMResult,
response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -93,7 +97,7 @@ class LLMManagerMixin:
"""Run when LLM ends running.
Args:
response (LLMResult): The response which was generated.
response (LLMResult | AIMessage): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.
@@ -261,7 +265,7 @@ class CallbackManagerMixin:
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -439,6 +443,9 @@ class BaseCallbackHandler(
run_inline: bool = False
"""Whether to run the callback inline."""
accepts_new_messages: bool = False
"""Whether the callback accepts new message format."""
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
@@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -540,7 +547,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
@@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
generated chunk, containing content and other information.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[list[str]]): The tags.
@@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_end(
self,
response: LLMResult,
response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""Run when LLM ends running.
Args:
response (LLMResult): The response which was generated.
response (LLMResult | AIMessage): The response which was generated.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[list[str]]): The tags.
@@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
parent_run_id: The parent run ID. This is the ID of the parent run.
tags: The tags.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
- response (LLMResult | AIMessage): The response which was generated
before the error occurred.
"""
async def on_chain_start(

View File

@@ -49,7 +49,7 @@ class FileCallbackHandler(BaseCallbackHandler):
mode: The file open mode. Defaults to ``'a'`` (append).
color: Default color for text output. Defaults to ``None``.
Note:
.. note::
When not used as a context manager, a deprecation warning will be issued
on first use. The file will be opened immediately in ``__init__`` and closed
in ``__del__`` or when ``close()`` is called explicitly.
@@ -65,6 +65,7 @@ class FileCallbackHandler(BaseCallbackHandler):
filename: Path to the output file.
mode: File open mode (e.g., ``'w'``, ``'a'``, ``'x'``). Defaults to ``'a'``.
color: Default text color for output. Defaults to ``None``.
"""
self.filename = filename
self.mode = mode
@@ -82,9 +83,10 @@ class FileCallbackHandler(BaseCallbackHandler):
Returns:
The FileCallbackHandler instance.
Note:
.. note::
The file is already opened in ``__init__``, so this just marks that
the handler is being used as a context manager.
"""
self._file_opened_in_context = True
return self
@@ -101,6 +103,7 @@ class FileCallbackHandler(BaseCallbackHandler):
exc_type: Exception type if an exception occurred.
exc_val: Exception value if an exception occurred.
exc_tb: Exception traceback if an exception occurred.
"""
self.close()
@@ -113,6 +116,7 @@ class FileCallbackHandler(BaseCallbackHandler):
This method is safe to call multiple times and will only close
the file if it's currently open.
"""
if hasattr(self, "file") and self.file and not self.file.closed:
self.file.close()
@@ -133,6 +137,7 @@ class FileCallbackHandler(BaseCallbackHandler):
Raises:
RuntimeError: If the file is closed or not available.
"""
global _GLOBAL_DEPRECATION_WARNED # noqa: PLW0603
if not self._file_opened_in_context and not _GLOBAL_DEPRECATION_WARNED:
@@ -163,6 +168,7 @@ class FileCallbackHandler(BaseCallbackHandler):
serialized: The serialized chain information.
inputs: The inputs to the chain.
**kwargs: Additional keyword arguments that may contain ``'name'``.
"""
name = (
kwargs.get("name")
@@ -178,6 +184,7 @@ class FileCallbackHandler(BaseCallbackHandler):
Args:
outputs: The outputs of the chain.
**kwargs: Additional keyword arguments.
"""
self._write("\n> Finished chain.", end="\n")
@@ -192,6 +199,7 @@ class FileCallbackHandler(BaseCallbackHandler):
color: Color override for this specific output. If ``None``, uses
``self.color``.
**kwargs: Additional keyword arguments.
"""
self._write(action.log, color=color or self.color)
@@ -213,6 +221,7 @@ class FileCallbackHandler(BaseCallbackHandler):
observation_prefix: Optional prefix to write before the output.
llm_prefix: Optional prefix to write after the output.
**kwargs: Additional keyword arguments.
"""
if observation_prefix is not None:
self._write(f"\n{observation_prefix}")
@@ -232,6 +241,7 @@ class FileCallbackHandler(BaseCallbackHandler):
``self.color``.
end: String appended after the text. Defaults to ``""``.
**kwargs: Additional keyword arguments.
"""
self._write(text, color=color or self.color, end=end)
@@ -246,5 +256,6 @@ class FileCallbackHandler(BaseCallbackHandler):
color: Color override for this specific output. If ``None``, uses
``self.color``.
**kwargs: Additional keyword arguments.
"""
self._write(finish.log, color=color or self.color, end="\n")

View File

@@ -11,15 +11,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
TypeVar,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
from uuid import UUID
from langsmith.run_helpers import get_tracing_context
@@ -37,8 +29,16 @@ from langchain_core.callbacks.base import (
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
from langchain_core.tracers.schemas import Run
from langchain_core.utils.env import env_var_is_set
from langchain_core.v1.messages import (
AIMessage,
AIMessageChunk,
MessageV1,
MessageV1Types,
)
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.documents import Document
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.outputs import GenerationChunk
from langchain_core.runnables.config import RunnableConfig
logger = logging.getLogger(__name__)
@@ -92,7 +92,8 @@ def trace_as_chain_group(
metadata (dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None.
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
.. note:
Must have ``LANGCHAIN_TRACING_V2`` env var set to true to see the trace in LangSmith.
Returns:
CallbackManagerForChainGroup: The callback manager for the chain group.
@@ -177,7 +178,8 @@ async def atrace_as_chain_group(
Returns:
AsyncCallbackManager: The async callback manager for the chain group.
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
.. note:
Must have ``LANGCHAIN_TRACING_V2`` env var set to true to see the trace in LangSmith.
Example:
.. code-block:: python
@@ -234,6 +236,7 @@ def shielded(func: Func) -> Func:
Returns:
Callable: The shielded function
"""
@functools.wraps(func)
@@ -243,6 +246,46 @@ def shielded(func: Func) -> Func:
return cast("Func", wrapped)
def _convert_llm_events(
event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
args_list = list(args)
if (
event_name == "on_chat_model_start"
and isinstance(args_list[1], list)
and args_list[1]
and isinstance(args_list[1][0], MessageV1Types)
):
batch = [
convert_from_v1_message(item)
for item in args_list[1]
if isinstance(item, MessageV1Types)
]
args_list[1] = [batch]
elif (
event_name == "on_llm_new_token"
and "chunk" in kwargs
and isinstance(kwargs["chunk"], MessageV1Types)
):
chunk = kwargs["chunk"]
kwargs["chunk"] = ChatGenerationChunk(text=chunk.text, message=chunk)
elif event_name == "on_llm_end" and isinstance(args_list[0], MessageV1Types):
args_list[0] = LLMResult(
generations=[
[
ChatGeneration(
text=args_list[0].text,
message=convert_from_v1_message(args_list[0]),
)
]
]
)
else:
pass
return tuple(args_list), kwargs
def handle_event(
handlers: list[BaseCallbackHandler],
event_name: str,
@@ -252,15 +295,17 @@ def handle_event(
) -> None:
"""Generic event handler for CallbackManager.
Note: This function is used by LangServe to handle events.
.. note::
This function is used by ``LangServe`` to handle events.
Args:
handlers: The list of handlers that will handle the event.
event_name: The name of the event (e.g., "on_llm_start").
event_name: The name of the event (e.g., ``'on_llm_start'``).
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event.
*args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler
"""
coros: list[Coroutine[Any, Any, Any]] = []
@@ -271,6 +316,8 @@ def handle_event(
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
if not handler.accepts_new_messages:
args, kwargs = _convert_llm_events(event_name, args, kwargs)
event = getattr(handler, event_name)(*args, **kwargs)
if asyncio.iscoroutine(event):
coros.append(event)
@@ -365,6 +412,8 @@ async def _ahandle_event_for_handler(
) -> None:
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
if not handler.accepts_new_messages:
args, kwargs = _convert_llm_events(event_name, args, kwargs)
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
@@ -415,17 +464,19 @@ async def ahandle_event(
*args: Any,
**kwargs: Any,
) -> None:
"""Async generic event handler for AsyncCallbackManager.
"""Async generic event handler for ``AsyncCallbackManager``.
Note: This function is used by LangServe to handle events.
.. note::
This function is used by ``LangServe`` to handle events.
Args:
handlers: The list of handlers that will handle the event.
event_name: The name of the event (e.g., "on_llm_start").
event_name: The name of the event (e.g., ``'on_llm_start'``).
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event.
*args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler.
"""
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
@@ -477,6 +528,7 @@ class BaseRunManager(RunManagerMixin):
Defaults to None.
inheritable_metadata (Optional[dict[str, Any]]): The inheritable metadata.
Defaults to None.
"""
self.run_id = run_id
self.handlers = handlers
@@ -493,6 +545,7 @@ class BaseRunManager(RunManagerMixin):
Returns:
BaseRunManager: The noop manager.
"""
return cls(
run_id=uuid.uuid4(),
@@ -545,6 +598,7 @@ class RunManager(BaseRunManager):
Args:
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -572,6 +626,7 @@ class ParentRunManager(RunManager):
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
@@ -591,6 +646,7 @@ class AsyncRunManager(BaseRunManager, ABC):
Returns:
RunManager: The sync RunManager.
"""
async def on_text(
@@ -606,6 +662,7 @@ class AsyncRunManager(BaseRunManager, ABC):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
@@ -630,6 +687,7 @@ class AsyncRunManager(BaseRunManager, ABC):
Args:
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -657,6 +715,7 @@ class AsyncParentRunManager(AsyncRunManager):
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
@@ -674,7 +733,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -684,6 +745,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -699,12 +761,13 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
**kwargs,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
"""Run when LLM ends running.
Args:
response (LLMResult): The LLM result.
response (LLMResult | AIMessage): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -729,8 +792,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
- response (LLMResult | AIMessage): The response which was generated
before the error occurred.
"""
if not self.handlers:
return
@@ -754,6 +818,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Returns:
CallbackManagerForLLMRun: The sync RunManager.
"""
return CallbackManagerForLLMRun(
run_id=self.run_id,
@@ -770,7 +835,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -780,6 +847,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -796,12 +864,15 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
)
@shielded
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
async def on_llm_end(
self, response: Union[LLMResult, AIMessage], **kwargs: Any
) -> None:
"""Run when LLM ends running.
Args:
response (LLMResult): The LLM result.
response (LLMResult | AIMessage): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -827,10 +898,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
- response (LLMResult | AIMessage): The response which was generated
before the error occurred.
"""
if not self.handlers:
@@ -856,6 +925,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -880,6 +950,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -903,6 +974,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
@@ -926,6 +998,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
@@ -970,6 +1043,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -995,6 +1069,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1018,6 +1093,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
@@ -1041,6 +1117,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
@@ -1069,6 +1146,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
Args:
output (Any): The output of the tool.
**kwargs (Any): The keyword arguments to pass to the event handler
"""
if not self.handlers:
return
@@ -1093,6 +1171,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1134,6 +1213,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
Args:
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1158,6 +1238,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1186,6 +1267,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
Args:
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1210,6 +1292,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
Args:
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1236,6 +1319,7 @@ class AsyncCallbackManagerForRetrieverRun(
Returns:
CallbackManagerForRetrieverRun: The sync RunManager.
"""
return CallbackManagerForRetrieverRun(
run_id=self.run_id,
@@ -1257,6 +1341,7 @@ class AsyncCallbackManagerForRetrieverRun(
Args:
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1282,6 +1367,7 @@ class AsyncCallbackManagerForRetrieverRun(
Args:
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
@@ -1318,6 +1404,7 @@ class CallbackManager(BaseCallbackManager):
Returns:
list[CallbackManagerForLLMRun]: A callback manager for each
prompt as an LLM run.
"""
managers = []
for i, prompt in enumerate(prompts):
@@ -1354,7 +1441,7 @@ class CallbackManager(BaseCallbackManager):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> list[CallbackManagerForLLMRun]:
@@ -1362,14 +1449,41 @@ class CallbackManager(BaseCallbackManager):
Args:
serialized (dict[str, Any]): The serialized LLM.
messages (list[list[BaseMessage]]): The list of messages.
messages (list[list[BaseMessage | MessageV1]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
list[CallbackManagerForLLMRun]: A callback manager for each
list of messages as an LLM run.
"""
if messages and isinstance(messages[0], MessageV1Types):
run_id_ = run_id if run_id is not None else uuid.uuid4()
handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
return [
CallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
]
managers = []
for message_list in messages:
if run_id is not None:
@@ -1422,6 +1536,7 @@ class CallbackManager(BaseCallbackManager):
Returns:
CallbackManagerForChainRun: The callback manager for the chain run.
"""
if run_id is None:
run_id = uuid.uuid4()
@@ -1476,6 +1591,7 @@ class CallbackManager(BaseCallbackManager):
Returns:
CallbackManagerForToolRun: The callback manager for the tool run.
"""
if run_id is None:
run_id = uuid.uuid4()
@@ -1522,6 +1638,7 @@ class CallbackManager(BaseCallbackManager):
run_id (UUID, optional): The ID of the run. Defaults to None.
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if run_id is None:
run_id = uuid.uuid4()
@@ -1569,6 +1686,7 @@ class CallbackManager(BaseCallbackManager):
run_id: The ID of the run. Defaults to None.
.. versionadded:: 0.2.14
"""
if not self.handlers:
return
@@ -1623,6 +1741,7 @@ class CallbackManager(BaseCallbackManager):
Returns:
CallbackManager: The configured callback manager.
"""
return _configure(
cls,
@@ -1657,6 +1776,7 @@ class CallbackManagerForChainGroup(CallbackManager):
parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None.
parent_run_manager (CallbackManagerForChainRun): The parent run manager.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__(
handlers,
@@ -1745,6 +1865,7 @@ class CallbackManagerForChainGroup(CallbackManager):
Args:
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
self.ended = True
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
@@ -1759,6 +1880,7 @@ class CallbackManagerForChainGroup(CallbackManager):
Args:
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
self.ended = True
return self.parent_run_manager.on_chain_error(error, **kwargs)
@@ -1864,7 +1986,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> list[AsyncCallbackManagerForLLMRun]:
@@ -1872,7 +1994,7 @@ class AsyncCallbackManager(BaseCallbackManager):
Args:
serialized (dict[str, Any]): The serialized LLM.
messages (list[list[BaseMessage]]): The list of messages.
messages (list[list[BaseMessage | MessageV1]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
@@ -1881,10 +2003,51 @@ class AsyncCallbackManager(BaseCallbackManager):
async callback managers, one for each LLM Run
corresponding to each inner message list.
"""
if messages and isinstance(messages[0], MessageV1Types):
run_id_ = run_id if run_id is not None else uuid.uuid4()
inline_tasks = []
non_inline_tasks = []
for handler in self.handlers:
task = ahandle_event(
[handler],
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
if handler.run_inline:
inline_tasks.append(task)
else:
non_inline_tasks.append(task)
managers = [
AsyncCallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
]
# Run inline tasks sequentially
for task in inline_tasks:
await task
# Run non-inline tasks concurrently
if non_inline_tasks:
await asyncio.gather(*non_inline_tasks)
return managers
inline_tasks = []
non_inline_tasks = []
managers = []
for message_list in messages:
if run_id is not None:
run_id_ = run_id

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from typing_extensions import override
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain_core.v1.messages import AIMessage, MessageV1
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
@@ -32,7 +33,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running.
@@ -54,7 +55,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
"""Run when LLM ends running.
Args:

View File

@@ -4,14 +4,16 @@ import threading
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import override
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.v1.messages import AIMessage as AIMessageV1
class UsageMetadataCallbackHandler(BaseCallbackHandler):
@@ -58,9 +60,17 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
return str(self.usage_metadata)
@override
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(
self, response: Union[LLMResult, AIMessageV1], **kwargs: Any
) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)
if isinstance(response, AIMessageV1):
response = LLMResult(
generations=[
[ChatGeneration(message=convert_from_v1_message(response))]
]
)
try:
generation = response.generations[0][0]
except IndexError:

View File

@@ -117,9 +117,9 @@ class BaseChatMessageHistory(ABC):
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
"""Convenience method for adding a human message string to the store.
Please note that this is a convenience method. Code should favor the
bulk add_messages interface instead to save on round-trips to the underlying
persistence layer.
.. note::
This is a convenience method. Code should favor the bulk ``add_messages``
interface instead to save on round-trips to the persistence layer.
This method may be deprecated in a future release.
@@ -134,9 +134,9 @@ class BaseChatMessageHistory(ABC):
def add_ai_message(self, message: Union[AIMessage, str]) -> None:
"""Convenience method for adding an AI message string to the store.
Please note that this is a convenience method. Code should favor the bulk
add_messages interface instead to save on round-trips to the underlying
persistence layer.
.. note::
This is a convenience method. Code should favor the bulk ``add_messages``
interface instead to save on round-trips to the persistence layer.
This method may be deprecated in a future release.

View File

@@ -19,17 +19,18 @@ if TYPE_CHECKING:
class BaseDocumentCompressor(BaseModel, ABC):
"""Base class for document compressors.
This abstraction is primarily used for
post-processing of retrieved documents.
This abstraction is primarily used for post-processing of retrieved documents.
Documents matching a given query are first retrieved.
Then the list of documents can be further processed.
For example, one could re-rank the retrieved documents
using an LLM.
For example, one could re-rank the retrieved documents using an LLM.
.. note::
Users should favor using a RunnableLambda instead of sub-classing from this
interface.
**Note** users should favor using a RunnableLambda
instead of sub-classing from this interface.
"""
@abstractmethod
@@ -48,6 +49,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
Returns:
The compressed documents.
"""
async def acompress_documents(
@@ -65,6 +67,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
Returns:
The compressed documents.
"""
return await run_in_executor(
None, self.compress_documents, documents, query, callbacks

View File

@@ -488,8 +488,8 @@ class DeleteResponse(TypedDict, total=False):
failed: Sequence[str]
"""The IDs that failed to be deleted.
Please note that deleting an ID that
does not exist is **NOT** considered a failure.
.. warning::
Deleting an ID that does not exist is **NOT** considered a failure.
"""
num_failed: int

View File

@@ -1,8 +1,10 @@
import copy
import re
from collections.abc import Sequence
from typing import Optional
from langchain_core.messages import BaseMessage
from langchain_core.v1.messages import MessageV1
def _is_openai_data_block(block: dict) -> bool:
@@ -138,3 +140,37 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
formatted_messages.append(formatted_message)
return formatted_messages
def _normalize_messages_v1(messages: Sequence[MessageV1]) -> list[MessageV1]:
"""Extend support for message formats.
Chat models implement support for images in OpenAI Chat Completions format, as well
as other multimodal data as standard data blocks. This function extends support to
audio and file data in OpenAI Chat Completions format by converting them to standard
data blocks.
"""
formatted_messages = []
for message in messages:
formatted_message = message
if isinstance(message.content, list):
for idx, block in enumerate(message.content):
if (
isinstance(block, dict)
# Subset to (PDF) files and audio, as most relevant chat models
# support images in OAI format (and some may not yet support the
# standard data block format)
and block.get("type") in {"file", "input_audio"}
and _is_openai_data_block(block) # type: ignore[arg-type]
):
if formatted_message is message:
formatted_message = copy.copy(message)
# Also shallow-copy content
formatted_message.content = list(formatted_message.content)
formatted_message.content[idx] = ( # type: ignore[call-overload]
_convert_openai_format_to_data_block(block) # type: ignore[arg-type]
)
formatted_messages.append(formatted_message)
return formatted_messages

View File

@@ -31,6 +31,7 @@ from langchain_core.messages import (
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
from langchain_core.v1.messages import AIMessage as AIMessageV1
if TYPE_CHECKING:
from langchain_core.outputs import LLMResult
@@ -57,8 +58,8 @@ class LangSmithParams(TypedDict, total=False):
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.
This function is cached to avoid re-loading the tokenizer
every time it is called.
This function is cached to avoid re-loading the tokenizer every time it is called.
"""
try:
from transformers import GPT2TokenizerFast # type: ignore[import-not-found]
@@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]:
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
LanguageModelOutput = Union[BaseMessage, str]
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
LanguageModelOutputVar = TypeVar(
"LanguageModelOutputVar", BaseMessage, str, AIMessageV1
)
def _get_verbosity() -> bool:
@@ -99,7 +102,8 @@ class BaseLanguageModel(
):
"""Abstract base class for interfacing with language models.
All language model wrappers inherited from BaseLanguageModel.
All language model wrappers inherited from ``BaseLanguageModel``.
"""
cache: Union[BaseCache, bool, None] = Field(default=None, exclude=True)
@@ -108,9 +112,10 @@ class BaseLanguageModel(
* If true, will use the global cache.
* If false, will not use a cache
* If None, will use the global cache if it's set, otherwise no cache.
* If instance of BaseCache, will use the provided cache.
* If instance of ``BaseCache``, will use the provided cache.
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
@@ -140,6 +145,7 @@ class BaseLanguageModel(
Returns:
The verbosity setting to use.
"""
if verbose is None:
return _get_verbosity()
@@ -195,7 +201,8 @@ class BaseLanguageModel(
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
prompt and additional model provider-specific output.
"""
@abstractmethod
@@ -229,8 +236,9 @@ class BaseLanguageModel(
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
An ``LLMResult``, which contains a list of candidate Generations for each
input prompt and additional model provider-specific output.
"""
def with_structured_output(
@@ -248,8 +256,8 @@ class BaseLanguageModel(
) -> str:
"""Pass a single string input to the model and return a string.
Use this method when passing in raw text. If you want to pass in specific
types of chat messages, use predict_messages.
Use this method when passing in raw text. If you want to pass in specific types
of chat messages, use predict_messages.
Args:
text: String input to pass to the model.
@@ -260,6 +268,7 @@ class BaseLanguageModel(
Returns:
Top model prediction as a string.
"""
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@@ -274,7 +283,7 @@ class BaseLanguageModel(
"""Pass a message sequence to the model and return a message.
Use this method when passing in chat messages. If you want to pass in raw text,
use predict.
use predict.
Args:
messages: A sequence of chat messages corresponding to a single model input.
@@ -285,6 +294,7 @@ class BaseLanguageModel(
Returns:
Top model prediction as a message.
"""
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@@ -295,7 +305,7 @@ class BaseLanguageModel(
"""Asynchronously pass a string to the model and return a string.
Use this method when calling pure text generation models and only the top
candidate generation is needed.
candidate generation is needed.
Args:
text: String input to pass to the model.
@@ -306,6 +316,7 @@ class BaseLanguageModel(
Returns:
Top model prediction as a string.
"""
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@@ -319,8 +330,8 @@ class BaseLanguageModel(
) -> BaseMessage:
"""Asynchronously pass messages to the model and return a message.
Use this method when calling chat models and only the top
candidate generation is needed.
Use this method when calling chat models and only the top candidate generation
is needed.
Args:
messages: A sequence of chat messages corresponding to a single model input.
@@ -331,6 +342,7 @@ class BaseLanguageModel(
Returns:
Top model prediction as a message.
"""
@property
@@ -346,7 +358,8 @@ class BaseLanguageModel(
Returns:
A list of ids corresponding to the tokens in the text, in order they occur
in the text.
in the text.
"""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
@@ -362,6 +375,7 @@ class BaseLanguageModel(
Returns:
The integer number of tokens in the text.
"""
return len(self.get_token_ids(text))
@@ -374,16 +388,18 @@ class BaseLanguageModel(
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
.. note::
The base implementation of ``get_num_tokens_from_messages`` ignores tool
schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
tools: If provided, sequence of dict, ``BaseModel``, function, or
``BaseTools`` to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
@@ -396,6 +412,7 @@ class BaseLanguageModel(
def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.
Use get_pydantic_field_names.
Use ``get_pydantic_field_names``.
"""
return get_pydantic_field_names(cls)

View File

@@ -97,17 +97,18 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
"""Format messages for tracing in on_chat_model_start.
"""Format messages for tracing in ``on_chat_model_start``.
- Update image content blocks to OpenAI Chat Completions format (backward
compatibility).
- Add "type" key to content blocks that have a single key.
- Add ``type`` key to content blocks that have a single key.
Args:
messages: List of messages to format.
Returns:
List of messages formatted for tracing.
"""
messages_to_trace = []
for message in messages:
@@ -153,10 +154,11 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream.
Args:
stream: Iterator of ChatGenerationChunk.
stream: Iterator of ``ChatGenerationChunk``.
Returns:
ChatResult: Chat result.
"""
generation = next(stream, None)
if generation:
@@ -180,10 +182,11 @@ async def agenerate_from_stream(
"""Async generate from a stream.
Args:
stream: Iterator of ChatGenerationChunk.
stream: Iterator of ``ChatGenerationChunk``.
Returns:
ChatResult: Chat result.
"""
chunks = [chunk async for chunk in stream]
return await run_in_executor(None, generate_from_stream, iter(chunks))
@@ -311,15 +314,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
provided. This offers the best of both worlds.
- If False (default), will always use streaming case if available.
The main reason for this flag is that code might be written using ``.stream()`` and
The main reason for this flag is that code might be written using ``stream()`` and
a user may want to swap out a given model for another model whose the implementation
does not properly support streaming.
"""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
"""Raise deprecation warning if ``callback_manager`` is used.
Args:
values (Dict): Values to validate.
@@ -328,7 +332,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Dict: Validated values.
Raises:
DeprecationWarning: If callback_manager is used.
DeprecationWarning: If ``callback_manager`` is used.
"""
if values.get("callback_manager") is not None:
warnings.warn(
@@ -653,6 +658,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
List of ChatGeneration objects.
"""
converted_generations = []
for gen in cache_val:
@@ -666,6 +672,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
converted_generations.append(chat_gen)
else:
# Already a ChatGeneration or other expected type
if hasattr(gen, "message") and isinstance(gen.message, AIMessage):
# We zero out cost on cache hits
gen.message = gen.message.model_copy(
update={
"usage_metadata": {
**(gen.message.usage_metadata or {}),
"total_cost": 0,
}
}
)
converted_generations.append(gen)
return converted_generations
@@ -768,7 +784,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
prompt and additional model provider-specific output.
"""
ls_structured_output_format = kwargs.pop(
"ls_structured_output_format", None
@@ -882,7 +899,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
prompt and additional model provider-specific output.
"""
ls_structured_output_format = kwargs.pop(
"ls_structured_output_format", None
@@ -1238,6 +1256,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
The model output message.
"""
generation = self.generate(
[messages], stop=stop, callbacks=callbacks, **kwargs
@@ -1278,6 +1297,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
The model output string.
"""
return self.predict(message, stop=stop, **kwargs)
@@ -1297,6 +1317,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
The predicted output string.
"""
stop_ = None if stop is None else list(stop)
result = self([HumanMessage(content=text)], stop=stop_, **kwargs)
@@ -1372,6 +1393,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Returns:
A Runnable that returns a message.
"""
raise NotImplementedError
@@ -1534,8 +1556,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
class SimpleChatModel(BaseChatModel):
"""Simplified implementation for a chat model to inherit from.
**Note** This implementation is primarily here for backwards compatibility.
For new implementations, please use `BaseChatModel` directly.
.. note::
This implementation is primarily here for backwards compatibility. For new
implementations, please use ``BaseChatModel`` directly.
"""
def _generate(

View File

@@ -3,7 +3,7 @@
import asyncio
import re
import time
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Iterable, Iterator
from typing import Any, Optional, Union, cast
from typing_extensions import override
@@ -16,6 +16,10 @@ from langchain_core.language_models.chat_models import BaseChatModel, SimpleChat
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig
from langchain_core.v1.chat_models import BaseChatModel as BaseChatModelV1
from langchain_core.v1.messages import AIMessage as AIMessageV1
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
from langchain_core.v1.messages import MessageV1
class FakeMessagesListChatModel(BaseChatModel):
@@ -223,11 +227,12 @@ class GenericFakeChatModel(BaseChatModel):
This can be expanded to accept other types like Callables / dicts / strings
to make the interface more generic if needed.
Note: if you want to pass a list, you can use `iter` to convert it to an iterator.
.. note::
if you want to pass a list, you can use ``iter`` to convert it to an iterator.
Please note that streaming is not implemented yet. We should try to implement it
in the future by delegating to invoke and then breaking the resulting output
into message chunks.
.. warning::
Streaming is not implemented yet. We should try to implement it in the future by
delegating to invoke and then breaking the resulting output into message chunks.
"""
@override
@@ -367,3 +372,69 @@ class ParrotFakeChatModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "parrot-fake-chat-model"
class GenericFakeChatModelV1(BaseChatModelV1):
"""Generic fake chat model that can be used to test the chat model interface."""
messages: Optional[Iterator[Union[AIMessageV1, str]]] = None
message_chunks: Optional[Iterable[Union[AIMessageChunkV1, str]]] = None
@override
def _invoke(
self,
messages: list[MessageV1],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AIMessageV1:
"""Top Level call."""
if self.messages is None:
error_msg = "Messages iterator is not set."
raise ValueError(error_msg)
message = next(self.messages)
return AIMessageV1(content=message) if isinstance(message, str) else message
@override
def _stream(
self,
messages: list[MessageV1],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[AIMessageChunkV1]:
"""Top Level call."""
if self.message_chunks is None:
error_msg = "Message chunks iterator is not set."
raise ValueError(error_msg)
for chunk in self.message_chunks:
if isinstance(chunk, str):
yield AIMessageChunkV1(chunk)
else:
yield chunk
@property
def _llm_type(self) -> str:
return "generic-fake-chat-model"
class ParrotFakeChatModelV1(BaseChatModelV1):
"""Generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests
"""
@override
def _invoke(
self,
messages: list[MessageV1],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AIMessageV1:
"""Top Level call."""
if isinstance(messages[-1], AIMessageV1):
return messages[-1]
return AIMessageV1(content=messages[-1].content)
@property
def _llm_type(self) -> str:
return "parrot-fake-chat-model"

View File

@@ -1,11 +1,14 @@
"""Dump objects to json."""
import dataclasses
import inspect
import json
from typing import Any
from pydantic import BaseModel
from langchain_core.load.serializable import Serializable, to_json_not_implemented
from langchain_core.v1.messages import MessageV1Types
def default(obj: Any) -> Any:
@@ -19,6 +22,24 @@ def default(obj: Any) -> Any:
"""
if isinstance(obj, Serializable):
return obj.to_json()
# Handle v1 message classes
if type(obj) in MessageV1Types:
# Get the constructor signature to only include valid parameters
init_sig = inspect.signature(type(obj).__init__)
valid_params = set(init_sig.parameters.keys()) - {"self"}
# Filter dataclass fields to only include constructor params
all_fields = dataclasses.asdict(obj)
kwargs = {k: v for k, v in all_fields.items() if k in valid_params}
return {
"lc": 1,
"type": "constructor",
"id": ["langchain_core", "v1", "messages", type(obj).__name__],
"kwargs": kwargs,
}
return to_json_not_implemented(obj)
@@ -73,10 +94,9 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
def dumpd(obj: Any) -> Any:
"""Return a dict representation of an object.
Note:
Unfortunately this function is not as efficient as it could be
because it first dumps the object to a json string and then loads it
back into a dictionary.
.. note::
Unfortunately this function is not as efficient as it could be because it first
dumps the object to a json string and then loads it back into a dictionary.
Args:
obj: The object to dump.

View File

@@ -156,8 +156,13 @@ class Reviver:
cls = getattr(mod, name)
# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):
# Import MessageV1Types lazily to avoid circular import:
# load.load -> v1.messages -> messages.ai -> messages.base ->
# load.serializable -> load.__init__ -> load.load
from langchain_core.v1.messages import MessageV1Types
# The class must be a subclass of Serializable or a v1 message class.
if not (issubclass(cls, Serializable) or cls in MessageV1Types):
msg = f"Invalid namespace: {value}"
raise ValueError(msg)

View File

@@ -33,9 +33,31 @@ if TYPE_CHECKING:
)
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.content_blocks import (
Annotation,
AudioContentBlock,
Citation,
CodeInterpreterCall,
CodeInterpreterOutput,
CodeInterpreterResult,
ContentBlock,
DataContentBlock,
FileContentBlock,
ImageContentBlock,
NonStandardAnnotation,
NonStandardContentBlock,
PlainTextContentBlock,
ReasoningContentBlock,
TextContentBlock,
VideoContentBlock,
WebSearchCall,
WebSearchResult,
convert_to_openai_data_block,
convert_to_openai_image_block,
is_data_content_block,
is_reasoning_block,
is_text_block,
is_tool_call_block,
is_tool_call_chunk,
)
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
@@ -65,24 +87,42 @@ if TYPE_CHECKING:
__all__ = (
"AIMessage",
"AIMessageChunk",
"Annotation",
"AnyMessage",
"AudioContentBlock",
"BaseMessage",
"BaseMessageChunk",
"ChatMessage",
"ChatMessageChunk",
"Citation",
"CodeInterpreterCall",
"CodeInterpreterOutput",
"CodeInterpreterResult",
"ContentBlock",
"DataContentBlock",
"FileContentBlock",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"ImageContentBlock",
"InvalidToolCall",
"MessageLikeRepresentation",
"NonStandardAnnotation",
"NonStandardContentBlock",
"PlainTextContentBlock",
"ReasoningContentBlock",
"RemoveMessage",
"SystemMessage",
"SystemMessageChunk",
"TextContentBlock",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"VideoContentBlock",
"WebSearchCall",
"WebSearchResult",
"_message_from_dict",
"convert_to_messages",
"convert_to_openai_data_block",
@@ -91,6 +131,10 @@ __all__ = (
"filter_messages",
"get_buffer_string",
"is_data_content_block",
"is_reasoning_block",
"is_text_block",
"is_tool_call_block",
"is_tool_call_chunk",
"merge_content",
"merge_message_runs",
"message_chunk_to_message",
@@ -103,25 +147,43 @@ __all__ = (
_dynamic_imports = {
"AIMessage": "ai",
"AIMessageChunk": "ai",
"Annotation": "content_blocks",
"AudioContentBlock": "content_blocks",
"BaseMessage": "base",
"BaseMessageChunk": "base",
"merge_content": "base",
"message_to_dict": "base",
"messages_to_dict": "base",
"Citation": "content_blocks",
"ContentBlock": "content_blocks",
"ChatMessage": "chat",
"ChatMessageChunk": "chat",
"CodeInterpreterCall": "content_blocks",
"CodeInterpreterOutput": "content_blocks",
"CodeInterpreterResult": "content_blocks",
"DataContentBlock": "content_blocks",
"FileContentBlock": "content_blocks",
"FunctionMessage": "function",
"FunctionMessageChunk": "function",
"HumanMessage": "human",
"HumanMessageChunk": "human",
"NonStandardAnnotation": "content_blocks",
"NonStandardContentBlock": "content_blocks",
"PlainTextContentBlock": "content_blocks",
"ReasoningContentBlock": "content_blocks",
"RemoveMessage": "modifier",
"SystemMessage": "system",
"SystemMessageChunk": "system",
"WebSearchCall": "content_blocks",
"WebSearchResult": "content_blocks",
"ImageContentBlock": "content_blocks",
"InvalidToolCall": "tool",
"TextContentBlock": "content_blocks",
"ToolCall": "tool",
"ToolCallChunk": "tool",
"ToolMessage": "tool",
"ToolMessageChunk": "tool",
"VideoContentBlock": "content_blocks",
"AnyMessage": "utils",
"MessageLikeRepresentation": "utils",
"_message_from_dict": "utils",
@@ -132,6 +194,10 @@ _dynamic_imports = {
"filter_messages": "utils",
"get_buffer_string": "utils",
"is_data_content_block": "content_blocks",
"is_reasoning_block": "content_blocks",
"is_text_block": "content_blocks",
"is_tool_call_block": "content_blocks",
"is_tool_call_chunk": "content_blocks",
"merge_message_runs": "utils",
"message_chunk_to_message": "utils",
"messages_from_dict": "utils",

View File

@@ -8,11 +8,7 @@ from typing import Any, Literal, Optional, Union, cast
from pydantic import model_validator
from typing_extensions import NotRequired, Self, TypedDict, override
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
@@ -20,23 +16,26 @@ from langchain_core.messages.tool import (
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.messages.tool import (
invalid_tool_call as create_invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.usage import _dict_int_op
logger = logging.getLogger(__name__)
_LC_AUTO_PREFIX = "lc_"
"""LangChain auto-generated ID prefix for messages and content blocks."""
_LC_ID_PREFIX = "run-"
_LC_ID_PREFIX = f"{_LC_AUTO_PREFIX}run-"
"""Internal tracing/callback system identifier.
Used for:
- Tracing. Every LangChain operation (LLM call, chain execution, tool use, etc.)
gets a unique run_id (UUID)
- Enables tracking parent-child relationships between operations
"""
class InputTokenDetails(TypedDict, total=False):
@@ -428,17 +427,27 @@ def add_ai_message_chunks(
chunk_id = None
candidates = [left.id] + [o.id for o in others]
# first pass: pick the first non-run-* id
# first pass: pick the first provider-assigned id (non-run-* and non-lc_*)
for id_ in candidates:
if id_ and not id_.startswith(_LC_ID_PREFIX):
if (
id_
and not id_.startswith(_LC_ID_PREFIX)
and not id_.startswith(_LC_AUTO_PREFIX)
):
chunk_id = id_
break
else:
# second pass: no provider-assigned id found, just take the first non-null
# second pass: prefer lc_run-* ids over lc_* ids
for id_ in candidates:
if id_:
if id_ and id_.startswith(_LC_ID_PREFIX):
chunk_id = id_
break
else:
# third pass: take any remaining id (auto-generated lc_* ids)
for id_ in candidates:
if id_:
chunk_id = id_
break
return left.__class__(
example=left.example,

File diff suppressed because it is too large Load Diff

View File

@@ -13,7 +13,7 @@ class RemoveMessage(BaseMessage):
def __init__(
self,
id: str, # noqa: A002
id: str,
**kwargs: Any,
) -> None:
"""Create a RemoveMessage.

View File

@@ -5,9 +5,12 @@ from typing import Any, Literal, Optional, Union
from uuid import UUID
from pydantic import Field, model_validator
from typing_extensions import NotRequired, TypedDict, override
from typing_extensions import override
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.messages.content_blocks import InvalidToolCall as InvalidToolCall
from langchain_core.messages.content_blocks import ToolCall as ToolCall
from langchain_core.messages.content_blocks import ToolCallChunk as ToolCallChunk
from langchain_core.utils._merge import merge_dicts, merge_obj
@@ -177,42 +180,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
return super().__add__(other)
class ToolCall(TypedDict):
"""Represents a request to call a tool.
Example:
.. code-block:: python
{
"name": "foo",
"args": {"a": 1},
"id": "123"
}
This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".
"""
name: str
"""The name of the tool to be called."""
args: dict[str, Any]
"""The arguments to the tool call."""
id: Optional[str]
"""An identifier associated with the tool call.
An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made.
"""
type: NotRequired[Literal["tool_call"]]
def tool_call(
*,
name: str,
args: dict[str, Any],
id: Optional[str], # noqa: A002
id: Optional[str],
) -> ToolCall:
"""Create a tool call.
@@ -224,43 +196,11 @@ def tool_call(
return ToolCall(name=name, args=args, id=id, type="tool_call")
class ToolCallChunk(TypedDict):
"""A chunk of a tool call (e.g., as part of a stream).
When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
Example:
.. code-block:: python
left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
(
AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]
"""
name: Optional[str]
"""The name of the tool to be called."""
args: Optional[str]
"""The arguments to the tool call."""
id: Optional[str]
"""An identifier associated with the tool call."""
index: Optional[int]
"""The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]
def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None, # noqa: A002
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
"""Create a tool call chunk.
@@ -276,29 +216,11 @@ def tool_call_chunk(
)
class InvalidToolCall(TypedDict):
"""Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
"""
name: Optional[str]
"""The name of the tool to be called."""
args: Optional[str]
"""The arguments to the tool call."""
id: Optional[str]
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]
def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None, # noqa: A002
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
"""Create an invalid tool call.

View File

@@ -35,11 +35,18 @@ from langchain_core.messages import convert_to_openai_data_block, is_data_conten
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.content_blocks import ContentBlock
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.modifier import RemoveMessage
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
from langchain_core.v1.messages import AIMessage as AIMessageV1
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
from langchain_core.v1.messages import HumanMessage as HumanMessageV1
from langchain_core.v1.messages import MessageV1, MessageV1Types, ResponseMetadata
from langchain_core.v1.messages import SystemMessage as SystemMessageV1
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
if TYPE_CHECKING:
from langchain_text_splitters import TextSplitter
@@ -203,7 +210,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
MessageLikeRepresentation = Union[
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
BaseMessage, list[str], tuple[str, str], str, dict[str, Any], MessageV1
]
@@ -213,7 +220,7 @@ def _create_message_from_message_type(
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,
id: Optional[str] = None, # noqa: A002
id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.
@@ -294,6 +301,130 @@ def _create_message_from_message_type(
return message
def _create_message_from_message_type_v1(
message_type: str,
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,
id: Optional[str] = None,
**kwargs: Any,
) -> MessageV1:
"""Create a message from a message type and content string.
Args:
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
content: (str) the content string.
name: (str) the name of the message. Default is None.
tool_call_id: (str) the tool call id. Default is None.
tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
id: (str) the id of the message. Default is None.
kwargs: (dict[str, Any]) additional keyword arguments.
Returns:
a message of the appropriate type.
Raises:
ValueError: if the message type is not one of "human", "user", "ai",
"assistant", "tool", "system", or "developer".
"""
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
kwargs["tool_call_id"] = tool_call_id
if kwargs and (response_metadata := kwargs.pop("response_metadata", None)):
kwargs["response_metadata"] = response_metadata
if id is not None:
kwargs["id"] = id
if tool_calls is not None:
kwargs["tool_calls"] = []
for tool_call in tool_calls:
# Convert OpenAI-format tool call to LangChain format.
if "function" in tool_call:
args = tool_call["function"]["arguments"]
if isinstance(args, str):
args = json.loads(args, strict=False)
kwargs["tool_calls"].append(
{
"name": tool_call["function"]["name"],
"args": args,
"id": tool_call["id"],
"type": "tool_call",
}
)
else:
kwargs["tool_calls"].append(tool_call)
if message_type in {"human", "user"}:
message: MessageV1 = HumanMessageV1(content=content, **kwargs)
elif message_type in {"ai", "assistant"}:
message = AIMessageV1(content=content, **kwargs)
elif message_type in {"system", "developer"}:
if message_type == "developer":
kwargs["custom_role"] = "developer"
message = SystemMessageV1(content=content, **kwargs)
elif message_type == "tool":
artifact = kwargs.pop("artifact", None)
message = ToolMessageV1(content=content, artifact=artifact, **kwargs)
else:
msg = (
f"Unexpected message type: '{message_type}'. Use one of 'human',"
f" 'user', 'ai', 'assistant', 'function', 'tool', 'system', or 'developer'."
)
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
raise ValueError(msg)
return message
def convert_from_v1_message(message: MessageV1) -> BaseMessage:
"""Compatibility layer to convert v1 messages to current messages.
Args:
message: MessageV1 instance to convert.
Returns:
BaseMessage: Converted message instance.
"""
content = cast("Union[str, list[str | dict]]", message.content)
if isinstance(message, AIMessageV1):
return AIMessage(
content=content,
id=message.id,
name=message.name,
tool_calls=message.tool_calls,
response_metadata=cast("dict", message.response_metadata),
)
if isinstance(message, AIMessageChunkV1):
return AIMessageChunk(
content=content,
id=message.id,
name=message.name,
tool_call_chunks=message.tool_call_chunks,
response_metadata=cast("dict", message.response_metadata),
)
if isinstance(message, HumanMessageV1):
return HumanMessage(
content=content,
id=message.id,
name=message.name,
)
if isinstance(message, SystemMessageV1):
return SystemMessage(
content=content,
id=message.id,
)
if isinstance(message, ToolMessageV1):
return ToolMessage(
content=content,
id=message.id,
tool_call_id=message.tool_call_id,
artifact=message.artifact,
name=message.name,
status=message.status,
)
message = f"Unsupported message type: {type(message)}"
raise NotImplementedError(message)
def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
"""Instantiate a message from a variety of message formats.
@@ -341,6 +472,143 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
message_ = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
elif isinstance(message, MessageV1Types):
message_ = convert_from_v1_message(message)
else:
msg = f"Unsupported message type: {type(message)}"
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
raise NotImplementedError(msg)
return message_
def _convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
"""Convert a v0 message to a v1 message."""
if isinstance(message, HumanMessage): # Checking for v0 HumanMessage
return HumanMessageV1(message.content, id=message.id, name=message.name) # type: ignore[arg-type]
if isinstance(message, AIMessage): # Checking for v0 AIMessage
return AIMessageV1(
content=message.content, # type: ignore[arg-type]
id=message.id,
name=message.name,
lc_version="v1",
response_metadata=message.response_metadata, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
tool_calls=message.tool_calls,
invalid_tool_calls=message.invalid_tool_calls,
)
if isinstance(message, SystemMessage): # Checking for v0 SystemMessage
return SystemMessageV1(
message.content, # type: ignore[arg-type]
id=message.id,
name=message.name,
)
if isinstance(message, ToolMessage): # Checking for v0 ToolMessage
return ToolMessageV1(
message.content, # type: ignore[arg-type]
message.tool_call_id,
id=message.id,
name=message.name,
artifact=message.artifact,
status=message.status,
)
msg = f"Unsupported v0 message type for conversion to v1: {type(message)}"
raise NotImplementedError(msg)
def _safe_convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
"""Convert a v0 message to a v1 message."""
from langchain_core.messages.content_blocks import create_text_block
if isinstance(message, HumanMessage): # Checking for v0 HumanMessage
content: list[ContentBlock] = [create_text_block(str(message.content))]
return HumanMessageV1(content, id=message.id, name=message.name)
if isinstance(message, AIMessage): # Checking for v0 AIMessage
content = [create_text_block(str(message.content))]
# Construct ResponseMetadata TypedDict from v0 response_metadata dict
# Since ResponseMetadata has total=False, we can safely cast the dict
response_metadata = cast("ResponseMetadata", message.response_metadata or {})
return AIMessageV1(
content=content,
id=message.id,
name=message.name,
lc_version="v1",
response_metadata=response_metadata,
usage_metadata=message.usage_metadata,
tool_calls=message.tool_calls,
invalid_tool_calls=message.invalid_tool_calls,
)
if isinstance(message, SystemMessage): # Checking for v0 SystemMessage
content = [create_text_block(str(message.content))]
return SystemMessageV1(content=content, id=message.id, name=message.name)
if isinstance(message, ToolMessage): # Checking for v0 ToolMessage
content = [create_text_block(str(message.content))]
return ToolMessageV1(
content,
message.tool_call_id,
id=message.id,
name=message.name,
artifact=message.artifact,
status=message.status,
)
msg = f"Unsupported v0 message type for conversion to v1: {type(message)}"
raise NotImplementedError(msg)
def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- dict: a message dict with role and content keys
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats.
Returns:
an instance of a message or a message template.
Raises:
NotImplementedError: if the message type is not supported.
ValueError: if the message dict does not contain the required keys.
"""
if isinstance(message, MessageV1Types):
if isinstance(message, AIMessageChunkV1):
message_: MessageV1 = message.to_message()
else:
message_ = message
elif isinstance(message, BaseMessage):
# Convert v0 messages to v1 messages
message_ = _convert_from_v0_to_v1(message)
elif isinstance(message, str):
message_ = _create_message_from_message_type_v1("human", message)
elif isinstance(message, Sequence) and len(message) == 2:
# mypy doesn't realise this can't be a string given the previous branch
message_type_str, template = message # type: ignore[misc]
message_ = _create_message_from_message_type_v1(message_type_str, template)
elif isinstance(message, dict):
msg_kwargs = message.copy()
try:
try:
msg_type = msg_kwargs.pop("role")
except KeyError:
msg_type = msg_kwargs.pop("type")
# None msg content is not allowed
msg_content = msg_kwargs.pop("content") or ""
except KeyError as e:
msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
msg = create_message(
message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
)
raise ValueError(msg) from e
message_ = _create_message_from_message_type_v1(
msg_type, msg_content, **msg_kwargs
)
else:
msg = f"Unsupported message type: {type(message)}"
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
@@ -368,6 +636,25 @@ def convert_to_messages(
return [_convert_to_message(m) for m in messages]
def convert_to_messages_v1(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> list[MessageV1]:
"""Convert a sequence of messages to a list of messages.
Args:
messages: Sequence of messages to convert.
Returns:
list of messages (BaseMessages).
"""
# Import here to avoid circular imports
from langchain_core.prompt_values import PromptValue
if isinstance(messages, PromptValue):
return messages.to_messages(message_version="v1")
return [_convert_to_message_v1(m) for m in messages]
def _runnable_support(func: Callable) -> Callable:
@overload
def wrapped(
@@ -656,22 +943,23 @@ def trim_messages(
properties:
1. The resulting chat history should be valid. Most chat models expect that chat
history starts with either (1) a `HumanMessage` or (2) a `SystemMessage` followed
by a `HumanMessage`. To achieve this, set `start_on="human"`.
In addition, generally a `ToolMessage` can only appear after an `AIMessage`
history starts with either (1) a ``HumanMessage`` or (2) a ``SystemMessage`` followed
by a ``HumanMessage``. To achieve this, set ``start_on="human"``.
In addition, generally a ``ToolMessage`` can only appear after an ``AIMessage``
that involved a tool call.
Please see the following link for more information about messages:
https://python.langchain.com/docs/concepts/#messages
2. It includes recent messages and drops old messages in the chat history.
To achieve this set the `strategy="last"`.
3. Usually, the new chat history should include the `SystemMessage` if it
was present in the original chat history since the `SystemMessage` includes
special instructions to the chat model. The `SystemMessage` is almost always
To achieve this set the ``strategy="last"``.
3. Usually, the new chat history should include the ``SystemMessage`` if it
was present in the original chat history since the ``SystemMessage`` includes
special instructions to the chat model. The ``SystemMessage`` is almost always
the first message in the history if present. To achieve this set the
`include_system=True`.
``include_system=True``.
**Note** The examples below show how to configure `trim_messages` to achieve
a behavior consistent with the above properties.
.. note::
The examples below show how to configure ``trim_messages`` to achieve a behavior
consistent with the above properties.
Args:
messages: Sequence of Message-like objects to trim.
@@ -1007,10 +1295,11 @@ def convert_to_openai_messages(
oai_messages: list = []
if is_single := isinstance(messages, (BaseMessage, dict, str)):
if is_single := isinstance(messages, (BaseMessage, dict, str, MessageV1Types)):
messages = [messages]
messages = convert_to_messages(messages)
# TODO: resolve type ignore here
messages = convert_to_messages(messages) # type: ignore[arg-type]
for i, message in enumerate(messages):
oai_msg: dict = {"role": _get_message_openai_role(message)}
@@ -1580,26 +1869,26 @@ def count_tokens_approximately(
chars_per_token: Number of characters per token to use for the approximation.
Default is 4 (one token corresponds to ~4 chars for common English text).
You can also specify float values for more fine-grained control.
See more here: https://platform.openai.com/tokenizer
`See more here. <https://platform.openai.com/tokenizer>`__
extra_tokens_per_message: Number of extra tokens to add per message.
Default is 3 (special tokens, including beginning/end of message).
You can also specify float values for more fine-grained control.
See more here:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
`See more here. <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>`__
count_name: Whether to include message names in the count.
Enabled by default.
Returns:
Approximate number of tokens in the messages.
Note:
This is a simple approximation that may not match the exact token count
used by specific models. For accurate counts, use model-specific tokenizers.
.. note::
This is a simple approximation that may not match the exact token count used by
specific models. For accurate counts, use model-specific tokenizers.
Warning:
This function does not currently support counting image tokens.
.. versionadded:: 0.3.46
"""
token_count = 0.0
for message in convert_to_messages(messages):

View File

@@ -11,6 +11,7 @@ from typing import (
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import override
@@ -20,19 +21,22 @@ from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor
from langchain_core.v1.messages import AIMessage, MessageV1, MessageV1Types
if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue
T = TypeVar("T")
OutputParserLike = Runnable[LanguageModelOutput, T]
OutputParserLike = Runnable[Union[LanguageModelOutput, AIMessage], T]
class BaseLLMOutputParser(ABC, Generic[T]):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@@ -46,7 +50,7 @@ class BaseLLMOutputParser(ABC, Generic[T]):
"""
async def aparse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Async parse a list of candidate model Generations into a specific format.
@@ -71,7 +75,7 @@ class BaseGenerationOutputParser(
@override
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
return Union[str, AnyMessage, MessageV1]
@property
@override
@@ -84,7 +88,7 @@ class BaseGenerationOutputParser(
@override
def invoke(
self,
input: Union[str, BaseMessage],
input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> T:
@@ -97,9 +101,16 @@ class BaseGenerationOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return self._call_with_config(
lambda inner_input: self.parse_result(inner_input),
input,
config,
run_type="parser",
)
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
cast("str", input),
config,
run_type="parser",
)
@@ -120,6 +131,13 @@ class BaseGenerationOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(inner_input),
input,
config,
run_type="parser",
)
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
@@ -129,7 +147,7 @@ class BaseGenerationOutputParser(
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
BaseLLMOutputParser, RunnableSerializable[Union[LanguageModelOutput, AIMessage], T]
):
"""Base class to parse the output of an LLM call.
@@ -162,7 +180,7 @@ class BaseOutputParser(
@override
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
return Union[str, AnyMessage, MessageV1]
@property
@override
@@ -189,7 +207,7 @@ class BaseOutputParser(
@override
def invoke(
self,
input: Union[str, BaseMessage],
input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> T:
@@ -202,9 +220,16 @@ class BaseOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return self._call_with_config(
lambda inner_input: self.parse_result(inner_input),
input,
config,
run_type="parser",
)
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
cast("str", input),
config,
run_type="parser",
)
@@ -212,7 +237,7 @@ class BaseOutputParser(
@override
async def ainvoke(
self,
input: Union[str, BaseMessage],
input: Union[str, BaseMessage, MessageV1],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
@@ -225,15 +250,24 @@ class BaseOutputParser(
config,
run_type="parser",
)
if isinstance(input, MessageV1Types):
return await self._acall_with_config(
lambda inner_input: self.aparse_result(inner_input),
input,
config,
run_type="parser",
)
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
cast("str", input),
config,
run_type="parser",
)
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@@ -248,6 +282,8 @@ class BaseOutputParser(
Returns:
Structured output.
"""
if isinstance(result, AIMessage):
return self.parse(result.text)
return self.parse(result[0].text)
@abstractmethod
@@ -262,7 +298,7 @@ class BaseOutputParser(
"""
async def aparse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> T:
"""Async parse a list of candidate model Generations into a specific format.

View File

@@ -21,6 +21,7 @@ from langchain_core.utils.json import (
parse_json_markdown,
parse_partial_json,
)
from langchain_core.v1.messages import AIMessage
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]
@@ -53,7 +54,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return pydantic_object.schema()
return None
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@@ -70,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
text = result[0].text
text = result.text if isinstance(result, AIMessage) else result[0].text
text = text.strip()
if partial:
try:

View File

@@ -13,6 +13,7 @@ from typing_extensions import override
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.v1.messages import AIMessage
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
@@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
@override
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[list[str]]:
buffer = ""
for chunk in input:
@@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
if not isinstance(chunk_content, str):
continue
buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text
else:
# add current chunk to buffer
buffer += chunk
@@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[list[str]]:
buffer = ""
async for chunk in input:
@@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
if not isinstance(chunk_content, str):
continue
buffer += chunk_content
elif isinstance(chunk, AIMessage):
buffer += chunk.text
else:
# add current chunk to buffer
buffer += chunk

View File

@@ -17,6 +17,7 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.v1.messages import AIMessage
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Whether to only return the arguments to the function call."""
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
if isinstance(result, AIMessage):
msg = (
"This output parser does not support v1 AIMessages. Use "
"JsonOutputToolsParser instead."
)
raise TypeError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
@@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
if isinstance(result, AIMessage):
msg = (
"This output parser does not support v1 AIMessages. Use "
"JsonOutputToolsParser instead."
)
raise TypeError(msg)
if len(result) != 1:
msg = f"Expected exactly one result, but got {len(result)}"
raise OutputParserException(msg)
@@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@@ -254,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
return values
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@@ -294,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""The name of the attribute to return."""
@override
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:

View File

@@ -4,7 +4,7 @@ import copy
import json
import logging
from json import JSONDecodeError
from typing import Annotated, Any, Optional
from typing import Annotated, Any, Optional, Union
from pydantic import SkipValidation, ValidationError
@@ -16,6 +16,7 @@ from langchain_core.output_parsers.transform import BaseCumulativeTransformOutpu
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
from langchain_core.v1.messages import AIMessage as AIMessageV1
logger = logging.getLogger(__name__)
@@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
If no tool calls are found, None will be returned.
"""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
@@ -173,31 +176,45 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
tool_calls = [dict(tc) for tc in message.tool_calls]
if isinstance(result, list):
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = (
"This output parser can only be used with a chat generation or "
"v1 AIMessage."
)
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(
message.additional_kwargs["tool_calls"]
)
except KeyError:
return []
tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
elif result.tool_calls:
# v1 message
tool_calls = [dict(tc) for tc in result.tool_calls]
for tool_call in tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
return []
tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
return []
# for backwards compatibility
for tc in tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
return tool_calls[0] if tool_calls else None
return tool_calls
@@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str
"""The type of tools to return."""
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
@@ -234,32 +253,47 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
Returns:
The parsed tool calls.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
if isinstance(result, list):
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(
message.additional_kwargs["tool_calls"]
)
except KeyError:
if self.first_tool_only:
return None
return []
parsed_tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
elif result.tool_calls:
# v1 message
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
if self.first_tool_only:
return None
return []
parsed_tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
if self.first_tool_only:
return None
return []
# For backwards compatibility
for tc in parsed_tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
parsed_result = list(
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
@@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
# TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present.
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects.
Args:
@@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
except (ValidationError, ValueError):
if partial:
continue
has_max_tokens_stop_reason = any(
generation.message.response_metadata.get("stop_reason")
== "max_tokens"
for generation in result
if isinstance(generation, ChatGeneration)
)
has_max_tokens_stop_reason = False
if isinstance(result, list):
has_max_tokens_stop_reason = any(
generation.message.response_metadata.get("stop_reason")
== "max_tokens"
for generation in result
if isinstance(generation, ChatGeneration)
)
else:
# v1 message
has_max_tokens_stop_reason = (
result.response_metadata.get("stop_reason") == "max_tokens"
)
if has_max_tokens_stop_reason:
logger.exception(_MAX_TOKENS_ERROR)
raise

View File

@@ -1,7 +1,7 @@
"""Output parsers using Pydantic."""
import json
from typing import Annotated, Generic, Optional
from typing import Annotated, Generic, Optional, Union
import pydantic
from pydantic import SkipValidation
@@ -14,6 +14,7 @@ from langchain_core.utils.pydantic import (
PydanticBaseModel,
TBaseModel,
)
from langchain_core.v1.messages import AIMessage
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
@@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return OutputParserException(msg, llm_output=json_string)
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
) -> Optional[TBaseModel]:
"""Parse the result of an LLM call to a pydantic object.

View File

@@ -20,6 +20,7 @@ from langchain_core.outputs import (
GenerationChunk,
)
from langchain_core.runnables.config import run_in_executor
from langchain_core.v1.messages import AIMessage, AIMessageChunk
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
@@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
def _transform(
self,
input: Iterator[Union[str, BaseMessage]], # noqa: A002
input: Iterator[Union[str, BaseMessage, AIMessage]],
) -> Iterator[T]:
for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
elif isinstance(chunk, AIMessage):
yield self.parse_result(chunk)
else:
yield self.parse_result([Generation(text=chunk)])
async def _atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield await run_in_executor(
None, self.parse_result, [ChatGeneration(message=chunk)]
)
elif isinstance(chunk, AIMessage):
yield await run_in_executor(None, self.parse_result, chunk)
else:
yield await run_in_executor(
None, self.parse_result, [Generation(text=chunk)]
@@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
@override
def transform(
self,
input: Iterator[Union[str, BaseMessage]],
input: Iterator[Union[str, BaseMessage, AIMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
@@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
@override
async def atransform(
self,
input: AsyncIterator[Union[str, BaseMessage]],
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
@@ -125,23 +130,42 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
raise NotImplementedError
@override
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
def _transform(
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[Any]:
prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
None
)
for chunk in input:
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.model_dump())
)
elif isinstance(chunk, AIMessageChunk):
chunk_gen = chunk
elif isinstance(chunk, AIMessage):
chunk_gen = AIMessageChunk(
content=chunk.content,
id=chunk.id,
name=chunk.name,
lc_version=chunk.lc_version,
response_metadata=chunk.response_metadata,
usage_metadata=chunk.usage_metadata,
parsed=chunk.parsed,
)
else:
chunk_gen = GenerationChunk(text=chunk)
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
parsed = self.parse_result([acc_gen], partial=True)
if isinstance(acc_gen, AIMessageChunk):
parsed = self.parse_result(acc_gen, partial=True)
else:
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
@@ -151,24 +175,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
None
)
async for chunk in input:
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.model_dump())
)
elif isinstance(chunk, AIMessageChunk):
chunk_gen = chunk
elif isinstance(chunk, AIMessage):
chunk_gen = AIMessageChunk(
content=chunk.content,
id=chunk.id,
name=chunk.name,
lc_version=chunk.lc_version,
response_metadata=chunk.response_metadata,
usage_metadata=chunk.usage_metadata,
parsed=chunk.parsed,
)
else:
chunk_gen = GenerationChunk(text=chunk)
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
parsed = await self.aparse_result([acc_gen], partial=True)
if isinstance(acc_gen, AIMessageChunk):
parsed = await self.aparse_result(acc_gen, partial=True)
else:
parsed = await self.aparse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield await run_in_executor(None, self._diff, prev_parsed, parsed)

View File

@@ -12,8 +12,10 @@ from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
from langchain_core.v1.messages import AIMessage
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
1. Output should conform to the tags below.
@@ -105,23 +107,27 @@ class _StreamingParser:
self.buffer = ""
# yield all events
try:
for event, elem in self.pull_parser.read_events():
if event == "start":
# update current path
self.current_path.append(elem.tag)
self.current_path_has_children = False
elif event == "end":
# remove last element from current path
#
self.current_path.pop()
# yield element
if not self.current_path_has_children:
yield nested_element(self.current_path, elem)
# prevent yielding of parent element
if self.current_path:
self.current_path_has_children = True
else:
self.xml_started = False
for raw_event in self.pull_parser.read_events():
if len(raw_event) <= 1:
continue
event, elem = raw_event
if isinstance(elem, ET.Element):
if event == "start":
# update current path
self.current_path.append(elem.tag)
self.current_path_has_children = False
elif event == "end":
# remove last element from current path
#
self.current_path.pop()
# yield element
if not self.current_path_has_children:
yield nested_element(self.current_path, elem)
# prevent yielding of parent element
if self.current_path:
self.current_path_has_children = True
else:
self.xml_started = False
except xml.etree.ElementTree.ParseError:
# This might be junk at the end of the XML input.
# Let's check whether the current path is empty.
@@ -240,21 +246,28 @@ class XMLOutputParser(BaseTransformOutputParser):
@override
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
) -> Iterator[AddableDict]:
streaming_parser = _StreamingParser(self.parser)
for chunk in input:
yield from streaming_parser.parse(chunk)
if isinstance(chunk, AIMessage):
yield from streaming_parser.parse(convert_from_v1_message(chunk))
else:
yield from streaming_parser.parse(chunk)
streaming_parser.close()
@override
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
) -> AsyncIterator[AddableDict]:
streaming_parser = _StreamingParser(self.parser)
async for chunk in input:
for output in streaming_parser.parse(chunk):
yield output
if isinstance(chunk, AIMessage):
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
yield output
else:
for output in streaming_parser.parse(chunk):
yield output
streaming_parser.close()
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:

View File

@@ -8,17 +8,65 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Literal, cast
from typing import Literal, Union, cast
from typing_extensions import TypedDict
from typing_extensions import TypedDict, overload
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
get_buffer_string,
)
from langchain_core.messages import content_blocks as types
from langchain_core.v1.messages import AIMessage as AIMessageV1
from langchain_core.v1.messages import HumanMessage as HumanMessageV1
from langchain_core.v1.messages import MessageV1, ResponseMetadata
from langchain_core.v1.messages import SystemMessage as SystemMessageV1
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
def _convert_to_v1(message: BaseMessage) -> MessageV1:
"""Best-effort conversion of a V0 AIMessage to V1."""
if isinstance(message.content, str):
content: list[types.ContentBlock] = []
if message.content:
content = [{"type": "text", "text": message.content}]
else:
content = []
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
content.append(cast("types.ContentBlock", block))
else:
pass
if isinstance(message, HumanMessage):
return HumanMessageV1(content=content)
if isinstance(message, AIMessage):
for tool_call in message.tool_calls:
content.append(tool_call)
return AIMessageV1(
content=content,
usage_metadata=message.usage_metadata,
response_metadata=cast("ResponseMetadata", message.response_metadata),
tool_calls=message.tool_calls,
)
if isinstance(message, SystemMessage):
return SystemMessageV1(content=content)
if isinstance(message, ToolMessage):
return ToolMessageV1(
tool_call_id=message.tool_call_id,
content=content,
artifact=message.artifact,
)
error_message = f"Unsupported message type: {type(message)}"
raise TypeError(error_message)
class PromptValue(Serializable, ABC):
@@ -46,8 +94,18 @@ class PromptValue(Serializable, ABC):
def to_string(self) -> str:
"""Return prompt value as string."""
@overload
def to_messages(
self, message_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
@abstractmethod
def to_messages(self) -> list[BaseMessage]:
def to_messages(
self, message_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of Messages."""
@@ -71,8 +129,20 @@ class StringPromptValue(PromptValue):
"""Return prompt as string."""
return self.text
def to_messages(self) -> list[BaseMessage]:
@overload
def to_messages(
self, message_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, message_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as messages."""
if message_version == "v1":
return [HumanMessageV1(content=self.text)]
return [HumanMessage(content=self.text)]
@@ -89,8 +159,24 @@ class ChatPromptValue(PromptValue):
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as a list of messages."""
@overload
def to_messages(
self, message_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, message_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt as a list of messages.
Args:
message_version: The output version, either "v0" (default) or "v1".
"""
if message_version == "v1":
return [_convert_to_v1(m) for m in self.messages]
return list(self.messages)
@classmethod
@@ -125,8 +211,26 @@ class ImagePromptValue(PromptValue):
"""Return prompt (image URL) as string."""
return self.image_url["url"]
def to_messages(self) -> list[BaseMessage]:
@overload
def to_messages(
self, message_version: Literal["v0"] = "v0"
) -> list[BaseMessage]: ...
@overload
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
def to_messages(
self, message_version: Literal["v0", "v1"] = "v0"
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
"""Return prompt (image URL) as messages."""
if message_version == "v1":
block: types.ImageContentBlock = {
"type": "image",
"url": self.image_url["url"],
}
if "detail" in self.image_url:
block["detail"] = self.image_url["detail"]
return [HumanMessageV1(content=[block])]
return [HumanMessage(content=[cast("dict", self.image_url)])]

File diff suppressed because it is too large Load Diff

View File

@@ -402,7 +402,7 @@ def call_func_with_variable_args(
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input, # noqa: A002
input: Input,
config: RunnableConfig,
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs: Any,
@@ -439,7 +439,7 @@ def acall_func_with_variable_args(
Awaitable[Output],
],
],
input: Input, # noqa: A002
input: Input,
config: RunnableConfig,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any,

View File

@@ -5,7 +5,7 @@ import inspect
import typing
from collections.abc import AsyncIterator, Iterator, Sequence
from functools import wraps
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from pydantic import BaseModel, ConfigDict
from typing_extensions import override
@@ -397,7 +397,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
)
to_return = {}
to_return: dict[int, Union[Output, BaseException]] = {}
run_again = dict(enumerate(inputs))
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
@@ -447,7 +447,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if not return_exceptions and sorted_handled_exceptions:
raise sorted_handled_exceptions[0][1]
to_return.update(handled_exceptions)
return [output for _, output in sorted(to_return.items())] # type: ignore[misc]
return [cast("Output", output) for _, output in sorted(to_return.items())]
@override
def stream(
@@ -569,7 +569,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
async for chunk in stream:
yield chunk
try:
output = output + chunk
output = output + chunk # type: ignore[operator]
except TypeError:
output = None
except BaseException as e:

View File

@@ -114,7 +114,7 @@ class Node(NamedTuple):
def copy(
self,
*,
id: Optional[str] = None, # noqa: A002
id: Optional[str] = None,
name: Optional[str] = None,
) -> Node:
"""Return a copy of the node with optional new id and name.
@@ -187,7 +187,7 @@ class MermaidDrawMethod(Enum):
def node_data_str(
id: str, # noqa: A002
id: str,
data: Union[type[BaseModel], RunnableType, None],
) -> str:
"""Convert the data of a node to a string.
@@ -328,7 +328,7 @@ class Graph:
def add_node(
self,
data: Union[type[BaseModel], RunnableType, None],
id: Optional[str] = None, # noqa: A002
id: Optional[str] = None,
*,
metadata: Optional[dict[str, Any]] = None,
) -> Node:

View File

@@ -68,13 +68,21 @@ from langchain_core.utils.pydantic import (
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
)
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
if TYPE_CHECKING:
import uuid
from collections.abc import Sequence
FILTERED_ARGS = ("run_manager", "callbacks")
TOOL_MESSAGE_BLOCK_TYPES = ("text", "image_url", "image", "json", "search_result")
TOOL_MESSAGE_BLOCK_TYPES = (
"text",
"image_url",
"image",
"json",
"search_result",
"custom_tool_call_output",
)
class SchemaAnnotationError(TypeError):
@@ -498,6 +506,15 @@ class ChildTool(BaseTool):
two-tuple corresponding to the (content, artifact) of a ToolMessage.
"""
message_version: Literal["v0", "v1"] = "v0"
"""Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the tool."""
if (
@@ -835,7 +852,7 @@ class ChildTool(BaseTool):
content = None
artifact = None
status = "success"
status: Literal["success", "error"] = "success"
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
@@ -879,7 +896,14 @@ class ChildTool(BaseTool):
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name, status)
output = _format_output(
content,
artifact,
tool_call_id,
self.name,
status,
message_version=self.message_version,
)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@@ -945,7 +969,7 @@ class ChildTool(BaseTool):
)
content = None
artifact = None
status = "success"
status: Literal["success", "error"] = "success"
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
@@ -993,7 +1017,14 @@ class ChildTool(BaseTool):
await run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name, status)
output = _format_output(
content,
artifact,
tool_call_id,
self.name,
status,
message_version=self.message_version,
)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@@ -1131,7 +1162,9 @@ def _format_output(
artifact: Any,
tool_call_id: Optional[str],
name: str,
status: str,
status: Literal["success", "error"],
*,
message_version: Literal["v0", "v1"] = "v0",
) -> Union[ToolOutputMixin, Any]:
"""Format tool output as a ToolMessage if appropriate.
@@ -1141,6 +1174,7 @@ def _format_output(
tool_call_id: The ID of the tool call.
name: The name of the tool.
status: The execution status.
message_version: The version of the ToolMessage to return.
Returns:
The formatted output, either as a ToolMessage or the original content.
@@ -1149,7 +1183,15 @@ def _format_output(
return content
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
if message_version == "v0":
return ToolMessage(
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)
return ToolMessageV1(
content,
artifact=artifact,
tool_call_id=tool_call_id,

View File

@@ -22,6 +22,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
message_version: Literal["v0", "v1"] = "v0",
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
@@ -37,6 +38,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
message_version: Literal["v0", "v1"] = "v0",
) -> BaseTool: ...
@@ -51,6 +53,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
message_version: Literal["v0", "v1"] = "v0",
) -> BaseTool: ...
@@ -65,6 +68,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
message_version: Literal["v0", "v1"] = "v0",
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
@@ -79,6 +83,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
message_version: Literal["v0", "v1"] = "v0",
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
@@ -118,6 +123,11 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
message_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
Returns:
The tool.
@@ -216,7 +226,7 @@ def tool(
\"\"\"
return bar
""" # noqa: D214, D410, D411
""" # noqa: D214, D410, D411, E501
def _create_tool_factory(
tool_name: str,
@@ -274,6 +284,7 @@ def tool(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
message_version=message_version,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
@@ -290,6 +301,7 @@ def tool(
return_direct=return_direct,
coroutine=coroutine,
response_format=response_format,
message_version=message_version,
)
return _tool_factory
@@ -383,6 +395,7 @@ def convert_runnable_to_tool(
name: Optional[str] = None,
description: Optional[str] = None,
arg_types: Optional[dict[str, type]] = None,
message_version: Literal["v0", "v1"] = "v0",
) -> BaseTool:
"""Convert a Runnable into a BaseTool.
@@ -392,10 +405,15 @@ def convert_runnable_to_tool(
name: The name of the tool. Defaults to None.
description: The description of the tool. Defaults to None.
arg_types: The types of the arguments. Defaults to None.
message_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
Returns:
The tool.
"""
""" # noqa: E501
if args_schema:
runnable = runnable.with_types(input_type=args_schema)
description = description or _get_description_from_runnable(runnable)
@@ -408,6 +426,7 @@ def convert_runnable_to_tool(
func=runnable.invoke,
coroutine=runnable.ainvoke,
description=description,
message_version=message_version,
)
async def ainvoke_wrapper(
@@ -435,4 +454,5 @@ def convert_runnable_to_tool(
coroutine=ainvoke_wrapper,
description=description,
args_schema=args_schema,
message_version=message_version,
)

View File

@@ -72,6 +72,7 @@ def create_retriever_tool(
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
response_format: Literal["content", "content_and_artifact"] = "content",
message_version: Literal["v0", "v1"] = "v1",
) -> Tool:
r"""Create a tool to do retrieval of documents.
@@ -88,10 +89,15 @@ def create_retriever_tool(
"content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage (artifact
being a list of documents in this case). Defaults to "content".
message_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
Returns:
Tool class to pass to an agent.
"""
""" # noqa: E501
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
_get_relevant_documents,
@@ -114,4 +120,5 @@ def create_retriever_tool(
coroutine=afunc,
args_schema=RetrieverInput,
response_format=response_format,
message_version=message_version,
)

View File

@@ -129,6 +129,7 @@ class StructuredTool(BaseTool):
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
message_version: Literal["v0", "v1"] = "v0",
**kwargs: Any,
) -> StructuredTool:
"""Create tool from a given function.
@@ -157,6 +158,12 @@ class StructuredTool(BaseTool):
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
message_version: Version of ToolMessage to return given
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
kwargs: Additional arguments to pass to the tool
Returns:
@@ -175,7 +182,7 @@ class StructuredTool(BaseTool):
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
""" # noqa: E501
if func is not None:
source_function = func
elif coroutine is not None:
@@ -232,6 +239,7 @@ class StructuredTool(BaseTool):
description=description_,
return_direct=return_direct,
response_format=response_format,
message_version=message_version,
**kwargs,
)

View File

@@ -17,6 +17,7 @@ from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.exceptions import TracerException # noqa: F401
from langchain_core.tracers.core import _TracerCore
from langchain_core.v1.messages import AIMessage, AIMessageChunk, MessageV1
if TYPE_CHECKING:
from collections.abc import Sequence
@@ -54,7 +55,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@@ -138,7 +139,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -190,7 +193,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
)
@override
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
def on_llm_end(
self, response: Union[LLMResult, AIMessage], *, run_id: UUID, **kwargs: Any
) -> Run:
"""End a trace for an LLM run.
Args:
@@ -562,7 +567,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -617,7 +622,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -646,7 +653,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
@override
async def on_llm_end(
self,
response: LLMResult,
response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -882,7 +889,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
) -> None:
"""Process new LLM token."""

View File

@@ -18,6 +18,7 @@ from typing import (
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@@ -25,6 +26,12 @@ from langchain_core.outputs import (
LLMResult,
)
from langchain_core.tracers.schemas import Run
from langchain_core.v1.messages import (
AIMessage,
AIMessageChunk,
MessageV1,
MessageV1Types,
)
if TYPE_CHECKING:
from collections.abc import Coroutine, Sequence
@@ -156,7 +163,7 @@ class _TracerCore(ABC):
def _create_chat_model_run(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: UUID,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
@@ -181,6 +188,12 @@ class _TracerCore(ABC):
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
if isinstance(messages[0], MessageV1Types):
# Convert from v1 messages to BaseMessage
messages = [
[convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type]
]
messages = cast("list[list[BaseMessage]]", messages)
return Run(
id=run_id,
parent_run_id=parent_run_id,
@@ -230,7 +243,9 @@ class _TracerCore(ABC):
self,
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
parent_run_id: Optional[UUID] = None, # noqa: ARG002
) -> Run:
"""Append token event to LLM run and return the run."""
@@ -276,7 +291,15 @@ class _TracerCore(ABC):
)
return llm_run
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
def _complete_llm_run(
self, response: Union[LLMResult, AIMessage], run_id: UUID
) -> Run:
if isinstance(response, AIMessage):
response = LLMResult(
generations=[
[ChatGeneration(message=convert_from_v1_message(response))]
]
)
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
@@ -558,7 +581,7 @@ class _TracerCore(ABC):
self,
run: Run, # noqa: ARG002
token: str, # noqa: ARG002
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]], # noqa: ARG002
) -> Union[None, Coroutine[Any, Any, None]]:
"""Process new LLM token."""
return None

View File

@@ -38,6 +38,7 @@ from langchain_core.runnables.utils import (
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import aclosing, py_anext
from langchain_core.v1.messages import MessageV1
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
@@ -45,6 +46,8 @@ if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tracers.log_stream import LogEntry
from langchain_core.v1.messages import AIMessage as AIMessageV1
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
logger = logging.getLogger(__name__)
@@ -297,7 +300,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@@ -307,6 +310,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
# below cast is because type is converted in handle_event
messages = cast("list[list[BaseMessage]]", messages)
name_ = _assign_name(name, serialized)
run_type = "chat_model"
@@ -407,13 +412,18 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
run_info = self.run_map.get(run_id)
chunk = cast(
"Optional[Union[GenerationChunk, ChatGenerationChunk]]", chunk
) # converted in handle_event
chunk_: Union[GenerationChunk, BaseMessageChunk]
if run_info is None:
@@ -456,9 +466,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
@override
async def on_llm_end(
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
self, response: Union[LLMResult, AIMessageV1], *, run_id: UUID, **kwargs: Any
) -> None:
"""End a trace for an LLM run."""
response = cast("LLMResult", response) # converted in handle_event
run_info = self.run_map.pop(run_id)
inputs_ = run_info["inputs"]

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import UUID
from langsmith import Client
@@ -21,12 +21,15 @@ from typing_extensions import override
from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
from langchain_core.v1.messages import MessageV1Types
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.v1.messages import AIMessageChunk, MessageV1
logger = logging.getLogger(__name__)
_LOGGED = set()
@@ -113,7 +116,7 @@ class LangChainTracer(BaseTracer):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@@ -140,6 +143,12 @@ class LangChainTracer(BaseTracer):
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
if isinstance(messages[0], MessageV1Types):
# Convert from v1 messages to BaseMessage
messages = [
[convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type]
]
messages = cast("list[list[BaseMessage]]", messages)
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,
@@ -232,7 +241,9 @@ class LangChainTracer(BaseTracer):
self,
token: str,
run_id: UUID,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
] = None,
parent_run_id: Optional[UUID] = None,
) -> Run:
"""Append token event to LLM run and return the run."""

View File

@@ -34,6 +34,7 @@ if TYPE_CHECKING:
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.schemas import Run
from langchain_core.v1.messages import AIMessageChunk
class LogEntry(TypedDict):
@@ -176,7 +177,7 @@ class RunLog(RunLogPatch):
# Then compare that the ops are the same
return super().__eq__(other)
__hash__ = None # type: ignore[assignment]
__hash__ = None
T = TypeVar("T")
@@ -485,7 +486,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
) -> None:
"""Process new LLM token."""
index = self._key_map_by_run_id.get(run.id)

View File

@@ -277,7 +277,7 @@ def _convert_any_typed_dicts_to_pydantic(
)
fields: dict = {}
for arg, arg_type in annotations_.items():
if get_origin(arg_type) is Annotated:
if get_origin(arg_type) is Annotated: # type: ignore[comparison-overlap]
annotated_args = get_args(arg_type)
new_arg_type = _convert_any_typed_dicts_to_pydantic(
annotated_args[0], depth=depth + 1, visited=visited
@@ -575,12 +575,23 @@ def convert_to_openai_tool(
Added support for OpenAI's image generation built-in tool.
"""
from langchain_core.tools import Tool
if isinstance(tool, dict):
if tool.get("type") in _WellKnownOpenAITools:
return tool
# As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
if (tool.get("type") or "").startswith("web_search_preview"):
return tool
if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
oai_tool = {
"type": "custom",
"name": tool.name,
"description": tool.description,
}
if tool.metadata is not None and "format" in tool.metadata:
oai_tool["format"] = tool.metadata["format"]
return oai_tool
oai_function = convert_to_openai_function(tool, strict=strict)
return {"type": "function", "function": oai_function}
@@ -616,7 +627,7 @@ def convert_to_json_schema(
@beta()
def tool_example_to_messages(
input: str, # noqa: A002
input: str,
tool_calls: list[BaseModel],
tool_outputs: Optional[list[str]] = None,
*,
@@ -629,15 +640,16 @@ def tool_example_to_messages(
The list of messages per example by default corresponds to:
1) HumanMessage: contains the content from which content should be extracted.
2) AIMessage: contains the extracted information from the model
3) ToolMessage: contains confirmation to the model that the model requested a tool
correctly.
1. ``HumanMessage``: contains the content from which content should be extracted.
2. ``AIMessage``: contains the extracted information from the model
3. ``ToolMessage``: contains confirmation to the model that the model requested a
tool correctly.
If `ai_response` is specified, there will be a final AIMessage with that response.
If ``ai_response`` is specified, there will be a final ``AIMessage`` with that
response.
The ToolMessage is required because some chat models are hyper-optimized for agents
rather than for an extraction use case.
The ``ToolMessage`` is required because some chat models are hyper-optimized for
agents rather than for an extraction use case.
Arguments:
input: string, the user input
@@ -646,7 +658,7 @@ def tool_example_to_messages(
tool_outputs: Optional[list[str]], a list of tool call outputs.
Does not need to be provided. If not provided, a placeholder value
will be inserted. Defaults to None.
ai_response: Optional[str], if provided, content for a final AIMessage.
ai_response: Optional[str], if provided, content for a final ``AIMessage``.
Returns:
A list of messages
@@ -728,6 +740,7 @@ def _parse_google_docstring(
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
"""
if docstring:
docstring_blocks = docstring.split("\n\n")

View File

@@ -0,0 +1 @@
"""LangChain v1.0.0 types."""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,986 @@
"""LangChain v1.0.0 message format.
Each message has content that may be comprised of content blocks, defined under
``langchain_core.messages.content_blocks``.
"""
import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, Union, cast, get_args
from pydantic import BaseModel
from typing_extensions import TypedDict
import langchain_core.messages.content_blocks as types
from langchain_core._api.deprecation import warn_deprecated
from langchain_core.messages.ai import (
_LC_AUTO_PREFIX,
_LC_ID_PREFIX,
UsageMetadata,
add_usage,
)
from langchain_core.messages.base import merge_content
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.json import parse_partial_json
class TextAccessor(str):
"""String-like object that supports both property and method access patterns.
Exists to maintain backward compatibility while transitioning from method-based to
property-based text access in message objects. In LangChain <v0.4, message text was
accessed via ``.text()`` method calls. In v0.4=<, the preferred pattern is property
access via ``.text``.
Rather than breaking existing code immediately, ``TextAccessor`` allows both
patterns:
- Modern property access: ``message.text`` (returns string directly)
- Legacy method access: ``message.text()`` (callable, emits deprecation warning)
Examples:
>>> msg = AIMessage("Hello world")
>>> text = msg.text # Preferred: property access
>>> text = msg.text() # Deprecated: method access (shows warning)
"""
__slots__ = ()
def __new__(cls, value: str) -> "TextAccessor":
"""Create new TextAccessor instance."""
return str.__new__(cls, value)
def __call__(self) -> str:
"""Enable method-style text access for backward compatibility.
.. deprecated:: 0.4.0
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
instead. This method will be removed in 2.0.0.
Returns:
The string content, identical to property access.
"""
warn_deprecated(
since="0.4.0",
message=(
"Calling .text() as a method is deprecated. "
"Use .text as a property instead (e.g., message.text)."
),
removal="2.0.0",
)
return str(self)
def _ensure_id(id_val: Optional[str]) -> str:
"""Ensure the ID is a valid string, generating a new UUID if not provided.
Auto-generated UUIDs are prefixed by ``'lc_'`` to indicate they are
LangChain-generated IDs.
Args:
id_val: Optional string ID value to validate.
Returns:
A valid string ID, either the provided value or a new UUID.
"""
return id_val or str(f"{_LC_AUTO_PREFIX}{uuid.uuid4()}")
class ResponseMetadata(TypedDict, total=False):
"""Metadata about the response from the AI provider.
Contains additional information returned by the provider, such as
response headers, service tiers, log probabilities, system fingerprints, etc.
**Extensibility Design:**
This uses ``total=False`` to allow arbitrary additional keys beyond the typed
fields below. This enables provider-specific metadata without breaking type safety:
- OpenAI might include: ``{"system_fingerprint": "fp_123", "logprobs": {...}}``
- Anthropic might include: ``{"stop_reason": "stop_sequence", "usage": {...}}``
- Custom providers can add their own fields
The common fields (``model_provider``, ``model_name``) provide a baseline
contract while preserving flexibility for provider innovations.
"""
model_provider: str
"""Name and version of the provider that created the message (ex: ``'openai'``)."""
model_name: str
"""Name of the model that generated the message."""
@dataclass
class AIMessage:
"""A v1 message generated by an AI assistant.
Represents a response from an AI model, including text content, tool calls,
and metadata about the generation process.
Attributes:
type: Message type identifier, always ``'ai'``.
id: Unique identifier for the message.
name: The name/identifier of the agent or assistant that generated this message.
lc_version: Encoding version for the message.
content: List of content blocks containing the message data.
tool_calls: Optional list of tool calls made by the AI.
invalid_tool_calls: Optional list of tool calls that failed validation.
usage: Optional dictionary containing usage statistics.
"""
type: Literal["ai"] = "ai"
"""The type of the message. Must be a string that is unique to the message type.
The purpose of this field is to allow for easy identification of the message type
when deserializing messages.
"""
name: Optional[str] = None
"""The name/identifier of the agent or assistant that generated this message.
Used primarily in multi-agent systems to track which agent is speaking. Also used by
some providers for conversation attribution and context.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
**Examples:**
.. python::
AIMessage(
content= [
TextContentBlock("Analysis complete"),
],
name="research_agent"
)
AIMessage(
content= [
TextContentBlock("Task routed to specialist"),
],
name="supervisor"
)
"""
id: Optional[str] = None
"""Unique identifier for the message.
If the provider assigns a meaningful ID, it should be used here. Otherwise, a
LangChain-generated ID will be used.
"""
lc_version: str = "v1"
"""Encoding version for the message. Used for serialization."""
content: list[types.ContentBlock] = field(default_factory=list)
"""Message content as a list of content blocks."""
usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts."""
response_metadata: ResponseMetadata = field(
default_factory=lambda: ResponseMetadata()
)
"""Metadata about the response.
This field should include non-standard data returned by the provider, such as
response headers, service tiers, or log probabilities.
"""
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
"""Auto-parsed message contents, if applicable."""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None,
tool_calls: Optional[list[types.ToolCall]] = None,
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
):
"""Initialize a v1 AI message.
Args:
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
name: The name/identifier of the agent or assistant that generated this
message.
lc_version: Encoding version for the message.
response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage.
tool_calls: Optional list of tool calls made by the AI. Tool calls should
generally be included in message content. If passed on init, they will
be added to the content list.
invalid_tool_calls: Optional list of tool calls that failed validation.
parsed: Optional auto-parsed message contents, if applicable.
"""
if isinstance(content, str):
self.content = [types.create_text_block(content)]
else:
self.content = content
self.id = _ensure_id(id)
self.name = name
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
if response_metadata is None:
self.response_metadata = {}
else:
self.response_metadata = response_metadata
# Add tool calls to content if provided on init
if tool_calls:
content_tool_calls = {
block["id"]
for block in self.content
if types.is_tool_call_block(block) and "id" in block
}
for tool_call in tool_calls:
if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue
self.content.append(tool_call)
if invalid_tool_calls:
content_tool_calls = {
block["id"]
for block in self.content
if types.is_invalid_tool_call_block(block) and "id" in block
}
for invalid_tool_call in invalid_tool_calls:
if (
"id" in invalid_tool_call
and invalid_tool_call["id"] in content_tool_calls
):
continue
self.content.append(invalid_tool_call)
self._tool_calls: list[types.ToolCall] = [
block for block in self.content if types.is_tool_call_block(block)
]
self._invalid_tool_calls: list[types.InvalidToolCall] = [
block for block in self.content if types.is_invalid_tool_call_block(block)
]
@property
def text(self) -> str:
"""Extract all text content from the AI message as a string.
Can be used as both property (``message.text``) and method (``message.text()``).
.. deprecated:: 0.4.0
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
instead. This method will be removed in 2.0.0.
"""
text_value = "".join(
block["text"] for block in self.content if types.is_text_block(block)
)
return cast("str", TextAccessor(text_value))
@property
def tool_calls(self) -> list[types.ToolCall]:
"""Get the tool calls made by the AI."""
if not self._tool_calls:
self._tool_calls = [
block for block in self.content if types.is_tool_call_block(block)
]
return self._tool_calls
@tool_calls.setter
def tool_calls(self, value: list[types.ToolCall]) -> None:
"""Set the tool calls for the AI message."""
self._tool_calls = value
@property
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
"""Get the invalid tool calls made by the AI."""
if not self._invalid_tool_calls:
self._invalid_tool_calls = [
block
for block in self.content
if types.is_invalid_tool_call_block(block)
]
return self._invalid_tool_calls
@dataclass
class AIMessageChunk(AIMessage):
"""A partial chunk of an AI message during streaming.
Represents a portion of an AI response that is delivered incrementally
during streaming generation. When AI providers stream responses token-by-token,
each chunk contains partial content that gets accumulated into a complete message.
**Streaming Workflow:**
1. Provider streams partial responses as ``AIMessageChunk`` objects
2. Chunks are accumulated: ``chunk1 + chunk2 + ...``
3. Final accumulated chunk can be converted to ``AIMessage`` via ``.to_message()``
**Tool Call Handling:**
During streaming, tool calls arrive as ``ToolCallChunk`` objects with partial
JSON. When chunks are accumulated, the final chunk (marked with
``chunk_position="last"``) triggers parsing of complete tool calls from the
accumulated JSON strings.
**Content Merging:**
Content blocks are merged intelligently - text blocks combine their strings,
tool call chunks accumulate arguments, and other blocks are concatenated.
Attributes:
type: Message type identifier, always ``'ai_chunk'``.
id: Unique identifier for the message chunk.
name: The name/identifier of the agent or assistant that generated this message.
content: List of content blocks containing partial message data.
tool_call_chunks: Optional list of partial tool call data.
usage_metadata: Optional metadata about token usage and costs.
"""
type: Literal["ai_chunk"] = "ai_chunk" # type: ignore[assignment]
"""The type of the message. Must be a string that is unique to the message type.
The purpose of this field is to allow for easy identification of the message type
when deserializing messages.
"""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
name: Optional[str] = None,
lc_version: str = "v1",
response_metadata: Optional[ResponseMetadata] = None,
usage_metadata: Optional[UsageMetadata] = None,
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
chunk_position: Optional[Literal["last"]] = None,
):
"""Initialize a v1 AI message.
Args:
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
name: The name/identifier of the agent or assistant that generated this
message.
lc_version: Encoding version for the message.
response_metadata: Optional metadata about the response.
usage_metadata: Optional metadata about token usage.
tool_call_chunks: Optional list of partial tool call data.
parsed: Optional auto-parsed message contents, if applicable.
chunk_position: Optional position of the chunk in the stream. If ``'last'``,
tool calls will be parsed when aggregated into a stream.
"""
if isinstance(content, str):
self.content = [{"type": "text", "text": content, "index": 0}]
else:
self.content = content
self.id = _ensure_id(id)
self.name = name
self.lc_version = lc_version
self.usage_metadata = usage_metadata
self.parsed = parsed
self.chunk_position = chunk_position
if response_metadata is None:
self.response_metadata = {}
else:
self.response_metadata = response_metadata
if tool_call_chunks:
content_tool_call_chunks = {
block["id"]
for block in self.content
if types.is_tool_call_chunk(block) and "id" in block
}
for chunk in tool_call_chunks:
if "id" in chunk and chunk["id"] in content_tool_call_chunks:
continue
self.content.append(chunk)
self._tool_call_chunks = [
block for block in self.content if types.is_tool_call_chunk(block)
]
self._tool_calls: list[types.ToolCall] = []
self._invalid_tool_calls: list[types.InvalidToolCall] = []
@property
def tool_call_chunks(self) -> list[types.ToolCallChunk]:
"""Get the tool calls made by the AI."""
if not self._tool_call_chunks:
self._tool_call_chunks = [
block for block in self.content if types.is_tool_call_chunk(block)
]
return self._tool_call_chunks
@property
def tool_calls(self) -> list[types.ToolCall]:
"""Get the tool calls made by the AI."""
if not self._tool_calls:
parsed_content = _init_tool_calls(self.content)
tool_calls: list[types.ToolCall] = []
invalid_tool_calls: list[types.InvalidToolCall] = []
for block in parsed_content:
if types.is_tool_call_block(block):
tool_calls.append(block)
elif types.is_invalid_tool_call_block(block):
invalid_tool_calls.append(block)
self._tool_calls = tool_calls
self._invalid_tool_calls = invalid_tool_calls
return self._tool_calls
@tool_calls.setter
def tool_calls(self, value: list[types.ToolCall]) -> None:
"""Set the tool calls for the AI message."""
self._tool_calls = value
@property
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
"""Get the invalid tool calls made by the AI."""
if not self._invalid_tool_calls:
parsed_content = _init_tool_calls(self.content)
tool_calls: list[types.ToolCall] = []
invalid_tool_calls: list[types.InvalidToolCall] = []
for block in parsed_content:
if types.is_tool_call_block(block):
tool_calls.append(block)
elif types.is_invalid_tool_call_block(block):
invalid_tool_calls.append(block)
self._tool_calls = tool_calls
self._invalid_tool_calls = invalid_tool_calls
return self._invalid_tool_calls
def __add__(self, other: Any) -> "AIMessageChunk":
"""Add ``AIMessageChunk`` to this one."""
if isinstance(other, AIMessageChunk):
return add_ai_message_chunks(self, other)
if isinstance(other, (list, tuple)) and all(
isinstance(o, AIMessageChunk) for o in other
):
return add_ai_message_chunks(self, *other)
error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
raise NotImplementedError(error_msg)
def to_message(self) -> "AIMessage":
"""Convert this ``AIMessageChunk`` to an ``AIMessage``."""
return AIMessage(
content=_init_tool_calls(self.content),
id=self.id,
name=self.name,
lc_version=self.lc_version,
response_metadata=self.response_metadata,
usage_metadata=self.usage_metadata,
parsed=self.parsed,
)
def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]:
"""Parse tool call chunks in content into tool calls."""
new_content = []
for block in content:
if not types.is_tool_call_chunk(block):
new_content.append(block)
continue
try:
args_str = block.get("args")
args_ = parse_partial_json(str(args_str)) if args_str else {}
if isinstance(args_, dict):
new_content.append(
create_tool_call(
name=block.get("name") or "",
args=args_,
id=block.get("id", ""),
)
)
else:
new_content.append(
create_invalid_tool_call(
name=block.get("name", ""),
args=block.get("args", ""),
id=block.get("id", ""),
error=None,
)
)
except Exception:
new_content.append(
create_invalid_tool_call(
name=block.get("name", ""),
args=block.get("args", ""),
id=block.get("id", ""),
error=None,
)
)
return new_content
def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
"""Add multiple ``AIMessageChunks`` together."""
if not others:
return left
content = cast(
"list[types.ContentBlock]",
merge_content(
cast("list[str | dict[Any, Any]]", left.content),
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
),
)
response_metadata = merge_dicts(
cast("dict", left.response_metadata),
*(cast("dict", o.response_metadata) for o in others),
)
# Token usage
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
for other in others:
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
else:
usage_metadata = None
# Parsed
# 'parsed' always represents an aggregation not an incremental value, so the last
# non-null value is kept.
parsed = None
for m in reversed([left, *others]):
if m.parsed is not None:
parsed = m.parsed
break
chunk_id = None
candidates = [left.id] + [o.id for o in others]
# first pass: pick the first provider-assigned id (non-`run-*` and non-`lc_*`)
for id_ in candidates:
if (
id_
and not id_.startswith(_LC_ID_PREFIX)
and not id_.startswith(_LC_AUTO_PREFIX)
):
chunk_id = id_
break
else:
# second pass: prefer lc_run-* ids over lc_* ids
for id_ in candidates:
if id_ and id_.startswith(_LC_ID_PREFIX):
chunk_id = id_
break
else:
# third pass: take any remaining id (auto-generated lc_* ids)
for id_ in candidates:
if id_:
chunk_id = id_
break
chunk_position: Optional[Literal["last"]] = (
"last" if any(x.chunk_position == "last" for x in [left, *others]) else None
)
if chunk_position == "last":
content = _init_tool_calls(content)
return left.__class__(
content=content,
response_metadata=cast("ResponseMetadata", response_metadata),
usage_metadata=usage_metadata,
parsed=parsed,
id=chunk_id,
chunk_position=chunk_position,
)
@dataclass
class HumanMessage:
"""A message from a human user.
Represents input from a human user in a conversation, containing text
or other content types like images.
Attributes:
type: Message type identifier, always ``'human'``.
id: Unique identifier for the message.
content: List of content blocks containing the user's input.
name: Optional identifier for the human user who sent this message.
"""
id: str
"""Used for serialization.
If the provider assigns a meaningful ID, it should be used here. Otherwise, a
LangChain-generated ID will be used.
"""
content: list[types.ContentBlock]
"""Message content as a list of content blocks."""
type: Literal["human"] = "human"
"""The type of the message. Must be a string that is unique to the message type.
The purpose of this field is to allow for easy identification of the message type
when deserializing messages.
"""
name: Optional[str] = None
"""Optional identifier for the human user who sent this message.
Can be helpful in multi-user scenarios or for conversation tracking. Most chat model
providers ignore this field for human messages.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
**Examples:**
.. python::
HumanMessage(
content= [
TextContentBlock("Hello"),
],
name="user_alice"
)
HumanMessage(
content= [
TextContentBlock("Run analysis"),
],
name="admin_bob"
)
"""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
name: Optional[str] = None,
):
"""Initialize a v1 human message.
Args:
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
name: Optional identifier for the human user who sent this message.
"""
self.id = _ensure_id(id)
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
@property
def text(self) -> str:
"""Extract all text content from the message as a string.
Can be used as both property (``message.text``) and method (``message.text()``).
.. deprecated:: 0.4.0
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
instead. This method will be removed in 2.0.0.
"""
text_value = "".join(
block["text"] for block in self.content if types.is_text_block(block)
)
return cast("str", TextAccessor(text_value))
@dataclass
class SystemMessage:
"""A system message containing instructions or context.
Represents system-level instructions or context that guides the AI's
behavior and understanding of the conversation.
Attributes:
type: Message type identifier, always ``'system'``.
id: Unique identifier for the message.
content: List of content blocks containing system instructions.
"""
id: str
"""Used for serialization.
If the provider assigns a meaningful ID, it should be used here. Otherwise, a
LangChain-generated ID will be used.
"""
content: list[types.ContentBlock]
"""Message content as a list of content blocks."""
type: Literal["system"] = "system"
"""The type of the message. Must be a string that is unique to the message type.
The purpose of this field is to allow for easy identification of the message type
when deserializing messages.
"""
name: Optional[str] = None
"""Optional identifier for the system component/context that generated this message.
Can be used to identify different system contexts or configurations.
Usage of this field is optional, and whether it's used or not is up to the
model implementation.
**Examples:**
.. python::
SystemMessage(
content= [
TextContentBlock("You are a helpful assistant"),
],
name="base_prompt"
)
SystemMessage(
content= [
TextContentBlock("Advanced mode enabled"),
],
name="config_update"
)
"""
custom_role: Optional[str] = None
"""If provided, a custom role for the system message.
Example: ``'developer'``.
Integration packages may use this field to assign the system message role if it
contains a recognized value.
"""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
*,
id: Optional[str] = None,
custom_role: Optional[str] = None,
name: Optional[str] = None,
):
"""Initialize a v1 system message.
Args:
content: Message content as string or list of content blocks.
id: Optional unique identifier for the message.
custom_role: If provided, a custom role for the system message.
name: Optional identifier for the system component/context that generated
this message.
"""
self.id = _ensure_id(id)
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.custom_role = custom_role
self.name = name
@property
def text(self) -> str:
"""Extract all text content from the system message as a string.
Can be used as both property (``message.text``) and method (``message.text()``).
.. deprecated:: 0.4.0
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
instead. This method will be removed in 2.0.0.
"""
text_value = "".join(
block["text"] for block in self.content if types.is_text_block(block)
)
return cast("str", TextAccessor(text_value))
@dataclass
class ToolMessage(ToolOutputMixin):
"""A message containing the result of a tool execution.
Represents the output from executing a tool or function call,
including the result data and execution status.
Attributes:
type: Message type identifier, always ``'tool'``.
id: Unique identifier for the message.
tool_call_id: ID of the tool call this message responds to.
content: The result content from tool execution.
artifact: Optional app-side payload not intended for the model.
name: Name of the tool/function that was executed to generate this message.
status: Execution status ("success" or "error").
"""
id: str
"""Used for serialization."""
tool_call_id: str
"""ID of the tool call this message responds to.
This should match the ID of the tool call that this message is responding to.
"""
content: list[types.ContentBlock]
"""Message content as a list of content blocks.
The tool's output should be included in the content, mapped to the appropriate
content block type (e.g., text, image, etc.). For instance, if the tool call returns
a string, it should be wrapped in a ``TextContentBlock``.
"""
type: Literal["tool"] = "tool"
"""The type of the message. Must be a string that is unique to the message type.
The purpose of this field is to allow for easy identification of the message type
when deserializing messages.
"""
artifact: Optional[Any] = None
"""App-side payload not intended for model consumption.
Additonal info and usage examples are available
`in the LangChain documentation <https://python.langchain.com/docs/concepts/tools/#tool-artifacts>`__.
"""
name: Optional[str] = None
"""Name of the tool/function that was executed to generate this message.
.. important::
This field is required by most chat model providers (OpenAI, Anthropic,
Google, etc.) for proper tool calling. The name must match the tool that was
called.
**Examples:**
.. python::
ToolMessage(
content= [
TextContentBlock("42"),
],
name="calculator",
tool_call_id="call_123"
)
ToolMessage(
content= [
TextContentBlock("Weather is sunny"),
],
name="get_weather",
tool_call_id="call_456"
)
"""
status: Literal["success", "error"] = "success"
"""Execution status of the tool call.
Indicates whether the tool call was successful or encountered an error.
Defaults to "success".
"""
def __init__(
self,
content: Union[str, list[types.ContentBlock]],
tool_call_id: str,
*,
id: Optional[str] = None,
name: Optional[str] = None,
artifact: Optional[Any] = None,
status: Literal["success", "error"] = "success",
):
"""Initialize a v1 tool message.
Args:
content: Message content as string or list of content blocks.
tool_call_id: ID of the tool call this message responds to.
id: Optional unique identifier for the message.
name: Name of the tool/function that was executed to generate this message.
artifact: Optional app-side payload not intended for the model.
status: Execution status (``'success'`` or ``'error'``).
"""
self.id = _ensure_id(id)
self.tool_call_id = tool_call_id
if isinstance(content, str):
self.content = [{"type": "text", "text": content}]
else:
self.content = content
self.name = name
self.artifact = artifact
self.status = status
@property
def text(self) -> str:
"""Extract all text content from the tool message as a string.
Can be used as both property (``message.text``) and method (``message.text()``).
.. deprecated:: 0.4.0
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
instead. This method will be removed in 2.0.0.
"""
text_value = "".join(
block["text"] for block in self.content if types.is_text_block(block)
)
return cast("str", TextAccessor(text_value))
def __post_init__(self) -> None:
"""Initialize computed fields after dataclass creation.
Ensures the tool message has a valid ID.
"""
self.id = _ensure_id(self.id)
# Alias for a message type that can be any of the defined message types
MessageV1 = Union[
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
]
MessageV1Types = get_args(MessageV1)

View File

@@ -1,3 +1,3 @@
"""langchain-core version information and utilities."""
VERSION = "0.3.72"
VERSION = "0.4.0.dev0"

View File

@@ -16,7 +16,7 @@ dependencies = [
"pydantic>=2.7.4",
]
name = "langchain-core"
version = "0.3.72"
version = "0.4.0.dev0"
description = "Building applications with LLMs through composability"
readme = "README.md"
@@ -28,7 +28,7 @@ repository = "https://github.com/langchain-ai/langchain"
[dependency-groups]
lint = ["ruff<0.13,>=0.12.2"]
typing = [
"mypy<1.16,>=1.15",
"mypy<1.18,>=1.17.1",
"types-pyyaml<7.0.0.0,>=6.0.12.2",
"types-requests<3.0.0.0,>=2.28.11.5",
"langchain-text-splitters",
@@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
strict = "True"
strict_bytes = "True"
enable_error_code = "deprecated"
disable_error_code = ["typeddict-unknown-key"]
# TODO: activate for 'strict' checking
disallow_any_generics = "False"
@@ -86,6 +87,7 @@ ignore = [
"FIX002", # Line contains TODO
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLC0414", # Enable re-export
"PLR09", # Too many something (arg, statements, etc)
"RUF012", # Doesn't play well with Pydantic
"TC001", # Doesn't play well with Pydantic
@@ -105,6 +107,7 @@ unfixable = ["PLW1510",]
flake8-annotations.allow-star-arg-any = true
flake8-annotations.mypy-init-return = true
flake8-builtins.ignorelist = ["id", "input", "type"]
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
pydocstyle.convention = "google"

View File

@@ -11,6 +11,8 @@ from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
from langchain_core.v1.messages import MessageV1
class MyCustomAsyncHandler(AsyncCallbackHandler):
@@ -18,7 +20,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -35,7 +37,9 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,

View File

@@ -9,6 +9,7 @@ from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.v1.messages import MessageV1
class BaseFakeCallbackHandler(BaseModel):
@@ -285,7 +286,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@@ -16,6 +16,8 @@ from langchain_core.language_models import (
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
from langchain_core.v1.messages import MessageV1
from tests.unit_tests.stubs import (
_any_id_ai_message,
_any_id_ai_message_chunk,
@@ -157,13 +159,13 @@ async def test_callback_handlers() -> None:
"""Verify that model is implemented correctly with handlers working."""
class MyCustomAsyncHandler(AsyncCallbackHandler):
def __init__(self, store: list[str]) -> None:
def __init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None:
self.store = store
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -178,9 +180,11 @@ async def test_callback_handlers() -> None:
@override
async def on_llm_new_token(
self,
token: str,
token: Union[str, AIMessageChunkV1],
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
@@ -194,7 +198,7 @@ async def test_callback_handlers() -> None:
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
tokens: list[str] = []
tokens: list[Union[str, AIMessageChunkV1]] = []
# New model
results = [
chunk

View File

@@ -14,7 +14,10 @@ from langchain_core.language_models import (
ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
from langchain_core.language_models.fake_chat_models import (
FakeListChatModelError,
GenericFakeChatModelV1,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -29,6 +32,7 @@ from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import collect_runs
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
from langchain_core.tracers.schemas import Run
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
from tests.unit_tests.fake.callbacks import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
@@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None:
)
]
assert messages == _normalize_messages(messages)
def test_streaming_v1() -> None:
chunks = [
AIMessageChunkV1(
[
{
"type": "reasoning",
"reasoning": "Let's call a tool.",
"index": 0,
}
]
),
AIMessageChunkV1(
[],
tool_call_chunks=[
{
"type": "tool_call_chunk",
"args": "",
"name": "tool_name",
"id": "call_123",
"index": 1,
},
],
),
AIMessageChunkV1(
[],
tool_call_chunks=[
{
"type": "tool_call_chunk",
"args": '{"a',
"name": "",
"id": "",
"index": 1,
},
],
),
AIMessageChunkV1(
[],
tool_call_chunks=[
{
"type": "tool_call_chunk",
"args": '": 1}',
"name": "",
"id": "",
"index": 1,
},
],
),
]
full: Optional[AIMessageChunkV1] = None
for chunk in chunks:
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
assert full.content == [
{
"type": "reasoning",
"reasoning": "Let's call a tool.",
"index": 0,
},
{
"type": "tool_call_chunk",
"args": '{"a": 1}',
"name": "tool_name",
"id": "call_123",
"index": 1,
},
]
llm = GenericFakeChatModelV1(message_chunks=chunks)
full = None
for chunk in llm.stream("anything"):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunkV1)
assert full.content == [
{
"type": "reasoning",
"reasoning": "Let's call a tool.",
"index": 0,
},
{
"type": "tool_call",
"args": {"a": 1},
"name": "tool_name",
"id": "call_123",
},
]

View File

@@ -458,3 +458,23 @@ def test_cleanup_serialized() -> None:
"name": "CustomChat",
"type": "constructor",
}
def test_token_costs_are_zeroed_out() -> None:
# We zero-out token costs for cache hits
local_cache = InMemoryCache()
messages = [
AIMessage(
content="Hello, how are you?",
usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
),
]
model = GenericFakeChatModel(messages=iter(messages), cache=local_cache)
first_response = model.invoke("Hello")
assert isinstance(first_response, AIMessage)
assert first_response.usage_metadata
second_response = model.invoke("Hello")
assert isinstance(second_response, AIMessage)
assert second_response.usage_metadata
assert second_response.usage_metadata["total_cost"] == 0 # type: ignore[typeddict-item]

View File

@@ -1,6 +1,9 @@
import json
import pytest
from pydantic import BaseModel, ConfigDict, Field
from langchain_core.load import Serializable, dumpd, load
from langchain_core.load import Serializable, dumpd, dumps, load
from langchain_core.load.serializable import _is_field_useful
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, Generation
@@ -276,3 +279,92 @@ def test_serialization_with_ignore_unserializable_fields() -> None:
]
]
}
# Tests for dumps() function
def test_dumps_basic_serialization() -> None:
"""Test basic string serialization with `dumps()`."""
foo = Foo(bar=42, baz="test")
json_str = dumps(foo)
# Should be valid JSON
parsed = json.loads(json_str)
assert parsed == {
"id": ["tests", "unit_tests", "load", "test_serializable", "Foo"],
"kwargs": {"bar": 42, "baz": "test"},
"lc": 1,
"type": "constructor",
}
def test_dumps_pretty_formatting() -> None:
"""Test pretty printing functionality."""
foo = Foo(bar=1, baz="hello")
# Test pretty=True with default indent
pretty_json = dumps(foo, pretty=True)
assert " " in pretty_json
# Test custom indent (4-space)
custom_indent = dumps(foo, pretty=True, indent=4)
assert " " in custom_indent
# Verify it's still valid JSON
parsed = json.loads(pretty_json)
assert parsed["kwargs"]["bar"] == 1
def test_dumps_invalid_default_kwarg() -> None:
"""Test that passing `'default'` as kwarg raises ValueError."""
foo = Foo(bar=1, baz="test")
with pytest.raises(ValueError, match="`default` should not be passed to dumps"):
dumps(foo, default=lambda x: x)
def test_dumps_additional_json_kwargs() -> None:
"""Test that additional JSON kwargs are passed through."""
foo = Foo(bar=1, baz="test")
compact_json = dumps(foo, separators=(",", ":"))
assert ", " not in compact_json # Should be compact
# Test sort_keys
sorted_json = dumps(foo, sort_keys=True)
parsed = json.loads(sorted_json)
assert parsed == dumpd(foo)
def test_dumps_non_serializable_object() -> None:
"""Test `dumps()` behavior with non-serializable objects."""
class NonSerializable:
def __init__(self, value: int) -> None:
self.value = value
obj = NonSerializable(42)
json_str = dumps(obj)
# Should create a "not_implemented" representation
parsed = json.loads(json_str)
assert parsed["lc"] == 1
assert parsed["type"] == "not_implemented"
assert "NonSerializable" in parsed["repr"]
def test_dumps_mixed_data_structure() -> None:
"""Test `dumps()` with complex nested data structures."""
data = {
"serializable": Foo(bar=1, baz="test"),
"list": [1, 2, {"nested": "value"}],
"primitive": "string",
}
json_str = dumps(data)
parsed = json.loads(json_str)
# Serializable object should be properly serialized
assert parsed["serializable"]["type"] == "constructor"
# Primitives should remain unchanged
assert parsed["list"] == [1, 2, {"nested": "value"}]
assert parsed["primitive"] == "string"

View File

@@ -0,0 +1,913 @@
"""Unit tests for ContentBlock factory functions."""
from uuid import UUID
import pytest
from langchain_core.messages.content_blocks import (
CodeInterpreterCall,
CodeInterpreterOutput,
CodeInterpreterResult,
InvalidToolCall,
ToolCallChunk,
WebSearchCall,
WebSearchResult,
create_audio_block,
create_citation,
create_file_block,
create_image_block,
create_non_standard_block,
create_plaintext_block,
create_reasoning_block,
create_text_block,
create_tool_call,
create_video_block,
)
def _validate_lc_uuid(id_value: str) -> None:
"""Validate that the ID has ``lc_`` prefix and valid UUID suffix.
Args:
id_value: The ID string to validate.
Raises:
AssertionError: If the ID doesn't have ``lc_`` prefix or invalid UUID.
"""
assert id_value.startswith("lc_"), f"ID should start with 'lc_' but got: {id_value}"
# Validate the UUID part after the lc_ prefix
UUID(id_value[3:])
class TestTextBlockFactory:
"""Test create_text_block factory function."""
def test_basic_creation(self) -> None:
"""Test basic text block creation."""
block = create_text_block("Hello world")
assert block["type"] == "text"
assert block.get("text") == "Hello world"
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
def test_with_custom_id(self) -> None:
"""Test text block creation with custom ID."""
custom_id = "custom-123"
block = create_text_block("Hello", id=custom_id)
assert block.get("id") == custom_id
def test_with_annotations(self) -> None:
"""Test text block creation with annotations."""
citation = create_citation(url="https://example.com", title="Example")
block = create_text_block("Hello", annotations=[citation])
assert block.get("annotations") == [citation]
def test_with_index(self) -> None:
"""Test text block creation with index."""
block = create_text_block("Hello", index=42)
assert block.get("index") == 42
def test_optional_fields_not_present_when_none(self) -> None:
"""Test that optional fields are not included when None."""
block = create_text_block("Hello")
assert "annotations" not in block
assert "index" not in block
class TestImageBlockFactory:
"""Test create_image_block factory function."""
def test_with_url(self) -> None:
"""Test image block creation with URL."""
block = create_image_block(url="https://example.com/image.jpg")
assert block["type"] == "image"
assert block.get("url") == "https://example.com/image.jpg"
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
def test_with_base64(self) -> None:
"""Test image block creation with base64 data."""
block = create_image_block(
base64="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ", mime_type="image/png"
)
assert block.get("base64") == "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"
assert block.get("mime_type") == "image/png"
def test_with_file_id(self) -> None:
"""Test image block creation with file ID."""
block = create_image_block(file_id="file-123")
assert block.get("file_id") == "file-123"
def test_no_source_raises_error(self) -> None:
"""Test that missing all sources raises ValueError."""
with pytest.raises(
ValueError, match="Must provide one of: url, base64, or file_id"
):
create_image_block()
def test_with_index(self) -> None:
"""Test image block creation with index."""
block = create_image_block(url="https://example.com/image.jpg", index=1)
assert block.get("index") == 1
def test_optional_fields_not_present_when_not_provided(self) -> None:
"""Test that optional fields are not included when not provided."""
block = create_image_block(url="https://example.com/image.jpg")
assert "base64" not in block
assert "file_id" not in block
assert "mime_type" not in block
assert "index" not in block
class TestVideoBlockFactory:
"""Test create_video_block factory function."""
def test_with_url(self) -> None:
"""Test video block creation with URL."""
block = create_video_block(url="https://example.com/video.mp4")
assert block["type"] == "video"
assert block.get("url") == "https://example.com/video.mp4"
def test_with_base64(self) -> None:
"""Test video block creation with base64 data."""
block = create_video_block(
base64="UklGRnoGAABXQVZFZm10IBAAAAABAAEA", mime_type="video/mp4"
)
assert block.get("base64") == "UklGRnoGAABXQVZFZm10IBAAAAABAAEA"
assert block.get("mime_type") == "video/mp4"
def test_no_source_raises_error(self) -> None:
"""Test that missing all sources raises ValueError."""
with pytest.raises(
ValueError, match="Must provide one of: url, base64, or file_id"
):
create_video_block()
class TestAudioBlockFactory:
"""Test create_audio_block factory function."""
def test_with_url(self) -> None:
"""Test audio block creation with URL."""
block = create_audio_block(url="https://example.com/audio.mp3")
assert block["type"] == "audio"
assert block.get("url") == "https://example.com/audio.mp3"
def test_with_base64(self) -> None:
"""Test audio block creation with base64 data."""
block = create_audio_block(
base64="UklGRnoGAABXQVZFZm10IBAAAAABAAEA", mime_type="audio/mp3"
)
assert block.get("base64") == "UklGRnoGAABXQVZFZm10IBAAAAABAAEA"
assert block.get("mime_type") == "audio/mp3"
def test_no_source_raises_error(self) -> None:
"""Test that missing all sources raises ValueError."""
with pytest.raises(
ValueError, match="Must provide one of: url, base64, or file_id"
):
create_audio_block()
class TestFileBlockFactory:
"""Test create_file_block factory function."""
def test_with_url(self) -> None:
"""Test file block creation with URL."""
block = create_file_block(url="https://example.com/document.pdf")
assert block["type"] == "file"
assert block.get("url") == "https://example.com/document.pdf"
def test_with_base64(self) -> None:
"""Test file block creation with base64 data."""
block = create_file_block(
base64="JVBERi0xLjQKJdPr6eEKMSAwIG9iago8PAovVHlwZSAvQ2F0YWxvZwo=",
mime_type="application/pdf",
)
assert (
block.get("base64")
== "JVBERi0xLjQKJdPr6eEKMSAwIG9iago8PAovVHlwZSAvQ2F0YWxvZwo="
)
assert block.get("mime_type") == "application/pdf"
def test_no_source_raises_error(self) -> None:
"""Test that missing all sources raises ValueError."""
with pytest.raises(
ValueError, match="Must provide one of: url, base64, or file_id"
):
create_file_block()
class TestPlainTextBlockFactory:
"""Test create_plain_text_block factory function."""
def test_basic_creation(self) -> None:
"""Test basic plain text block creation."""
block = create_plaintext_block("This is plain text content.")
assert block["type"] == "text-plain"
assert block.get("mime_type") == "text/plain"
assert block.get("text") == "This is plain text content."
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
def test_with_title_and_context(self) -> None:
"""Test plain text block creation with title and context."""
block = create_plaintext_block(
"Document content here.",
title="Important Document",
context="This document contains important information.",
)
assert block.get("title") == "Important Document"
assert block.get("context") == "This document contains important information."
def test_with_url(self) -> None:
"""Test plain text block creation with URL."""
block = create_plaintext_block(
"Content", url="https://example.com/document.txt"
)
assert block.get("url") == "https://example.com/document.txt"
class TestToolCallFactory:
"""Test create_tool_call factory function."""
def test_basic_creation(self) -> None:
"""Test basic tool call creation."""
block = create_tool_call("search", {"query": "python"})
assert block["type"] == "tool_call"
assert block["name"] == "search"
assert block["args"] == {"query": "python"}
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
def test_with_custom_id(self) -> None:
"""Test tool call creation with custom ID."""
block = create_tool_call("search", {"query": "python"}, id="tool-123")
assert block.get("id") == "tool-123"
def test_with_index(self) -> None:
"""Test tool call creation with index."""
block = create_tool_call("search", {"query": "python"}, index=2)
assert block.get("index") == 2
class TestReasoningBlockFactory:
"""Test create_reasoning_block factory function."""
def test_basic_creation(self) -> None:
"""Test basic reasoning block creation."""
block = create_reasoning_block("Let me think about this problem...")
assert block["type"] == "reasoning"
assert block.get("reasoning") == "Let me think about this problem..."
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
@pytest.mark.xfail(reason="Optional fields not implemented yet")
def test_with_signatures(self) -> None:
"""Test reasoning block creation with signatures."""
block = create_reasoning_block(
"Thinking...",
thought_signature="thought-sig-123", # type: ignore[call-arg]
signature="auth-sig-456", # type: ignore[call-arg, unused-ignore]
)
assert block.get("thought_signature") == "thought-sig-123"
assert block.get("signature") == "auth-sig-456"
def test_with_index(self) -> None:
"""Test reasoning block creation with index."""
block = create_reasoning_block("Thinking...", index=3)
assert block.get("index") == 3
class TestCitationFactory:
"""Test create_citation factory function."""
def test_basic_creation(self) -> None:
"""Test basic citation creation."""
block = create_citation()
assert block["type"] == "citation"
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
def test_with_all_fields(self) -> None:
"""Test citation creation with all fields."""
block = create_citation(
url="https://example.com/source",
title="Source Document",
start_index=10,
end_index=50,
cited_text="This is the cited text.",
)
assert block.get("url") == "https://example.com/source"
assert block.get("title") == "Source Document"
assert block.get("start_index") == 10
assert block.get("end_index") == 50
assert block.get("cited_text") == "This is the cited text."
def test_optional_fields_not_present_when_none(self) -> None:
"""Test that optional fields are not included when None."""
block = create_citation()
assert "url" not in block
assert "title" not in block
assert "start_index" not in block
assert "end_index" not in block
assert "cited_text" not in block
class TestNonStandardBlockFactory:
"""Test create_non_standard_block factory function."""
def test_basic_creation(self) -> None:
"""Test basic non-standard block creation."""
value = {"custom_field": "custom_value", "number": 42}
block = create_non_standard_block(value)
assert block["type"] == "non_standard"
assert block["value"] == value
assert "id" in block
id_value = block.get("id")
assert id_value is not None, "block id is None"
_validate_lc_uuid(id_value)
def test_with_index(self) -> None:
"""Test non-standard block creation with index."""
value = {"data": "test"}
block = create_non_standard_block(value, index=5)
assert block.get("index") == 5
def test_optional_fields_not_present_when_none(self) -> None:
"""Test that optional fields are not included when None."""
value = {"data": "test"}
block = create_non_standard_block(value)
assert "index" not in block
class TestUUIDValidation:
"""Test UUID generation and validation behavior."""
def test_custom_id_bypasses_lc_prefix_requirement(self) -> None:
"""Test that custom IDs can use any format (don't require lc_ prefix)."""
custom_id = "custom-123"
block = create_text_block("Hello", id=custom_id)
assert block.get("id") == custom_id
# Custom IDs should not be validated with lc_ prefix requirement
def test_generated_ids_are_unique(self) -> None:
"""Test that multiple factory calls generate unique IDs."""
blocks = [create_text_block("test") for _ in range(10)]
ids = [block.get("id") for block in blocks]
# All IDs should be unique
assert len(set(ids)) == len(ids)
# All generated IDs should have lc_ prefix
for id_value in ids:
_validate_lc_uuid(id_value or "")
def test_empty_string_id_generates_new_uuid(self) -> None:
"""Test that empty string ID generates new UUID with lc_ prefix."""
block = create_text_block("Hello", id="")
id_value: str = block.get("id", "")
assert id_value != ""
_validate_lc_uuid(id_value)
def test_generated_id_length(self) -> None:
"""Test that generated IDs have correct length (UUID4 + lc_ prefix)."""
block = create_text_block("Hello")
id_value = block.get("id")
assert id_value is not None
# UUID4 string length is 36 chars, plus 3 for "lc_" prefix = 39 total
expected_length = 36 + 3
assert len(id_value) == expected_length, (
f"Expected length {expected_length}, got {len(id_value)}"
)
# Validate it's properly formatted
_validate_lc_uuid(id_value)
class TestFactoryTypeConsistency:
"""Test that factory functions return correctly typed objects."""
def test_factories_return_correct_types(self) -> None:
"""Test that all factory functions return the expected TypedDict types."""
text_block = create_text_block("test")
assert isinstance(text_block, dict)
assert text_block["type"] == "text"
image_block = create_image_block(url="https://example.com/image.jpg")
assert isinstance(image_block, dict)
assert image_block["type"] == "image"
video_block = create_video_block(url="https://example.com/video.mp4")
assert isinstance(video_block, dict)
assert video_block["type"] == "video"
audio_block = create_audio_block(url="https://example.com/audio.mp3")
assert isinstance(audio_block, dict)
assert audio_block["type"] == "audio"
file_block = create_file_block(url="https://example.com/file.pdf")
assert isinstance(file_block, dict)
assert file_block["type"] == "file"
plain_text_block = create_plaintext_block("content")
assert isinstance(plain_text_block, dict)
assert plain_text_block["type"] == "text-plain"
tool_call = create_tool_call("tool", {"arg": "value"})
assert isinstance(tool_call, dict)
assert tool_call["type"] == "tool_call"
reasoning_block = create_reasoning_block("reasoning")
assert isinstance(reasoning_block, dict)
assert reasoning_block["type"] == "reasoning"
citation = create_citation()
assert isinstance(citation, dict)
assert citation["type"] == "citation"
non_standard_block = create_non_standard_block({"data": "value"})
assert isinstance(non_standard_block, dict)
assert non_standard_block["type"] == "non_standard"
class TestExtraItems:
"""Test that content blocks support extra items."""
def test_text_block_extras_field(self) -> None:
"""Test that TextContentBlock properly supports the extras field."""
block = create_text_block("Hello world")
block["extras"] = {
"openai_metadata": {"model": "gpt-4", "temperature": 0.7},
"anthropic_usage": {"input_tokens": 10, "output_tokens": 20},
"custom_field": "any value",
}
assert block["type"] == "text"
assert block["text"] == "Hello world"
assert "id" in block
assert "extras" in block
extras = block.get("extras", {})
assert extras.get("openai_metadata") == {"model": "gpt-4", "temperature": 0.7}
expected_usage = {"input_tokens": 10, "output_tokens": 20}
assert extras.get("anthropic_usage") == expected_usage
assert extras.get("custom_field") == "any value"
def test_extra_items_do_not_interfere_with_standard_fields(self) -> None:
"""Test that extra items don't interfere with standard field access."""
block = create_text_block("Original text", index=1)
# Add many extra fields
for i in range(10):
block[f"extra_field_{i}"] = f"value_{i}" # type: ignore[literal-required]
# Standard fields should still work correctly
assert block["type"] == "text"
assert block["text"] == "Original text"
assert block["index"] == 1 if "index" in block else None
assert "id" in block
# Extra fields should also be accessible
for i in range(10):
assert block.get(f"extra_field_{i}") == f"value_{i}"
def test_extra_items_can_be_modified(self) -> None:
"""Test that extra items can be modified after creation."""
block = create_image_block(url="https://example.com/image.jpg")
# Add an extra field
block["extras"] = {"status": "pending"}
assert block["extras"].get("status") == "pending"
# Modify the extra field
block["extras"] = {"status": "processed"}
assert block["extras"].get("status") == "processed"
# Add more fields
block["extras"] = {"metadata": {"version": 1}}
metadata = block["extras"].get("metadata", {})
assert isinstance(metadata, dict)
assert metadata.get("version") == 1
# Modify nested extra field
metadata["version"] = 2
assert isinstance(metadata, dict)
assert metadata.get("version") == 2
def test_all_content_blocks_support_extra_items(self) -> None:
"""Test that all content block types support extra items."""
# Test each content block type
text_block = create_text_block("test")
text_block["extras"] = {"text_extra": "a"}
assert text_block.get("extras") == {"text_extra": "a"}
image_block = create_image_block(url="https://example.com/image.jpg")
image_block["extras"] = {"image_extra": "a"}
assert image_block.get("extras") == {"image_extra": "a"}
video_block = create_video_block(url="https://example.com/video.mp4")
video_block["extras"] = {"video_extra": "a"}
assert video_block.get("extras") == {"video_extra": "a"}
audio_block = create_audio_block(url="https://example.com/audio.mp3")
audio_block["extras"] = {"audio_extra": "a"}
assert audio_block.get("extras") == {"audio_extra": "a"}
file_block = create_file_block(url="https://example.com/file.pdf")
file_block["extras"] = {"file_extra": "a"}
assert file_block.get("extras") == {"file_extra": "a"}
plain_text_block = create_plaintext_block("content")
plain_text_block["extras"] = {"plaintext_extra": "a"}
assert plain_text_block.get("extras") == {"plaintext_extra": "a"}
tool_call = create_tool_call("tool", {"arg": "value"})
tool_call["extras"] = {"tool_extra": "a"}
assert tool_call.get("extras") == {"tool_extra": "a"}
reasoning_block = create_reasoning_block("reasoning")
reasoning_block["extras"] = {"reasoning_extra": "a"}
assert reasoning_block.get("extras") == {"reasoning_extra": "a"}
class TestExtrasField:
"""Test the explicit extras field across all content block types."""
def test_all_content_blocks_support_extras_field(self) -> None:
"""Test that all content block types support the explicit extras field."""
provider_metadata = {
"provider": "openai",
"model": "gpt-4",
"temperature": 0.7,
"usage": {"input_tokens": 10, "output_tokens": 20},
}
# Test TextContentBlock
text_block = create_text_block("test")
text_block["extras"] = provider_metadata
assert text_block.get("extras") == provider_metadata
assert text_block["type"] == "text"
# Test ImageContentBlock
image_block = create_image_block(url="https://example.com/image.jpg")
image_block["extras"] = provider_metadata
assert image_block.get("extras") == provider_metadata
assert image_block["type"] == "image"
# Test VideoContentBlock
video_block = create_video_block(url="https://example.com/video.mp4")
video_block["extras"] = provider_metadata
assert video_block.get("extras") == provider_metadata
assert video_block["type"] == "video"
# Test AudioContentBlock
audio_block = create_audio_block(url="https://example.com/audio.mp3")
audio_block["extras"] = provider_metadata
assert audio_block.get("extras") == provider_metadata
assert audio_block["type"] == "audio"
# Test FileContentBlock
file_block = create_file_block(url="https://example.com/file.pdf")
file_block["extras"] = provider_metadata
assert file_block.get("extras") == provider_metadata
assert file_block["type"] == "file"
# Test PlainTextContentBlock
plain_text_block = create_plaintext_block("content")
plain_text_block["extras"] = provider_metadata
assert plain_text_block.get("extras") == provider_metadata
assert plain_text_block["type"] == "text-plain"
# Test ToolCall
tool_call = create_tool_call("tool", {"arg": "value"})
tool_call["extras"] = provider_metadata
assert tool_call.get("extras") == provider_metadata
assert tool_call["type"] == "tool_call"
# Test ReasoningContentBlock
reasoning_block = create_reasoning_block("reasoning")
reasoning_block["extras"] = provider_metadata
assert reasoning_block.get("extras") == provider_metadata
assert reasoning_block["type"] == "reasoning"
# Test Citation
citation = create_citation()
citation["extras"] = provider_metadata
assert citation.get("extras") == provider_metadata
assert citation["type"] == "citation"
def test_extras_field_is_optional(self) -> None:
"""Test that the extras field is optional and blocks work without it."""
# Create blocks without extras
text_block = create_text_block("test")
image_block = create_image_block(url="https://example.com/image.jpg")
tool_call = create_tool_call("tool", {"arg": "value"})
reasoning_block = create_reasoning_block("reasoning")
citation = create_citation()
# Verify blocks work correctly without extras
assert text_block["type"] == "text"
assert image_block["type"] == "image"
assert tool_call["type"] == "tool_call"
assert reasoning_block["type"] == "reasoning"
assert citation["type"] == "citation"
# Verify extras field is not present when not set
assert "extras" not in text_block
assert "extras" not in image_block
assert "extras" not in tool_call
assert "extras" not in reasoning_block
assert "extras" not in citation
def test_extras_field_can_be_modified(self) -> None:
"""Test that the extras field can be modified after creation."""
block = create_text_block("test")
# Add extras
block["extras"] = {"initial": "value"}
assert block.get("extras") == {"initial": "value"}
# Modify extras
block["extras"] = {"updated": "value", "count": 42}
extras = block.get("extras", {})
assert extras.get("updated") == "value"
assert extras.get("count") == 42
assert "initial" not in extras
# Update nested values in extras
if "extras" in block:
block["extras"]["nested"] = {"deep": "value"}
extras = block.get("extras", {})
nested = extras.get("nested", {})
assert isinstance(nested, dict)
assert nested.get("deep") == "value"
def test_extras_field_supports_various_data_types(self) -> None:
"""Test that the extras field can store various data types."""
block = create_text_block("test")
complex_extras = {
"string_val": "test string",
"int_val": 42,
"float_val": 3.14,
"bool_val": True,
"none_val": None,
"list_val": ["item1", "item2", {"nested": "in_list"}],
"dict_val": {"nested": {"deeply": {"nested": "value"}}},
}
block["extras"] = complex_extras
extras = block.get("extras", {})
assert extras.get("string_val") == "test string"
assert extras.get("int_val") == 42
assert extras.get("float_val") == 3.14
assert extras.get("bool_val") is True
assert extras.get("none_val") is None
list_val = extras.get("list_val", [])
assert isinstance(list_val, list)
assert len(list_val) == 3
assert list_val[0] == "item1"
assert list_val[1] == "item2"
assert isinstance(list_val[2], dict)
assert list_val[2].get("nested") == "in_list"
dict_val = extras.get("dict_val", {})
assert isinstance(dict_val, dict)
nested = dict_val.get("nested", {})
assert isinstance(nested, dict)
deeply = nested.get("deeply", {})
assert isinstance(deeply, dict)
assert deeply.get("nested") == "value"
def test_extras_field_does_not_interfere_with_standard_fields(self) -> None:
"""Test that the extras field doesn't interfere with standard fields."""
# Create a complex block with all standard fields
block = create_text_block(
"Test content",
annotations=[create_citation(url="https://example.com")],
index=42,
)
# Add extensive extras
large_extras = {f"field_{i}": f"value_{i}" for i in range(100)}
block["extras"] = large_extras
# Verify all standard fields still work
assert block["type"] == "text"
assert block["text"] == "Test content"
assert block.get("index") == 42
assert "id" in block
assert "annotations" in block
annotations = block.get("annotations", [])
assert len(annotations) == 1
assert annotations[0]["type"] == "citation"
# Verify extras field works
extras = block.get("extras", {})
assert len(extras) == 100
for i in range(100):
assert extras.get(f"field_{i}") == f"value_{i}"
def test_special_content_blocks_support_extras_field(self) -> None:
"""Test that special content blocks support extras field."""
provider_metadata = {
"provider": "openai",
"request_id": "req_12345",
"timing": {"start": 1234567890, "end": 1234567895},
}
# Test ToolCallChunk
tool_call_chunk: ToolCallChunk = {
"type": "tool_call_chunk",
"id": "tool_123",
"name": "search",
"args": '{"query": "test"}',
"index": 0,
"extras": provider_metadata,
}
assert tool_call_chunk.get("extras") == provider_metadata
assert tool_call_chunk["type"] == "tool_call_chunk"
# Test InvalidToolCall
invalid_tool_call: InvalidToolCall = {
"type": "invalid_tool_call",
"id": "invalid_123",
"name": "bad_tool",
"args": "invalid json",
"error": "JSON parse error",
"extras": provider_metadata,
}
assert invalid_tool_call.get("extras") == provider_metadata
assert invalid_tool_call["type"] == "invalid_tool_call"
# Test WebSearchCall
web_search_call: WebSearchCall = {
"type": "web_search_call",
"id": "search_123",
"query": "python langchain",
"index": 0,
"extras": provider_metadata,
}
assert web_search_call.get("extras") == provider_metadata
assert web_search_call["type"] == "web_search_call"
# Test WebSearchResult
web_search_result: WebSearchResult = {
"type": "web_search_result",
"id": "result_123",
"urls": ["https://example.com", "https://test.com"],
"index": 0,
"extras": provider_metadata,
}
assert web_search_result.get("extras") == provider_metadata
assert web_search_result["type"] == "web_search_result"
# Test CodeInterpreterCall
code_interpreter_call: CodeInterpreterCall = {
"type": "code_interpreter_call",
"id": "code_123",
"language": "python",
"code": "print('hello world')",
"index": 0,
"extras": provider_metadata,
}
assert code_interpreter_call.get("extras") == provider_metadata
assert code_interpreter_call["type"] == "code_interpreter_call"
# Test CodeInterpreterOutput
code_interpreter_output: CodeInterpreterOutput = {
"type": "code_interpreter_output",
"id": "output_123",
"return_code": 0,
"stderr": "",
"stdout": "hello world\n",
"file_ids": ["file_123"],
"index": 0,
"extras": provider_metadata,
}
assert code_interpreter_output.get("extras") == provider_metadata
assert code_interpreter_output["type"] == "code_interpreter_output"
# Test CodeInterpreterResult
code_interpreter_result: CodeInterpreterResult = {
"type": "code_interpreter_result",
"id": "result_123",
"output": [code_interpreter_output],
"index": 0,
"extras": provider_metadata,
}
assert code_interpreter_result.get("extras") == provider_metadata
assert code_interpreter_result["type"] == "code_interpreter_result"
def test_extras_field_is_not_required_for_special_blocks(self) -> None:
"""Test that extras field is optional for all special content blocks."""
# Create blocks without extras field
tool_call_chunk: ToolCallChunk = {
"id": "tool_123",
"name": "search",
"args": '{"query": "test"}',
"index": 0,
}
invalid_tool_call: InvalidToolCall = {
"type": "invalid_tool_call",
"id": "invalid_123",
"name": "bad_tool",
"args": "invalid json",
"error": "JSON parse error",
}
web_search_call: WebSearchCall = {
"type": "web_search_call",
"query": "python langchain",
}
web_search_result: WebSearchResult = {
"type": "web_search_result",
"urls": ["https://example.com"],
}
code_interpreter_call: CodeInterpreterCall = {
"type": "code_interpreter_call",
"code": "print('hello')",
}
code_interpreter_output: CodeInterpreterOutput = {
"type": "code_interpreter_output",
"stdout": "hello\n",
}
code_interpreter_result: CodeInterpreterResult = {
"type": "code_interpreter_result",
"output": [code_interpreter_output],
}
# Verify they work without extras
assert tool_call_chunk.get("name") == "search"
assert invalid_tool_call["type"] == "invalid_tool_call"
assert web_search_call["type"] == "web_search_call"
assert web_search_result["type"] == "web_search_result"
assert code_interpreter_call["type"] == "code_interpreter_call"
assert code_interpreter_output["type"] == "code_interpreter_output"
assert code_interpreter_result["type"] == "code_interpreter_result"
# Verify extras field is not present
assert "extras" not in tool_call_chunk
assert "extras" not in invalid_tool_call
assert "extras" not in web_search_call
assert "extras" not in web_search_result
assert "extras" not in code_interpreter_call
assert "extras" not in code_interpreter_output
assert "extras" not in code_interpreter_result

View File

@@ -5,26 +5,48 @@ EXPECTED_ALL = [
"_message_from_dict",
"AIMessage",
"AIMessageChunk",
"Annotation",
"AnyMessage",
"AudioContentBlock",
"BaseMessage",
"BaseMessageChunk",
"ContentBlock",
"ChatMessage",
"ChatMessageChunk",
"Citation",
"CodeInterpreterCall",
"CodeInterpreterOutput",
"CodeInterpreterResult",
"DataContentBlock",
"FileContentBlock",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"ImageContentBlock",
"InvalidToolCall",
"NonStandardAnnotation",
"NonStandardContentBlock",
"PlainTextContentBlock",
"SystemMessage",
"SystemMessageChunk",
"TextContentBlock",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"VideoContentBlock",
"WebSearchCall",
"WebSearchResult",
"ReasoningContentBlock",
"RemoveMessage",
"convert_to_messages",
"get_buffer_string",
"is_data_content_block",
"is_reasoning_block",
"is_text_block",
"is_tool_call_block",
"is_tool_call_chunk",
"merge_content",
"message_chunk_to_message",
"message_to_dict",

View File

@@ -0,0 +1,343 @@
"""Unit tests for ResponseMetadata TypedDict."""
from langchain_core.v1.messages import AIMessage, AIMessageChunk, ResponseMetadata
class TestResponseMetadata:
"""Test the ResponseMetadata TypedDict functionality."""
def test_response_metadata_basic_fields(self) -> None:
"""Test ResponseMetadata with basic required fields."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
}
assert metadata.get("model_provider") == "openai"
assert metadata.get("model_name") == "gpt-4"
def test_response_metadata_is_optional(self) -> None:
"""Test that ResponseMetadata fields are optional due to total=False."""
# Should be able to create empty ResponseMetadata
metadata: ResponseMetadata = {}
assert metadata == {}
# Should be able to create with just one field
metadata_partial: ResponseMetadata = {"model_provider": "anthropic"}
assert metadata_partial.get("model_provider") == "anthropic"
assert "model_name" not in metadata_partial
def test_response_metadata_supports_extra_fields(self) -> None:
"""Test that ResponseMetadata supports provider-specific extra fields."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4-turbo",
# Extra fields should be allowed
"system_fingerprint": "fp_12345",
"logprobs": None,
"finish_reason": "stop",
"request_id": "req_abc123",
}
assert metadata.get("model_provider") == "openai"
assert metadata.get("model_name") == "gpt-4-turbo"
assert metadata.get("system_fingerprint") == "fp_12345"
assert metadata.get("logprobs") is None
assert metadata.get("finish_reason") == "stop"
assert metadata.get("request_id") == "req_abc123"
def test_response_metadata_various_data_types(self) -> None:
"""Test that ResponseMetadata can store various data types in extra fields."""
metadata: ResponseMetadata = {
"model_provider": "anthropic",
"model_name": "claude-3-sonnet",
"string_field": "test_value",
"int_field": 42,
"float_field": 3.14,
"bool_field": True,
"none_field": None,
"list_field": [1, 2, 3, "test"],
"dict_field": {"nested": {"deeply": "nested_value"}},
}
assert metadata.get("string_field") == "test_value"
assert metadata.get("int_field") == 42
assert metadata.get("float_field") == 3.14
assert metadata.get("bool_field") is True
assert metadata.get("none_field") is None
list_field = metadata.get("list_field")
assert isinstance(list_field, list)
assert list_field == [1, 2, 3, "test"]
dict_field = metadata.get("dict_field")
assert isinstance(dict_field, dict)
nested = dict_field.get("nested")
assert isinstance(nested, dict)
assert nested.get("deeply") == "nested_value"
def test_response_metadata_can_be_modified(self) -> None:
"""Test that ResponseMetadata can be modified after creation."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-3.5-turbo",
}
# Modify existing fields
metadata["model_name"] = "gpt-4"
assert metadata.get("model_name") == "gpt-4"
# Add new fields
metadata["request_id"] = "req_12345"
assert metadata.get("request_id") == "req_12345"
# Modify nested structures
metadata["headers"] = {"x-request-id": "abc123"}
metadata["headers"]["x-rate-limit"] = "100" # type: ignore[typeddict-item]
headers = metadata.get("headers")
assert isinstance(headers, dict)
assert headers.get("x-request-id") == "abc123"
assert headers.get("x-rate-limit") == "100"
def test_response_metadata_provider_specific_examples(self) -> None:
"""Test ResponseMetadata with realistic provider-specific examples."""
# OpenAI-style metadata
openai_metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4-turbo-2024-04-09",
"system_fingerprint": "fp_abc123",
"created": 1234567890,
"logprobs": None,
"finish_reason": "stop",
}
assert openai_metadata.get("model_provider") == "openai"
assert openai_metadata.get("system_fingerprint") == "fp_abc123"
# Anthropic-style metadata
anthropic_metadata: ResponseMetadata = {
"model_provider": "anthropic",
"model_name": "claude-3-sonnet-20240229",
"stop_reason": "end_turn",
"stop_sequence": None,
}
assert anthropic_metadata.get("model_provider") == "anthropic"
assert anthropic_metadata.get("stop_reason") == "end_turn"
# Custom provider metadata
custom_metadata: ResponseMetadata = {
"model_provider": "custom_llm_service",
"model_name": "custom-model-v1",
"service_tier": "premium",
"rate_limit_info": {
"requests_remaining": 100,
"reset_time": "2024-01-01T00:00:00Z",
},
"response_time_ms": 1250,
}
assert custom_metadata.get("service_tier") == "premium"
rate_limit = custom_metadata.get("rate_limit_info")
assert isinstance(rate_limit, dict)
assert rate_limit.get("requests_remaining") == 100
class TestResponseMetadataWithAIMessages:
"""Test ResponseMetadata integration with AI message classes."""
def test_ai_message_with_response_metadata(self) -> None:
"""Test AIMessage with ResponseMetadata."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"system_fingerprint": "fp_xyz789",
}
message = AIMessage(content="Hello, world!", response_metadata=metadata)
assert message.response_metadata == metadata
assert message.response_metadata.get("model_provider") == "openai"
assert message.response_metadata.get("model_name") == "gpt-4"
assert message.response_metadata.get("system_fingerprint") == "fp_xyz789"
def test_ai_message_chunk_with_response_metadata(self) -> None:
"""Test AIMessageChunk with ResponseMetadata."""
metadata: ResponseMetadata = {
"model_provider": "anthropic",
"model_name": "claude-3-sonnet",
"stream_id": "stream_12345",
}
chunk = AIMessageChunk(content="Hello", response_metadata=metadata)
assert chunk.response_metadata == metadata
assert chunk.response_metadata.get("stream_id") == "stream_12345"
def test_ai_message_default_empty_response_metadata(self) -> None:
"""Test that AIMessage creates empty ResponseMetadata by default."""
message = AIMessage(content="Test message")
# Should have empty dict as default
assert message.response_metadata == {}
assert isinstance(message.response_metadata, dict)
def test_ai_message_chunk_default_empty_response_metadata(self) -> None:
"""Test that AIMessageChunk creates empty ResponseMetadata by default."""
chunk = AIMessageChunk(content="Test chunk")
# Should have empty dict as default
assert chunk.response_metadata == {}
assert isinstance(chunk.response_metadata, dict)
def test_response_metadata_merging_in_chunks(self) -> None:
"""Test that ResponseMetadata is properly merged when adding AIMessageChunks."""
metadata1: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"request_id": "req_123",
"system_fingerprint": "fp_abc",
}
metadata2: ResponseMetadata = {
"stream_chunk": 1,
"finish_reason": "length",
}
chunk1 = AIMessageChunk(content="Hello ", response_metadata=metadata1)
chunk2 = AIMessageChunk(content="world!", response_metadata=metadata2)
merged = chunk1 + chunk2
# Should have merged response_metadata
assert merged.response_metadata.get("model_provider") == "openai"
assert merged.response_metadata.get("model_name") == "gpt-4"
assert merged.response_metadata.get("request_id") == "req_123"
assert merged.response_metadata.get("stream_chunk") == 1
assert merged.response_metadata.get("system_fingerprint") == "fp_abc"
assert merged.response_metadata.get("finish_reason") == "length"
def test_response_metadata_modification_after_message_creation(self) -> None:
"""Test that ResponseMetadata can be modified after message creation."""
message = AIMessage(
content="Initial message",
response_metadata={"model_provider": "openai", "model_name": "gpt-3.5"},
)
# Modify existing field
message.response_metadata["model_name"] = "gpt-4"
assert message.response_metadata.get("model_name") == "gpt-4"
# Add new field
message.response_metadata["finish_reason"] = "stop"
assert message.response_metadata.get("finish_reason") == "stop"
def test_response_metadata_with_none_values(self) -> None:
"""Test ResponseMetadata handling of None values."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"system_fingerprint": None,
"logprobs": None,
}
message = AIMessage(content="Test", response_metadata=metadata)
assert message.response_metadata.get("system_fingerprint") is None
assert message.response_metadata.get("logprobs") is None
assert "system_fingerprint" in message.response_metadata
assert "logprobs" in message.response_metadata
class TestResponseMetadataEdgeCases:
"""Test edge cases and error conditions for ResponseMetadata."""
def test_response_metadata_with_complex_nested_structures(self) -> None:
"""Test ResponseMetadata with deeply nested and complex structures."""
metadata: ResponseMetadata = {
"model_provider": "custom",
"model_name": "complex-model",
"complex_data": {
"level1": {
"level2": {
"level3": {
"deeply_nested": "value",
"array": [
{"item": 1, "metadata": {"nested": True}},
{"item": 2, "metadata": {"nested": False}},
],
}
}
}
},
}
complex_data = metadata.get("complex_data")
assert isinstance(complex_data, dict)
level1 = complex_data.get("level1")
assert isinstance(level1, dict)
level2 = level1.get("level2")
assert isinstance(level2, dict)
level3 = level2.get("level3")
assert isinstance(level3, dict)
assert level3.get("deeply_nested") == "value"
array = level3.get("array")
assert isinstance(array, list)
assert len(array) == 2
assert array[0]["item"] == 1
assert array[0]["metadata"]["nested"] is True
def test_response_metadata_large_data(self) -> None:
"""Test ResponseMetadata with large amounts of data."""
# Create metadata with many fields
large_metadata: ResponseMetadata = {
"model_provider": "test_provider",
"model_name": "test_model",
}
# Add 100 extra fields
for i in range(100):
large_metadata[f"field_{i}"] = f"value_{i}" # type: ignore[literal-required]
message = AIMessage(content="Test", response_metadata=large_metadata)
# Verify all fields are accessible
assert message.response_metadata.get("model_provider") == "test_provider"
for i in range(100):
assert message.response_metadata.get(f"field_{i}") == f"value_{i}"
def test_response_metadata_empty_vs_none(self) -> None:
"""Test the difference between empty ResponseMetadata and None."""
# Message with empty metadata
message_empty = AIMessage(content="Test", response_metadata={})
assert message_empty.response_metadata == {}
assert isinstance(message_empty.response_metadata, dict)
# Message with None metadata (should become empty dict)
message_none = AIMessage(content="Test", response_metadata=None)
assert message_none.response_metadata == {}
assert isinstance(message_none.response_metadata, dict)
# Default message (no metadata specified)
message_default = AIMessage(content="Test")
assert message_default.response_metadata == {}
assert isinstance(message_default.response_metadata, dict)
def test_response_metadata_preserves_original_dict_type(self) -> None:
"""Test that ResponseMetadata preserves the original dict when passed."""
original_dict: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"custom_field": "custom_value",
}
message = AIMessage(content="Test", response_metadata=original_dict)
# Should be the same dict object
assert message.response_metadata is original_dict
# Modifications to the message's response_metadata should affect original
message.response_metadata["new_field"] = "new_value"
assert original_dict.get("new_field") == "new_value"

View File

@@ -0,0 +1,361 @@
"""Unit tests for ResponseMetadata TypedDict."""
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, ResponseMetadata
class TestResponseMetadata:
"""Test the ResponseMetadata TypedDict functionality."""
def test_response_metadata_basic_fields(self) -> None:
"""Test ResponseMetadata with basic required fields."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
}
assert metadata.get("model_provider") == "openai"
assert metadata.get("model_name") == "gpt-4"
def test_response_metadata_is_optional(self) -> None:
"""Test that ResponseMetadata fields are optional due to total=False."""
# Should be able to create empty ResponseMetadata
metadata: ResponseMetadata = {}
assert metadata == {}
# Should be able to create with just one field
metadata_partial: ResponseMetadata = {"model_provider": "anthropic"}
assert metadata_partial.get("model_provider") == "anthropic"
assert "model_name" not in metadata_partial
def test_response_metadata_supports_extra_fields(self) -> None:
"""Test that ResponseMetadata supports provider-specific extra fields."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4-turbo",
# Extra fields should be allowed
"usage": {"input_tokens": 100, "output_tokens": 50},
"system_fingerprint": "fp_12345",
"logprobs": None,
"finish_reason": "stop",
}
assert metadata.get("model_provider") == "openai"
assert metadata.get("model_name") == "gpt-4-turbo"
assert metadata.get("usage") == {"input_tokens": 100, "output_tokens": 50}
assert metadata.get("system_fingerprint") == "fp_12345"
assert metadata.get("logprobs") is None
assert metadata.get("finish_reason") == "stop"
def test_response_metadata_various_data_types(self) -> None:
"""Test that ResponseMetadata can store various data types in extra fields."""
metadata: ResponseMetadata = {
"model_provider": "anthropic",
"model_name": "claude-3-sonnet",
"string_field": "test_value", # type: ignore[typeddict-unknown-key]
"int_field": 42, # type: ignore[typeddict-unknown-key]
"float_field": 3.14, # type: ignore[typeddict-unknown-key]
"bool_field": True, # type: ignore[typeddict-unknown-key]
"none_field": None, # type: ignore[typeddict-unknown-key]
"list_field": [1, 2, 3, "test"], # type: ignore[typeddict-unknown-key]
"dict_field": { # type: ignore[typeddict-unknown-key]
"nested": {"deeply": "nested_value"}
},
}
assert metadata.get("string_field") == "test_value" # type: ignore[typeddict-item]
assert metadata.get("int_field") == 42 # type: ignore[typeddict-item]
assert metadata.get("float_field") == 3.14 # type: ignore[typeddict-item]
assert metadata.get("bool_field") is True # type: ignore[typeddict-item]
assert metadata.get("none_field") is None # type: ignore[typeddict-item]
list_field = metadata.get("list_field") # type: ignore[typeddict-item]
assert isinstance(list_field, list)
assert list_field == [1, 2, 3, "test"]
dict_field = metadata.get("dict_field") # type: ignore[typeddict-item]
assert isinstance(dict_field, dict)
nested = dict_field.get("nested") # type: ignore[union-attr]
assert isinstance(nested, dict)
assert nested.get("deeply") == "nested_value" # type: ignore[union-attr]
def test_response_metadata_can_be_modified(self) -> None:
"""Test that ResponseMetadata can be modified after creation."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-3.5-turbo",
}
# Modify existing fields
metadata["model_name"] = "gpt-4"
assert metadata.get("model_name") == "gpt-4"
# Add new fields
metadata["request_id"] = "req_12345" # type: ignore[typeddict-unknown-key]
assert metadata.get("request_id") == "req_12345" # type: ignore[typeddict-item]
# Modify nested structures
metadata["usage"] = {"input_tokens": 10} # type: ignore[typeddict-unknown-key]
metadata["usage"]["output_tokens"] = 20 # type: ignore[typeddict-item]
usage = metadata.get("usage") # type: ignore[typeddict-item]
assert isinstance(usage, dict)
assert usage.get("input_tokens") == 10 # type: ignore[union-attr]
assert usage.get("output_tokens") == 20 # type: ignore[union-attr]
def test_response_metadata_provider_specific_examples(self) -> None:
"""Test ResponseMetadata with realistic provider-specific examples."""
# OpenAI-style metadata
openai_metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4-turbo-2024-04-09",
"usage": { # type: ignore[typeddict-unknown-key]
"prompt_tokens": 50,
"completion_tokens": 25,
"total_tokens": 75,
},
"system_fingerprint": "fp_abc123", # type: ignore[typeddict-unknown-key]
"created": 1234567890, # type: ignore[typeddict-unknown-key]
"logprobs": None, # type: ignore[typeddict-unknown-key]
"finish_reason": "stop", # type: ignore[typeddict-unknown-key]
}
assert openai_metadata.get("model_provider") == "openai"
assert openai_metadata.get("system_fingerprint") == "fp_abc123" # type: ignore[typeddict-item]
# Anthropic-style metadata
anthropic_metadata: ResponseMetadata = {
"model_provider": "anthropic",
"model_name": "claude-3-sonnet-20240229",
"usage": { # type: ignore[typeddict-unknown-key]
"input_tokens": 75,
"output_tokens": 30,
},
"stop_reason": "end_turn", # type: ignore[typeddict-unknown-key]
"stop_sequence": None, # type: ignore[typeddict-unknown-key]
}
assert anthropic_metadata.get("model_provider") == "anthropic"
assert anthropic_metadata.get("stop_reason") == "end_turn" # type: ignore[typeddict-item]
# Custom provider metadata
custom_metadata: ResponseMetadata = {
"model_provider": "custom_llm_service",
"model_name": "custom-model-v1",
"service_tier": "premium", # type: ignore[typeddict-unknown-key]
"rate_limit_info": { # type: ignore[typeddict-unknown-key]
"requests_remaining": 100,
"reset_time": "2024-01-01T00:00:00Z",
},
"response_time_ms": 1250, # type: ignore[typeddict-unknown-key]
}
assert custom_metadata.get("service_tier") == "premium" # type: ignore[typeddict-item]
rate_limit = custom_metadata.get("rate_limit_info") # type: ignore[typeddict-item]
assert isinstance(rate_limit, dict)
assert rate_limit.get("requests_remaining") == 100 # type: ignore[union-attr]
class TestResponseMetadataWithAIMessages:
"""Test ResponseMetadata integration with AI message classes."""
def test_ai_message_with_response_metadata(self) -> None:
"""Test AIMessage with ResponseMetadata."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"usage": {"input_tokens": 10, "output_tokens": 5}, # type: ignore[typeddict-unknown-key]
}
message = AIMessage(content="Hello, world!", response_metadata=metadata)
assert message.response_metadata == metadata
assert message.response_metadata.get("model_provider") == "openai"
assert message.response_metadata.get("model_name") == "gpt-4"
usage = message.response_metadata.get("usage") # type: ignore[typeddict-item]
assert isinstance(usage, dict)
assert usage.get("input_tokens") == 10 # type: ignore[union-attr]
def test_ai_message_chunk_with_response_metadata(self) -> None:
"""Test AIMessageChunk with ResponseMetadata."""
metadata: ResponseMetadata = {
"model_provider": "anthropic",
"model_name": "claude-3-sonnet",
"stream_id": "stream_12345", # type: ignore[typeddict-unknown-key]
}
chunk = AIMessageChunk(content="Hello", response_metadata=metadata)
assert chunk.response_metadata == metadata
assert chunk.response_metadata.get("stream_id") == "stream_12345" # type: ignore[typeddict-item]
def test_ai_message_default_empty_response_metadata(self) -> None:
"""Test that AIMessage creates empty ResponseMetadata by default."""
message = AIMessage(content="Test message")
# Should have empty dict as default
assert message.response_metadata == {}
assert isinstance(message.response_metadata, dict)
def test_ai_message_chunk_default_empty_response_metadata(self) -> None:
"""Test that AIMessageChunk creates empty ResponseMetadata by default."""
chunk = AIMessageChunk(content="Test chunk")
# Should have empty dict as default
assert chunk.response_metadata == {}
assert isinstance(chunk.response_metadata, dict)
def test_response_metadata_merging_in_chunks(self) -> None:
"""Test that ResponseMetadata is properly merged when adding AIMessageChunks."""
metadata1: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"request_id": "req_123", # type: ignore[typeddict-unknown-key]
"usage": {"input_tokens": 10}, # type: ignore[typeddict-unknown-key]
}
metadata2: ResponseMetadata = {
"stream_chunk": 1, # type: ignore[typeddict-unknown-key]
"usage": {"output_tokens": 5}, # type: ignore[typeddict-unknown-key]
}
chunk1 = AIMessageChunk(content="Hello ", response_metadata=metadata1)
chunk2 = AIMessageChunk(content="world!", response_metadata=metadata2)
merged = chunk1 + chunk2
# Should have merged response_metadata
assert merged.response_metadata.get("model_provider") == "openai"
assert merged.response_metadata.get("model_name") == "gpt-4"
assert merged.response_metadata.get("request_id") == "req_123" # type: ignore[typeddict-item]
assert merged.response_metadata.get("stream_chunk") == 1 # type: ignore[typeddict-item]
# Usage should be merged (from merge_dicts behavior)
merged_usage = merged.response_metadata.get("usage") # type: ignore[typeddict-item]
assert isinstance(merged_usage, dict)
assert merged_usage.get("input_tokens") == 10 # type: ignore[union-attr]
assert merged_usage.get("output_tokens") == 5 # type: ignore[union-attr]
def test_response_metadata_modification_after_message_creation(self) -> None:
"""Test that ResponseMetadata can be modified after message creation."""
message = AIMessage(
content="Initial message",
response_metadata={"model_provider": "openai", "model_name": "gpt-3.5"},
)
# Modify existing field
message.response_metadata["model_name"] = "gpt-4"
assert message.response_metadata.get("model_name") == "gpt-4"
# Add new field
message.response_metadata["finish_reason"] = "stop" # type: ignore[typeddict-unknown-key]
assert message.response_metadata.get("finish_reason") == "stop" # type: ignore[typeddict-item]
def test_response_metadata_with_none_values(self) -> None:
"""Test ResponseMetadata handling of None values."""
metadata: ResponseMetadata = {
"model_provider": "openai",
"model_name": "gpt-4",
"system_fingerprint": None, # type: ignore[typeddict-unknown-key]
"logprobs": None, # type: ignore[typeddict-unknown-key]
}
message = AIMessage(content="Test", response_metadata=metadata)
assert message.response_metadata.get("system_fingerprint") is None # type: ignore[typeddict-item]
assert message.response_metadata.get("logprobs") is None # type: ignore[typeddict-item]
assert "system_fingerprint" in message.response_metadata
assert "logprobs" in message.response_metadata
class TestResponseMetadataEdgeCases:
"""Test edge cases and error conditions for ResponseMetadata."""
def test_response_metadata_with_complex_nested_structures(self) -> None:
"""Test ResponseMetadata with deeply nested and complex structures."""
metadata: ResponseMetadata = {
"model_provider": "custom",
"model_name": "complex-model",
"complex_data": { # type: ignore[typeddict-unknown-key]
"level1": {
"level2": {
"level3": {
"deeply_nested": "value",
"array": [
{"item": 1, "metadata": {"nested": True}},
{"item": 2, "metadata": {"nested": False}},
],
}
}
}
},
}
complex_data = metadata.get("complex_data") # type: ignore[typeddict-item]
assert isinstance(complex_data, dict)
level1 = complex_data.get("level1") # type: ignore[union-attr]
assert isinstance(level1, dict)
level2 = level1.get("level2") # type: ignore[union-attr]
assert isinstance(level2, dict)
level3 = level2.get("level3") # type: ignore[union-attr]
assert isinstance(level3, dict)
assert level3.get("deeply_nested") == "value" # type: ignore[union-attr]
array = level3.get("array") # type: ignore[union-attr]
assert isinstance(array, list)
assert len(array) == 2 # type: ignore[arg-type]
assert array[0]["item"] == 1 # type: ignore[index, typeddict-item]
assert array[0]["metadata"]["nested"] is True # type: ignore[index, typeddict-item]
def test_response_metadata_large_data(self) -> None:
"""Test ResponseMetadata with large amounts of data."""
# Create metadata with many fields
large_metadata: ResponseMetadata = {
"model_provider": "test_provider",
"model_name": "test_model",
}
# Add 100 extra fields
for i in range(100):
large_metadata[f"field_{i}"] = f"value_{i}" # type: ignore[literal-required]
message = AIMessage(content="Test", response_metadata=large_metadata)
# Verify all fields are accessible
assert message.response_metadata.get("model_provider") == "test_provider"
for i in range(100):
assert message.response_metadata.get(f"field_{i}") == f"value_{i}" # type: ignore[typeddict-item]
def test_response_metadata_empty_vs_none(self) -> None:
"""Test the difference between empty ResponseMetadata and None."""
# Message with empty metadata
message_empty = AIMessage(content="Test", response_metadata={})
assert message_empty.response_metadata == {}
assert isinstance(message_empty.response_metadata, dict)
# Message with None metadata (should become empty dict)
message_none = AIMessage(content="Test", response_metadata=None)
assert message_none.response_metadata == {}
assert isinstance(message_none.response_metadata, dict)
# Default message (no metadata specified)
message_default = AIMessage(content="Test")
assert message_default.response_metadata == {}
assert isinstance(message_default.response_metadata, dict)
def test_response_metadata_preserves_original_dict_type(self) -> None:
"""Test that ResponseMetadata preserves the original dict when passed."""
original_dict = {
"model_provider": "openai",
"model_name": "gpt-4",
"custom_field": "custom_value",
}
message = AIMessage(content="Test", response_metadata=original_dict)
# Should be the same dict object
assert message.response_metadata is original_dict
# Modifications to the message's response_metadata should affect original
message.response_metadata["new_field"] = "new_value" # type: ignore[typeddict-unknown-key]
assert original_dict.get("new_field") == "new_value" # type: ignore[typeddict-item]

View File

@@ -1221,15 +1221,30 @@ def test_convert_to_openai_messages_multimodal() -> None:
{"type": "text", "text": "Text message"},
{
"type": "image",
"source_type": "url",
"url": "https://example.com/test.png",
},
{
"type": "image",
"source_type": "url", # backward compatibility
"url": "https://example.com/test.png",
},
{
"type": "image",
"base64": "<base64 string>",
"mime_type": "image/png",
},
{
"type": "image",
"source_type": "base64",
"data": "<base64 string>",
"mime_type": "image/png",
},
{
"type": "file",
"base64": "<base64 string>",
"mime_type": "application/pdf",
"filename": "test.pdf",
},
{
"type": "file",
"source_type": "base64",
@@ -1244,11 +1259,20 @@ def test_convert_to_openai_messages_multimodal() -> None:
"file_data": "data:application/pdf;base64,<base64 string>",
},
},
{
"type": "file",
"file_id": "file-abc123",
},
{
"type": "file",
"source_type": "id",
"id": "file-abc123",
},
{
"type": "audio",
"base64": "<base64 string>",
"mime_type": "audio/wav",
},
{
"type": "audio",
"source_type": "base64",
@@ -1268,7 +1292,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
result = convert_to_openai_messages(messages, text_format="block")
assert len(result) == 1
message = result[0]
assert len(message["content"]) == 8
assert len(message["content"]) == 13
# Test adding filename
messages = [
@@ -1276,8 +1300,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
content=[
{
"type": "file",
"source_type": "base64",
"data": "<base64 string>",
"base64": "<base64 string>",
"mime_type": "application/pdf",
},
]

View File

@@ -1,15 +1,19 @@
"""Module to test base parser implementations."""
from typing import Union
from typing_extensions import override
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModelV1
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseTransformOutputParser,
)
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.v1.messages import AIMessage as AIMessageV1
def test_base_generation_parser() -> None:
@@ -20,7 +24,7 @@ def test_base_generation_parser() -> None:
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str:
"""Parse a list of model Generations into a specific format.
@@ -32,16 +36,22 @@ def test_base_generation_parser() -> None:
partial: Whether to allow partial results. This is used for parsers
that support streaming
"""
if len(result) != 1:
msg = "This output parser can only be used with a single generation."
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
if isinstance(result, AIMessageV1):
content = result.text
else:
if len(result) != 1:
msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content
content = generation.message.content
assert isinstance(content, str)
return content.swapcase()
@@ -49,6 +59,10 @@ def test_base_generation_parser() -> None:
chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO"
model_v1 = GenericFakeChatModelV1(messages=iter([AIMessageV1("hEllo")]))
chain_v1 = model_v1 | StrInvertCase()
assert chain_v1.invoke("") == "HeLLO"
def test_base_transform_output_parser() -> None:
"""Test base transform output parser."""
@@ -62,7 +76,7 @@ def test_base_transform_output_parser() -> None:
@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str:
"""Parse a list of model Generations into a specific format.
@@ -74,15 +88,22 @@ def test_base_transform_output_parser() -> None:
partial: Whether to allow partial results. This is used for parsers
that support streaming
"""
if len(result) != 1:
msg = "This output parser can only be used with a single generation."
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
content = generation.message.content
if isinstance(result, AIMessageV1):
content = result.text
else:
if len(result) != 1:
msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content
assert isinstance(content, str)
return content.swapcase()
@@ -91,3 +112,8 @@ def test_base_transform_output_parser() -> None:
# inputs to models are ignored, response is hard-coded in model definition
chunks = list(chain.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]
model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
chain_v1 = model_v1 | StrInvertCase()
chunks = list(chain_v1.stream(""))
assert chunks == ["HELLO", " ", "WORLD", ""]

View File

@@ -16,6 +16,8 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.v1.messages import AIMessage as AIMessageV1
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@@ -331,6 +333,14 @@ for message in STREAMED_MESSAGES:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
STREAMED_MESSAGES_V1 = [
AIMessageChunkV1(
content=[],
tool_call_chunks=chunk.tool_call_chunks,
)
for chunk in STREAMED_MESSAGES_WITH_TOOL_CALLS
]
EXPECTED_STREAMED_JSON = [
{},
{"names": ["suz"]},
@@ -398,6 +408,19 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
assert actual == expected
def test_partial_json_output_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputToolsParser()
actual = list(chain.stream(None))
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@@ -410,6 +433,20 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
assert actual == expected
async def test_partial_json_output_parser_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | JsonOutputToolsParser()
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@@ -429,6 +466,26 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
assert actual == expected
def test_partial_json_output_parser_return_id_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputToolsParser(return_id=True)
actual = list(chain.stream(None))
expected: list = [[]] + [
[
{
"type": "NameCollector",
"args": chunk,
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
}
]
for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@@ -439,6 +496,17 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
assert actual == expected
def test_partial_json_output_key_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@@ -450,6 +518,18 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
assert actual == expected
async def test_partial_json_output_parser_key_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@@ -461,6 +541,17 @@ def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> N
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_json_output_key_parser_first_only_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async_first_only(
*,
@@ -475,6 +566,18 @@ async def test_partial_json_output_parser_key_async_first_only(
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
async def test_partial_json_output_parser_key_async_first_only_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only(
*, use_tool_calls: bool
@@ -531,6 +634,42 @@ def test_json_output_key_tools_parser_multiple_tools_first_only(
assert output_no_id == {"a": 1}
def test_json_output_key_tools_parser_multiple_tools_first_only_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func", "name": "func", "args": {"a": 1}},
],
)
# Test with return_id=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)
# Should return the func tool call, not None
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1}
assert "id" in output
# Test with return_id=False
parser_no_id = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(message)
# Should return just the args
assert output_no_id == {"a": 1}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match(
*, use_tool_calls: bool
@@ -583,6 +722,44 @@ def test_json_output_key_tools_parser_multiple_tools_no_match(
assert output_no_id is None
def test_json_output_key_tools_parser_multiple_tools_no_match_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{
"type": "tool_call",
"id": "call_another",
"name": "another",
"args": {"c": 3},
},
],
)
# Test with return_id=True, first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)
# Should return None when no matches
assert output is None
# Test with return_id=False, first_tool_only=True
parser_no_id = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(message)
# Should return None when no matches
assert output_no_id is None
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools(
*, use_tool_calls: bool
@@ -643,6 +820,42 @@ def test_json_output_key_tools_parser_multiple_matching_tools(
assert output_all[1]["args"] == {"a": 3}
def test_json_output_key_tools_parser_multiple_matching_tools_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
],
)
# Test with first_tool_only=True - should return first matching
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1} # First matching tool call
# Test with first_tool_only=False - should return all matching
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(message)
assert len(output_all) == 2
assert output_all[0]["args"] == {"a": 1}
assert output_all[1]["args"] == {"a": 3}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
def create_message() -> AIMessage:
@@ -671,6 +884,35 @@ def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) ->
assert output_all == []
@pytest.mark.parametrize(
"empty_message",
[
AIMessageV1(content=[], tool_calls=[]),
AIMessageV1(content="", tool_calls=[]),
],
)
def test_json_output_key_tools_parser_empty_results_v1(
empty_message: AIMessageV1,
) -> None:
# Test with first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(empty_message)
# Should return None for empty results
assert output is None
# Test with first_tool_only=False
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(empty_message)
# Should return empty list for empty results
assert output_all == []
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations(
*, use_tool_calls: bool
@@ -746,6 +988,56 @@ def test_json_output_key_tools_parser_parameter_combinations(
assert output4 == [{"a": 1}, {"a": 3}]
def test_json_output_key_tools_parser_parameter_combinations_v1() -> None:
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
result = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
],
)
# Test: first_tool_only=True, return_id=True
parser1 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output1 = parser1.parse_result(result)
assert output1["type"] == "func"
assert output1["args"] == {"a": 1}
assert "id" in output1
# Test: first_tool_only=True, return_id=False
parser2 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output2 = parser2.parse_result(result)
assert output2 == {"a": 1}
# Test: first_tool_only=False, return_id=True
parser3 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output3 = parser3.parse_result(result)
assert len(output3) == 2
assert all("id" in item for item in output3)
assert output3[0]["args"] == {"a": 1}
assert output3[1]["args"] == {"a": 3}
# Test: first_tool_only=False, return_id=False
parser4 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=False
)
output4 = parser4.parse_result(result)
assert output4 == [{"a": 1}, {"a": 3}]
class Person(BaseModel):
age: int
hair_color: str
@@ -788,6 +1080,18 @@ def test_partial_pydantic_output_parser() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC
def test_partial_pydantic_output_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
actual = list(chain.stream(None))
assert actual == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@@ -800,6 +1104,19 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
@@ -870,20 +1187,22 @@ def test_parse_with_different_pydantic_2_proper() -> None:
def test_max_tokens_error(caplog: Any) -> None:
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "NameCollector",
"args": {"names": ["suz", "jerm"]},
}
],
response_metadata={"stop_reason": "max_tokens"},
)
with pytest.raises(ValidationError):
_ = parser.invoke(message)
assert any(
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
for record, msg in zip(caplog.records, caplog.messages)
)
for msg_class in [AIMessage, AIMessageV1]:
message = msg_class(
content="",
tool_calls=[
{
"type": "tool_call",
"id": "call_OwL7f5PE",
"name": "NameCollector",
"args": {"names": ["suz", "jerm"]},
}
],
response_metadata={"stop_reason": "max_tokens"},
)
with pytest.raises(ValidationError):
_ = parser.invoke(message)
assert any(
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
for record, msg in zip(caplog.records, caplog.messages)
)

View File

@@ -726,7 +726,7 @@
'description': '''
Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
Here we add an ``error`` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
@@ -752,6 +752,10 @@
]),
'title': 'Error',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -763,6 +767,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@@ -781,9 +789,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
@@ -998,12 +1007,23 @@
This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".
.. note::
``create_tool_call`` may also be used as a factory to create a
``ToolCall``. Benefits include:
* Automatic ID generation (when not provided)
* Required arguments strictly validated at creation time
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -1015,6 +1035,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -1026,9 +1050,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
@@ -1037,9 +1062,9 @@
'description': '''
A chunk of a tool call (e.g., as part of a stream).
When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
When merging ``ToolCallChunks`` (e.g., via ``AIMessageChunk.__add__``),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
values of ``index`` are equal and not ``None``.
Example:
@@ -1065,6 +1090,10 @@
]),
'title': 'Args',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -1105,9 +1134,9 @@
}),
}),
'required': list([
'id',
'name',
'args',
'id',
'index',
]),
'title': 'ToolCallChunk',
@@ -2158,7 +2187,7 @@
'description': '''
Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
Here we add an ``error`` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
@@ -2184,6 +2213,10 @@
]),
'title': 'Error',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -2195,6 +2228,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@@ -2213,9 +2250,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
@@ -2430,12 +2468,23 @@
This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".
.. note::
``create_tool_call`` may also be used as a factory to create a
``ToolCall``. Benefits include:
* Automatic ID generation (when not provided)
* Required arguments strictly validated at creation time
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -2447,6 +2496,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -2458,9 +2511,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
@@ -2469,9 +2523,9 @@
'description': '''
A chunk of a tool call (e.g., as part of a stream).
When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
When merging ``ToolCallChunks`` (e.g., via ``AIMessageChunk.__add__``),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
values of ``index`` are equal and not ``None``.
Example:
@@ -2497,6 +2551,10 @@
]),
'title': 'Args',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -2537,9 +2595,9 @@
}),
}),
'required': list([
'id',
'name',
'args',
'id',
'index',
]),
'title': 'ToolCallChunk',

Some files were not shown because too many files have changed in this diff Show More