Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-13 14:21:27 +00:00

Compare commits: 38 commits (sr/v1-mvp ... mdrxy/mess)
| Author | SHA1 | Date |
|---|---|---|
| | 56379fb94c | |
| | 7f989d3c3b | |
| | b7968c2b7d | |
| | 2f0c6421a1 | |
| | c31236264e | |
| | cfe13f673a | |
| | 02001212b0 | |
| | 00244122bd | |
| | 5599c59d4a | |
| | 6727d6e8c8 | |
| | 5036bd7adb | |
| | ec2b34a02d | |
| | 11d68a0b9e | |
| | 566774a893 | |
| | 255a6d668a | |
| | cbf4c0e565 | |
| | 145d38f7dd | |
| | 68c70da33e | |
| | 754528d23f | |
| | dc66737f03 | |
| | 499dc35cfb | |
| | 42c1159991 | |
| | ac706c77d4 | |
| | 8493887b6f | |
| | a647073b26 | |
| | e120604774 | |
| | 06d8754b0b | |
| | 6e108c1cb4 | |
| | cc6139860c | |
| | ae8f58ac6f | |
| | 346731544b | |
| | c1b86cc929 | |
| | 376f70be96 | |
| | ac2de920b1 | |
| | e02eed5489 | |
| | 5414527236 | |
| | 881c6534a6 | |
| | 5e9eb19a83 | |
@@ -15,12 +15,12 @@ You may use the button above, or follow these steps to open this repo in a Codes

1. Click **Create codespace on master**.

For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace).

## VS Code Dev Containers

[](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain)

> [!NOTE]
> If you click the link above you will open the main repo (`langchain-ai/langchain`) and *not* your local cloned repo. This is fine if you only want to run and test the library, but if you want to contribute you can use the link below and replace with your username and cloned repo name:

```txt
@@ -4,7 +4,7 @@ services:
    build:
      dockerfile: libs/langchain/dev.Dockerfile
      context: ..

    networks:
      - langchain-network
2 .github/CODE_OF_CONDUCT.md (vendored)

@@ -129,4 +129,4 @@ For answers to common questions about this code of conduct, see the FAQ at

[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
8 .github/ISSUE_TEMPLATE/bug-report.yml (vendored)

@@ -5,7 +5,7 @@ body:
  - type: markdown
    attributes:
      value: |
        Thank you for taking the time to file a bug report.

        Use this to report BUGS in LangChain. For usage questions, feature requests and general design questions, please use the [LangChain Forum](https://forum.langchain.com/).

@@ -50,7 +50,7 @@ body:

        If a maintainer can copy it, run it, and see it right away, there's a much higher chance that you'll be able to get help.

        **Important!**

        * Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
        * Reduce your code to the minimum required to reproduce the issue if possible. This makes it much easier for others to help you.

@@ -58,14 +58,14 @@ body:
        * INCLUDE the language label (e.g. `python`) after the first three backticks to enable syntax highlighting. (e.g., ```python rather than ```).

      placeholder: |
        The following code:

        ```python
        from langchain_core.runnables import RunnableLambda

        def bad_code(inputs) -> int:
            raise NotImplementedError('For demo purpose')

        chain = RunnableLambda(bad_code)
        chain.invoke('Hello!')
        ```
2 .github/ISSUE_TEMPLATE/documentation.yml (vendored)

@@ -14,7 +14,7 @@ body:

        Do **NOT** use this to ask usage questions or reporting issues with your code.

        If you have usage questions or need help solving some problem,
        please use the [LangChain Forum](https://forum.langchain.com/).

        If you're in the wrong place, here are some helpful links to find a better
2 .github/ISSUE_TEMPLATE/privileged.yml (vendored)

@@ -8,7 +8,7 @@ body:

        If you are not a LangChain maintainer or were not asked directly by a maintainer to create an issue, then please start the conversation on the [LangChain Forum](https://forum.langchain.com/) instead.

        You are a LangChain maintainer if you maintain any of the packages inside of the LangChain repository
        or are a regular contributor to LangChain with previous merged pull requests.
  - type: checkboxes
    id: privileged
2 .github/actions/people/Dockerfile (vendored)

@@ -4,4 +4,4 @@ RUN pip install httpx PyGithub "pydantic==2.0.2" pydantic-settings "pyyaml>=5.3.

COPY ./app /app

CMD ["python", "/app/main.py"]
6 .github/actions/people/action.yml (vendored)

@@ -4,8 +4,8 @@ description: "Generate the data for the LangChain People page"
author: "Jacob Lee <jacob@langchain.dev>"
inputs:
  token:
    description: 'User token, to read the GitHub API. Can be passed in using {{ secrets.LANGCHAIN_PEOPLE_GITHUB_TOKEN }}'
    description: "User token, to read the GitHub API. Can be passed in using {{ secrets.LANGCHAIN_PEOPLE_GITHUB_TOKEN }}"
    required: true
runs:
  using: 'docker'
  image: 'Dockerfile'
  using: "docker"
  image: "Dockerfile"
25 .github/scripts/check_diff.py (vendored)
@@ -3,14 +3,12 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import tomllib
|
||||
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
from get_min_versions import get_min_version_from_toml
|
||||
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
LANGCHAIN_DIRS = [
|
||||
"libs/core",
|
||||
@@ -38,7 +36,7 @@ IGNORED_PARTNERS = [
|
||||
]
|
||||
|
||||
PY_312_MAX_PACKAGES = [
|
||||
"libs/partners/chroma", # https://github.com/chroma-core/chroma/issues/4382
|
||||
"libs/partners/chroma", # https://github.com/chroma-core/chroma/issues/4382
|
||||
]
|
||||
|
||||
|
||||
@@ -85,9 +83,9 @@ def dependents_graph() -> dict:
|
||||
for depline in extended_deps:
|
||||
if depline.startswith("-e "):
|
||||
# editable dependency
|
||||
assert depline.startswith(
|
||||
"-e ../partners/"
|
||||
), "Extended test deps should only editable install partner packages"
|
||||
assert depline.startswith("-e ../partners/"), (
|
||||
"Extended test deps should only editable install partner packages"
|
||||
)
|
||||
partner = depline.split("partners/")[1]
|
||||
dep = f"langchain-{partner}"
|
||||
else:
|
||||
@@ -271,7 +269,7 @@ if __name__ == "__main__":
|
||||
dirs_to_run["extended-test"].add(dir_)
|
||||
elif file.startswith("libs/standard-tests"):
|
||||
# TODO: update to include all packages that rely on standard-tests (all partner packages)
|
||||
# note: won't run on external repo partners
|
||||
# Note: won't run on external repo partners
|
||||
dirs_to_run["lint"].add("libs/standard-tests")
|
||||
dirs_to_run["test"].add("libs/standard-tests")
|
||||
dirs_to_run["lint"].add("libs/cli")
|
||||
@@ -285,7 +283,7 @@ if __name__ == "__main__":
|
||||
elif file.startswith("libs/cli"):
|
||||
dirs_to_run["lint"].add("libs/cli")
|
||||
dirs_to_run["test"].add("libs/cli")
|
||||
|
||||
|
||||
elif file.startswith("libs/partners"):
|
||||
partner_dir = file.split("/")[2]
|
||||
if os.path.isdir(f"libs/partners/{partner_dir}") and [
|
||||
@@ -303,7 +301,10 @@ if __name__ == "__main__":
|
||||
f"Unknown lib: {file}. check_diff.py likely needs "
|
||||
"an update for this new library!"
|
||||
)
|
||||
elif file.startswith("docs/") or file in ["pyproject.toml", "uv.lock"]: # docs or root uv files
|
||||
elif file.startswith("docs/") or file in [
|
||||
"pyproject.toml",
|
||||
"uv.lock",
|
||||
]: # docs or root uv files
|
||||
docs_edited = True
|
||||
dirs_to_run["lint"].add(".")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
|
||||
import tomllib
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
26 .github/scripts/get_min_versions.py (vendored)
@@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@@ -8,17 +8,13 @@ else:
|
||||
# for python 3.10 and below, which doesnt have stdlib tomllib
|
||||
import tomli as tomllib
|
||||
|
||||
from packaging.requirements import Requirement
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
import requests
|
||||
from packaging.version import parse
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import re
|
||||
|
||||
import requests
|
||||
from packaging.requirements import Requirement
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import Version, parse
|
||||
|
||||
MIN_VERSION_LIBS = [
|
||||
"langchain-core",
|
||||
@@ -72,11 +68,13 @@ def get_minimum_version(package_name: str, spec_string: str) -> Optional[str]:
|
||||
spec_string = re.sub(r"\^0\.0\.(\d+)", r"0.0.\1", spec_string)
|
||||
# rewrite occurrences of ^0.y.z to >=0.y.z,<0.y+1 (can be anywhere in constraint string)
|
||||
for y in range(1, 10):
|
||||
spec_string = re.sub(rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y+1}", spec_string)
|
||||
spec_string = re.sub(
|
||||
rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y + 1}", spec_string
|
||||
)
|
||||
# rewrite occurrences of ^x.y.z to >=x.y.z,<x+1.0.0 (can be anywhere in constraint string)
|
||||
for x in range(1, 10):
|
||||
spec_string = re.sub(
|
||||
rf"\^{x}\.(\d+)\.(\d+)", rf">={x}.\1.\2,<{x+1}", spec_string
|
||||
rf"\^{x}\.(\d+)\.(\d+)", rf">={x}.\1.\2,<{x + 1}", spec_string
|
||||
)
|
||||
|
||||
spec_set = SpecifierSet(spec_string)
|
||||
@@ -169,12 +167,12 @@ def check_python_version(version_string, constraint_string):
|
||||
# rewrite occurrences of ^0.y.z to >=0.y.z,<0.y+1.0 (can be anywhere in constraint string)
|
||||
for y in range(1, 10):
|
||||
constraint_string = re.sub(
|
||||
rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y+1}.0", constraint_string
|
||||
rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y + 1}.0", constraint_string
|
||||
)
|
||||
# rewrite occurrences of ^x.y.z to >=x.y.z,<x+1.0.0 (can be anywhere in constraint string)
|
||||
for x in range(1, 10):
|
||||
constraint_string = re.sub(
|
||||
rf"\^{x}\.0\.(\d+)", rf">={x}.0.\1,<{x+1}.0.0", constraint_string
|
||||
rf"\^{x}\.0\.(\d+)", rf">={x}.0.\1,<{x + 1}.0.0", constraint_string
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
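As a quick aside on the caret-constraint rewriting shown in the `get_min_versions.py` hunks above: the `re.sub` loop expands Poetry-style `^0.y.z` specifiers into explicit lower/upper bounds that `SpecifierSet` understands. A minimal standalone sketch of that behavior (an illustration, not part of the diff itself):

```python
import re

spec = "^0.3.5"
# Same rewrite as in get_min_versions.py: ^0.y.z becomes >=0.y.z,<0.(y+1)
for y in range(1, 10):
    spec = re.sub(rf"\^0\.{y}\.(\d+)", rf">=0.{y}.\1,<0.{y + 1}", spec)

print(spec)  # prints: >=0.3.5,<0.4
```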
47 .github/scripts/prep_api_docs_build.py (vendored)
@@ -3,9 +3,10 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_packages_yaml() -> Dict[str, Any]:
|
||||
@@ -28,7 +29,6 @@ def get_target_dir(package_name: str) -> Path:
|
||||
def clean_target_directories(packages: list) -> None:
|
||||
"""Remove old directories that will be replaced."""
|
||||
for package in packages:
|
||||
|
||||
target_dir = get_target_dir(package["name"])
|
||||
if target_dir.exists():
|
||||
print(f"Removing {target_dir}")
|
||||
@@ -38,7 +38,6 @@ def clean_target_directories(packages: list) -> None:
|
||||
def move_libraries(packages: list) -> None:
|
||||
"""Move libraries from their source locations to the target directories."""
|
||||
for package in packages:
|
||||
|
||||
repo_name = package["repo"].split("/")[1]
|
||||
source_path = package["path"]
|
||||
target_dir = get_target_dir(package["name"])
|
||||
@@ -68,23 +67,33 @@ def main():
|
||||
package_yaml = load_packages_yaml()
|
||||
|
||||
# Clean target directories
|
||||
clean_target_directories([
|
||||
p
|
||||
for p in package_yaml["packages"]
|
||||
if (p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref"))
|
||||
and p["repo"] != "langchain-ai/langchain"
|
||||
and p["name"] != "langchain-ai21" # Skip AI21 due to dependency conflicts
|
||||
])
|
||||
clean_target_directories(
|
||||
[
|
||||
p
|
||||
for p in package_yaml["packages"]
|
||||
if (
|
||||
p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref")
|
||||
)
|
||||
and p["repo"] != "langchain-ai/langchain"
|
||||
and p["name"]
|
||||
!= "langchain-ai21" # Skip AI21 due to dependency conflicts
|
||||
]
|
||||
)
|
||||
|
||||
# Move libraries to their new locations
|
||||
move_libraries([
|
||||
p
|
||||
for p in package_yaml["packages"]
|
||||
if not p.get("disabled", False)
|
||||
and (p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref"))
|
||||
and p["repo"] != "langchain-ai/langchain"
|
||||
and p["name"] != "langchain-ai21" # Skip AI21 due to dependency conflicts
|
||||
])
|
||||
move_libraries(
|
||||
[
|
||||
p
|
||||
for p in package_yaml["packages"]
|
||||
if not p.get("disabled", False)
|
||||
and (
|
||||
p["repo"].startswith("langchain-ai/") or p.get("include_in_api_ref")
|
||||
)
|
||||
and p["repo"] != "langchain-ai/langchain"
|
||||
and p["name"]
|
||||
!= "langchain-ai21" # Skip AI21 due to dependency conflicts
|
||||
]
|
||||
)
|
||||
|
||||
# Delete ones without a pyproject.toml
|
||||
for partner in Path("langchain/libs/partners").iterdir():
|
||||
|
||||
416 .github/tools/git-restore-mtime (vendored)
@@ -81,56 +81,93 @@ import time
|
||||
__version__ = "2022.12+dev"
|
||||
|
||||
# Update symlinks only if the platform supports not following them
|
||||
UPDATE_SYMLINKS = bool(os.utime in getattr(os, 'supports_follow_symlinks', []))
|
||||
UPDATE_SYMLINKS = bool(os.utime in getattr(os, "supports_follow_symlinks", []))
|
||||
|
||||
# Call os.path.normpath() only if not in a POSIX platform (Windows)
|
||||
NORMALIZE_PATHS = (os.path.sep != '/')
|
||||
NORMALIZE_PATHS = os.path.sep != "/"
|
||||
|
||||
# How many files to process in each batch when re-trying merge commits
|
||||
STEPMISSING = 100
|
||||
|
||||
# (Extra) keywords for the os.utime() call performed by touch()
|
||||
UTIME_KWS = {} if not UPDATE_SYMLINKS else {'follow_symlinks': False}
|
||||
UTIME_KWS = {} if not UPDATE_SYMLINKS else {"follow_symlinks": False}
|
||||
|
||||
|
||||
# Command-line interface ######################################################
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__.split('\n---')[0])
|
||||
parser = argparse.ArgumentParser(description=__doc__.split("\n---")[0])
|
||||
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument('--quiet', '-q', dest='loglevel',
|
||||
action="store_const", const=logging.WARNING, default=logging.INFO,
|
||||
help="Suppress informative messages and summary statistics.")
|
||||
group.add_argument('--verbose', '-v', action="count", help="""
|
||||
group.add_argument(
|
||||
"--quiet",
|
||||
"-q",
|
||||
dest="loglevel",
|
||||
action="store_const",
|
||||
const=logging.WARNING,
|
||||
default=logging.INFO,
|
||||
help="Suppress informative messages and summary statistics.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--verbose",
|
||||
"-v",
|
||||
action="count",
|
||||
help="""
|
||||
Print additional information for each processed file.
|
||||
Specify twice to further increase verbosity.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--cwd', '-C', metavar="DIRECTORY", help="""
|
||||
parser.add_argument(
|
||||
"--cwd",
|
||||
"-C",
|
||||
metavar="DIRECTORY",
|
||||
help="""
|
||||
Run as if %(prog)s was started in directory %(metavar)s.
|
||||
This affects how --work-tree, --git-dir and PATHSPEC arguments are handled.
|
||||
See 'man 1 git' or 'git --help' for more information.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--git-dir', dest='gitdir', metavar="GITDIR", help="""
|
||||
parser.add_argument(
|
||||
"--git-dir",
|
||||
dest="gitdir",
|
||||
metavar="GITDIR",
|
||||
help="""
|
||||
Path to the git repository, by default auto-discovered by searching
|
||||
the current directory and its parents for a .git/ subdirectory.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--work-tree', dest='workdir', metavar="WORKTREE", help="""
|
||||
parser.add_argument(
|
||||
"--work-tree",
|
||||
dest="workdir",
|
||||
metavar="WORKTREE",
|
||||
help="""
|
||||
Path to the work tree root, by default the parent of GITDIR if it's
|
||||
automatically discovered, or the current directory if GITDIR is set.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--force', '-f', default=False, action="store_true", help="""
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
"-f",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="""
|
||||
Force updating files with uncommitted modifications.
|
||||
Untracked files and uncommitted deletions, renames and additions are
|
||||
always ignored.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--merge', '-m', default=False, action="store_true", help="""
|
||||
parser.add_argument(
|
||||
"--merge",
|
||||
"-m",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="""
|
||||
Include merge commits.
|
||||
Leads to more recent times and more files per commit, thus with the same
|
||||
time, which may or may not be what you want.
|
||||
@@ -138,71 +175,130 @@ def parse_args():
|
||||
are found sooner, which can improve performance, sometimes substantially.
|
||||
But as merge commits are usually huge, processing them may also take longer.
|
||||
By default, merge commits are only used for files missing from regular commits.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--first-parent', default=False, action="store_true", help="""
|
||||
parser.add_argument(
|
||||
"--first-parent",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="""
|
||||
Consider only the first parent, the "main branch", when evaluating merge commits.
|
||||
Only effective when merge commits are processed, either when --merge is
|
||||
used or when finding missing files after the first regular log search.
|
||||
See --skip-missing.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--skip-missing', '-s', dest="missing", default=True,
|
||||
action="store_false", help="""
|
||||
parser.add_argument(
|
||||
"--skip-missing",
|
||||
"-s",
|
||||
dest="missing",
|
||||
default=True,
|
||||
action="store_false",
|
||||
help="""
|
||||
Do not try to find missing files.
|
||||
If merge commits were not evaluated with --merge and some files were
|
||||
not found in regular commits, by default %(prog)s searches for these
|
||||
files again in the merge commits.
|
||||
This option disables this retry, so files found only in merge commits
|
||||
will not have their timestamp updated.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--no-directories', '-D', dest='dirs', default=True,
|
||||
action="store_false", help="""
|
||||
parser.add_argument(
|
||||
"--no-directories",
|
||||
"-D",
|
||||
dest="dirs",
|
||||
default=True,
|
||||
action="store_false",
|
||||
help="""
|
||||
Do not update directory timestamps.
|
||||
By default, use the time of its most recently created, renamed or deleted file.
|
||||
Note that just modifying a file will NOT update its directory time.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--test', '-t', default=False, action="store_true",
|
||||
help="Test run: do not actually update any file timestamp.")
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
"-t",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Test run: do not actually update any file timestamp.",
|
||||
)
|
||||
|
||||
parser.add_argument('--commit-time', '-c', dest='commit_time', default=False,
|
||||
action='store_true', help="Use commit time instead of author time.")
|
||||
parser.add_argument(
|
||||
"--commit-time",
|
||||
"-c",
|
||||
dest="commit_time",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use commit time instead of author time.",
|
||||
)
|
||||
|
||||
parser.add_argument('--oldest-time', '-o', dest='reverse_order', default=False,
|
||||
action='store_true', help="""
|
||||
parser.add_argument(
|
||||
"--oldest-time",
|
||||
"-o",
|
||||
dest="reverse_order",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="""
|
||||
Update times based on the oldest, instead of the most recent commit of a file.
|
||||
This reverses the order in which the git log is processed to emulate a
|
||||
file "creation" date. Note this will be inaccurate for files deleted and
|
||||
re-created at later dates.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--skip-older-than', metavar='SECONDS', type=int, help="""
|
||||
parser.add_argument(
|
||||
"--skip-older-than",
|
||||
metavar="SECONDS",
|
||||
type=int,
|
||||
help="""
|
||||
Ignore files that are currently older than %(metavar)s.
|
||||
Useful in workflows that assume such files already have a correct timestamp,
|
||||
as it may improve performance by processing fewer files.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--skip-older-than-commit', '-N', default=False,
|
||||
action='store_true', help="""
|
||||
parser.add_argument(
|
||||
"--skip-older-than-commit",
|
||||
"-N",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="""
|
||||
Ignore files older than the timestamp it would be updated to.
|
||||
Such files may be considered "original", likely in the author's repository.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--unique-times', default=False, action="store_true", help="""
|
||||
parser.add_argument(
|
||||
"--unique-times",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="""
|
||||
Set the microseconds to a unique value per commit.
|
||||
Allows telling apart changes that would otherwise have identical timestamps,
|
||||
as git's time accuracy is in seconds.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('pathspec', nargs='*', metavar='PATHSPEC', help="""
|
||||
parser.add_argument(
|
||||
"pathspec",
|
||||
nargs="*",
|
||||
metavar="PATHSPEC",
|
||||
help="""
|
||||
Only modify paths matching %(metavar)s, relative to current directory.
|
||||
By default, update all but untracked files and submodules.
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument('--version', '-V', action='version',
|
||||
version='%(prog)s version {version}'.format(version=get_version()))
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
"-V",
|
||||
action="version",
|
||||
version="%(prog)s version {version}".format(version=get_version()),
|
||||
)
|
||||
|
||||
args_ = parser.parse_args()
|
||||
if args_.verbose:
|
||||
@@ -212,17 +308,18 @@ def parse_args():
|
||||
|
||||
|
||||
def get_version(version=__version__):
|
||||
if not version.endswith('+dev'):
|
||||
if not version.endswith("+dev"):
|
||||
return version
|
||||
try:
|
||||
cwd = os.path.dirname(os.path.realpath(__file__))
|
||||
return Git(cwd=cwd, errors=False).describe().lstrip('v')
|
||||
return Git(cwd=cwd, errors=False).describe().lstrip("v")
|
||||
except Git.Error:
|
||||
return '-'.join((version, "unknown"))
|
||||
return "-".join((version, "unknown"))
|
||||
|
||||
|
||||
# Helper functions ############################################################
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""Add TRACE logging level and corresponding method, return the root logger"""
|
||||
logging.TRACE = TRACE = logging.DEBUG // 2
|
||||
@@ -255,11 +352,13 @@ def normalize(path):
|
||||
if path and path[0] == '"':
|
||||
# Python 2: path = path[1:-1].decode("string-escape")
|
||||
# Python 3: https://stackoverflow.com/a/46650050/624066
|
||||
path = (path[1:-1] # Remove enclosing double quotes
|
||||
.encode('latin1') # Convert to bytes, required by 'unicode-escape'
|
||||
.decode('unicode-escape') # Perform the actual octal-escaping decode
|
||||
.encode('latin1') # 1:1 mapping to bytes, UTF-8 encoded
|
||||
.decode('utf8', 'surrogateescape')) # Decode from UTF-8
|
||||
path = (
|
||||
path[1:-1] # Remove enclosing double quotes
|
||||
.encode("latin1") # Convert to bytes, required by 'unicode-escape'
|
||||
.decode("unicode-escape") # Perform the actual octal-escaping decode
|
||||
.encode("latin1") # 1:1 mapping to bytes, UTF-8 encoded
|
||||
.decode("utf8", "surrogateescape")
|
||||
) # Decode from UTF-8
|
||||
if NORMALIZE_PATHS:
|
||||
# Make sure the slash matches the OS; for Windows we need a backslash
|
||||
path = os.path.normpath(path)
|
||||
@@ -282,12 +381,12 @@ def touch_ns(path, mtime_ns):
|
||||
|
||||
def isodate(secs: int):
|
||||
# time.localtime() accepts floats, but discards fractional part
|
||||
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(secs))
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(secs))
|
||||
|
||||
|
||||
def isodate_ns(ns: int):
|
||||
# for integers fromtimestamp() is equivalent and ~16% slower than isodate()
|
||||
return datetime.datetime.fromtimestamp(ns / 1000000000).isoformat(sep=' ')
|
||||
return datetime.datetime.fromtimestamp(ns / 1000000000).isoformat(sep=" ")
|
||||
|
||||
|
||||
def get_mtime_ns(secs: int, idx: int):
|
||||
@@ -305,35 +404,49 @@ def get_mtime_path(path):
|
||||
|
||||
# Git class and parse_log(), the heart of the script ##########################
|
||||
|
||||
|
||||
class Git:
|
||||
def __init__(self, workdir=None, gitdir=None, cwd=None, errors=True):
|
||||
self.gitcmd = ['git']
|
||||
self.gitcmd = ["git"]
|
||||
self.errors = errors
|
||||
self._proc = None
|
||||
if workdir: self.gitcmd.extend(('--work-tree', workdir))
|
||||
if gitdir: self.gitcmd.extend(('--git-dir', gitdir))
|
||||
if cwd: self.gitcmd.extend(('-C', cwd))
|
||||
if workdir:
|
||||
self.gitcmd.extend(("--work-tree", workdir))
|
||||
if gitdir:
|
||||
self.gitcmd.extend(("--git-dir", gitdir))
|
||||
if cwd:
|
||||
self.gitcmd.extend(("-C", cwd))
|
||||
self.workdir, self.gitdir = self._get_repo_dirs()
|
||||
|
||||
def ls_files(self, paths: list = None):
|
||||
return (normalize(_) for _ in self._run('ls-files --full-name', paths))
|
||||
return (normalize(_) for _ in self._run("ls-files --full-name", paths))
|
||||
|
||||
def ls_dirty(self, force=False):
|
||||
return (normalize(_[3:].split(' -> ', 1)[-1])
|
||||
for _ in self._run('status --porcelain')
|
||||
if _[:2] != '??' and (not force or (_[0] in ('R', 'A')
|
||||
or _[1] == 'D')))
|
||||
return (
|
||||
normalize(_[3:].split(" -> ", 1)[-1])
|
||||
for _ in self._run("status --porcelain")
|
||||
if _[:2] != "??" and (not force or (_[0] in ("R", "A") or _[1] == "D"))
|
||||
)
|
||||
|
||||
def log(self, merge=False, first_parent=False, commit_time=False,
|
||||
reverse_order=False, paths: list = None):
|
||||
cmd = 'whatchanged --pretty={}'.format('%ct' if commit_time else '%at')
|
||||
if merge: cmd += ' -m'
|
||||
if first_parent: cmd += ' --first-parent'
|
||||
if reverse_order: cmd += ' --reverse'
|
||||
def log(
|
||||
self,
|
||||
merge=False,
|
||||
first_parent=False,
|
||||
commit_time=False,
|
||||
reverse_order=False,
|
||||
paths: list = None,
|
||||
):
|
||||
cmd = "whatchanged --pretty={}".format("%ct" if commit_time else "%at")
|
||||
if merge:
|
||||
cmd += " -m"
|
||||
if first_parent:
|
||||
cmd += " --first-parent"
|
||||
if reverse_order:
|
||||
cmd += " --reverse"
|
||||
return self._run(cmd, paths)
|
||||
|
||||
def describe(self):
|
||||
return self._run('describe --tags', check=True)[0]
|
||||
return self._run("describe --tags", check=True)[0]
|
||||
|
||||
def terminate(self):
|
||||
if self._proc is None:
|
||||
@@ -345,18 +458,22 @@ class Git:
|
||||
pass
|
||||
|
||||
def _get_repo_dirs(self):
|
||||
return (os.path.normpath(_) for _ in
|
||||
self._run('rev-parse --show-toplevel --absolute-git-dir', check=True))
|
||||
return (
|
||||
os.path.normpath(_)
|
||||
for _ in self._run(
|
||||
"rev-parse --show-toplevel --absolute-git-dir", check=True
|
||||
)
|
||||
)
|
||||
|
||||
def _run(self, cmdstr: str, paths: list = None, output=True, check=False):
|
||||
cmdlist = self.gitcmd + shlex.split(cmdstr)
|
||||
if paths:
|
||||
cmdlist.append('--')
|
||||
cmdlist.append("--")
|
||||
cmdlist.extend(paths)
|
||||
popen_args = dict(universal_newlines=True, encoding='utf8')
|
||||
popen_args = dict(universal_newlines=True, encoding="utf8")
|
||||
if not self.errors:
|
||||
popen_args['stderr'] = subprocess.DEVNULL
|
||||
log.trace("Executing: %s", ' '.join(cmdlist))
|
||||
popen_args["stderr"] = subprocess.DEVNULL
|
||||
log.trace("Executing: %s", " ".join(cmdlist))
|
||||
if not output:
|
||||
return subprocess.call(cmdlist, **popen_args)
|
||||
if check:
|
||||
@@ -379,30 +496,26 @@ def parse_log(filelist, dirlist, stats, git, merge=False, filterlist=None):
|
||||
mtime = 0
|
||||
datestr = isodate(0)
|
||||
for line in git.log(
|
||||
merge,
|
||||
args.first_parent,
|
||||
args.commit_time,
|
||||
args.reverse_order,
|
||||
filterlist
|
||||
merge, args.first_parent, args.commit_time, args.reverse_order, filterlist
|
||||
):
|
||||
stats['loglines'] += 1
|
||||
stats["loglines"] += 1
|
||||
|
||||
# Blank line between Date and list of files
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Date line
|
||||
if line[0] != ':': # Faster than `not line.startswith(':')`
|
||||
stats['commits'] += 1
|
||||
if line[0] != ":": # Faster than `not line.startswith(':')`
|
||||
stats["commits"] += 1
|
||||
mtime = int(line)
|
||||
if args.unique_times:
|
||||
mtime = get_mtime_ns(mtime, stats['commits'])
|
||||
mtime = get_mtime_ns(mtime, stats["commits"])
|
||||
if args.debug:
|
||||
datestr = isodate(mtime)
|
||||
continue
|
||||
|
||||
# File line: three tokens if it describes a renaming, otherwise two
|
||||
tokens = line.split('\t')
|
||||
tokens = line.split("\t")
|
||||
|
||||
# Possible statuses:
|
||||
# M: Modified (content changed)
|
||||
@@ -411,7 +524,7 @@ def parse_log(filelist, dirlist, stats, git, merge=False, filterlist=None):
|
||||
# T: Type changed: to/from regular file, symlinks, submodules
|
||||
# R099: Renamed (moved), with % of unchanged content. 100 = pure rename
|
||||
# Not possible in log: C=Copied, U=Unmerged, X=Unknown, B=pairing Broken
|
||||
status = tokens[0].split(' ')[-1]
|
||||
status = tokens[0].split(" ")[-1]
|
||||
file = tokens[-1]
|
||||
|
||||
# Handles non-ASCII chars and OS path separator
|
||||
@@ -419,56 +532,76 @@ def parse_log(filelist, dirlist, stats, git, merge=False, filterlist=None):
|
||||
|
||||
def do_file():
|
||||
if args.skip_older_than_commit and get_mtime_path(file) <= mtime:
|
||||
stats['skip'] += 1
|
||||
stats["skip"] += 1
|
||||
return
|
||||
if args.debug:
|
||||
log.debug("%d\t%d\t%d\t%s\t%s",
|
||||
stats['loglines'], stats['commits'], stats['files'],
|
||||
datestr, file)
|
||||
log.debug(
|
||||
"%d\t%d\t%d\t%s\t%s",
|
||||
stats["loglines"],
|
||||
stats["commits"],
|
||||
stats["files"],
|
||||
datestr,
|
||||
file,
|
||||
)
|
||||
try:
|
||||
touch(os.path.join(git.workdir, file), mtime)
|
||||
stats['touches'] += 1
|
||||
stats["touches"] += 1
|
||||
except Exception as e:
|
||||
log.error("ERROR: %s: %s", e, file)
|
||||
stats['errors'] += 1
|
||||
stats["errors"] += 1
|
||||
|
||||
def do_dir():
|
||||
if args.debug:
|
||||
log.debug("%d\t%d\t-\t%s\t%s",
|
||||
stats['loglines'], stats['commits'],
|
||||
datestr, "{}/".format(dirname or '.'))
|
||||
log.debug(
|
||||
"%d\t%d\t-\t%s\t%s",
|
||||
stats["loglines"],
|
||||
stats["commits"],
|
||||
datestr,
|
||||
"{}/".format(dirname or "."),
|
||||
)
|
||||
try:
|
||||
touch(os.path.join(git.workdir, dirname), mtime)
|
||||
stats['dirtouches'] += 1
|
||||
stats["dirtouches"] += 1
|
||||
except Exception as e:
|
||||
log.error("ERROR: %s: %s", e, dirname)
|
||||
stats['direrrors'] += 1
|
||||
stats["direrrors"] += 1
|
||||
|
||||
if file in filelist:
|
||||
stats['files'] -= 1
|
||||
stats["files"] -= 1
|
||||
filelist.remove(file)
|
||||
do_file()
|
||||
|
||||
if args.dirs and status in ('A', 'D'):
|
||||
if args.dirs and status in ("A", "D"):
|
||||
dirname = os.path.dirname(file)
|
||||
if dirname in dirlist:
|
||||
dirlist.remove(dirname)
|
||||
do_dir()
|
||||
|
||||
# All files done?
|
||||
if not stats['files']:
|
||||
if not stats["files"]:
|
||||
git.terminate()
|
||||
return
|
||||
|
||||
|
||||
# Main Logic ##################################################################
|
||||
|
||||
|
||||
def main():
|
||||
start = time.time() # yes, Wall time. CPU time is not realistic for users.
|
||||
stats = {_: 0 for _ in ('loglines', 'commits', 'touches', 'skip', 'errors',
|
||||
'dirtouches', 'direrrors')}
|
||||
stats = {
|
||||
_: 0
|
||||
for _ in (
|
||||
"loglines",
|
||||
"commits",
|
||||
"touches",
|
||||
"skip",
|
||||
"errors",
|
||||
"dirtouches",
|
||||
"direrrors",
|
||||
)
|
||||
}
|
||||
|
||||
logging.basicConfig(level=args.loglevel, format='%(message)s')
|
||||
logging.basicConfig(level=args.loglevel, format="%(message)s")
|
||||
log.trace("Arguments: %s", args)
|
||||
|
||||
# First things first: Where and Who are we?
|
||||
@@ -499,13 +632,16 @@ def main():
|
||||
|
||||
# Symlink (to file, to dir or broken - git handles the same way)
|
||||
if not UPDATE_SYMLINKS and os.path.islink(fullpath):
|
||||
log.warning("WARNING: Skipping symlink, no OS support for updates: %s",
|
||||
path)
|
||||
log.warning(
|
||||
"WARNING: Skipping symlink, no OS support for updates: %s", path
|
||||
)
|
||||
continue
|
||||
|
||||
# skip files which are older than given threshold
|
||||
if (args.skip_older_than
|
||||
and start - get_mtime_path(fullpath) > args.skip_older_than):
|
||||
if (
|
||||
args.skip_older_than
|
||||
and start - get_mtime_path(fullpath) > args.skip_older_than
|
||||
):
|
||||
continue
|
||||
|
||||
# Always add files relative to worktree root
|
||||
@@ -519,15 +655,17 @@ def main():
|
||||
else:
|
||||
dirty = set(git.ls_dirty())
|
||||
if dirty:
|
||||
log.warning("WARNING: Modified files in the working directory were ignored."
|
||||
"\nTo include such files, commit your changes or use --force.")
|
||||
log.warning(
|
||||
"WARNING: Modified files in the working directory were ignored."
|
||||
"\nTo include such files, commit your changes or use --force."
|
||||
)
|
||||
filelist -= dirty
|
||||
|
||||
# Build dir list to be processed
|
||||
dirlist = set(os.path.dirname(_) for _ in filelist) if args.dirs else set()
|
||||
|
||||
stats['totalfiles'] = stats['files'] = len(filelist)
|
||||
log.info("{0:,} files to be processed in work dir".format(stats['totalfiles']))
|
||||
stats["totalfiles"] = stats["files"] = len(filelist)
|
||||
log.info("{0:,} files to be processed in work dir".format(stats["totalfiles"]))
|
||||
|
||||
if not filelist:
|
||||
# Nothing to do. Exit silently and without errors, just like git does
|
||||
@@ -544,10 +682,18 @@ def main():
|
||||
if args.missing and not args.merge:
|
||||
filterlist = list(filelist)
|
||||
missing = len(filterlist)
|
||||
log.info("{0:,} files not found in log, trying merge commits".format(missing))
|
||||
log.info(
|
||||
"{0:,} files not found in log, trying merge commits".format(missing)
|
||||
)
|
||||
for i in range(0, missing, STEPMISSING):
|
||||
parse_log(filelist, dirlist, stats, git,
|
||||
merge=True, filterlist=filterlist[i:i + STEPMISSING])
|
||||
parse_log(
|
||||
filelist,
|
||||
dirlist,
|
||||
stats,
|
||||
git,
|
||||
merge=True,
|
||||
filterlist=filterlist[i : i + STEPMISSING],
|
||||
)
|
||||
|
||||
# Still missing some?
|
||||
for file in filelist:
|
||||
@@ -556,29 +702,33 @@ def main():
|
||||
# Final statistics
|
||||
# Suggestion: use git-log --before=mtime to brag about skipped log entries
|
||||
def log_info(msg, *a, width=13):
|
||||
ifmt = '{:%d,}' % (width,) # not using 'n' for consistency with ffmt
|
||||
ffmt = '{:%d,.2f}' % (width,)
|
||||
ifmt = "{:%d,}" % (width,) # not using 'n' for consistency with ffmt
|
||||
ffmt = "{:%d,.2f}" % (width,)
|
||||
# %-formatting lacks a thousand separator, must pre-render with .format()
|
||||
log.info(msg.replace('%d', ifmt).replace('%f', ffmt).format(*a))
|
||||
log.info(msg.replace("%d", ifmt).replace("%f", ffmt).format(*a))
|
||||
|
||||
log_info(
|
||||
"Statistics:\n"
|
||||
"%f seconds\n"
|
||||
"%d log lines processed\n"
|
||||
"%d commits evaluated",
|
||||
time.time() - start, stats['loglines'], stats['commits'])
|
||||
"Statistics:\n%f seconds\n%d log lines processed\n%d commits evaluated",
|
||||
time.time() - start,
|
||||
stats["loglines"],
|
||||
stats["commits"],
|
||||
)
|
||||
|
||||
if args.dirs:
|
||||
if stats['direrrors']: log_info("%d directory update errors", stats['direrrors'])
|
||||
log_info("%d directories updated", stats['dirtouches'])
|
||||
if stats["direrrors"]:
|
||||
log_info("%d directory update errors", stats["direrrors"])
|
||||
log_info("%d directories updated", stats["dirtouches"])
|
||||
|
||||
if stats['touches'] != stats['totalfiles']:
|
||||
log_info("%d files", stats['totalfiles'])
|
||||
if stats['skip']: log_info("%d files skipped", stats['skip'])
|
||||
if stats['files']: log_info("%d files missing", stats['files'])
|
||||
if stats['errors']: log_info("%d file update errors", stats['errors'])
|
||||
if stats["touches"] != stats["totalfiles"]:
|
||||
log_info("%d files", stats["totalfiles"])
|
||||
if stats["skip"]:
|
||||
log_info("%d files skipped", stats["skip"])
|
||||
if stats["files"]:
|
||||
log_info("%d files missing", stats["files"])
|
||||
if stats["errors"]:
|
||||
log_info("%d file update errors", stats["errors"])
|
||||
|
||||
log_info("%d files updated", stats['touches'])
|
||||
log_info("%d files updated", stats["touches"])
|
||||
|
||||
if args.test:
|
||||
log.info("TEST RUN - No files modified!")
|
||||
|
||||
3 .github/workflows/_release.yml (vendored)

@@ -388,11 +388,12 @@ jobs:
      - name: Test against ${{ matrix.partner }}
        if: startsWith(inputs.working-directory, 'libs/core')
        run: |
          # Identify latest tag
          # Identify latest tag, excluding pre-releases
          LATEST_PACKAGE_TAG="$(
            git ls-remote --tags origin "langchain-${{ matrix.partner }}*" \
              | awk '{print $2}' \
              | sed 's|refs/tags/||' \
              | grep -Ev '==[^=]*(\.?dev[0-9]*|\.?rc[0-9]*)$' \
              | sort -Vr \
              | head -n 1
          )"
2 .github/workflows/_test.yml (vendored)

@@ -79,4 +79,4 @@ jobs:
          # grep will exit non-zero if the target message isn't found,
          # and `set -e` above will cause the step to fail.
          echo "$STATUS" | grep 'nothing to commit, working tree clean'
2 .github/workflows/_test_pydantic.yml (vendored)

@@ -64,4 +64,4 @@ jobs:

          # grep will exit non-zero if the target message isn't found,
          # and `set -e` above will cause the step to fail.
          echo "$STATUS" | grep 'nothing to commit, working tree clean'
2 .github/workflows/api_doc_build.yml (vendored)

@@ -52,7 +52,6 @@ jobs:
        run: |
          # Get unique repositories
          REPOS=$(echo "$REPOS_UNSORTED" | sort -u)

          # Checkout each unique repository
          for repo in $REPOS; do
            # Validate repository format (allow any org with proper format)
@@ -68,7 +67,6 @@ jobs:
              echo "Error: Invalid repository name: $REPO_NAME"
              exit 1
            fi

            echo "Checking out $repo to $REPO_NAME"
            git clone --depth 1 https://github.com/$repo.git $REPO_NAME
          done
1 .github/workflows/check_diffs.yml (vendored)

@@ -30,6 +30,7 @@ jobs:
  build:
    name: 'Detect Changes & Set Matrix'
    runs-on: ubuntu-latest
    if: ${{ !contains(github.event.pull_request.labels.*.name, 'ci-ignore') }}
    steps:
      - name: '📋 Checkout Code'
        uses: actions/checkout@v4
1 .github/workflows/codspeed.yml (vendored)

@@ -20,6 +20,7 @@ jobs:
  codspeed:
    name: 'Benchmark'
    runs-on: ubuntu-latest
    if: ${{ !contains(github.event.pull_request.labels.*.name, 'codspeed-ignore') }}
    strategy:
      matrix:
        include:
@@ -11,4 +11,4 @@
  "MD046": {
    "style": "fenced"
  }
}
4 .vscode/settings.json (vendored)

@@ -21,7 +21,7 @@
  "[python]": {
    "editor.formatOnSave": true,
    "editor.codeActionsOnSave": {
      "source.organizeImports": "explicit",
      "source.organizeImports.ruff": "explicit",
      "source.fixAll": "explicit"
    },
    "editor.defaultFormatter": "charliermarsh.ruff"
@@ -77,4 +77,6 @@
    "editor.tabSize": 2,
    "editor.insertSpaces": true
  },
  "python.terminal.activateEnvironment": false,
  "python.defaultInterpreterPath": "./.venv/bin/python"
}
@@ -63,4 +63,4 @@ Notebook | Description

[rag-locally-on-intel-cpu.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/rag-locally-on-intel-cpu.ipynb) | Perform Retrieval-Augmented-Generation (RAG) on locally downloaded open-source models using langchain and open source tools and execute it on Intel Xeon CPU. We showed an example of how to apply RAG on Llama 2 model and enable it to answer the queries related to Intel Q1 2024 earnings release.
[visual_RAG_vdms.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/visual_RAG_vdms.ipynb) | Performs Visual Retrieval-Augmented-Generation (RAG) using videos and scene descriptions generated by open source models.
[contextual_rag.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/contextual_rag.ipynb) | Performs contextual retrieval-augmented generation (RAG) prepending chunk-specific explanatory context to each chunk before embedding.
[rag-agents-locally-on-intel-cpu.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/local_rag_agents_intel_cpu.ipynb) | Build a RAG agent locally with open source models that routes questions through one of two paths to find answers. The agent generates answers based on documents retrieved from either the vector database or retrieved from web search. If the vector database lacks relevant information, the agent opts for web search. Open-source models for LLM and embeddings are used locally on an Intel Xeon CPU to execute this pipeline.
@@ -97,7 +97,7 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
|
||||
if type(type_) is typing_extensions._TypedDictMeta: # type: ignore
|
||||
kind: ClassKind = "TypedDict"
|
||||
elif type(type_) is typing._TypedDictMeta: # type: ignore
|
||||
kind: ClassKind = "TypedDict"
|
||||
kind = "TypedDict"
|
||||
elif (
|
||||
issubclass(type_, Runnable)
|
||||
and issubclass(type_, BaseModel)
|
||||
@@ -189,7 +189,7 @@ def _load_package_modules(
|
||||
if isinstance(package_directory, str)
|
||||
else package_directory
|
||||
)
|
||||
modules_by_namespace = {}
|
||||
modules_by_namespace: Dict[str, ModuleMembers] = {}
|
||||
|
||||
# Get the high level package name
|
||||
package_name = package_path.name
|
||||
@@ -217,7 +217,11 @@ def _load_package_modules(
|
||||
# Get the full namespace of the module
|
||||
namespace = str(relative_module_name).replace(".py", "").replace("/", ".")
|
||||
# Keep only the top level namespace
|
||||
top_namespace = namespace.split(".")[0]
|
||||
# (but make special exception for content_blocks and v1.messages)
|
||||
if namespace == "messages.content_blocks" or namespace == "v1.messages":
|
||||
top_namespace = namespace # Keep full namespace for content_blocks
|
||||
else:
|
||||
top_namespace = namespace.split(".")[0]
|
||||
|
||||
try:
|
||||
# If submodule is present, we need to construct the paths in a slightly
|
||||
@@ -283,7 +287,7 @@ def _construct_doc(
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
|
||||
"""
|
||||
index_autosummary = """
|
||||
"""
|
||||
@@ -365,9 +369,9 @@ def _construct_doc(
|
||||
|
||||
module_doc += f"""\
|
||||
:template: {template}
|
||||
|
||||
|
||||
{class_["qualified_name"]}
|
||||
|
||||
|
||||
"""
|
||||
index_autosummary += f"""
|
||||
{class_["qualified_name"]}
|
||||
@@ -550,8 +554,8 @@ def _build_index(dirs: List[str]) -> None:
|
||||
integrations = sorted(dir_ for dir_ in dirs if dir_ not in main_)
|
||||
doc = """# LangChain Python API Reference
|
||||
|
||||
Welcome to the LangChain Python API reference. This is a reference for all
|
||||
`langchain-x` packages.
|
||||
Welcome to the LangChain Python API reference. This is a reference for all
|
||||
`langchain-x` packages.
|
||||
|
||||
For user guides see [https://python.langchain.com](https://python.langchain.com).
|
||||
|
||||
|
||||
@@ -124,6 +124,47 @@ start "" htmlcov/index.html || open htmlcov/index.html

```

## Snapshot Testing

Some tests use [syrupy](https://github.com/tophat/syrupy) for snapshot testing, which captures the output of functions and compares them to stored snapshots. This is particularly useful for testing JSON schema generation and other structured outputs.

### Updating Snapshots

To update snapshots when the expected output has legitimately changed:

```bash
uv run --group test pytest path/to/test.py --snapshot-update
```

### Pydantic Version Compatibility Issues

Pydantic generates different JSON schemas across versions, which can cause snapshot test failures in CI when tests run with different Pydantic versions than what was used to generate the snapshots.

**Symptoms:**

- CI fails with snapshot mismatches showing differences like missing or extra fields.
- Tests pass locally but fail in CI with different Pydantic versions

**Solution:**
Locally update snapshots using the same Pydantic version that CI uses:

1. **Identify the failing Pydantic version** from CI logs (e.g., `2.7.0`, `2.8.0`, `2.9.0`)

2. **Update snapshots with that version:**
   ```bash
   uv run --with "pydantic==2.9.0" --group test pytest tests/unit_tests/path/to/test.py::test_name --snapshot-update
   ```

3. **Verify compatibility across supported versions:**
   ```bash
   # Test with the version you used to update
   uv run --with "pydantic==2.9.0" --group test pytest tests/unit_tests/path/to/test.py::test_name

   # Test with other supported versions
   uv run --with "pydantic==2.8.0" --group test pytest tests/unit_tests/path/to/test.py::test_name
   ```

**Note:** Some tests use `@pytest.mark.skipif` decorators to only run with specific Pydantic version ranges (e.g., `PYDANTIC_VERSION_AT_LEAST_210`). Make sure to understand these constraints when updating snapshots.
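For illustration only, a minimal sketch of what such a version-gated snapshot test can look like. It assumes syrupy's `snapshot` fixture is installed; the flag name echoes the note above, while the `Person` model and test wiring are hypothetical, not code taken from the repository.

```python
from importlib import metadata

import pytest
from packaging.version import Version
from pydantic import BaseModel

# Gate the test on the installed Pydantic version, mirroring the
# version-range markers mentioned in the note above.
PYDANTIC_VERSION_AT_LEAST_210 = Version(metadata.version("pydantic")) >= Version("2.10")


class Person(BaseModel):
    name: str
    age: int


@pytest.mark.skipif(
    not PYDANTIC_VERSION_AT_LEAST_210,
    reason="Generated JSON schema differs on older Pydantic versions.",
)
def test_person_schema(snapshot) -> None:
    # syrupy injects the `snapshot` fixture and fails the test if the
    # generated schema drifts from the stored snapshot file.
    assert Person.model_json_schema() == snapshot
```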
## Coverage

Code coverage (i.e. the amount of code that is covered by unit tests) helps identify areas of the code that are potentially more or less brittle.
@@ -122,13 +122,13 @@
   },
   {
    "cell_type": "code",
    "execution_count": 4,
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import os\n",
     "\n",
     "from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
     "# from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
     "from langchain_openai import ChatOpenAI\n",
     "\n",
     "llm = ChatOpenAI(temperature=0, model_name=\"gpt-4-turbo\")\n",
@@ -74,12 +74,12 @@
   },
   {
    "cell_type": "code",
    "execution_count": 4,
    "execution_count": null,
    "id": "a88ff70c",
    "metadata": {},
    "outputs": [],
    "source": [
     "from langchain_experimental.text_splitter import SemanticChunker\n",
     "# from langchain_experimental.text_splitter import SemanticChunker\n",
     "from langchain_openai.embeddings import OpenAIEmbeddings\n",
     "\n",
     "text_splitter = SemanticChunker(OpenAIEmbeddings())"
@@ -612,56 +612,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": null,
|
||||
"id": "35ea904e-795f-411b-bef8-6484dbb6e35c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m\n",
|
||||
"Invoking: `python_repl_ast` with `{'query': \"df[['Age', 'Fare']].corr().iloc[0,1]\"}`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[0m\u001b[36;1m\u001b[1;3m0.11232863699941621\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||
"Invoking: `python_repl_ast` with `{'query': \"df[['Fare', 'Survived']].corr().iloc[0,1]\"}`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[0m\u001b[36;1m\u001b[1;3m0.2561785496289603\u001b[0m\u001b[32;1m\u001b[1;3mThe correlation between Age and Fare is approximately 0.112, and the correlation between Fare and Survival is approximately 0.256.\n",
|
||||
"\n",
|
||||
"Therefore, the correlation between Fare and Survival (0.256) is greater than the correlation between Age and Fare (0.112).\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input': \"What's the correlation between age and fare? is that greater than the correlation between fare and survival?\",\n",
|
||||
" 'output': 'The correlation between Age and Fare is approximately 0.112, and the correlation between Fare and Survival is approximately 0.256.\\n\\nTherefore, the correlation between Fare and Survival (0.256) is greater than the correlation between Age and Fare (0.112).'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_experimental.agents import create_pandas_dataframe_agent\n",
|
||||
"\n",
|
||||
"agent = create_pandas_dataframe_agent(\n",
|
||||
" llm, df, agent_type=\"openai-tools\", verbose=True, allow_dangerous_code=True\n",
|
||||
")\n",
|
||||
"agent.invoke(\n",
|
||||
" {\n",
|
||||
" \"input\": \"What's the correlation between age and fare? is that greater than the correlation between fare and survival?\"\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
"outputs": [],
|
||||
"source": "from langchain_experimental.agents import create_pandas_dataframe_agent\n\nagent = create_pandas_dataframe_agent(\n llm, df, agent_type=\"openai-tools\", verbose=True, allow_dangerous_code=True\n)\nagent.invoke(\n {\n \"input\": \"What's the correlation between age and fare? is that greater than the correlation between fare and survival?\"\n }\n)"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -786,4 +741,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
(File diff suppressed because it is too large.)
@@ -447,6 +447,163 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c5d9d19d-8ab1-4d9d-b3a0-56ee4e89c528",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Custom tools\n",
|
||||
"\n",
|
||||
":::info Requires ``langchain-openai>=0.3.29``\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"[Custom tools](https://platform.openai.com/docs/guides/function-calling#custom-tools) support tools with arbitrary string inputs. They can be particularly useful when you expect your string arguments to be long or complex."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "a47c809b-852f-46bd-8b9e-d9534c17213d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"================================\u001b[1m Human Message \u001b[0m=================================\n",
|
||||
"\n",
|
||||
"Use the tool to calculate 3^3.\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"\n",
|
||||
"[{'id': 'rs_6894ff5747c0819d9b02fc5645b0be9c000169fd9fb68d99', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_7SYwMSQPbbEqFcKlKOpXeEux', 'input': 'print(3**3)', 'name': 'execute_code', 'type': 'custom_tool_call', 'id': 'ctc_6894ff5b9f54819d8155a63638d34103000169fd9fb68d99', 'status': 'completed'}]\n",
|
||||
"Tool Calls:\n",
|
||||
" execute_code (call_7SYwMSQPbbEqFcKlKOpXeEux)\n",
|
||||
" Call ID: call_7SYwMSQPbbEqFcKlKOpXeEux\n",
|
||||
" Args:\n",
|
||||
" __arg1: print(3**3)\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
"Name: execute_code\n",
|
||||
"\n",
|
||||
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"\n",
|
||||
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6894ff5db3b8819d9159b3a370a25843000169fd9fb68d99'}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_openai import ChatOpenAI, custom_tool\n",
|
||||
"from langgraph.prebuilt import create_react_agent\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@custom_tool\n",
|
||||
"def execute_code(code: str) -> str:\n",
|
||||
" \"\"\"Execute python code.\"\"\"\n",
|
||||
" return \"27\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
|
||||
"\n",
|
||||
"agent = create_react_agent(llm, [execute_code])\n",
|
||||
"\n",
|
||||
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
|
||||
"for step in agent.stream(\n",
|
||||
" {\"messages\": [input_message]},\n",
|
||||
" stream_mode=\"values\",\n",
|
||||
"):\n",
|
||||
" step[\"messages\"][-1].pretty_print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5ef93be6-6d4c-4eea-acfd-248774074082",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<details>\n",
|
||||
"<summary>Context-free grammars</summary>\n",
|
||||
"\n",
|
||||
"OpenAI supports the specification of a [context-free grammar](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for custom tool inputs in `lark` or `regex` format. See [OpenAI docs](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for details. The `format` parameter can be passed into `@custom_tool` as shown below:"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": 3,
"id": "2ae04586-be33-49c6-8947-7867801d868f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_689500828a8481a297ff0f98e328689c0681550c89797f43', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_jzH01RVhu6EFz7yUrOFXX55s', 'input': '3 * 3 * 3', 'name': 'do_math', 'type': 'custom_tool_call', 'id': 'ctc_6895008d57bc81a2b84d0993517a66b90681550c89797f43', 'status': 'completed'}]\n",
"Tool Calls:\n",
"  do_math (call_jzH01RVhu6EFz7yUrOFXX55s)\n",
" Call ID: call_jzH01RVhu6EFz7yUrOFXX55s\n",
"  Args:\n",
"    __arg1: 3 * 3 * 3\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: do_math\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6895009776b881a2a25f0be8507d08f20681550c89797f43'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"grammar = \"\"\"\n",
"start: expr\n",
"expr: term (SP ADD SP term)* -> add\n",
"| term\n",
"term: factor (SP MUL SP factor)* -> mul\n",
"| factor\n",
"factor: INT\n",
"SP: \" \"\n",
"ADD: \"+\"\n",
"MUL: \"*\"\n",
"%import common.INT\n",
"\"\"\"\n",
"\n",
"format_ = {\"type\": \"grammar\", \"syntax\": \"lark\", \"definition\": grammar}\n",
"\n",
"\n",
"# highlight-next-line\n",
"@custom_tool(format=format_)\n",
"def do_math(input_string: str) -> str:\n",
"    \"\"\"Do a mathematical operation.\"\"\"\n",
"    return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [do_math])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
"    {\"messages\": [input_message]},\n",
"    stream_mode=\"values\",\n",
"):\n",
"    step[\"messages\"][-1].pretty_print()"
]
},
|
||||
{
"cell_type": "markdown",
"id": "c63430c9-c7b0-4e92-a491-3f165dddeb8f",
"metadata": {},
"source": [
"</details>"
]
},
{
"cell_type": "markdown",
"id": "84833dd0-17e9-4269-82ed-550639d65751",

@@ -132,12 +132,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.documents import Document\n",
"from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
"\n",
"# from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"# Define the LLMGraphTransformer\n",

@@ -548,12 +548,12 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.documents import Document\n",
"from langchain_experimental.graph_transformers import LLMGraphTransformer"
"# from langchain_experimental.graph_transformers import LLMGraphTransformer"
]
},
{

@@ -29,8 +29,8 @@
" Please refer to the instructions in:\n",
" [www.jaguardb.com](http://www.jaguardb.com)\n",
" For quick setup in docker environment:\n",
" docker pull jaguardb/jaguardb_with_http\n",
" docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb_with_http jaguardb/jaguardb_with_http\n",
" docker pull jaguardb/jaguardb\n",
" docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb jaguardb/jaguardb\n",
"\n",
"2. You must install the http client package for JaguarDB:\n",
" ```\n",

@@ -35,6 +35,7 @@ embeddings.embed_query("What is the meaning of life?")
```

## LLMs

`__ModuleName__LLM` class exposes LLMs from __ModuleName__.

```python

@@ -1,3 +1,3 @@
version: 0.0.1
patterns:
- name: github.com/getgrit/stdlib#*
- name: github.com/getgrit/stdlib#*

@@ -27,16 +27,16 @@ langchain app add __package_name__
```

And add the following code to your `server.py` file:

```python
__app_route_code__
```

(Optional) Let's now configure LangSmith.
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
(Optional) Let's now configure LangSmith.
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
If you don't have access, you can skip this section


```shell
export LANGSMITH_TRACING=true
export LANGSMITH_API_KEY=<your-api-key>
@@ -49,11 +49,11 @@ If you are inside this directory, then you can spin up a LangServe instance dire
langchain serve
```

This will start the FastAPI app with a server running locally at
This will start the FastAPI app with a server running locally at
[http://localhost:8000](http://localhost:8000)

We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
We can access the playground at [http://127.0.0.1:8000/__package_name__/playground](http://127.0.0.1:8000/__package_name__/playground)
We can access the playground at [http://127.0.0.1:8000/__package_name__/playground](http://127.0.0.1:8000/__package_name__/playground)

We can access the template from code with:

@@ -61,4 +61,4 @@ We can access the template from code with:
from langserve.client import RemoteRunnable

runnable = RemoteRunnable("http://localhost:8000/__package_name__")
```
```

@@ -11,7 +11,7 @@ pip install -U langchain-cli
## Adding packages

```bash
# adding packages from
# adding packages from
# https://github.com/langchain-ai/langchain/tree/master/templates
langchain app add $PROJECT_NAME

@@ -31,10 +31,10 @@ langchain app remove my/custom/path/rag
```

## Setup LangSmith (Optional)
LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
If you don't have access, you can skip this section

LangSmith will help us trace, monitor and debug LangChain applications.
You can sign up for LangSmith [here](https://smith.langchain.com/).
If you don't have access, you can skip this section

```shell
export LANGSMITH_TRACING=true

@@ -144,7 +144,7 @@ def beta(
|
||||
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
|
||||
warn_if_direct_instance
|
||||
)
|
||||
return cast("T", obj)
|
||||
return obj
|
||||
|
||||
elif isinstance(obj, property):
|
||||
# note(erick): this block doesn't seem to be used?
|
||||
|
||||
@@ -225,7 +225,7 @@ def deprecated(
|
||||
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
|
||||
warn_if_direct_instance
|
||||
)
|
||||
return cast("T", obj)
|
||||
return obj
|
||||
|
||||
elif isinstance(obj, FieldInfoV1):
|
||||
wrapped = None
|
||||
|
||||
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.v1.messages import AIMessage, AIMessageChunk, MessageV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
@@ -66,7 +68,9 @@ class LLMManagerMixin:
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
@@ -75,8 +79,8 @@ class LLMManagerMixin:
|
||||
|
||||
Args:
|
||||
token (str): The new token.
|
||||
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
|
||||
containing content and other information.
|
||||
chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
|
||||
generated chunk, containing content and other information.
|
||||
run_id (UUID): The run ID. This is the ID of the current run.
|
||||
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
@@ -84,7 +88,7 @@ class LLMManagerMixin:
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
response: Union[LLMResult, AIMessage],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -93,7 +97,7 @@ class LLMManagerMixin:
|
||||
"""Run when LLM ends running.
|
||||
|
||||
Args:
|
||||
response (LLMResult): The response which was generated.
|
||||
response (LLMResult | AIMessage): The response which was generated.
|
||||
run_id (UUID): The run ID. This is the ID of the current run.
|
||||
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
@@ -261,7 +265,7 @@ class CallbackManagerMixin:
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -439,6 +443,9 @@ class BaseCallbackHandler(
|
||||
run_inline: bool = False
|
||||
"""Whether to run the callback inline."""
|
||||
|
||||
accepts_new_messages: bool = False
|
||||
"""Whether the callback accepts new message format."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
@@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -540,7 +547,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
@@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
Args:
|
||||
token (str): The new token.
|
||||
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
|
||||
containing content and other information.
|
||||
chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
|
||||
generated chunk, containing content and other information.
|
||||
run_id (UUID): The run ID. This is the ID of the current run.
|
||||
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
|
||||
tags (Optional[list[str]]): The tags.
|
||||
@@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
response: Union[LLMResult, AIMessage],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Run when LLM ends running.
|
||||
|
||||
Args:
|
||||
response (LLMResult): The response which was generated.
|
||||
response (LLMResult | AIMessage): The response which was generated.
|
||||
run_id (UUID): The run ID. This is the ID of the current run.
|
||||
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
|
||||
tags (Optional[list[str]]): The tags.
|
||||
@@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: The parent run ID. This is the ID of the parent run.
|
||||
tags: The tags.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
- response (LLMResult): The response which was generated before
|
||||
the error occurred.
|
||||
- response (LLMResult | AIMessage): The response which was generated
|
||||
before the error occurred.
|
||||
"""
|
||||
|
||||
async def on_chain_start(
|
||||
|
||||
@@ -49,7 +49,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
mode: The file open mode. Defaults to ``'a'`` (append).
|
||||
color: Default color for text output. Defaults to ``None``.
|
||||
|
||||
Note:
|
||||
.. note::
|
||||
When not used as a context manager, a deprecation warning will be issued
|
||||
on first use. The file will be opened immediately in ``__init__`` and closed
|
||||
in ``__del__`` or when ``close()`` is called explicitly.
|
||||
@@ -65,6 +65,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
filename: Path to the output file.
|
||||
mode: File open mode (e.g., ``'w'``, ``'a'``, ``'x'``). Defaults to ``'a'``.
|
||||
color: Default text color for output. Defaults to ``None``.
|
||||
|
||||
"""
|
||||
self.filename = filename
|
||||
self.mode = mode
|
||||
@@ -82,9 +83,10 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
Returns:
|
||||
The FileCallbackHandler instance.
|
||||
|
||||
Note:
|
||||
.. note::
|
||||
The file is already opened in ``__init__``, so this just marks that
|
||||
the handler is being used as a context manager.
|
||||
|
||||
"""
|
||||
self._file_opened_in_context = True
|
||||
return self
|
||||
@@ -101,6 +103,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
exc_type: Exception type if an exception occurred.
|
||||
exc_val: Exception value if an exception occurred.
|
||||
exc_tb: Exception traceback if an exception occurred.
|
||||
|
||||
"""
|
||||
self.close()
|
||||
|
||||
@@ -113,6 +116,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
This method is safe to call multiple times and will only close
|
||||
the file if it's currently open.
|
||||
|
||||
"""
|
||||
if hasattr(self, "file") and self.file and not self.file.closed:
|
||||
self.file.close()
|
||||
@@ -133,6 +137,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the file is closed or not available.
|
||||
|
||||
"""
|
||||
global _GLOBAL_DEPRECATION_WARNED # noqa: PLW0603
|
||||
if not self._file_opened_in_context and not _GLOBAL_DEPRECATION_WARNED:
|
||||
@@ -163,6 +168,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
serialized: The serialized chain information.
|
||||
inputs: The inputs to the chain.
|
||||
**kwargs: Additional keyword arguments that may contain ``'name'``.
|
||||
|
||||
"""
|
||||
name = (
|
||||
kwargs.get("name")
|
||||
@@ -178,6 +184,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
Args:
|
||||
outputs: The outputs of the chain.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
"""
|
||||
self._write("\n> Finished chain.", end="\n")
|
||||
|
||||
@@ -192,6 +199,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
color: Color override for this specific output. If ``None``, uses
|
||||
``self.color``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
"""
|
||||
self._write(action.log, color=color or self.color)
|
||||
|
||||
@@ -213,6 +221,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
observation_prefix: Optional prefix to write before the output.
|
||||
llm_prefix: Optional prefix to write after the output.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if observation_prefix is not None:
|
||||
self._write(f"\n{observation_prefix}")
|
||||
@@ -232,6 +241,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
``self.color``.
|
||||
end: String appended after the text. Defaults to ``""``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
"""
|
||||
self._write(text, color=color or self.color, end=end)
|
||||
|
||||
@@ -246,5 +256,6 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
color: Color override for this specific output. If ``None``, uses
|
||||
``self.color``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
"""
|
||||
self._write(finish.log, color=color or self.color, end="\n")
|
||||
|
||||
@@ -11,15 +11,7 @@ from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import copy_context
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith.run_helpers import get_tracing_context
|
||||
@@ -37,8 +29,16 @@ from langchain_core.callbacks.base import (
|
||||
)
|
||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.utils.env import env_var_is_set
|
||||
from langchain_core.v1.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
MessageV1,
|
||||
MessageV1Types,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
|
||||
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -92,7 +92,8 @@ def trace_as_chain_group(
|
||||
metadata (dict[str, Any], optional): The metadata to apply to all runs.
|
||||
Defaults to None.
|
||||
|
||||
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
|
||||
.. note::
|
||||
Must have ``LANGCHAIN_TRACING_V2`` env var set to true to see the trace in LangSmith.
|
||||
|
||||
Returns:
|
||||
CallbackManagerForChainGroup: The callback manager for the chain group.
|
||||
@@ -177,7 +178,8 @@ async def atrace_as_chain_group(
|
||||
Returns:
|
||||
AsyncCallbackManager: The async callback manager for the chain group.
|
||||
|
||||
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
|
||||
.. note::
|
||||
Must have ``LANGCHAIN_TRACING_V2`` env var set to true to see the trace in LangSmith.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -234,6 +236,7 @@ def shielded(func: Func) -> Func:
|
||||
|
||||
Returns:
|
||||
Callable: The shielded function
|
||||
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
@@ -243,6 +246,46 @@ def shielded(func: Func) -> Func:
|
||||
return cast("Func", wrapped)
|
||||
|
||||
|
||||
def _convert_llm_events(
|
||||
event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
args_list = list(args)
|
||||
if (
|
||||
event_name == "on_chat_model_start"
|
||||
and isinstance(args_list[1], list)
|
||||
and args_list[1]
|
||||
and isinstance(args_list[1][0], MessageV1Types)
|
||||
):
|
||||
batch = [
|
||||
convert_from_v1_message(item)
|
||||
for item in args_list[1]
|
||||
if isinstance(item, MessageV1Types)
|
||||
]
|
||||
args_list[1] = [batch]
|
||||
elif (
|
||||
event_name == "on_llm_new_token"
|
||||
and "chunk" in kwargs
|
||||
and isinstance(kwargs["chunk"], MessageV1Types)
|
||||
):
|
||||
chunk = kwargs["chunk"]
|
||||
kwargs["chunk"] = ChatGenerationChunk(text=chunk.text, message=chunk)
|
||||
elif event_name == "on_llm_end" and isinstance(args_list[0], MessageV1Types):
|
||||
args_list[0] = LLMResult(
|
||||
generations=[
|
||||
[
|
||||
ChatGeneration(
|
||||
text=args_list[0].text,
|
||||
message=convert_from_v1_message(args_list[0]),
|
||||
)
|
||||
]
|
||||
]
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
return tuple(args_list), kwargs
|
||||
|
||||
|
||||
def handle_event(
|
||||
handlers: list[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
@@ -252,15 +295,17 @@ def handle_event(
|
||||
) -> None:
|
||||
"""Generic event handler for CallbackManager.
|
||||
|
||||
Note: This function is used by LangServe to handle events.
|
||||
.. note::
|
||||
This function is used by ``LangServe`` to handle events.
|
||||
|
||||
Args:
|
||||
handlers: The list of handlers that will handle the event.
|
||||
event_name: The name of the event (e.g., "on_llm_start").
|
||||
event_name: The name of the event (e.g., ``'on_llm_start'``).
|
||||
ignore_condition_name: Name of the attribute defined on handler
|
||||
that if True will cause the handler to be skipped for the given event.
|
||||
*args: The arguments to pass to the event handler.
|
||||
**kwargs: The keyword arguments to pass to the event handler
|
||||
|
||||
"""
|
||||
coros: list[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
@@ -271,6 +316,8 @@ def handle_event(
|
||||
if ignore_condition_name is None or not getattr(
|
||||
handler, ignore_condition_name
|
||||
):
|
||||
if not handler.accepts_new_messages:
|
||||
args, kwargs = _convert_llm_events(event_name, args, kwargs)
|
||||
event = getattr(handler, event_name)(*args, **kwargs)
|
||||
if asyncio.iscoroutine(event):
|
||||
coros.append(event)
|
||||
@@ -365,6 +412,8 @@ async def _ahandle_event_for_handler(
|
||||
) -> None:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
if not handler.accepts_new_messages:
|
||||
args, kwargs = _convert_llm_events(event_name, args, kwargs)
|
||||
event = getattr(handler, event_name)
|
||||
if asyncio.iscoroutinefunction(event):
|
||||
await event(*args, **kwargs)
|
||||
@@ -415,17 +464,19 @@ async def ahandle_event(
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Async generic event handler for AsyncCallbackManager.
|
||||
"""Async generic event handler for ``AsyncCallbackManager``.
|
||||
|
||||
Note: This function is used by LangServe to handle events.
|
||||
.. note::
|
||||
This function is used by ``LangServe`` to handle events.
|
||||
|
||||
Args:
|
||||
handlers: The list of handlers that will handle the event.
|
||||
event_name: The name of the event (e.g., "on_llm_start").
|
||||
event_name: The name of the event (e.g., ``'on_llm_start'``).
|
||||
ignore_condition_name: Name of the attribute defined on handler
|
||||
that if True will cause the handler to be skipped for the given event.
|
||||
*args: The arguments to pass to the event handler.
|
||||
**kwargs: The keyword arguments to pass to the event handler.
|
||||
|
||||
"""
|
||||
for handler in [h for h in handlers if h.run_inline]:
|
||||
await _ahandle_event_for_handler(
|
||||
@@ -477,6 +528,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
Defaults to None.
|
||||
inheritable_metadata (Optional[dict[str, Any]]): The inheritable metadata.
|
||||
Defaults to None.
|
||||
|
||||
"""
|
||||
self.run_id = run_id
|
||||
self.handlers = handlers
|
||||
@@ -493,6 +545,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
|
||||
Returns:
|
||||
BaseRunManager: The noop manager.
|
||||
|
||||
"""
|
||||
return cls(
|
||||
run_id=uuid.uuid4(),
|
||||
@@ -545,6 +598,7 @@ class RunManager(BaseRunManager):
|
||||
Args:
|
||||
retry_state (RetryCallState): The retry state.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -572,6 +626,7 @@ class ParentRunManager(RunManager):
|
||||
|
||||
Returns:
|
||||
CallbackManager: The child callback manager.
|
||||
|
||||
"""
|
||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
@@ -591,6 +646,7 @@ class AsyncRunManager(BaseRunManager, ABC):
|
||||
|
||||
Returns:
|
||||
RunManager: The sync RunManager.
|
||||
|
||||
"""
|
||||
|
||||
async def on_text(
|
||||
@@ -606,6 +662,7 @@ class AsyncRunManager(BaseRunManager, ABC):
|
||||
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -630,6 +687,7 @@ class AsyncRunManager(BaseRunManager, ABC):
|
||||
Args:
|
||||
retry_state (RetryCallState): The retry state.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -657,6 +715,7 @@ class AsyncParentRunManager(AsyncRunManager):
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The child callback manager.
|
||||
|
||||
"""
|
||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
@@ -674,7 +733,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token.
|
||||
@@ -684,6 +745,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
|
||||
The chunk. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -699,12 +761,13 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running.
|
||||
|
||||
Args:
|
||||
response (LLMResult): The LLM result.
|
||||
response (LLMResult | AIMessage): The LLM result.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -729,8 +792,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
- response (LLMResult): The response which was generated before
|
||||
the error occurred.
|
||||
- response (LLMResult | AIMessage): The response which was generated
|
||||
before the error occurred.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -754,6 +818,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
|
||||
Returns:
|
||||
CallbackManagerForLLMRun: The sync RunManager.
|
||||
|
||||
"""
|
||||
return CallbackManagerForLLMRun(
|
||||
run_id=self.run_id,
|
||||
@@ -770,7 +835,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM generates a new token.
|
||||
@@ -780,6 +847,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
chunk (Optional[Union[GenerationChunk, ChatGenerationChunk]], optional):
|
||||
The chunk. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -796,12 +864,15 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
async def on_llm_end(
|
||||
self, response: Union[LLMResult, AIMessage], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM ends running.
|
||||
|
||||
Args:
|
||||
response (LLMResult): The LLM result.
|
||||
response (LLMResult | AIMessage): The LLM result.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -827,10 +898,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
- response (LLMResult): The response which was generated before
|
||||
the error occurred.
|
||||
|
||||
|
||||
- response (LLMResult | AIMessage): The response which was generated
|
||||
before the error occurred.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
@@ -856,6 +925,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
Args:
|
||||
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -880,6 +950,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -903,6 +974,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -926,6 +998,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -970,6 +1043,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
Args:
|
||||
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -995,6 +1069,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1018,6 +1093,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1041,6 +1117,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1069,6 +1146,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
Args:
|
||||
output (Any): The output of the tool.
|
||||
**kwargs (Any): The keyword arguments to pass to the event handler
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1093,6 +1171,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1134,6 +1213,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
Args:
|
||||
output (Any): The output of the tool.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1158,6 +1238,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1186,6 +1267,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
||||
Args:
|
||||
documents (Sequence[Document]): The retrieved documents.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1210,6 +1292,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
||||
Args:
|
||||
error (BaseException): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1236,6 +1319,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
|
||||
Returns:
|
||||
CallbackManagerForRetrieverRun: The sync RunManager.
|
||||
|
||||
"""
|
||||
return CallbackManagerForRetrieverRun(
|
||||
run_id=self.run_id,
|
||||
@@ -1257,6 +1341,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
Args:
|
||||
documents (Sequence[Document]): The retrieved documents.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1282,6 +1367,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
Args:
|
||||
error (BaseException): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1318,6 +1404,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
Returns:
|
||||
list[CallbackManagerForLLMRun]: A callback manager for each
|
||||
prompt as an LLM run.
|
||||
|
||||
"""
|
||||
managers = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
@@ -1354,7 +1441,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[CallbackManagerForLLMRun]:
|
||||
@@ -1362,14 +1449,41 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
Args:
|
||||
serialized (dict[str, Any]): The serialized LLM.
|
||||
messages (list[list[BaseMessage]]): The list of messages.
|
||||
messages (list[list[BaseMessage | MessageV1]]): The list of messages.
|
||||
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
list[CallbackManagerForLLMRun]: A callback manager for each
|
||||
list of messages as an LLM run.
|
||||
|
||||
"""
|
||||
if messages and isinstance(messages[0], MessageV1Types):
|
||||
run_id_ = run_id if run_id is not None else uuid.uuid4()
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
return [
|
||||
CallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
]
|
||||
managers = []
|
||||
for message_list in messages:
|
||||
if run_id is not None:
|
||||
@@ -1422,6 +1536,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
Returns:
|
||||
CallbackManagerForChainRun: The callback manager for the chain run.
|
||||
|
||||
"""
|
||||
if run_id is None:
|
||||
run_id = uuid.uuid4()
|
||||
@@ -1476,6 +1591,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
Returns:
|
||||
CallbackManagerForToolRun: The callback manager for the tool run.
|
||||
|
||||
"""
|
||||
if run_id is None:
|
||||
run_id = uuid.uuid4()
|
||||
@@ -1522,6 +1638,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||
parent_run_id (UUID, optional): The ID of the parent run. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
if run_id is None:
|
||||
run_id = uuid.uuid4()
|
||||
@@ -1569,6 +1686,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id: The ID of the run. Defaults to None.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
@@ -1623,6 +1741,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
Returns:
|
||||
CallbackManager: The configured callback manager.
|
||||
|
||||
"""
|
||||
return _configure(
|
||||
cls,
|
||||
@@ -1657,6 +1776,7 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
parent_run_id (Optional[UUID]): The ID of the parent run. Defaults to None.
|
||||
parent_run_manager (CallbackManagerForChainRun): The parent run manager.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
handlers,
|
||||
@@ -1745,6 +1865,7 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
Args:
|
||||
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
self.ended = True
|
||||
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
|
||||
@@ -1759,6 +1880,7 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
"""
|
||||
self.ended = True
|
||||
return self.parent_run_manager.on_chain_error(error, **kwargs)
|
||||
@@ -1864,7 +1986,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[AsyncCallbackManagerForLLMRun]:
|
||||
@@ -1872,7 +1994,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
Args:
|
||||
serialized (dict[str, Any]): The serialized LLM.
|
||||
messages (list[list[BaseMessage]]): The list of messages.
|
||||
messages (list[list[BaseMessage | MessageV1]]): The list of messages.
|
||||
run_id (UUID, optional): The ID of the run. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
@@ -1881,10 +2003,51 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
async callback managers, one for each LLM Run
|
||||
corresponding to each inner message list.
|
||||
"""
|
||||
if messages and isinstance(messages[0], MessageV1Types):
|
||||
run_id_ = run_id if run_id is not None else uuid.uuid4()
|
||||
inline_tasks = []
|
||||
non_inline_tasks = []
|
||||
for handler in self.handlers:
|
||||
task = ahandle_event(
|
||||
[handler],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
if handler.run_inline:
|
||||
inline_tasks.append(task)
|
||||
else:
|
||||
non_inline_tasks.append(task)
|
||||
managers = [
|
||||
AsyncCallbackManagerForLLMRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
]
|
||||
# Run inline tasks sequentially
|
||||
for task in inline_tasks:
|
||||
await task
|
||||
|
||||
# Run non-inline tasks concurrently
|
||||
if non_inline_tasks:
|
||||
await asyncio.gather(*non_inline_tasks)
|
||||
|
||||
return managers
|
||||
inline_tasks = []
|
||||
non_inline_tasks = []
|
||||
managers = []
|
||||
|
||||
for message_list in messages:
|
||||
if run_id is not None:
|
||||
run_id_ = run_id
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.v1.messages import AIMessage, MessageV1
|
||||
|
||||
|
||||
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
@@ -32,7 +33,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
@@ -54,7 +55,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,14 +4,16 @@ import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages.ai import UsageMetadata, add_usage
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
|
||||
|
||||
class UsageMetadataCallbackHandler(BaseCallbackHandler):
|
||||
@@ -58,9 +60,17 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
|
||||
return str(self.usage_metadata)
|
||||
|
||||
@override
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
def on_llm_end(
|
||||
self, response: Union[LLMResult, AIMessageV1], **kwargs: Any
|
||||
) -> None:
|
||||
"""Collect token usage."""
|
||||
# Check for usage_metadata (langchain-core >= 0.2.2)
|
||||
if isinstance(response, AIMessageV1):
|
||||
response = LLMResult(
|
||||
generations=[
|
||||
[ChatGeneration(message=convert_from_v1_message(response))]
|
||||
]
|
||||
)
|
||||
try:
|
||||
generation = response.generations[0][0]
|
||||
except IndexError:
|
||||
|
||||
@@ -117,9 +117,9 @@ class BaseChatMessageHistory(ABC):
|
||||
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
||||
"""Convenience method for adding a human message string to the store.
|
||||
|
||||
Please note that this is a convenience method. Code should favor the
|
||||
bulk add_messages interface instead to save on round-trips to the underlying
|
||||
persistence layer.
|
||||
.. note::
|
||||
This is a convenience method. Code should favor the bulk ``add_messages``
|
||||
interface instead to save on round-trips to the persistence layer.
|
||||
|
||||
This method may be deprecated in a future release.
|
||||
|
||||
@@ -134,9 +134,9 @@ class BaseChatMessageHistory(ABC):
|
||||
def add_ai_message(self, message: Union[AIMessage, str]) -> None:
|
||||
"""Convenience method for adding an AI message string to the store.
|
||||
|
||||
Please note that this is a convenience method. Code should favor the bulk
|
||||
add_messages interface instead to save on round-trips to the underlying
|
||||
persistence layer.
|
||||
.. note::
|
||||
This is a convenience method. Code should favor the bulk ``add_messages``
|
||||
interface instead to save on round-trips to the persistence layer.
|
||||
|
||||
This method may be deprecated in a future release.
|
||||
|
||||
|
||||
@@ -19,17 +19,18 @@ if TYPE_CHECKING:
|
||||
class BaseDocumentCompressor(BaseModel, ABC):
|
||||
"""Base class for document compressors.
|
||||
|
||||
This abstraction is primarily used for
|
||||
post-processing of retrieved documents.
|
||||
This abstraction is primarily used for post-processing of retrieved documents.
|
||||
|
||||
Documents matching a given query are first retrieved.
|
||||
|
||||
Then the list of documents can be further processed.
|
||||
|
||||
For example, one could re-rank the retrieved documents
|
||||
using an LLM.
|
||||
For example, one could re-rank the retrieved documents using an LLM.
|
||||
|
||||
.. note::
|
||||
Users should favor using a RunnableLambda instead of sub-classing from this
|
||||
interface.
|
||||
|
||||
**Note** users should favor using a RunnableLambda
|
||||
instead of sub-classing from this interface.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -48,6 +49,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
|
||||
|
||||
Returns:
|
||||
The compressed documents.
|
||||
|
||||
"""
|
||||
|
||||
async def acompress_documents(
|
||||
@@ -65,6 +67,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
|
||||
|
||||
Returns:
|
||||
The compressed documents.
|
||||
|
||||
"""
|
||||
return await run_in_executor(
|
||||
None, self.compress_documents, documents, query, callbacks
|
||||
|
||||
@@ -488,8 +488,8 @@ class DeleteResponse(TypedDict, total=False):
|
||||
failed: Sequence[str]
|
||||
"""The IDs that failed to be deleted.
|
||||
|
||||
Please note that deleting an ID that
|
||||
does not exist is **NOT** considered a failure.
|
||||
.. warning::
|
||||
Deleting an ID that does not exist is **NOT** considered a failure.
|
||||
"""
|
||||
|
||||
num_failed: int
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import copy
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.v1.messages import MessageV1
|
||||
|
||||
|
||||
def _is_openai_data_block(block: dict) -> bool:
|
||||
@@ -138,3 +140,37 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
|
||||
formatted_messages.append(formatted_message)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
def _normalize_messages_v1(messages: Sequence[MessageV1]) -> list[MessageV1]:
|
||||
"""Extend support for message formats.
|
||||
|
||||
Chat models implement support for images in OpenAI Chat Completions format, as well
|
||||
as other multimodal data as standard data blocks. This function extends support to
|
||||
audio and file data in OpenAI Chat Completions format by converting them to standard
|
||||
data blocks.
|
||||
"""
|
||||
formatted_messages = []
|
||||
for message in messages:
|
||||
formatted_message = message
|
||||
if isinstance(message.content, list):
|
||||
for idx, block in enumerate(message.content):
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
# Subset to (PDF) files and audio, as most relevant chat models
|
||||
# support images in OAI format (and some may not yet support the
|
||||
# standard data block format)
|
||||
and block.get("type") in {"file", "input_audio"}
|
||||
and _is_openai_data_block(block) # type: ignore[arg-type]
|
||||
):
|
||||
if formatted_message is message:
|
||||
formatted_message = copy.copy(message)
|
||||
# Also shallow-copy content
|
||||
formatted_message.content = list(formatted_message.content)
|
||||
|
||||
formatted_message.content[idx] = ( # type: ignore[call-overload]
|
||||
_convert_openai_format_to_data_block(block) # type: ignore[arg-type]
|
||||
)
|
||||
formatted_messages.append(formatted_message)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
@@ -31,6 +31,7 @@ from langchain_core.messages import (
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.outputs import LLMResult
|
||||
@@ -57,8 +58,8 @@ class LangSmithParams(TypedDict, total=False):
|
||||
def get_tokenizer() -> Any:
|
||||
"""Get a GPT-2 tokenizer instance.
|
||||
|
||||
This function is cached to avoid re-loading the tokenizer
|
||||
every time it is called.
|
||||
This function is cached to avoid re-loading the tokenizer every time it is called.
|
||||
|
||||
"""
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast # type: ignore[import-not-found]
|
||||
@@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]:
|
||||
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
|
||||
LanguageModelOutput = Union[BaseMessage, str]
|
||||
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
|
||||
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
|
||||
LanguageModelOutputVar = TypeVar(
|
||||
"LanguageModelOutputVar", BaseMessage, str, AIMessageV1
|
||||
)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
@@ -99,7 +102,8 @@ class BaseLanguageModel(
|
||||
):
|
||||
"""Abstract base class for interfacing with language models.
|
||||
|
||||
All language model wrappers inherited from BaseLanguageModel.
|
||||
All language model wrappers inherited from ``BaseLanguageModel``.
|
||||
|
||||
"""
|
||||
|
||||
cache: Union[BaseCache, bool, None] = Field(default=None, exclude=True)
|
||||
@@ -108,9 +112,10 @@ class BaseLanguageModel(
|
||||
* If true, will use the global cache.
|
||||
* If false, will not use a cache
|
||||
* If None, will use the global cache if it's set, otherwise no cache.
|
||||
* If instance of BaseCache, will use the provided cache.
|
||||
* If instance of ``BaseCache``, will use the provided cache.
|
||||
|
||||
Caching is not currently supported for streaming methods of models.
|
||||
|
||||
"""
|
||||
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
|
||||
"""Whether to print out response text."""
|
||||
@@ -140,6 +145,7 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
The verbosity setting to use.
|
||||
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
@@ -195,7 +201,8 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
prompt and additional model provider-specific output.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -229,8 +236,9 @@ class BaseLanguageModel(
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
An ``LLMResult``, which contains a list of candidate Generations for each
|
||||
input prompt and additional model provider-specific output.
|
||||
|
||||
"""
|
||||
|
||||
def with_structured_output(
|
||||
@@ -248,8 +256,8 @@ class BaseLanguageModel(
|
||||
) -> str:
|
||||
"""Pass a single string input to the model and return a string.
|
||||
|
||||
Use this method when passing in raw text. If you want to pass in specific
|
||||
types of chat messages, use predict_messages.
|
||||
Use this method when passing in raw text. If you want to pass in specific types
|
||||
of chat messages, use predict_messages.
|
||||
|
||||
Args:
|
||||
text: String input to pass to the model.
|
||||
@@ -260,6 +268,7 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
Top model prediction as a string.
|
||||
|
||||
"""
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
@@ -274,7 +283,7 @@ class BaseLanguageModel(
|
||||
"""Pass a message sequence to the model and return a message.
|
||||
|
||||
Use this method when passing in chat messages. If you want to pass in raw text,
|
||||
use predict.
|
||||
use predict.
|
||||
|
||||
Args:
|
||||
messages: A sequence of chat messages corresponding to a single model input.
|
||||
@@ -285,6 +294,7 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
Top model prediction as a message.
|
||||
|
||||
"""
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
@@ -295,7 +305,7 @@ class BaseLanguageModel(
|
||||
"""Asynchronously pass a string to the model and return a string.
|
||||
|
||||
Use this method when calling pure text generation models and only the top
|
||||
candidate generation is needed.
|
||||
candidate generation is needed.
|
||||
|
||||
Args:
|
||||
text: String input to pass to the model.
|
||||
@@ -306,6 +316,7 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
Top model prediction as a string.
|
||||
|
||||
"""
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
@@ -319,8 +330,8 @@ class BaseLanguageModel(
|
||||
) -> BaseMessage:
|
||||
"""Asynchronously pass messages to the model and return a message.
|
||||
|
||||
Use this method when calling chat models and only the top
|
||||
candidate generation is needed.
|
||||
Use this method when calling chat models and only the top candidate generation
|
||||
is needed.
|
||||
|
||||
Args:
|
||||
messages: A sequence of chat messages corresponding to a single model input.
|
||||
@@ -331,6 +342,7 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
Top model prediction as a message.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
@@ -346,7 +358,8 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
A list of ids corresponding to the tokens in the text, in order they occur
|
||||
in the text.
|
||||
in the text.
|
||||
|
||||
"""
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
@@ -362,6 +375,7 @@ class BaseLanguageModel(
|
||||
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
|
||||
"""
|
||||
return len(self.get_token_ids(text))
|
||||
|
||||
@@ -374,16 +388,18 @@ class BaseLanguageModel(
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
|
||||
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||
tool schemas.
|
||||
.. note::
|
||||
The base implementation of ``get_num_tokens_from_messages`` ignores tool
|
||||
schemas.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||
to be converted to tool schemas.
|
||||
tools: If provided, sequence of dict, ``BaseModel``, function, or
|
||||
``BaseTools`` to be converted to tool schemas.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
|
||||
"""
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
@@ -396,6 +412,7 @@ class BaseLanguageModel(
|
||||
def _all_required_field_names(cls) -> set:
|
||||
"""DEPRECATED: Kept for backwards compatibility.
|
||||
|
||||
Use get_pydantic_field_names.
|
||||
Use ``get_pydantic_field_names``.
|
||||
|
||||
"""
|
||||
return get_pydantic_field_names(cls)
|
||||
|
||||
@@ -97,17 +97,18 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
|
||||
|
||||
|
||||
def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Format messages for tracing in on_chat_model_start.
|
||||
"""Format messages for tracing in ``on_chat_model_start``.
|
||||
|
||||
- Update image content blocks to OpenAI Chat Completions format (backward
|
||||
compatibility).
|
||||
- Add "type" key to content blocks that have a single key.
|
||||
- Add ``type`` key to content blocks that have a single key.
|
||||
|
||||
Args:
|
||||
messages: List of messages to format.
|
||||
|
||||
Returns:
|
||||
List of messages formatted for tracing.
|
||||
|
||||
"""
|
||||
messages_to_trace = []
|
||||
for message in messages:
|
||||
@@ -153,10 +154,11 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
|
||||
"""Generate from a stream.
|
||||
|
||||
Args:
|
||||
stream: Iterator of ChatGenerationChunk.
|
||||
stream: Iterator of ``ChatGenerationChunk``.
|
||||
|
||||
Returns:
|
||||
ChatResult: Chat result.
|
||||
|
||||
"""
|
||||
generation = next(stream, None)
|
||||
if generation:
|
||||
@@ -180,10 +182,11 @@ async def agenerate_from_stream(
|
||||
"""Async generate from a stream.
|
||||
|
||||
Args:
|
||||
stream: Iterator of ChatGenerationChunk.
|
||||
stream: Iterator of ``ChatGenerationChunk``.
|
||||
|
||||
Returns:
|
||||
ChatResult: Chat result.
|
||||
|
||||
"""
|
||||
chunks = [chunk async for chunk in stream]
|
||||
return await run_in_executor(None, generate_from_stream, iter(chunks))
|
||||
@@ -311,15 +314,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
provided. This offers the best of both worlds.
|
||||
- If False (default), will always use streaming case if available.
|
||||
|
||||
The main reason for this flag is that code might be written using ``.stream()`` and
|
||||
The main reason for this flag is that code might be written using ``stream()`` and
|
||||
a user may want to swap out a given model for another model whose the implementation
|
||||
does not properly support streaming.
|
||||
|
||||
"""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used.
|
||||
"""Raise deprecation warning if ``callback_manager`` is used.
|
||||
|
||||
Args:
|
||||
values (Dict): Values to validate.
|
||||
@@ -328,7 +332,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
Dict: Validated values.
|
||||
|
||||
Raises:
|
||||
DeprecationWarning: If callback_manager is used.
|
||||
DeprecationWarning: If ``callback_manager`` is used.
|
||||
|
||||
"""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
@@ -653,6 +658,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
List of ChatGeneration objects.
|
||||
|
||||
"""
|
||||
converted_generations = []
|
||||
for gen in cache_val:
|
||||
@@ -666,6 +672,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
converted_generations.append(chat_gen)
|
||||
else:
|
||||
# Already a ChatGeneration or other expected type
|
||||
if hasattr(gen, "message") and isinstance(gen.message, AIMessage):
|
||||
# We zero out cost on cache hits
|
||||
gen.message = gen.message.model_copy(
|
||||
update={
|
||||
"usage_metadata": {
|
||||
**(gen.message.usage_metadata or {}),
|
||||
"total_cost": 0,
|
||||
}
|
||||
}
|
||||
)
|
||||
converted_generations.append(gen)
|
||||
return converted_generations
|
||||
|
||||
@@ -768,7 +784,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
prompt and additional model provider-specific output.
|
||||
|
||||
"""
|
||||
ls_structured_output_format = kwargs.pop(
|
||||
"ls_structured_output_format", None
|
||||
@@ -882,7 +899,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
prompt and additional model provider-specific output.
|
||||
|
||||
"""
|
||||
ls_structured_output_format = kwargs.pop(
|
||||
"ls_structured_output_format", None
|
||||
@@ -1238,6 +1256,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
The model output message.
|
||||
|
||||
"""
|
||||
generation = self.generate(
|
||||
[messages], stop=stop, callbacks=callbacks, **kwargs
|
||||
@@ -1278,6 +1297,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
The model output string.
|
||||
|
||||
"""
|
||||
return self.predict(message, stop=stop, **kwargs)
|
||||
|
||||
@@ -1297,6 +1317,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
The predicted output string.
|
||||
|
||||
"""
|
||||
stop_ = None if stop is None else list(stop)
|
||||
result = self([HumanMessage(content=text)], stop=stop_, **kwargs)
|
||||
@@ -1372,6 +1393,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Returns:
|
||||
A Runnable that returns a message.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1534,8 +1556,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
class SimpleChatModel(BaseChatModel):
|
||||
"""Simplified implementation for a chat model to inherit from.
|
||||
|
||||
**Note** This implementation is primarily here for backwards compatibility.
|
||||
For new implementations, please use `BaseChatModel` directly.
|
||||
.. note::
|
||||
This implementation is primarily here for backwards compatibility. For new
|
||||
implementations, please use ``BaseChatModel`` directly.
|
||||
|
||||
"""
|
||||
|
||||
def _generate(
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from collections.abc import AsyncIterator, Iterable, Iterator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from typing_extensions import override
|
||||
@@ -16,6 +16,10 @@ from langchain_core.language_models.chat_models import BaseChatModel, SimpleChat
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.v1.chat_models import BaseChatModel as BaseChatModelV1
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.v1.messages import MessageV1
|
||||
|
||||
|
||||
class FakeMessagesListChatModel(BaseChatModel):
|
||||
@@ -223,11 +227,12 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
This can be expanded to accept other types like Callables / dicts / strings
|
||||
to make the interface more generic if needed.
|
||||
|
||||
Note: if you want to pass a list, you can use `iter` to convert it to an iterator.
|
||||
.. note::
|
||||
if you want to pass a list, you can use ``iter`` to convert it to an iterator.
|
||||
|
||||
Please note that streaming is not implemented yet. We should try to implement it
|
||||
in the future by delegating to invoke and then breaking the resulting output
|
||||
into message chunks.
|
||||
.. warning::
|
||||
Streaming is not implemented yet. We should try to implement it in the future by
|
||||
delegating to invoke and then breaking the resulting output into message chunks.
|
||||
"""
|
||||
|
||||
@override
|
||||
@@ -367,3 +372,69 @@ class ParrotFakeChatModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "parrot-fake-chat-model"
|
||||
|
||||
|
||||
class GenericFakeChatModelV1(BaseChatModelV1):
|
||||
"""Generic fake chat model that can be used to test the chat model interface."""
|
||||
|
||||
messages: Optional[Iterator[Union[AIMessageV1, str]]] = None
|
||||
message_chunks: Optional[Iterable[Union[AIMessageChunkV1, str]]] = None
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AIMessageV1:
|
||||
"""Top Level call."""
|
||||
if self.messages is None:
|
||||
error_msg = "Messages iterator is not set."
|
||||
raise ValueError(error_msg)
|
||||
message = next(self.messages)
|
||||
return AIMessageV1(content=message) if isinstance(message, str) else message
|
||||
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[AIMessageChunkV1]:
|
||||
"""Top Level call."""
|
||||
if self.message_chunks is None:
|
||||
error_msg = "Message chunks iterator is not set."
|
||||
raise ValueError(error_msg)
|
||||
for chunk in self.message_chunks:
|
||||
if isinstance(chunk, str):
|
||||
yield AIMessageChunkV1(chunk)
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "generic-fake-chat-model"
|
||||
|
||||
|
||||
class ParrotFakeChatModelV1(BaseChatModelV1):
|
||||
"""Generic fake chat model that can be used to test the chat model interface.
|
||||
|
||||
* Chat model should be usable in both sync and async tests
|
||||
"""
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AIMessageV1:
|
||||
"""Top Level call."""
|
||||
if isinstance(messages[-1], AIMessageV1):
|
||||
return messages[-1]
|
||||
return AIMessageV1(content=messages[-1].content)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "parrot-fake-chat-model"
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
"""Dump objects to json."""
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
||||
from langchain_core.v1.messages import MessageV1Types
|
||||
|
||||
|
||||
def default(obj: Any) -> Any:
|
||||
@@ -19,6 +22,24 @@ def default(obj: Any) -> Any:
|
||||
"""
|
||||
if isinstance(obj, Serializable):
|
||||
return obj.to_json()
|
||||
|
||||
# Handle v1 message classes
|
||||
if type(obj) in MessageV1Types:
|
||||
# Get the constructor signature to only include valid parameters
|
||||
init_sig = inspect.signature(type(obj).__init__)
|
||||
valid_params = set(init_sig.parameters.keys()) - {"self"}
|
||||
|
||||
# Filter dataclass fields to only include constructor params
|
||||
all_fields = dataclasses.asdict(obj)
|
||||
kwargs = {k: v for k, v in all_fields.items() if k in valid_params}
|
||||
|
||||
return {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain_core", "v1", "messages", type(obj).__name__],
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
return to_json_not_implemented(obj)
|
||||
|
||||
|
||||
@@ -73,10 +94,9 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
|
||||
def dumpd(obj: Any) -> Any:
|
||||
"""Return a dict representation of an object.
|
||||
|
||||
Note:
|
||||
Unfortunately this function is not as efficient as it could be
|
||||
because it first dumps the object to a json string and then loads it
|
||||
back into a dictionary.
|
||||
.. note::
|
||||
Unfortunately this function is not as efficient as it could be because it first
|
||||
dumps the object to a json string and then loads it back into a dictionary.
|
||||
|
||||
Args:
|
||||
obj: The object to dump.
|
||||
|
||||
@@ -156,8 +156,13 @@ class Reviver:
|
||||
|
||||
cls = getattr(mod, name)
|
||||
|
||||
# The class must be a subclass of Serializable.
|
||||
if not issubclass(cls, Serializable):
|
||||
# Import MessageV1Types lazily to avoid circular import:
|
||||
# load.load -> v1.messages -> messages.ai -> messages.base ->
|
||||
# load.serializable -> load.__init__ -> load.load
|
||||
from langchain_core.v1.messages import MessageV1Types
|
||||
|
||||
# The class must be a subclass of Serializable or a v1 message class.
|
||||
if not (issubclass(cls, Serializable) or cls in MessageV1Types):
|
||||
msg = f"Invalid namespace: {value}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@@ -33,9 +33,31 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||
from langchain_core.messages.content_blocks import (
|
||||
Annotation,
|
||||
AudioContentBlock,
|
||||
Citation,
|
||||
CodeInterpreterCall,
|
||||
CodeInterpreterOutput,
|
||||
CodeInterpreterResult,
|
||||
ContentBlock,
|
||||
DataContentBlock,
|
||||
FileContentBlock,
|
||||
ImageContentBlock,
|
||||
NonStandardAnnotation,
|
||||
NonStandardContentBlock,
|
||||
PlainTextContentBlock,
|
||||
ReasoningContentBlock,
|
||||
TextContentBlock,
|
||||
VideoContentBlock,
|
||||
WebSearchCall,
|
||||
WebSearchResult,
|
||||
convert_to_openai_data_block,
|
||||
convert_to_openai_image_block,
|
||||
is_data_content_block,
|
||||
is_reasoning_block,
|
||||
is_text_block,
|
||||
is_tool_call_block,
|
||||
is_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
||||
@@ -65,24 +87,42 @@ if TYPE_CHECKING:
|
||||
__all__ = (
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"Annotation",
|
||||
"AnyMessage",
|
||||
"AudioContentBlock",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"Citation",
|
||||
"CodeInterpreterCall",
|
||||
"CodeInterpreterOutput",
|
||||
"CodeInterpreterResult",
|
||||
"ContentBlock",
|
||||
"DataContentBlock",
|
||||
"FileContentBlock",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"ImageContentBlock",
|
||||
"InvalidToolCall",
|
||||
"MessageLikeRepresentation",
|
||||
"NonStandardAnnotation",
|
||||
"NonStandardContentBlock",
|
||||
"PlainTextContentBlock",
|
||||
"ReasoningContentBlock",
|
||||
"RemoveMessage",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"TextContentBlock",
|
||||
"ToolCall",
|
||||
"ToolCallChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"VideoContentBlock",
|
||||
"WebSearchCall",
|
||||
"WebSearchResult",
|
||||
"_message_from_dict",
|
||||
"convert_to_messages",
|
||||
"convert_to_openai_data_block",
|
||||
@@ -91,6 +131,10 @@ __all__ = (
|
||||
"filter_messages",
|
||||
"get_buffer_string",
|
||||
"is_data_content_block",
|
||||
"is_reasoning_block",
|
||||
"is_text_block",
|
||||
"is_tool_call_block",
|
||||
"is_tool_call_chunk",
|
||||
"merge_content",
|
||||
"merge_message_runs",
|
||||
"message_chunk_to_message",
|
||||
@@ -103,25 +147,43 @@ __all__ = (
|
||||
_dynamic_imports = {
|
||||
"AIMessage": "ai",
|
||||
"AIMessageChunk": "ai",
|
||||
"Annotation": "content_blocks",
|
||||
"AudioContentBlock": "content_blocks",
|
||||
"BaseMessage": "base",
|
||||
"BaseMessageChunk": "base",
|
||||
"merge_content": "base",
|
||||
"message_to_dict": "base",
|
||||
"messages_to_dict": "base",
|
||||
"Citation": "content_blocks",
|
||||
"ContentBlock": "content_blocks",
|
||||
"ChatMessage": "chat",
|
||||
"ChatMessageChunk": "chat",
|
||||
"CodeInterpreterCall": "content_blocks",
|
||||
"CodeInterpreterOutput": "content_blocks",
|
||||
"CodeInterpreterResult": "content_blocks",
|
||||
"DataContentBlock": "content_blocks",
|
||||
"FileContentBlock": "content_blocks",
|
||||
"FunctionMessage": "function",
|
||||
"FunctionMessageChunk": "function",
|
||||
"HumanMessage": "human",
|
||||
"HumanMessageChunk": "human",
|
||||
"NonStandardAnnotation": "content_blocks",
|
||||
"NonStandardContentBlock": "content_blocks",
|
||||
"PlainTextContentBlock": "content_blocks",
|
||||
"ReasoningContentBlock": "content_blocks",
|
||||
"RemoveMessage": "modifier",
|
||||
"SystemMessage": "system",
|
||||
"SystemMessageChunk": "system",
|
||||
"WebSearchCall": "content_blocks",
|
||||
"WebSearchResult": "content_blocks",
|
||||
"ImageContentBlock": "content_blocks",
|
||||
"InvalidToolCall": "tool",
|
||||
"TextContentBlock": "content_blocks",
|
||||
"ToolCall": "tool",
|
||||
"ToolCallChunk": "tool",
|
||||
"ToolMessage": "tool",
|
||||
"ToolMessageChunk": "tool",
|
||||
"VideoContentBlock": "content_blocks",
|
||||
"AnyMessage": "utils",
|
||||
"MessageLikeRepresentation": "utils",
|
||||
"_message_from_dict": "utils",
|
||||
@@ -132,6 +194,10 @@ _dynamic_imports = {
|
||||
"filter_messages": "utils",
|
||||
"get_buffer_string": "utils",
|
||||
"is_data_content_block": "content_blocks",
|
||||
"is_reasoning_block": "content_blocks",
|
||||
"is_text_block": "content_blocks",
|
||||
"is_tool_call_block": "content_blocks",
|
||||
"is_tool_call_chunk": "content_blocks",
|
||||
"merge_message_runs": "utils",
|
||||
"message_chunk_to_message": "utils",
|
||||
"messages_from_dict": "utils",
|
||||
|
||||
@@ -8,11 +8,7 @@ from typing import Any, Literal, Optional, Union, cast
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import NotRequired, Self, TypedDict, override
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
merge_content,
|
||||
)
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
from langchain_core.messages.tool import (
|
||||
InvalidToolCall,
|
||||
ToolCall,
|
||||
@@ -20,23 +16,26 @@ from langchain_core.messages.tool import (
|
||||
default_tool_chunk_parser,
|
||||
default_tool_parser,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
invalid_tool_call as create_invalid_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call as create_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call_chunk as create_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.usage import _dict_int_op
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LC_AUTO_PREFIX = "lc_"
|
||||
"""LangChain auto-generated ID prefix for messages and content blocks."""
|
||||
|
||||
_LC_ID_PREFIX = "run-"
|
||||
_LC_ID_PREFIX = f"{_LC_AUTO_PREFIX}run-"
|
||||
"""Internal tracing/callback system identifier.
|
||||
|
||||
Used for:
|
||||
- Tracing. Every LangChain operation (LLM call, chain execution, tool use, etc.)
|
||||
gets a unique run_id (UUID)
|
||||
- Enables tracking parent-child relationships between operations
|
||||
"""
|
||||
|
||||
|
||||
class InputTokenDetails(TypedDict, total=False):
|
||||
@@ -428,17 +427,27 @@ def add_ai_message_chunks(
|
||||
|
||||
chunk_id = None
|
||||
candidates = [left.id] + [o.id for o in others]
|
||||
# first pass: pick the first non-run-* id
|
||||
# first pass: pick the first provider-assigned id (non-run-* and non-lc_*)
|
||||
for id_ in candidates:
|
||||
if id_ and not id_.startswith(_LC_ID_PREFIX):
|
||||
if (
|
||||
id_
|
||||
and not id_.startswith(_LC_ID_PREFIX)
|
||||
and not id_.startswith(_LC_AUTO_PREFIX)
|
||||
):
|
||||
chunk_id = id_
|
||||
break
|
||||
else:
|
||||
# second pass: no provider-assigned id found, just take the first non-null
|
||||
# second pass: prefer lc_run-* ids over lc_* ids
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
if id_ and id_.startswith(_LC_ID_PREFIX):
|
||||
chunk_id = id_
|
||||
break
|
||||
else:
|
||||
# third pass: take any remaining id (auto-generated lc_* ids)
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
chunk_id = id_
|
||||
break
|
||||
|
||||
return left.__class__(
|
||||
example=left.example,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@ class RemoveMessage(BaseMessage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str, # noqa: A002
|
||||
id: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a RemoveMessage.
|
||||
|
||||
@@ -5,9 +5,12 @@ from typing import Any, Literal, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import NotRequired, TypedDict, override
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
from langchain_core.messages.content_blocks import InvalidToolCall as InvalidToolCall
|
||||
from langchain_core.messages.content_blocks import ToolCall as ToolCall
|
||||
from langchain_core.messages.content_blocks import ToolCallChunk as ToolCallChunk
|
||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||
|
||||
|
||||
@@ -177,42 +180,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
return super().__add__(other)
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
"""Represents a request to call a tool.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"name": "foo",
|
||||
"args": {"a": 1},
|
||||
"id": "123"
|
||||
}
|
||||
|
||||
This represents a request to call the tool named "foo" with arguments {"a": 1}
|
||||
and an identifier of "123".
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""The name of the tool to be called."""
|
||||
args: dict[str, Any]
|
||||
"""The arguments to the tool call."""
|
||||
id: Optional[str]
|
||||
"""An identifier associated with the tool call.
|
||||
|
||||
An identifier is needed to associate a tool call request with a tool
|
||||
call result in events when multiple concurrent tool calls are made.
|
||||
"""
|
||||
type: NotRequired[Literal["tool_call"]]
|
||||
|
||||
|
||||
def tool_call(
|
||||
*,
|
||||
name: str,
|
||||
args: dict[str, Any],
|
||||
id: Optional[str], # noqa: A002
|
||||
id: Optional[str],
|
||||
) -> ToolCall:
|
||||
"""Create a tool call.
|
||||
|
||||
@@ -224,43 +196,11 @@ def tool_call(
|
||||
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
||||
|
||||
|
||||
class ToolCallChunk(TypedDict):
|
||||
"""A chunk of a tool call (e.g., as part of a stream).
|
||||
|
||||
When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
|
||||
all string attributes are concatenated. Chunks are only merged if their
|
||||
values of `index` are equal and not None.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
|
||||
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
|
||||
|
||||
(
|
||||
AIMessageChunk(content="", tool_call_chunks=left_chunks)
|
||||
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
|
||||
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]
|
||||
|
||||
"""
|
||||
|
||||
name: Optional[str]
|
||||
"""The name of the tool to be called."""
|
||||
args: Optional[str]
|
||||
"""The arguments to the tool call."""
|
||||
id: Optional[str]
|
||||
"""An identifier associated with the tool call."""
|
||||
index: Optional[int]
|
||||
"""The index of the tool call in a sequence."""
|
||||
type: NotRequired[Literal["tool_call_chunk"]]
|
||||
|
||||
|
||||
def tool_call_chunk(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
id: Optional[str] = None,
|
||||
index: Optional[int] = None,
|
||||
) -> ToolCallChunk:
|
||||
"""Create a tool call chunk.
|
||||
@@ -276,29 +216,11 @@ def tool_call_chunk(
|
||||
)
|
||||
|
||||
|
||||
class InvalidToolCall(TypedDict):
|
||||
"""Allowance for errors made by LLM.
|
||||
|
||||
Here we add an `error` key to surface errors made during generation
|
||||
(e.g., invalid JSON arguments.)
|
||||
"""
|
||||
|
||||
name: Optional[str]
|
||||
"""The name of the tool to be called."""
|
||||
args: Optional[str]
|
||||
"""The arguments to the tool call."""
|
||||
id: Optional[str]
|
||||
"""An identifier associated with the tool call."""
|
||||
error: Optional[str]
|
||||
"""An error message associated with the tool call."""
|
||||
type: NotRequired[Literal["invalid_tool_call"]]
|
||||
|
||||
|
||||
def invalid_tool_call(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
id: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> InvalidToolCall:
|
||||
"""Create an invalid tool call.
|
||||
|
||||
@@ -35,11 +35,18 @@ from langchain_core.messages import convert_to_openai_data_block, is_data_conten
|
||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
|
||||
from langchain_core.messages.content_blocks import ContentBlock
|
||||
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
|
||||
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
|
||||
from langchain_core.messages.modifier import RemoveMessage
|
||||
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.v1.messages import HumanMessage as HumanMessageV1
|
||||
from langchain_core.v1.messages import MessageV1, MessageV1Types, ResponseMetadata
|
||||
from langchain_core.v1.messages import SystemMessage as SystemMessageV1
|
||||
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_text_splitters import TextSplitter
|
||||
@@ -203,7 +210,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
|
||||
|
||||
MessageLikeRepresentation = Union[
|
||||
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
|
||||
BaseMessage, list[str], tuple[str, str], str, dict[str, Any], MessageV1
|
||||
]
|
||||
|
||||
|
||||
@@ -213,7 +220,7 @@ def _create_message_from_message_type(
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
id: Optional[str] = None,
|
||||
**additional_kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
"""Create a message from a message type and content string.
|
||||
@@ -294,6 +301,130 @@ def _create_message_from_message_type(
|
||||
return message
|
||||
|
||||
|
||||
def _create_message_from_message_type_v1(
|
||||
message_type: str,
|
||||
content: str,
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None,
|
||||
id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> MessageV1:
|
||||
"""Create a message from a message type and content string.
|
||||
|
||||
Args:
|
||||
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
|
||||
content: (str) the content string.
|
||||
name: (str) the name of the message. Default is None.
|
||||
tool_call_id: (str) the tool call id. Default is None.
|
||||
tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
|
||||
id: (str) the id of the message. Default is None.
|
||||
kwargs: (dict[str, Any]) additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
a message of the appropriate type.
|
||||
|
||||
Raises:
|
||||
ValueError: if the message type is not one of "human", "user", "ai",
|
||||
"assistant", "tool", "system", or "developer".
|
||||
"""
|
||||
if name is not None:
|
||||
kwargs["name"] = name
|
||||
if tool_call_id is not None:
|
||||
kwargs["tool_call_id"] = tool_call_id
|
||||
if kwargs and (response_metadata := kwargs.pop("response_metadata", None)):
|
||||
kwargs["response_metadata"] = response_metadata
|
||||
if id is not None:
|
||||
kwargs["id"] = id
|
||||
if tool_calls is not None:
|
||||
kwargs["tool_calls"] = []
|
||||
for tool_call in tool_calls:
|
||||
# Convert OpenAI-format tool call to LangChain format.
|
||||
if "function" in tool_call:
|
||||
args = tool_call["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args, strict=False)
|
||||
kwargs["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": args,
|
||||
"id": tool_call["id"],
|
||||
"type": "tool_call",
|
||||
}
|
||||
)
|
||||
else:
|
||||
kwargs["tool_calls"].append(tool_call)
|
||||
if message_type in {"human", "user"}:
|
||||
message: MessageV1 = HumanMessageV1(content=content, **kwargs)
|
||||
elif message_type in {"ai", "assistant"}:
|
||||
message = AIMessageV1(content=content, **kwargs)
|
||||
elif message_type in {"system", "developer"}:
|
||||
if message_type == "developer":
|
||||
kwargs["custom_role"] = "developer"
|
||||
message = SystemMessageV1(content=content, **kwargs)
|
||||
elif message_type == "tool":
|
||||
artifact = kwargs.pop("artifact", None)
|
||||
message = ToolMessageV1(content=content, artifact=artifact, **kwargs)
|
||||
else:
|
||||
msg = (
|
||||
f"Unexpected message type: '{message_type}'. Use one of 'human',"
|
||||
f" 'user', 'ai', 'assistant', 'function', 'tool', 'system', or 'developer'."
|
||||
)
|
||||
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
|
||||
raise ValueError(msg)
|
||||
return message
|
||||
|
||||
|
||||
def convert_from_v1_message(message: MessageV1) -> BaseMessage:
|
||||
"""Compatibility layer to convert v1 messages to current messages.
|
||||
|
||||
Args:
|
||||
message: MessageV1 instance to convert.
|
||||
|
||||
Returns:
|
||||
BaseMessage: Converted message instance.
|
||||
"""
|
||||
content = cast("Union[str, list[str | dict]]", message.content)
|
||||
if isinstance(message, AIMessageV1):
|
||||
return AIMessage(
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
tool_calls=message.tool_calls,
|
||||
response_metadata=cast("dict", message.response_metadata),
|
||||
)
|
||||
if isinstance(message, AIMessageChunkV1):
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
tool_call_chunks=message.tool_call_chunks,
|
||||
response_metadata=cast("dict", message.response_metadata),
|
||||
)
|
||||
if isinstance(message, HumanMessageV1):
|
||||
return HumanMessage(
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
)
|
||||
if isinstance(message, SystemMessageV1):
|
||||
return SystemMessage(
|
||||
content=content,
|
||||
id=message.id,
|
||||
)
|
||||
if isinstance(message, ToolMessageV1):
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
id=message.id,
|
||||
tool_call_id=message.tool_call_id,
|
||||
artifact=message.artifact,
|
||||
name=message.name,
|
||||
status=message.status,
|
||||
)
|
||||
message = f"Unsupported message type: {type(message)}"
|
||||
raise NotImplementedError(message)
|
||||
|
||||
|
||||
def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
@@ -341,6 +472,143 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
||||
message_ = _create_message_from_message_type(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
elif isinstance(message, MessageV1Types):
|
||||
message_ = convert_from_v1_message(message)
|
||||
else:
|
||||
msg = f"Unsupported message type: {type(message)}"
|
||||
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return message_
|
||||
|
||||
|
||||
def _convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
|
||||
"""Convert a v0 message to a v1 message."""
|
||||
if isinstance(message, HumanMessage): # Checking for v0 HumanMessage
|
||||
return HumanMessageV1(message.content, id=message.id, name=message.name) # type: ignore[arg-type]
|
||||
if isinstance(message, AIMessage): # Checking for v0 AIMessage
|
||||
return AIMessageV1(
|
||||
content=message.content, # type: ignore[arg-type]
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
lc_version="v1",
|
||||
response_metadata=message.response_metadata, # type: ignore[arg-type]
|
||||
usage_metadata=message.usage_metadata,
|
||||
tool_calls=message.tool_calls,
|
||||
invalid_tool_calls=message.invalid_tool_calls,
|
||||
)
|
||||
if isinstance(message, SystemMessage): # Checking for v0 SystemMessage
|
||||
return SystemMessageV1(
|
||||
message.content, # type: ignore[arg-type]
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
)
|
||||
if isinstance(message, ToolMessage): # Checking for v0 ToolMessage
|
||||
return ToolMessageV1(
|
||||
message.content, # type: ignore[arg-type]
|
||||
message.tool_call_id,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
artifact=message.artifact,
|
||||
status=message.status,
|
||||
)
|
||||
msg = f"Unsupported v0 message type for conversion to v1: {type(message)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _safe_convert_from_v0_to_v1(message: BaseMessage) -> MessageV1:
|
||||
"""Convert a v0 message to a v1 message."""
|
||||
from langchain_core.messages.content_blocks import create_text_block
|
||||
|
||||
if isinstance(message, HumanMessage): # Checking for v0 HumanMessage
|
||||
content: list[ContentBlock] = [create_text_block(str(message.content))]
|
||||
return HumanMessageV1(content, id=message.id, name=message.name)
|
||||
if isinstance(message, AIMessage): # Checking for v0 AIMessage
|
||||
content = [create_text_block(str(message.content))]
|
||||
|
||||
# Construct ResponseMetadata TypedDict from v0 response_metadata dict
|
||||
# Since ResponseMetadata has total=False, we can safely cast the dict
|
||||
response_metadata = cast("ResponseMetadata", message.response_metadata or {})
|
||||
return AIMessageV1(
|
||||
content=content,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
lc_version="v1",
|
||||
response_metadata=response_metadata,
|
||||
usage_metadata=message.usage_metadata,
|
||||
tool_calls=message.tool_calls,
|
||||
invalid_tool_calls=message.invalid_tool_calls,
|
||||
)
|
||||
if isinstance(message, SystemMessage): # Checking for v0 SystemMessage
|
||||
content = [create_text_block(str(message.content))]
|
||||
return SystemMessageV1(content=content, id=message.id, name=message.name)
|
||||
if isinstance(message, ToolMessage): # Checking for v0 ToolMessage
|
||||
content = [create_text_block(str(message.content))]
|
||||
return ToolMessageV1(
|
||||
content,
|
||||
message.tool_call_id,
|
||||
id=message.id,
|
||||
name=message.name,
|
||||
artifact=message.artifact,
|
||||
status=message.status,
|
||||
)
|
||||
msg = f"Unsupported v0 message type for conversion to v1: {type(message)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
The message format can be one of the following:
|
||||
|
||||
- BaseMessagePromptTemplate
|
||||
- BaseMessage
|
||||
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
|
||||
- dict: a message dict with role and content keys
|
||||
- string: shorthand for ("human", template); e.g., "{user_input}"
|
||||
|
||||
Args:
|
||||
message: a representation of a message in one of the supported formats.
|
||||
|
||||
Returns:
|
||||
an instance of a message or a message template.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the message type is not supported.
|
||||
ValueError: if the message dict does not contain the required keys.
|
||||
"""
|
||||
if isinstance(message, MessageV1Types):
|
||||
if isinstance(message, AIMessageChunkV1):
|
||||
message_: MessageV1 = message.to_message()
|
||||
else:
|
||||
message_ = message
|
||||
elif isinstance(message, BaseMessage):
|
||||
# Convert v0 messages to v1 messages
|
||||
message_ = _convert_from_v0_to_v1(message)
|
||||
elif isinstance(message, str):
|
||||
message_ = _create_message_from_message_type_v1("human", message)
|
||||
elif isinstance(message, Sequence) and len(message) == 2:
|
||||
# mypy doesn't realise this can't be a string given the previous branch
|
||||
message_type_str, template = message # type: ignore[misc]
|
||||
message_ = _create_message_from_message_type_v1(message_type_str, template)
|
||||
elif isinstance(message, dict):
|
||||
msg_kwargs = message.copy()
|
||||
try:
|
||||
try:
|
||||
msg_type = msg_kwargs.pop("role")
|
||||
except KeyError:
|
||||
msg_type = msg_kwargs.pop("type")
|
||||
# None msg content is not allowed
|
||||
msg_content = msg_kwargs.pop("content") or ""
|
||||
except KeyError as e:
|
||||
msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||
msg = create_message(
|
||||
message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
message_ = _create_message_from_message_type_v1(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
else:
|
||||
msg = f"Unsupported message type: {type(message)}"
|
||||
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
|
||||
@@ -368,6 +636,25 @@ def convert_to_messages(
|
||||
return [_convert_to_message(m) for m in messages]
|
||||
|
||||
|
||||
def convert_to_messages_v1(
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
) -> list[MessageV1]:
|
||||
"""Convert a sequence of messages to a list of messages.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages to convert.
|
||||
|
||||
Returns:
|
||||
list of messages (BaseMessages).
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
if isinstance(messages, PromptValue):
|
||||
return messages.to_messages(message_version="v1")
|
||||
return [_convert_to_message_v1(m) for m in messages]
|
||||
|
||||
|
||||
def _runnable_support(func: Callable) -> Callable:
|
||||
@overload
|
||||
def wrapped(
|
||||
@@ -656,22 +943,23 @@ def trim_messages(
|
||||
properties:
|
||||
|
||||
1. The resulting chat history should be valid. Most chat models expect that chat
|
||||
history starts with either (1) a `HumanMessage` or (2) a `SystemMessage` followed
|
||||
by a `HumanMessage`. To achieve this, set `start_on="human"`.
|
||||
In addition, generally a `ToolMessage` can only appear after an `AIMessage`
|
||||
history starts with either (1) a ``HumanMessage`` or (2) a ``SystemMessage`` followed
|
||||
by a ``HumanMessage``. To achieve this, set ``start_on="human"``.
|
||||
In addition, generally a ``ToolMessage`` can only appear after an ``AIMessage``
|
||||
that involved a tool call.
|
||||
Please see the following link for more information about messages:
|
||||
https://python.langchain.com/docs/concepts/#messages
|
||||
2. It includes recent messages and drops old messages in the chat history.
|
||||
To achieve this set the `strategy="last"`.
|
||||
3. Usually, the new chat history should include the `SystemMessage` if it
|
||||
was present in the original chat history since the `SystemMessage` includes
|
||||
special instructions to the chat model. The `SystemMessage` is almost always
|
||||
To achieve this set the ``strategy="last"``.
|
||||
3. Usually, the new chat history should include the ``SystemMessage`` if it
|
||||
was present in the original chat history since the ``SystemMessage`` includes
|
||||
special instructions to the chat model. The ``SystemMessage`` is almost always
|
||||
the first message in the history if present. To achieve this set the
|
||||
`include_system=True`.
|
||||
``include_system=True``.
|
||||
|
||||
**Note** The examples below show how to configure `trim_messages` to achieve
|
||||
a behavior consistent with the above properties.
|
||||
.. note::
|
||||
The examples below show how to configure ``trim_messages`` to achieve a behavior
|
||||
consistent with the above properties.
|
||||
|
||||
Args:
|
||||
messages: Sequence of Message-like objects to trim.
|
||||
@@ -1007,10 +1295,11 @@ def convert_to_openai_messages(
|
||||
|
||||
oai_messages: list = []
|
||||
|
||||
if is_single := isinstance(messages, (BaseMessage, dict, str)):
|
||||
if is_single := isinstance(messages, (BaseMessage, dict, str, MessageV1Types)):
|
||||
messages = [messages]
|
||||
|
||||
messages = convert_to_messages(messages)
|
||||
# TODO: resolve type ignore here
|
||||
messages = convert_to_messages(messages) # type: ignore[arg-type]
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
oai_msg: dict = {"role": _get_message_openai_role(message)}
|
||||
@@ -1580,26 +1869,26 @@ def count_tokens_approximately(
|
||||
chars_per_token: Number of characters per token to use for the approximation.
|
||||
Default is 4 (one token corresponds to ~4 chars for common English text).
|
||||
You can also specify float values for more fine-grained control.
|
||||
See more here: https://platform.openai.com/tokenizer
|
||||
`See more here. <https://platform.openai.com/tokenizer>`__
|
||||
extra_tokens_per_message: Number of extra tokens to add per message.
|
||||
Default is 3 (special tokens, including beginning/end of message).
|
||||
You can also specify float values for more fine-grained control.
|
||||
See more here:
|
||||
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
`See more here. <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>`__
|
||||
count_name: Whether to include message names in the count.
|
||||
Enabled by default.
|
||||
|
||||
Returns:
|
||||
Approximate number of tokens in the messages.
|
||||
|
||||
Note:
|
||||
This is a simple approximation that may not match the exact token count
|
||||
used by specific models. For accurate counts, use model-specific tokenizers.
|
||||
.. note::
|
||||
This is a simple approximation that may not match the exact token count used by
|
||||
specific models. For accurate counts, use model-specific tokenizers.
|
||||
|
||||
Warning:
|
||||
This function does not currently support counting image tokens.
|
||||
|
||||
.. versionadded:: 0.3.46
|
||||
|
||||
"""
|
||||
token_count = 0.0
|
||||
for message in convert_to_messages(messages):
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import (
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import override
|
||||
@@ -20,19 +21,22 @@ from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.v1.messages import AIMessage, MessageV1, MessageV1Types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
T = TypeVar("T")
|
||||
OutputParserLike = Runnable[LanguageModelOutput, T]
|
||||
OutputParserLike = Runnable[Union[LanguageModelOutput, AIMessage], T]
|
||||
|
||||
|
||||
class BaseLLMOutputParser(ABC, Generic[T]):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
@@ -46,7 +50,7 @@ class BaseLLMOutputParser(ABC, Generic[T]):
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Async parse a list of candidate model Generations into a specific format.
|
||||
|
||||
@@ -71,7 +75,7 @@ class BaseGenerationOutputParser(
|
||||
@override
|
||||
def InputType(self) -> Any:
|
||||
"""Return the input type for the parser."""
|
||||
return Union[str, AnyMessage]
|
||||
return Union[str, AnyMessage, MessageV1]
|
||||
|
||||
@property
|
||||
@override
|
||||
@@ -84,7 +88,7 @@ class BaseGenerationOutputParser(
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
input: Union[str, BaseMessage, MessageV1],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
@@ -97,9 +101,16 @@ class BaseGenerationOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
cast("str", input),
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
@@ -120,6 +131,13 @@ class BaseGenerationOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
@@ -129,7 +147,7 @@ class BaseGenerationOutputParser(
|
||||
|
||||
|
||||
class BaseOutputParser(
|
||||
BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
|
||||
BaseLLMOutputParser, RunnableSerializable[Union[LanguageModelOutput, AIMessage], T]
|
||||
):
|
||||
"""Base class to parse the output of an LLM call.
|
||||
|
||||
@@ -162,7 +180,7 @@ class BaseOutputParser(
|
||||
@override
|
||||
def InputType(self) -> Any:
|
||||
"""Return the input type for the parser."""
|
||||
return Union[str, AnyMessage]
|
||||
return Union[str, AnyMessage, MessageV1]
|
||||
|
||||
@property
|
||||
@override
|
||||
@@ -189,7 +207,7 @@ class BaseOutputParser(
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
input: Union[str, BaseMessage, MessageV1],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
@@ -202,9 +220,16 @@ class BaseOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return self._call_with_config(
|
||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
cast("str", input),
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
@@ -212,7 +237,7 @@ class BaseOutputParser(
|
||||
@override
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
input: Union[str, BaseMessage, MessageV1],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> T:
|
||||
@@ -225,15 +250,24 @@ class BaseOutputParser(
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
if isinstance(input, MessageV1Types):
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result(inner_input),
|
||||
input,
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
return await self._acall_with_config(
|
||||
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
|
||||
input,
|
||||
cast("str", input),
|
||||
config,
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
@@ -248,6 +282,8 @@ class BaseOutputParser(
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
return self.parse(result.text)
|
||||
return self.parse(result[0].text)
|
||||
|
||||
@abstractmethod
|
||||
@@ -262,7 +298,7 @@ class BaseOutputParser(
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Async parse a list of candidate model Generations into a specific format.
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from langchain_core.utils.json import (
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
from langchain_core.v1.messages import AIMessage
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]
|
||||
@@ -53,7 +54,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
return pydantic_object.schema()
|
||||
return None
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@@ -70,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
text = result[0].text
|
||||
text = result.text if isinstance(result, AIMessage) else result[0].text
|
||||
text = text.strip()
|
||||
if partial:
|
||||
try:
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing_extensions import override
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
from langchain_core.v1.messages import AIMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
@@ -71,7 +72,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
|
||||
@override
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> Iterator[list[str]]:
|
||||
buffer = ""
|
||||
for chunk in input:
|
||||
@@ -81,6 +82,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
if not isinstance(chunk_content, str):
|
||||
continue
|
||||
buffer += chunk_content
|
||||
elif isinstance(chunk, AIMessage):
|
||||
buffer += chunk.text
|
||||
else:
|
||||
# add current chunk to buffer
|
||||
buffer += chunk
|
||||
@@ -105,7 +108,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> AsyncIterator[list[str]]:
|
||||
buffer = ""
|
||||
async for chunk in input:
|
||||
@@ -115,6 +118,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
if not isinstance(chunk_content, str):
|
||||
continue
|
||||
buffer += chunk_content
|
||||
elif isinstance(chunk, AIMessage):
|
||||
buffer += chunk.text
|
||||
else:
|
||||
# add current chunk to buffer
|
||||
buffer += chunk
|
||||
|
||||
@@ -17,6 +17,7 @@ from langchain_core.output_parsers import (
|
||||
)
|
||||
from langchain_core.output_parsers.json import parse_partial_json
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.v1.messages import AIMessage
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
@@ -26,7 +27,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@@ -39,6 +42,12 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
msg = (
|
||||
"This output parser does not support v1 AIMessages. Use "
|
||||
"JsonOutputToolsParser instead."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
@@ -77,7 +86,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@@ -90,6 +101,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
if isinstance(result, AIMessage):
|
||||
msg = (
|
||||
"This output parser does not support v1 AIMessages. Use "
|
||||
"JsonOutputToolsParser instead."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
if len(result) != 1:
|
||||
msg = f"Expected exactly one result, but got {len(result)}"
|
||||
raise OutputParserException(msg)
|
||||
@@ -160,7 +177,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
key_name: str
|
||||
"""The name of the key to return."""
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@@ -254,7 +273,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
return values
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@@ -294,7 +315,9 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
"""The name of the attribute to return."""
|
||||
|
||||
@override
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,7 +4,7 @@ import copy
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
|
||||
@@ -16,6 +16,7 @@ from langchain_core.output_parsers.transform import BaseCumulativeTransformOutpu
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -156,7 +157,9 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
If no tool calls are found, None will be returned.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a list of tool calls.
|
||||
|
||||
Args:
|
||||
@@ -173,31 +176,45 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
"""
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
if isinstance(result, list):
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = (
|
||||
"This output parser can only be used with a chat generation or "
|
||||
"v1 AIMessage."
|
||||
)
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
for tool_call in tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(
|
||||
message.additional_kwargs["tool_calls"]
|
||||
)
|
||||
except KeyError:
|
||||
return []
|
||||
tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
elif result.tool_calls:
|
||||
# v1 message
|
||||
tool_calls = [dict(tc) for tc in result.tool_calls]
|
||||
for tool_call in tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
||||
except KeyError:
|
||||
return []
|
||||
tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
return []
|
||||
|
||||
# for backwards compatibility
|
||||
for tc in tool_calls:
|
||||
tc["type"] = tc.pop("name")
|
||||
|
||||
if self.first_tool_only:
|
||||
return tool_calls[0] if tool_calls else None
|
||||
return tool_calls
|
||||
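
The hunk above lets `JsonOutputToolsParser.parse_result` accept either the classic `list[Generation]` input or a v1 `AIMessage` directly, reading tool calls off the message. A minimal sketch of the new path, assuming this branch's v1 `AIMessage` constructor accepts `content` and `tool_calls` keywords as it is called elsewhere in this diff (the tool call values are made up for illustration):

```python
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.v1.messages import AIMessage

# Hypothetical tool call; the parser only looks at name/args/id.
msg = AIMessage(
    content=[{"type": "text", "text": ""}],
    tool_calls=[tool_call(name="multiply", args={"a": 2, "b": 3}, id="call_1")],
)

parser = JsonOutputToolsParser(return_id=True)
# With a v1 message the parser uses msg.tool_calls directly instead of
# unwrapping a ChatGeneration; "name" is renamed to "type" for compatibility.
print(parser.parse_result(msg))
# [{'args': {'a': 2, 'b': 3}, 'id': 'call_1', 'type': 'multiply'}]
```
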
@@ -220,7 +237,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
key_name: str
|
||||
"""The type of tools to return."""
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a list of tool calls.
|
||||
|
||||
Args:
|
||||
@@ -234,32 +253,47 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
Returns:
|
||||
The parsed tool calls.
|
||||
"""
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
if isinstance(result, list):
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
msg = "This output parser can only be used with a chat generation."
|
||||
raise OutputParserException(msg)
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||
for tool_call in parsed_tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(
|
||||
message.additional_kwargs["tool_calls"]
|
||||
)
|
||||
except KeyError:
|
||||
if self.first_tool_only:
|
||||
return None
|
||||
return []
|
||||
parsed_tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
elif result.tool_calls:
|
||||
# v1 message
|
||||
parsed_tool_calls = [dict(tc) for tc in result.tool_calls]
|
||||
for tool_call in parsed_tool_calls:
|
||||
if not self.return_id:
|
||||
_ = tool_call.pop("id")
|
||||
else:
|
||||
try:
|
||||
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
||||
except KeyError:
|
||||
if self.first_tool_only:
|
||||
return None
|
||||
return []
|
||||
parsed_tool_calls = parse_tool_calls(
|
||||
raw_tool_calls,
|
||||
partial=partial,
|
||||
strict=self.strict,
|
||||
return_id=self.return_id,
|
||||
)
|
||||
if self.first_tool_only:
|
||||
return None
|
||||
return []
|
||||
|
||||
# For backwards compatibility
|
||||
for tc in parsed_tool_calls:
|
||||
tc["type"] = tc.pop("name")
|
||||
|
||||
if self.first_tool_only:
|
||||
parsed_result = list(
|
||||
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
|
||||
@@ -299,7 +333,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
# Pydantic object fields are present.
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(
|
||||
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
|
||||
) -> Any:
|
||||
"""Parse the result of an LLM call to a list of Pydantic objects.
|
||||
|
||||
Args:
|
||||
@@ -337,12 +373,19 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
||||
except (ValidationError, ValueError):
|
||||
if partial:
|
||||
continue
|
||||
has_max_tokens_stop_reason = any(
|
||||
generation.message.response_metadata.get("stop_reason")
|
||||
== "max_tokens"
|
||||
for generation in result
|
||||
if isinstance(generation, ChatGeneration)
|
||||
)
|
||||
has_max_tokens_stop_reason = False
|
||||
if isinstance(result, list):
|
||||
has_max_tokens_stop_reason = any(
|
||||
generation.message.response_metadata.get("stop_reason")
|
||||
== "max_tokens"
|
||||
for generation in result
|
||||
if isinstance(generation, ChatGeneration)
|
||||
)
|
||||
else:
|
||||
# v1 message
|
||||
has_max_tokens_stop_reason = (
|
||||
result.response_metadata.get("stop_reason") == "max_tokens"
|
||||
)
|
||||
if has_max_tokens_stop_reason:
|
||||
logger.exception(_MAX_TOKENS_ERROR)
|
||||
raise
|
||||
|
||||
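
`PydanticToolsParser` builds on the same dispatch: when validation of a truncated tool call fails, it now also checks `response_metadata["stop_reason"]` on a v1 message before re-raising, so `max_tokens` truncation gets logged. A hedged usage sketch with made-up values, under the same assumptions about the v1 `AIMessage` constructor as above:

```python
from pydantic import BaseModel

from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.v1.messages import AIMessage


class Multiply(BaseModel):
    """Schema the model is expected to call."""

    a: int
    b: int


msg = AIMessage(
    content=[],
    tool_calls=[tool_call(name="Multiply", args={"a": 2, "b": 3}, id="call_1")],
)

parser = PydanticToolsParser(tools=[Multiply])
print(parser.parse_result(msg))  # [Multiply(a=2, b=3)]
```
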
@@ -1,7 +1,7 @@
|
||||
"""Output parsers using Pydantic."""
|
||||
|
||||
import json
|
||||
from typing import Annotated, Generic, Optional
|
||||
from typing import Annotated, Generic, Optional, Union
|
||||
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
@@ -14,6 +14,7 @@ from langchain_core.utils.pydantic import (
|
||||
PydanticBaseModel,
|
||||
TBaseModel,
|
||||
)
|
||||
from langchain_core.v1.messages import AIMessage
|
||||
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
@@ -43,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return OutputParserException(msg, llm_output=json_string)
|
||||
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
self, result: Union[list[Generation], AIMessage], *, partial: bool = False
|
||||
) -> Optional[TBaseModel]:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from langchain_core.outputs import (
|
||||
GenerationChunk,
|
||||
)
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.v1.messages import AIMessage, AIMessageChunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
@@ -32,23 +33,27 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Union[str, BaseMessage]], # noqa: A002
|
||||
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
||||
) -> Iterator[T]:
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield self.parse_result([ChatGeneration(message=chunk)])
|
||||
elif isinstance(chunk, AIMessage):
|
||||
yield self.parse_result(chunk)
|
||||
else:
|
||||
yield self.parse_result([Generation(text=chunk)])
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]], # noqa: A002
|
||||
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
||||
) -> AsyncIterator[T]:
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
yield await run_in_executor(
|
||||
None, self.parse_result, [ChatGeneration(message=chunk)]
|
||||
)
|
||||
elif isinstance(chunk, AIMessage):
|
||||
yield await run_in_executor(None, self.parse_result, chunk)
|
||||
else:
|
||||
yield await run_in_executor(
|
||||
None, self.parse_result, [Generation(text=chunk)]
|
||||
@@ -57,7 +62,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
@override
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Union[str, BaseMessage]],
|
||||
input: Iterator[Union[str, BaseMessage, AIMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[T]:
|
||||
@@ -78,7 +83,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
||||
@override
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Union[str, BaseMessage]],
|
||||
input: AsyncIterator[Union[str, BaseMessage, AIMessage]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[T]:
|
||||
@@ -125,23 +130,42 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> Iterator[Any]:
|
||||
prev_parsed = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
||||
None
|
||||
)
|
||||
for chunk in input:
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
elif isinstance(chunk, AIMessageChunk):
|
||||
chunk_gen = chunk
|
||||
elif isinstance(chunk, AIMessage):
|
||||
chunk_gen = AIMessageChunk(
|
||||
content=chunk.content,
|
||||
id=chunk.id,
|
||||
name=chunk.name,
|
||||
lc_version=chunk.lc_version,
|
||||
response_metadata=chunk.response_metadata,
|
||||
usage_metadata=chunk.usage_metadata,
|
||||
parsed=chunk.parsed,
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if isinstance(acc_gen, AIMessageChunk):
|
||||
parsed = self.parse_result(acc_gen, partial=True)
|
||||
else:
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield self._diff(prev_parsed, parsed)
|
||||
@@ -151,24 +175,41 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
prev_parsed = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk, None] = (
|
||||
None
|
||||
)
|
||||
async for chunk in input:
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.model_dump())
|
||||
)
|
||||
elif isinstance(chunk, AIMessageChunk):
|
||||
chunk_gen = chunk
|
||||
elif isinstance(chunk, AIMessage):
|
||||
chunk_gen = AIMessageChunk(
|
||||
content=chunk.content,
|
||||
id=chunk.id,
|
||||
name=chunk.name,
|
||||
lc_version=chunk.lc_version,
|
||||
response_metadata=chunk.response_metadata,
|
||||
usage_metadata=chunk.usage_metadata,
|
||||
parsed=chunk.parsed,
|
||||
)
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||
|
||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||
if isinstance(acc_gen, AIMessageChunk):
|
||||
parsed = await self.aparse_result(acc_gen, partial=True)
|
||||
else:
|
||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
if self.diff:
|
||||
yield await run_in_executor(None, self._diff, prev_parsed, parsed)
|
||||
|
||||
@@ -12,8 +12,10 @@ from typing_extensions import override
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
from langchain_core.runnables.utils import AddableDict
|
||||
from langchain_core.v1.messages import AIMessage
|
||||
|
||||
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
||||
1. Output should conform to the tags below.
|
||||
@@ -105,23 +107,27 @@ class _StreamingParser:
|
||||
self.buffer = ""
|
||||
# yield all events
|
||||
try:
|
||||
for event, elem in self.pull_parser.read_events():
|
||||
if event == "start":
|
||||
# update current path
|
||||
self.current_path.append(elem.tag)
|
||||
self.current_path_has_children = False
|
||||
elif event == "end":
|
||||
# remove last element from current path
|
||||
#
|
||||
self.current_path.pop()
|
||||
# yield element
|
||||
if not self.current_path_has_children:
|
||||
yield nested_element(self.current_path, elem)
|
||||
# prevent yielding of parent element
|
||||
if self.current_path:
|
||||
self.current_path_has_children = True
|
||||
else:
|
||||
self.xml_started = False
|
||||
for raw_event in self.pull_parser.read_events():
|
||||
if len(raw_event) <= 1:
|
||||
continue
|
||||
event, elem = raw_event
|
||||
if isinstance(elem, ET.Element):
|
||||
if event == "start":
|
||||
# update current path
|
||||
self.current_path.append(elem.tag)
|
||||
self.current_path_has_children = False
|
||||
elif event == "end":
|
||||
# remove last element from current path
|
||||
#
|
||||
self.current_path.pop()
|
||||
# yield element
|
||||
if not self.current_path_has_children:
|
||||
yield nested_element(self.current_path, elem)
|
||||
# prevent yielding of parent element
|
||||
if self.current_path:
|
||||
self.current_path_has_children = True
|
||||
else:
|
||||
self.xml_started = False
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# This might be junk at the end of the XML input.
|
||||
# Let's check whether the current path is empty.
|
||||
@@ -240,21 +246,28 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
|
||||
@override
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
self, input: Iterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> Iterator[AddableDict]:
|
||||
streaming_parser = _StreamingParser(self.parser)
|
||||
for chunk in input:
|
||||
yield from streaming_parser.parse(chunk)
|
||||
if isinstance(chunk, AIMessage):
|
||||
yield from streaming_parser.parse(convert_from_v1_message(chunk))
|
||||
else:
|
||||
yield from streaming_parser.parse(chunk)
|
||||
streaming_parser.close()
|
||||
|
||||
@override
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
self, input: AsyncIterator[Union[str, BaseMessage, AIMessage]]
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
streaming_parser = _StreamingParser(self.parser)
|
||||
async for chunk in input:
|
||||
for output in streaming_parser.parse(chunk):
|
||||
yield output
|
||||
if isinstance(chunk, AIMessage):
|
||||
for output in streaming_parser.parse(convert_from_v1_message(chunk)):
|
||||
yield output
|
||||
else:
|
||||
for output in streaming_parser.parse(chunk):
|
||||
yield output
|
||||
streaming_parser.close()
|
||||
|
||||
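
For the XML parser, v1 messages are routed through `convert_from_v1_message` before the incremental pull-parser sees them, so the parsing behavior itself is unchanged. A small sketch of the pieces involved; the string-content constructor is assumed from the `AIMessage("Hello world")` doctest later in this diff:

```python
from langchain_core.messages.utils import convert_from_v1_message
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.v1.messages import AIMessage

v1_msg = AIMessage("<movies><movie>Alien</movie></movies>")
v0_msg = convert_from_v1_message(v1_msg)  # what the streaming transforms now do internally

parser = XMLOutputParser()
print(parser.parse("<movies><movie>Alien</movie></movies>"))
# {'movies': [{'movie': 'Alien'}]}
```
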
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
|
||||
|
||||
@@ -8,17 +8,65 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
from typing import Literal, Union, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import TypedDict, overload
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.messages import content_blocks as types
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
from langchain_core.v1.messages import HumanMessage as HumanMessageV1
|
||||
from langchain_core.v1.messages import MessageV1, ResponseMetadata
|
||||
from langchain_core.v1.messages import SystemMessage as SystemMessageV1
|
||||
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
|
||||
|
||||
|
||||
def _convert_to_v1(message: BaseMessage) -> MessageV1:
|
||||
"""Best-effort conversion of a V0 AIMessage to V1."""
|
||||
if isinstance(message.content, str):
|
||||
content: list[types.ContentBlock] = []
|
||||
if message.content:
|
||||
content = [{"type": "text", "text": message.content}]
|
||||
else:
|
||||
content = []
|
||||
for block in message.content:
|
||||
if isinstance(block, str):
|
||||
content.append({"type": "text", "text": block})
|
||||
elif isinstance(block, dict):
|
||||
content.append(cast("types.ContentBlock", block))
|
||||
else:
|
||||
pass
|
||||
|
||||
if isinstance(message, HumanMessage):
|
||||
return HumanMessageV1(content=content)
|
||||
if isinstance(message, AIMessage):
|
||||
for tool_call in message.tool_calls:
|
||||
content.append(tool_call)
|
||||
return AIMessageV1(
|
||||
content=content,
|
||||
usage_metadata=message.usage_metadata,
|
||||
response_metadata=cast("ResponseMetadata", message.response_metadata),
|
||||
tool_calls=message.tool_calls,
|
||||
)
|
||||
if isinstance(message, SystemMessage):
|
||||
return SystemMessageV1(content=content)
|
||||
if isinstance(message, ToolMessage):
|
||||
return ToolMessageV1(
|
||||
tool_call_id=message.tool_call_id,
|
||||
content=content,
|
||||
artifact=message.artifact,
|
||||
)
|
||||
error_message = f"Unsupported message type: {type(message)}"
|
||||
raise TypeError(error_message)
|
||||
|
||||
|
||||
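
`_convert_to_v1` above performs a best-effort translation of v0 messages into the v1 dataclasses: string content becomes a single text block, list content is passed through block by block, and AI tool calls are both appended to the content and kept on `tool_calls`. A hypothetical illustration only; the helper is private on this branch, so this sketches the mapping rather than a supported API:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import _convert_to_v1  # private helper on this branch

v1_human = _convert_to_v1(HumanMessage(content="What is 2 + 2?"))
# content -> [{'type': 'text', 'text': 'What is 2 + 2?'}]

v1_ai = _convert_to_v1(
    AIMessage(
        content="",
        tool_calls=[{"name": "add", "args": {"a": 2, "b": 2}, "id": "call_1"}],
    )
)
# Tool calls are carried over and also appended to the content blocks.
print(v1_ai.tool_calls)
```
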
class PromptValue(Serializable, ABC):
|
||||
@@ -46,8 +94,18 @@ class PromptValue(Serializable, ABC):
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@overload
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
|
||||
|
||||
@@ -71,8 +129,20 @@ class StringPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
@overload
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt as messages."""
|
||||
if message_version == "v1":
|
||||
return [HumanMessageV1(content=self.text)]
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
|
||||
@@ -89,8 +159,24 @@ class ChatPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
@overload
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt as a list of messages.
|
||||
|
||||
Args:
|
||||
message_version: The output version, either "v0" (default) or "v1".
|
||||
"""
|
||||
if message_version == "v1":
|
||||
return [_convert_to_v1(m) for m in self.messages]
|
||||
return list(self.messages)
|
||||
|
||||
@classmethod
|
||||
@@ -125,8 +211,26 @@ class ImagePromptValue(PromptValue):
|
||||
"""Return prompt (image URL) as string."""
|
||||
return self.image_url["url"]
|
||||
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
@overload
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0"] = "v0"
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
@overload
|
||||
def to_messages(self, message_version: Literal["v1"]) -> list[MessageV1]: ...
|
||||
|
||||
def to_messages(
|
||||
self, message_version: Literal["v0", "v1"] = "v0"
|
||||
) -> Union[Sequence[BaseMessage], Sequence[MessageV1]]:
|
||||
"""Return prompt (image URL) as messages."""
|
||||
if message_version == "v1":
|
||||
block: types.ImageContentBlock = {
|
||||
"type": "image",
|
||||
"url": self.image_url["url"],
|
||||
}
|
||||
if "detail" in self.image_url:
|
||||
block["detail"] = self.image_url["detail"]
|
||||
return [HumanMessageV1(content=[block])]
|
||||
return [HumanMessage(content=[cast("dict", self.image_url)])]
|
||||
|
||||
|
||||
|
||||
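
Every `PromptValue` subclass now exposes the same `message_version` switch, so callers can opt into v1 message objects while the default stays v0. A short sketch, assuming the overloads land as written above:

```python
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("system", "You are terse."), ("human", "{question}")]
)
value = prompt.invoke({"question": "What is LCEL?"})

v0_messages = value.to_messages()                      # BaseMessage subclasses
v1_messages = value.to_messages(message_version="v1")  # v1 dataclass messages

print(type(v0_messages[0]).__module__)  # langchain_core.messages.system
print(type(v1_messages[0]).__module__)  # langchain_core.v1.messages
```
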
File diff suppressed because it is too large
@@ -402,7 +402,7 @@ def call_func_with_variable_args(
|
||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||
],
|
||||
input: Input, # noqa: A002
|
||||
input: Input,
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
@@ -439,7 +439,7 @@ def acall_func_with_variable_args(
|
||||
Awaitable[Output],
|
||||
],
|
||||
],
|
||||
input: Input, # noqa: A002
|
||||
input: Input,
|
||||
config: RunnableConfig,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -5,7 +5,7 @@ import inspect
|
||||
import typing
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import override
|
||||
@@ -397,7 +397,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
)
|
||||
|
||||
to_return = {}
|
||||
to_return: dict[int, Union[Output, BaseException]] = {}
|
||||
run_again = dict(enumerate(inputs))
|
||||
handled_exceptions: dict[int, BaseException] = {}
|
||||
first_to_raise = None
|
||||
@@ -447,7 +447,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
if not return_exceptions and sorted_handled_exceptions:
|
||||
raise sorted_handled_exceptions[0][1]
|
||||
to_return.update(handled_exceptions)
|
||||
return [output for _, output in sorted(to_return.items())] # type: ignore[misc]
|
||||
return [cast("Output", output) for _, output in sorted(to_return.items())]
|
||||
|
||||
@override
|
||||
def stream(
|
||||
@@ -569,7 +569,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
try:
|
||||
output = output + chunk
|
||||
output = output + chunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
output = None
|
||||
except BaseException as e:
|
||||
|
||||
@@ -114,7 +114,7 @@ class Node(NamedTuple):
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a copy of the node with optional new id and name.
|
||||
@@ -187,7 +187,7 @@ class MermaidDrawMethod(Enum):
|
||||
|
||||
|
||||
def node_data_str(
|
||||
id: str, # noqa: A002
|
||||
id: str,
|
||||
data: Union[type[BaseModel], RunnableType, None],
|
||||
) -> str:
|
||||
"""Convert the data of a node to a string.
|
||||
@@ -328,7 +328,7 @@ class Graph:
|
||||
def add_node(
|
||||
self,
|
||||
data: Union[type[BaseModel], RunnableType, None],
|
||||
id: Optional[str] = None, # noqa: A002
|
||||
id: Optional[str] = None,
|
||||
*,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> Node:
|
||||
|
||||
@@ -68,13 +68,21 @@ from langchain_core.utils.pydantic import (
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
)
|
||||
from langchain_core.v1.messages import ToolMessage as ToolMessageV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
FILTERED_ARGS = ("run_manager", "callbacks")
|
||||
TOOL_MESSAGE_BLOCK_TYPES = ("text", "image_url", "image", "json", "search_result")
|
||||
TOOL_MESSAGE_BLOCK_TYPES = (
|
||||
"text",
|
||||
"image_url",
|
||||
"image",
|
||||
"json",
|
||||
"search_result",
|
||||
"custom_tool_call_output",
|
||||
)
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
@@ -498,6 +506,15 @@ class ChildTool(BaseTool):
|
||||
two-tuple corresponding to the (content, artifact) of a ToolMessage.
|
||||
"""
|
||||
|
||||
message_version: Literal["v0", "v1"] = "v0"
|
||||
"""Version of ToolMessage to return given
|
||||
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
||||
|
||||
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
|
||||
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the tool."""
|
||||
if (
|
||||
@@ -835,7 +852,7 @@ class ChildTool(BaseTool):
|
||||
|
||||
content = None
|
||||
artifact = None
|
||||
status = "success"
|
||||
status: Literal["success", "error"] = "success"
|
||||
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
@@ -879,7 +896,14 @@ class ChildTool(BaseTool):
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||
output = _format_output(
|
||||
content,
|
||||
artifact,
|
||||
tool_call_id,
|
||||
self.name,
|
||||
status,
|
||||
message_version=self.message_version,
|
||||
)
|
||||
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@@ -945,7 +969,7 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
content = None
|
||||
artifact = None
|
||||
status = "success"
|
||||
status: Literal["success", "error"] = "success"
|
||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||
try:
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
|
||||
@@ -993,7 +1017,14 @@ class ChildTool(BaseTool):
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
|
||||
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||
output = _format_output(
|
||||
content,
|
||||
artifact,
|
||||
tool_call_id,
|
||||
self.name,
|
||||
status,
|
||||
message_version=self.message_version,
|
||||
)
|
||||
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@@ -1131,7 +1162,9 @@ def _format_output(
|
||||
artifact: Any,
|
||||
tool_call_id: Optional[str],
|
||||
name: str,
|
||||
status: str,
|
||||
status: Literal["success", "error"],
|
||||
*,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> Union[ToolOutputMixin, Any]:
|
||||
"""Format tool output as a ToolMessage if appropriate.
|
||||
|
||||
@@ -1141,6 +1174,7 @@ def _format_output(
|
||||
tool_call_id: The ID of the tool call.
|
||||
name: The name of the tool.
|
||||
status: The execution status.
|
||||
message_version: The version of the ToolMessage to return.
|
||||
|
||||
Returns:
|
||||
The formatted output, either as a ToolMessage or the original content.
|
||||
@@ -1149,7 +1183,15 @@ def _format_output(
|
||||
return content
|
||||
if not _is_message_content_type(content):
|
||||
content = _stringify(content)
|
||||
return ToolMessage(
|
||||
if message_version == "v0":
|
||||
return ToolMessage(
|
||||
content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
name=name,
|
||||
status=status,
|
||||
)
|
||||
return ToolMessageV1(
|
||||
content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
|
||||
@@ -22,6 +22,7 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
|
||||
|
||||
|
||||
@@ -37,6 +38,7 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> BaseTool: ...
|
||||
|
||||
|
||||
@@ -51,6 +53,7 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> BaseTool: ...
|
||||
|
||||
|
||||
@@ -65,6 +68,7 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
|
||||
|
||||
|
||||
@@ -79,6 +83,7 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> Union[
|
||||
BaseTool,
|
||||
Callable[[Union[Callable, Runnable]], BaseTool],
|
||||
@@ -118,6 +123,11 @@ def tool(
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
Defaults to True.
|
||||
message_version: Version of ToolMessage to return given
|
||||
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
||||
|
||||
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
|
||||
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
|
||||
|
||||
Returns:
|
||||
The tool.
|
||||
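
The `message_version` flag threads from the `@tool` decorator down to `_format_output`, deciding which `ToolMessage` class is produced when the tool is invoked with a `ToolCall` dict. A hedged example; the default stays `"v0"` in this diff:

```python
from langchain_core.tools import tool


@tool(message_version="v1")
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


result = add.invoke(
    {"name": "add", "args": {"a": 2, "b": 3}, "id": "call_1", "type": "tool_call"}
)
# With "v1" this is a langchain_core.v1.messages.ToolMessage; with the default
# "v0" it would be the classic langchain_core.messages.ToolMessage.
print(type(result), result.content)
```
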
@@ -216,7 +226,7 @@ def tool(
|
||||
\"\"\"
|
||||
return bar
|
||||
|
||||
""" # noqa: D214, D410, D411
|
||||
""" # noqa: D214, D410, D411, E501
|
||||
|
||||
def _create_tool_factory(
|
||||
tool_name: str,
|
||||
@@ -274,6 +284,7 @@ def tool(
|
||||
response_format=response_format,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
message_version=message_version,
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
@@ -290,6 +301,7 @@ def tool(
|
||||
return_direct=return_direct,
|
||||
coroutine=coroutine,
|
||||
response_format=response_format,
|
||||
message_version=message_version,
|
||||
)
|
||||
|
||||
return _tool_factory
|
||||
@@ -383,6 +395,7 @@ def convert_runnable_to_tool(
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
arg_types: Optional[dict[str, type]] = None,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
) -> BaseTool:
|
||||
"""Convert a Runnable into a BaseTool.
|
||||
|
||||
@@ -392,10 +405,15 @@ def convert_runnable_to_tool(
|
||||
name: The name of the tool. Defaults to None.
|
||||
description: The description of the tool. Defaults to None.
|
||||
arg_types: The types of the arguments. Defaults to None.
|
||||
message_version: Version of ToolMessage to return given
|
||||
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
||||
|
||||
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
|
||||
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
|
||||
|
||||
Returns:
|
||||
The tool.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
if args_schema:
|
||||
runnable = runnable.with_types(input_type=args_schema)
|
||||
description = description or _get_description_from_runnable(runnable)
|
||||
@@ -408,6 +426,7 @@ def convert_runnable_to_tool(
|
||||
func=runnable.invoke,
|
||||
coroutine=runnable.ainvoke,
|
||||
description=description,
|
||||
message_version=message_version,
|
||||
)
|
||||
|
||||
async def ainvoke_wrapper(
|
||||
@@ -435,4 +454,5 @@ def convert_runnable_to_tool(
|
||||
coroutine=ainvoke_wrapper,
|
||||
description=description,
|
||||
args_schema=args_schema,
|
||||
message_version=message_version,
|
||||
)
|
||||
|
||||
@@ -72,6 +72,7 @@ def create_retriever_tool(
|
||||
document_prompt: Optional[BasePromptTemplate] = None,
|
||||
document_separator: str = "\n\n",
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
message_version: Literal["v0", "v1"] = "v1",
|
||||
) -> Tool:
|
||||
r"""Create a tool to do retrieval of documents.
|
||||
|
||||
@@ -88,10 +89,15 @@ def create_retriever_tool(
|
||||
"content_and_artifact" then the output is expected to be a two-tuple
|
||||
corresponding to the (content, artifact) of a ToolMessage (artifact
|
||||
being a list of documents in this case). Defaults to "content".
|
||||
message_version: Version of ToolMessage to return given
|
||||
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
||||
|
||||
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
|
||||
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
|
||||
|
||||
Returns:
|
||||
Tool class to pass to an agent.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
||||
func = partial(
|
||||
_get_relevant_documents,
|
||||
@@ -114,4 +120,5 @@ def create_retriever_tool(
|
||||
coroutine=afunc,
|
||||
args_schema=RetrieverInput,
|
||||
response_format=response_format,
|
||||
message_version=message_version,
|
||||
)
|
||||
|
||||
@@ -129,6 +129,7 @@ class StructuredTool(BaseTool):
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
message_version: Literal["v0", "v1"] = "v0",
|
||||
**kwargs: Any,
|
||||
) -> StructuredTool:
|
||||
"""Create tool from a given function.
|
||||
@@ -157,6 +158,12 @@ class StructuredTool(BaseTool):
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
Defaults to False.
|
||||
message_version: Version of ToolMessage to return given
|
||||
:class:`~langchain_core.messages.content_blocks.ToolCall` input.
|
||||
|
||||
If ``"v0"``, output will be a v0 :class:`~langchain_core.messages.tool.ToolMessage`.
|
||||
If ``"v1"``, output will be a v1 :class:`~langchain_core.v1.messages.ToolMessage`.
|
||||
|
||||
kwargs: Additional arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
@@ -175,7 +182,7 @@ class StructuredTool(BaseTool):
|
||||
tool = StructuredTool.from_function(add)
|
||||
tool.run(1, 2) # 3
|
||||
|
||||
"""
|
||||
""" # noqa: E501
|
||||
if func is not None:
|
||||
source_function = func
|
||||
elif coroutine is not None:
|
||||
@@ -232,6 +239,7 @@ class StructuredTool(BaseTool):
|
||||
description=description_,
|
||||
return_direct=return_direct,
|
||||
response_format=response_format,
|
||||
message_version=message_version,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing_extensions import override
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.exceptions import TracerException # noqa: F401
|
||||
from langchain_core.tracers.core import _TracerCore
|
||||
from langchain_core.v1.messages import AIMessage, AIMessageChunk, MessageV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
@@ -54,7 +55,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[list[str]] = None,
|
||||
@@ -138,7 +139,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
@@ -190,7 +193,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
)
|
||||
|
||||
@override
|
||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
def on_llm_end(
|
||||
self, response: Union[LLMResult, AIMessage], *, run_id: UUID, **kwargs: Any
|
||||
) -> Run:
|
||||
"""End a trace for an LLM run.
|
||||
|
||||
Args:
|
||||
@@ -562,7 +567,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -617,7 +622,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
@@ -646,7 +653,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
@override
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
response: Union[LLMResult, AIMessage],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -882,7 +889,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
|
||||
) -> None:
|
||||
"""Process new LLM token."""
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import (
|
||||
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@@ -25,6 +26,12 @@ from langchain_core.outputs import (
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.v1.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
MessageV1,
|
||||
MessageV1Types,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Coroutine, Sequence
|
||||
@@ -156,7 +163,7 @@ class _TracerCore(ABC):
|
||||
def _create_chat_model_run(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
run_id: UUID,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -181,6 +188,12 @@ class _TracerCore(ABC):
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
if isinstance(messages[0], MessageV1Types):
|
||||
# Convert from v1 messages to BaseMessage
|
||||
messages = [
|
||||
[convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type]
|
||||
]
|
||||
messages = cast("list[list[BaseMessage]]", messages)
|
||||
return Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
@@ -230,7 +243,9 @@ class _TracerCore(ABC):
|
||||
self,
|
||||
token: str,
|
||||
run_id: UUID,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
parent_run_id: Optional[UUID] = None, # noqa: ARG002
|
||||
) -> Run:
|
||||
"""Append token event to LLM run and return the run."""
|
||||
@@ -276,7 +291,15 @@ class _TracerCore(ABC):
|
||||
)
|
||||
return llm_run
|
||||
|
||||
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
|
||||
def _complete_llm_run(
|
||||
self, response: Union[LLMResult, AIMessage], run_id: UUID
|
||||
) -> Run:
|
||||
if isinstance(response, AIMessage):
|
||||
response = LLMResult(
|
||||
generations=[
|
||||
[ChatGeneration(message=convert_from_v1_message(response))]
|
||||
]
|
||||
)
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
if getattr(llm_run, "outputs", None) is None:
|
||||
llm_run.outputs = {}
|
||||
@@ -558,7 +581,7 @@ class _TracerCore(ABC):
|
||||
self,
|
||||
run: Run, # noqa: ARG002
|
||||
token: str, # noqa: ARG002
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]], # noqa: ARG002
|
||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process new LLM token."""
|
||||
return None
|
||||
|
||||
@@ -38,6 +38,7 @@ from langchain_core.runnables.utils import (
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.tracers.memory_stream import _MemoryStream
|
||||
from langchain_core.utils.aiter import aclosing, py_anext
|
||||
from langchain_core.v1.messages import MessageV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
@@ -45,6 +46,8 @@ if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tracers.log_stream import LogEntry
|
||||
from langchain_core.v1.messages import AIMessage as AIMessageV1
|
||||
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -297,7 +300,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[list[str]] = None,
|
||||
@@ -307,6 +310,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
# below cast is because type is converted in handle_event
|
||||
messages = cast("list[list[BaseMessage]]", messages)
|
||||
name_ = _assign_name(name, serialized)
|
||||
run_type = "chat_model"
|
||||
|
||||
@@ -407,13 +412,18 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
run_info = self.run_map.get(run_id)
|
||||
chunk = cast(
|
||||
"Optional[Union[GenerationChunk, ChatGenerationChunk]]", chunk
|
||||
) # converted in handle_event
|
||||
chunk_: Union[GenerationChunk, BaseMessageChunk]
|
||||
|
||||
if run_info is None:
|
||||
@@ -456,9 +466,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
@override
|
||||
async def on_llm_end(
|
||||
self, response: LLMResult, *, run_id: UUID, **kwargs: Any
|
||||
self, response: Union[LLMResult, AIMessageV1], *, run_id: UUID, **kwargs: Any
|
||||
) -> None:
|
||||
"""End a trace for an LLM run."""
|
||||
response = cast("LLMResult", response) # converted in handle_event
|
||||
run_info = self.run_map.pop(run_id)
|
||||
inputs_ = run_info["inputs"]
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import Client
|
||||
@@ -21,12 +21,15 @@ from typing_extensions import override
|
||||
|
||||
from langchain_core.env import get_runtime_environment
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages.utils import convert_from_v1_message
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.v1.messages import MessageV1Types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from langchain_core.v1.messages import AIMessageChunk, MessageV1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_LOGGED = set()
|
||||
@@ -113,7 +116,7 @@ class LangChainTracer(BaseTracer):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[list[str]] = None,
|
||||
@@ -140,6 +143,12 @@ class LangChainTracer(BaseTracer):
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
if isinstance(messages[0], MessageV1Types):
|
||||
# Convert from v1 messages to BaseMessage
|
||||
messages = [
|
||||
[convert_from_v1_message(msg) for msg in messages] # type: ignore[arg-type]
|
||||
]
|
||||
messages = cast("list[list[BaseMessage]]", messages)
|
||||
chat_model_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
@@ -232,7 +241,9 @@ class LangChainTracer(BaseTracer):
|
||||
self,
|
||||
token: str,
|
||||
run_id: UUID,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
|
||||
] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
) -> Run:
|
||||
"""Append token event to LLM run and return the run."""
|
||||
|
||||
@@ -34,6 +34,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.v1.messages import AIMessageChunk
|
||||
|
||||
|
||||
class LogEntry(TypedDict):
|
||||
@@ -176,7 +177,7 @@ class RunLog(RunLogPatch):
|
||||
# Then compare that the ops are the same
|
||||
return super().__eq__(other)
|
||||
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
__hash__ = None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -485,7 +486,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
|
||||
) -> None:
|
||||
"""Process new LLM token."""
|
||||
index = self._key_map_by_run_id.get(run.id)
|
||||
|
||||
@@ -277,7 +277,7 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
)
|
||||
fields: dict = {}
|
||||
for arg, arg_type in annotations_.items():
|
||||
if get_origin(arg_type) is Annotated:
|
||||
if get_origin(arg_type) is Annotated: # type: ignore[comparison-overlap]
|
||||
annotated_args = get_args(arg_type)
|
||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||
annotated_args[0], depth=depth + 1, visited=visited
|
||||
@@ -575,12 +575,23 @@ def convert_to_openai_tool(
|
||||
|
||||
Added support for OpenAI's image generation built-in tool.
|
||||
"""
|
||||
from langchain_core.tools import Tool
|
||||
|
||||
if isinstance(tool, dict):
|
||||
if tool.get("type") in _WellKnownOpenAITools:
|
||||
return tool
|
||||
# As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
|
||||
if (tool.get("type") or "").startswith("web_search_preview"):
|
||||
return tool
|
||||
if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
|
||||
oai_tool = {
|
||||
"type": "custom",
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
}
|
||||
if tool.metadata is not None and "format" in tool.metadata:
|
||||
oai_tool["format"] = tool.metadata["format"]
|
||||
return oai_tool
|
||||
oai_function = convert_to_openai_function(tool, strict=strict)
|
||||
return {"type": "function", "function": oai_function}
|
||||
|
||||
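
Per the hunk above, `convert_to_openai_tool` now passes well-known built-in tool dicts (including the dated `web_search_preview*` variants) through untouched, and emits a `{"type": "custom", ...}` payload for `Tool` instances whose metadata marks them as custom tools. A hedged sketch; the `format` value is illustrative:

```python
from langchain_core.tools import Tool
from langchain_core.utils.function_calling import convert_to_openai_tool

code_tool = Tool(
    name="run_python",
    description="Execute a Python snippet.",
    func=lambda code: code,
    metadata={"type": "custom_tool", "format": {"type": "text"}},
)
print(convert_to_openai_tool(code_tool))
# {'type': 'custom', 'name': 'run_python',
#  'description': 'Execute a Python snippet.', 'format': {'type': 'text'}}

# Built-in tool dicts such as web search previews are passed through unchanged.
print(convert_to_openai_tool({"type": "web_search_preview"}))
```
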
@@ -616,7 +627,7 @@ def convert_to_json_schema(
|
||||
|
||||
@beta()
|
||||
def tool_example_to_messages(
|
||||
input: str, # noqa: A002
|
||||
input: str,
|
||||
tool_calls: list[BaseModel],
|
||||
tool_outputs: Optional[list[str]] = None,
|
||||
*,
|
||||
@@ -629,15 +640,16 @@ def tool_example_to_messages(
|
||||
|
||||
The list of messages per example by default corresponds to:
|
||||
|
||||
1) HumanMessage: contains the content from which content should be extracted.
|
||||
2) AIMessage: contains the extracted information from the model
|
||||
3) ToolMessage: contains confirmation to the model that the model requested a tool
|
||||
correctly.
|
||||
1. ``HumanMessage``: contains the content from which content should be extracted.
|
||||
2. ``AIMessage``: contains the extracted information from the model
|
||||
3. ``ToolMessage``: contains confirmation to the model that the model requested a
|
||||
tool correctly.
|
||||
|
||||
If `ai_response` is specified, there will be a final AIMessage with that response.
|
||||
If ``ai_response`` is specified, there will be a final ``AIMessage`` with that
|
||||
response.
|
||||
|
||||
The ToolMessage is required because some chat models are hyper-optimized for agents
|
||||
rather than for an extraction use case.
|
||||
The ``ToolMessage`` is required because some chat models are hyper-optimized for
|
||||
agents rather than for an extraction use case.
|
||||
|
||||
Arguments:
|
||||
input: string, the user input
|
||||
@@ -646,7 +658,7 @@ def tool_example_to_messages(
|
||||
tool_outputs: Optional[list[str]], a list of tool call outputs.
|
||||
Does not need to be provided. If not provided, a placeholder value
|
||||
will be inserted. Defaults to None.
|
||||
ai_response: Optional[str], if provided, content for a final AIMessage.
|
||||
ai_response: Optional[str], if provided, content for a final ``AIMessage``.
|
||||
|
||||
Returns:
|
||||
A list of messages
|
||||
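
The docstring change above only reflows the numbered list; the behavior of `tool_example_to_messages` is unchanged. For reference, a minimal usage sketch (values invented for illustration):

```python
from pydantic import BaseModel

from langchain_core.utils.function_calling import tool_example_to_messages


class Person(BaseModel):
    name: str
    age: int


messages = tool_example_to_messages(
    "Alice is 30 years old.",
    tool_calls=[Person(name="Alice", age=30)],
    ai_response="Recorded Alice.",
)
for m in messages:
    print(type(m).__name__)
# HumanMessage, AIMessage (with tool call), ToolMessage (placeholder), AIMessage
```
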
@@ -728,6 +740,7 @@ def _parse_google_docstring(
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
|
||||
"""
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
|
||||
libs/core/langchain_core/v1/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""LangChain v1.0.0 types."""
libs/core/langchain_core/v1/chat_models.py (new file, 1065 lines)
File diff suppressed because it is too large
libs/core/langchain_core/v1/messages.py (new file, 986 lines)
@@ -0,0 +1,986 @@
"""LangChain v1.0.0 message format.

Each message has content that may be comprised of content blocks, defined under
``langchain_core.messages.content_blocks``.

"""

import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, Union, cast, get_args

from pydantic import BaseModel
from typing_extensions import TypedDict

import langchain_core.messages.content_blocks as types
from langchain_core._api.deprecation import warn_deprecated
from langchain_core.messages.ai import (
    _LC_AUTO_PREFIX,
    _LC_ID_PREFIX,
    UsageMetadata,
    add_usage,
)
from langchain_core.messages.base import merge_content
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.json import parse_partial_json

class TextAccessor(str):
|
||||
"""String-like object that supports both property and method access patterns.
|
||||
|
||||
Exists to maintain backward compatibility while transitioning from method-based to
|
||||
property-based text access in message objects. In LangChain <v0.4, message text was
|
||||
accessed via ``.text()`` method calls. In v0.4=<, the preferred pattern is property
|
||||
access via ``.text``.
|
||||
|
||||
Rather than breaking existing code immediately, ``TextAccessor`` allows both
|
||||
patterns:
|
||||
- Modern property access: ``message.text`` (returns string directly)
|
||||
- Legacy method access: ``message.text()`` (callable, emits deprecation warning)
|
||||
|
||||
Examples:
|
||||
>>> msg = AIMessage("Hello world")
|
||||
>>> text = msg.text # Preferred: property access
|
||||
>>> text = msg.text() # Deprecated: method access (shows warning)
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, value: str) -> "TextAccessor":
|
||||
"""Create new TextAccessor instance."""
|
||||
return str.__new__(cls, value)
|
||||
|
||||
def __call__(self) -> str:
|
||||
"""Enable method-style text access for backward compatibility.
|
||||
|
||||
.. deprecated:: 0.4.0
|
||||
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
|
||||
instead. This method will be removed in 2.0.0.
|
||||
|
||||
Returns:
|
||||
The string content, identical to property access.
|
||||
|
||||
"""
|
||||
warn_deprecated(
|
||||
since="0.4.0",
|
||||
message=(
|
||||
"Calling .text() as a method is deprecated. "
|
||||
"Use .text as a property instead (e.g., message.text)."
|
||||
),
|
||||
removal="2.0.0",
|
||||
)
|
||||
return str(self)
|
||||
|
||||
|
||||
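A minimal sketch of the dual access pattern described above (not part of the diff itself; it assumes the v1 ``AIMessage`` defined later in this file, whose ``text`` property returns a ``TextAccessor``):

```python
from langchain_core.v1.messages import AIMessage

msg = AIMessage("Hello world")

# Preferred: property access. TextAccessor subclasses str, so it behaves as a string.
text = msg.text
assert isinstance(text, str)
assert text == "Hello world"

# Legacy: calling it as a method still works, but emits a deprecation warning
# via warn_deprecated(since="0.4.0") and returns a plain str.
legacy_text = msg.text()
assert legacy_text == "Hello world"
```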
def _ensure_id(id_val: Optional[str]) -> str:
|
||||
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
||||
|
||||
Auto-generated UUIDs are prefixed by ``'lc_'`` to indicate they are
|
||||
LangChain-generated IDs.
|
||||
|
||||
Args:
|
||||
id_val: Optional string ID value to validate.
|
||||
|
||||
Returns:
|
||||
A valid string ID, either the provided value or a new UUID.
|
||||
|
||||
"""
|
||||
return id_val or str(f"{_LC_AUTO_PREFIX}{uuid.uuid4()}")
|
||||
|
||||
|
||||
class ResponseMetadata(TypedDict, total=False):
|
||||
"""Metadata about the response from the AI provider.
|
||||
|
||||
Contains additional information returned by the provider, such as
|
||||
response headers, service tiers, log probabilities, system fingerprints, etc.
|
||||
|
||||
**Extensibility Design:**
|
||||
|
||||
This uses ``total=False`` to allow arbitrary additional keys beyond the typed
|
||||
fields below. This enables provider-specific metadata without breaking type safety:
|
||||
|
||||
- OpenAI might include: ``{"system_fingerprint": "fp_123", "logprobs": {...}}``
|
||||
- Anthropic might include: ``{"stop_reason": "stop_sequence", "usage": {...}}``
|
||||
- Custom providers can add their own fields
|
||||
|
||||
The common fields (``model_provider``, ``model_name``) provide a baseline
|
||||
contract while preserving flexibility for provider innovations.
|
||||
|
||||
"""
|
||||
|
||||
model_provider: str
|
||||
"""Name and version of the provider that created the message (ex: ``'openai'``)."""
|
||||
|
||||
model_name: str
|
||||
"""Name of the model that generated the message."""
|
||||
|
||||
|
||||
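As a small illustrative sketch (the model name below is made up), the baseline fields are typed while provider-specific extras are ordinary dict keys at runtime; note that this PR also disables mypy's ``typeddict-unknown-key`` error code in ``pyproject.toml``, which fits this pattern:

```python
from langchain_core.v1.messages import ResponseMetadata

metadata: ResponseMetadata = {
    "model_provider": "openai",
    "model_name": "gpt-4o-mini",  # illustrative model name
}
# Provider-specific extras are plain dict keys at runtime
metadata["system_fingerprint"] = "fp_123"
```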
@dataclass
|
||||
class AIMessage:
|
||||
"""A v1 message generated by an AI assistant.
|
||||
|
||||
Represents a response from an AI model, including text content, tool calls,
|
||||
and metadata about the generation process.
|
||||
|
||||
Attributes:
|
||||
type: Message type identifier, always ``'ai'``.
|
||||
id: Unique identifier for the message.
|
||||
name: The name/identifier of the agent or assistant that generated this message.
|
||||
lc_version: Encoding version for the message.
|
||||
content: List of content blocks containing the message data.
|
||||
tool_calls: Optional list of tool calls made by the AI.
|
||||
invalid_tool_calls: Optional list of tool calls that failed validation.
|
||||
usage: Optional dictionary containing usage statistics.
|
||||
|
||||
"""
|
||||
|
||||
type: Literal["ai"] = "ai"
|
||||
"""The type of the message. Must be a string that is unique to the message type.
|
||||
|
||||
The purpose of this field is to allow for easy identification of the message type
|
||||
when deserializing messages.
|
||||
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""The name/identifier of the agent or assistant that generated this message.
|
||||
|
||||
Used primarily in multi-agent systems to track which agent is speaking. Also used by
|
||||
some providers for conversation attribution and context.
|
||||
|
||||
Usage of this field is optional, and whether it's used or not is up to the
|
||||
model implementation.
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. python::
|
||||
|
||||
AIMessage(
|
||||
content= [
|
||||
TextContentBlock("Analysis complete"),
|
||||
],
|
||||
name="research_agent"
|
||||
)
|
||||
|
||||
AIMessage(
|
||||
content= [
|
||||
TextContentBlock("Task routed to specialist"),
|
||||
],
|
||||
name="supervisor"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
"""Unique identifier for the message.
|
||||
|
||||
If the provider assigns a meaningful ID, it should be used here. Otherwise, a
|
||||
LangChain-generated ID will be used.
|
||||
|
||||
"""
|
||||
|
||||
lc_version: str = "v1"
|
||||
"""Encoding version for the message. Used for serialization."""
|
||||
|
||||
content: list[types.ContentBlock] = field(default_factory=list)
|
||||
"""Message content as a list of content blocks."""
|
||||
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
"""If provided, usage metadata for a message, such as token counts."""
|
||||
|
||||
response_metadata: ResponseMetadata = field(
|
||||
default_factory=lambda: ResponseMetadata()
|
||||
)
|
||||
"""Metadata about the response.
|
||||
|
||||
This field should include non-standard data returned by the provider, such as
|
||||
response headers, service tiers, or log probabilities.
|
||||
|
||||
"""
|
||||
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None
|
||||
"""Auto-parsed message contents, if applicable."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
lc_version: str = "v1",
|
||||
response_metadata: Optional[ResponseMetadata] = None,
|
||||
usage_metadata: Optional[UsageMetadata] = None,
|
||||
tool_calls: Optional[list[types.ToolCall]] = None,
|
||||
invalid_tool_calls: Optional[list[types.InvalidToolCall]] = None,
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
):
|
||||
"""Initialize a v1 AI message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
id: Optional unique identifier for the message.
|
||||
name: The name/identifier of the agent or assistant that generated this
|
||||
message.
|
||||
lc_version: Encoding version for the message.
|
||||
response_metadata: Optional metadata about the response.
|
||||
usage_metadata: Optional metadata about token usage.
|
||||
tool_calls: Optional list of tool calls made by the AI. Tool calls should
|
||||
generally be included in message content. If passed on init, they will
|
||||
be added to the content list.
|
||||
invalid_tool_calls: Optional list of tool calls that failed validation.
|
||||
parsed: Optional auto-parsed message contents, if applicable.
|
||||
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
self.content = [types.create_text_block(content)]
|
||||
else:
|
||||
self.content = content
|
||||
|
||||
self.id = _ensure_id(id)
|
||||
self.name = name
|
||||
self.lc_version = lc_version
|
||||
self.usage_metadata = usage_metadata
|
||||
self.parsed = parsed
|
||||
if response_metadata is None:
|
||||
self.response_metadata = {}
|
||||
else:
|
||||
self.response_metadata = response_metadata
|
||||
|
||||
# Add tool calls to content if provided on init
|
||||
if tool_calls:
|
||||
content_tool_calls = {
|
||||
block["id"]
|
||||
for block in self.content
|
||||
if types.is_tool_call_block(block) and "id" in block
|
||||
}
|
||||
for tool_call in tool_calls:
|
||||
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
||||
continue
|
||||
self.content.append(tool_call)
|
||||
if invalid_tool_calls:
|
||||
content_tool_calls = {
|
||||
block["id"]
|
||||
for block in self.content
|
||||
if types.is_invalid_tool_call_block(block) and "id" in block
|
||||
}
|
||||
for invalid_tool_call in invalid_tool_calls:
|
||||
if (
|
||||
"id" in invalid_tool_call
|
||||
and invalid_tool_call["id"] in content_tool_calls
|
||||
):
|
||||
continue
|
||||
self.content.append(invalid_tool_call)
|
||||
self._tool_calls: list[types.ToolCall] = [
|
||||
block for block in self.content if types.is_tool_call_block(block)
|
||||
]
|
||||
self._invalid_tool_calls: list[types.InvalidToolCall] = [
|
||||
block for block in self.content if types.is_invalid_tool_call_block(block)
|
||||
]
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the AI message as a string.
|
||||
|
||||
Can be used as both property (``message.text``) and method (``message.text()``).
|
||||
|
||||
.. deprecated:: 0.4.0
|
||||
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
|
||||
instead. This method will be removed in 2.0.0.
|
||||
|
||||
"""
|
||||
text_value = "".join(
|
||||
block["text"] for block in self.content if types.is_text_block(block)
|
||||
)
|
||||
return cast("str", TextAccessor(text_value))
|
||||
|
||||
@property
|
||||
def tool_calls(self) -> list[types.ToolCall]:
|
||||
"""Get the tool calls made by the AI."""
|
||||
if not self._tool_calls:
|
||||
self._tool_calls = [
|
||||
block for block in self.content if types.is_tool_call_block(block)
|
||||
]
|
||||
return self._tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
def tool_calls(self, value: list[types.ToolCall]) -> None:
|
||||
"""Set the tool calls for the AI message."""
|
||||
self._tool_calls = value
|
||||
|
||||
@property
|
||||
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
|
||||
"""Get the invalid tool calls made by the AI."""
|
||||
if not self._invalid_tool_calls:
|
||||
self._invalid_tool_calls = [
|
||||
block
|
||||
for block in self.content
|
||||
if types.is_invalid_tool_call_block(block)
|
||||
]
|
||||
return self._invalid_tool_calls
|
||||
|
||||
|
||||
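As a rough usage sketch (not part of the diff), tool calls passed on init are folded into ``content`` and surfaced again through the ``tool_calls`` property; the tool call is built with the ``tool_call`` factory that this module imports as ``create_tool_call``:

```python
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.v1.messages import AIMessage

call = create_tool_call(name="search", args={"query": "python"}, id="call_123")
msg = AIMessage("Looking that up.", tool_calls=[call])

# The tool call is appended to the content block list...
assert any(block["type"] == "tool_call" for block in msg.content)
# ...and exposed again via the tool_calls property.
assert msg.tool_calls[0]["name"] == "search"
assert msg.text == "Looking that up."
```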
@dataclass
|
||||
class AIMessageChunk(AIMessage):
|
||||
"""A partial chunk of an AI message during streaming.
|
||||
|
||||
Represents a portion of an AI response that is delivered incrementally
|
||||
during streaming generation. When AI providers stream responses token-by-token,
|
||||
each chunk contains partial content that gets accumulated into a complete message.
|
||||
|
||||
**Streaming Workflow:**
|
||||
|
||||
1. Provider streams partial responses as ``AIMessageChunk`` objects
|
||||
2. Chunks are accumulated: ``chunk1 + chunk2 + ...``
|
||||
3. Final accumulated chunk can be converted to ``AIMessage`` via ``.to_message()``
|
||||
|
||||
**Tool Call Handling:**
|
||||
|
||||
During streaming, tool calls arrive as ``ToolCallChunk`` objects with partial
|
||||
JSON. When chunks are accumulated, the final chunk (marked with
|
||||
``chunk_position="last"``) triggers parsing of complete tool calls from the
|
||||
accumulated JSON strings.
|
||||
|
||||
**Content Merging:**
|
||||
|
||||
Content blocks are merged intelligently - text blocks combine their strings,
|
||||
tool call chunks accumulate arguments, and other blocks are concatenated.
|
||||
|
||||
Attributes:
|
||||
type: Message type identifier, always ``'ai_chunk'``.
|
||||
id: Unique identifier for the message chunk.
|
||||
name: The name/identifier of the agent or assistant that generated this message.
|
||||
content: List of content blocks containing partial message data.
|
||||
tool_call_chunks: Optional list of partial tool call data.
|
||||
usage_metadata: Optional metadata about token usage and costs.
|
||||
|
||||
"""
|
||||
|
||||
type: Literal["ai_chunk"] = "ai_chunk" # type: ignore[assignment]
|
||||
"""The type of the message. Must be a string that is unique to the message type.
|
||||
|
||||
The purpose of this field is to allow for easy identification of the message type
|
||||
when deserializing messages.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
lc_version: str = "v1",
|
||||
response_metadata: Optional[ResponseMetadata] = None,
|
||||
usage_metadata: Optional[UsageMetadata] = None,
|
||||
tool_call_chunks: Optional[list[types.ToolCallChunk]] = None,
|
||||
parsed: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
chunk_position: Optional[Literal["last"]] = None,
|
||||
):
|
||||
"""Initialize a v1 AI message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
id: Optional unique identifier for the message.
|
||||
name: The name/identifier of the agent or assistant that generated this
|
||||
message.
|
||||
lc_version: Encoding version for the message.
|
||||
response_metadata: Optional metadata about the response.
|
||||
usage_metadata: Optional metadata about token usage.
|
||||
tool_call_chunks: Optional list of partial tool call data.
|
||||
parsed: Optional auto-parsed message contents, if applicable.
|
||||
chunk_position: Optional position of the chunk in the stream. If ``'last'``,
|
||||
tool calls will be parsed when aggregated into a stream.
|
||||
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content, "index": 0}]
|
||||
else:
|
||||
self.content = content
|
||||
|
||||
self.id = _ensure_id(id)
|
||||
self.name = name
|
||||
self.lc_version = lc_version
|
||||
self.usage_metadata = usage_metadata
|
||||
self.parsed = parsed
|
||||
self.chunk_position = chunk_position
|
||||
if response_metadata is None:
|
||||
self.response_metadata = {}
|
||||
else:
|
||||
self.response_metadata = response_metadata
|
||||
|
||||
if tool_call_chunks:
|
||||
content_tool_call_chunks = {
|
||||
block["id"]
|
||||
for block in self.content
|
||||
if types.is_tool_call_chunk(block) and "id" in block
|
||||
}
|
||||
for chunk in tool_call_chunks:
|
||||
if "id" in chunk and chunk["id"] in content_tool_call_chunks:
|
||||
continue
|
||||
self.content.append(chunk)
|
||||
self._tool_call_chunks = [
|
||||
block for block in self.content if types.is_tool_call_chunk(block)
|
||||
]
|
||||
|
||||
self._tool_calls: list[types.ToolCall] = []
|
||||
self._invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||
|
||||
@property
|
||||
def tool_call_chunks(self) -> list[types.ToolCallChunk]:
|
||||
"""Get the tool calls made by the AI."""
|
||||
if not self._tool_call_chunks:
|
||||
self._tool_call_chunks = [
|
||||
block for block in self.content if types.is_tool_call_chunk(block)
|
||||
]
|
||||
return self._tool_call_chunks
|
||||
|
||||
@property
|
||||
def tool_calls(self) -> list[types.ToolCall]:
|
||||
"""Get the tool calls made by the AI."""
|
||||
if not self._tool_calls:
|
||||
parsed_content = _init_tool_calls(self.content)
|
||||
tool_calls: list[types.ToolCall] = []
|
||||
invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||
for block in parsed_content:
|
||||
if types.is_tool_call_block(block):
|
||||
tool_calls.append(block)
|
||||
elif types.is_invalid_tool_call_block(block):
|
||||
invalid_tool_calls.append(block)
|
||||
self._tool_calls = tool_calls
|
||||
self._invalid_tool_calls = invalid_tool_calls
|
||||
return self._tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
def tool_calls(self, value: list[types.ToolCall]) -> None:
|
||||
"""Set the tool calls for the AI message."""
|
||||
self._tool_calls = value
|
||||
|
||||
@property
|
||||
def invalid_tool_calls(self) -> list[types.InvalidToolCall]:
|
||||
"""Get the invalid tool calls made by the AI."""
|
||||
if not self._invalid_tool_calls:
|
||||
parsed_content = _init_tool_calls(self.content)
|
||||
tool_calls: list[types.ToolCall] = []
|
||||
invalid_tool_calls: list[types.InvalidToolCall] = []
|
||||
for block in parsed_content:
|
||||
if types.is_tool_call_block(block):
|
||||
tool_calls.append(block)
|
||||
elif types.is_invalid_tool_call_block(block):
|
||||
invalid_tool_calls.append(block)
|
||||
self._tool_calls = tool_calls
|
||||
self._invalid_tool_calls = invalid_tool_calls
|
||||
return self._invalid_tool_calls
|
||||
|
||||
def __add__(self, other: Any) -> "AIMessageChunk":
|
||||
"""Add ``AIMessageChunk`` to this one."""
|
||||
if isinstance(other, AIMessageChunk):
|
||||
return add_ai_message_chunks(self, other)
|
||||
if isinstance(other, (list, tuple)) and all(
|
||||
isinstance(o, AIMessageChunk) for o in other
|
||||
):
|
||||
return add_ai_message_chunks(self, *other)
|
||||
error_msg = "Can only add AIMessageChunk or sequence of AIMessageChunk."
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
def to_message(self) -> "AIMessage":
|
||||
"""Convert this ``AIMessageChunk`` to an ``AIMessage``."""
|
||||
return AIMessage(
|
||||
content=_init_tool_calls(self.content),
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
lc_version=self.lc_version,
|
||||
response_metadata=self.response_metadata,
|
||||
usage_metadata=self.usage_metadata,
|
||||
parsed=self.parsed,
|
||||
)
|
||||
|
||||
|
||||
def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]:
|
||||
"""Parse tool call chunks in content into tool calls."""
|
||||
new_content = []
|
||||
for block in content:
|
||||
if not types.is_tool_call_chunk(block):
|
||||
new_content.append(block)
|
||||
continue
|
||||
try:
|
||||
args_str = block.get("args")
|
||||
args_ = parse_partial_json(str(args_str)) if args_str else {}
|
||||
if isinstance(args_, dict):
|
||||
new_content.append(
|
||||
create_tool_call(
|
||||
name=block.get("name") or "",
|
||||
args=args_,
|
||||
id=block.get("id", ""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_content.append(
|
||||
create_invalid_tool_call(
|
||||
name=block.get("name", ""),
|
||||
args=block.get("args", ""),
|
||||
id=block.get("id", ""),
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
new_content.append(
|
||||
create_invalid_tool_call(
|
||||
name=block.get("name", ""),
|
||||
args=block.get("args", ""),
|
||||
id=block.get("id", ""),
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
return new_content
|
||||
|
||||
|
||||
def add_ai_message_chunks(
|
||||
left: AIMessageChunk, *others: AIMessageChunk
|
||||
) -> AIMessageChunk:
|
||||
"""Add multiple ``AIMessageChunks`` together."""
|
||||
if not others:
|
||||
return left
|
||||
content = cast(
|
||||
"list[types.ContentBlock]",
|
||||
merge_content(
|
||||
cast("list[str | dict[Any, Any]]", left.content),
|
||||
*(cast("list[str | dict[Any, Any]]", o.content) for o in others),
|
||||
),
|
||||
)
|
||||
response_metadata = merge_dicts(
|
||||
cast("dict", left.response_metadata),
|
||||
*(cast("dict", o.response_metadata) for o in others),
|
||||
)
|
||||
|
||||
# Token usage
|
||||
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
||||
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
|
||||
for other in others:
|
||||
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
# Parsed
|
||||
# 'parsed' always represents an aggregation not an incremental value, so the last
|
||||
# non-null value is kept.
|
||||
parsed = None
|
||||
for m in reversed([left, *others]):
|
||||
if m.parsed is not None:
|
||||
parsed = m.parsed
|
||||
break
|
||||
|
||||
chunk_id = None
|
||||
candidates = [left.id] + [o.id for o in others]
|
||||
# first pass: pick the first provider-assigned id (non-`run-*` and non-`lc_*`)
|
||||
for id_ in candidates:
|
||||
if (
|
||||
id_
|
||||
and not id_.startswith(_LC_ID_PREFIX)
|
||||
and not id_.startswith(_LC_AUTO_PREFIX)
|
||||
):
|
||||
chunk_id = id_
|
||||
break
|
||||
else:
|
||||
# second pass: prefer lc_run-* ids over lc_* ids
|
||||
for id_ in candidates:
|
||||
if id_ and id_.startswith(_LC_ID_PREFIX):
|
||||
chunk_id = id_
|
||||
break
|
||||
else:
|
||||
# third pass: take any remaining id (auto-generated lc_* ids)
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
chunk_id = id_
|
||||
break
|
||||
|
||||
chunk_position: Optional[Literal["last"]] = (
|
||||
"last" if any(x.chunk_position == "last" for x in [left, *others]) else None
|
||||
)
|
||||
if chunk_position == "last":
|
||||
content = _init_tool_calls(content)
|
||||
|
||||
return left.__class__(
|
||||
content=content,
|
||||
response_metadata=cast("ResponseMetadata", response_metadata),
|
||||
usage_metadata=usage_metadata,
|
||||
parsed=parsed,
|
||||
id=chunk_id,
|
||||
chunk_position=chunk_position,
|
||||
)
|
||||
|
||||
|
||||
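A short sketch of the streaming flow described above, mirroring the ``test_streaming_v1`` test added later in this PR (the chunk contents are illustrative): chunks are accumulated with ``+``, and the chunk marked ``chunk_position="last"`` triggers parsing of the accumulated tool-call JSON.

```python
from langchain_core.v1.messages import AIMessage, AIMessageChunk

chunks = [
    AIMessageChunk("Hello, "),
    AIMessageChunk("world!"),
    AIMessageChunk(
        [],
        tool_call_chunks=[
            {"type": "tool_call_chunk", "name": "search", "args": '{"q": 1}',
             "id": "call_1", "index": 1}
        ],
        chunk_position="last",
    ),
]

full = None
for chunk in chunks:
    full = chunk if full is None else full + chunk

# Text blocks sharing an index are merged; the final chunk converts the
# accumulated tool-call chunk into a complete tool call.
assert isinstance(full, AIMessageChunk)
final = full.to_message()
assert isinstance(final, AIMessage)
assert final.text == "Hello, world!"
assert final.tool_calls[0]["args"] == {"q": 1}
```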
@dataclass
|
||||
class HumanMessage:
|
||||
"""A message from a human user.
|
||||
|
||||
Represents input from a human user in a conversation, containing text
|
||||
or other content types like images.
|
||||
|
||||
Attributes:
|
||||
type: Message type identifier, always ``'human'``.
|
||||
id: Unique identifier for the message.
|
||||
content: List of content blocks containing the user's input.
|
||||
name: Optional identifier for the human user who sent this message.
|
||||
|
||||
"""
|
||||
|
||||
id: str
|
||||
"""Used for serialization.
|
||||
|
||||
If the provider assigns a meaningful ID, it should be used here. Otherwise, a
|
||||
LangChain-generated ID will be used.
|
||||
|
||||
"""
|
||||
|
||||
content: list[types.ContentBlock]
|
||||
"""Message content as a list of content blocks."""
|
||||
|
||||
type: Literal["human"] = "human"
|
||||
"""The type of the message. Must be a string that is unique to the message type.
|
||||
|
||||
The purpose of this field is to allow for easy identification of the message type
|
||||
when deserializing messages.
|
||||
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""Optional identifier for the human user who sent this message.
|
||||
|
||||
Can be helpful in multi-user scenarios or for conversation tracking. Most chat model
|
||||
providers ignore this field for human messages.
|
||||
|
||||
Usage of this field is optional, and whether it's used or not is up to the
|
||||
model implementation.
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. python::
|
||||
|
||||
HumanMessage(
|
||||
content= [
|
||||
TextContentBlock("Hello"),
|
||||
],
|
||||
name="user_alice"
|
||||
)
|
||||
|
||||
HumanMessage(
|
||||
content= [
|
||||
TextContentBlock("Run analysis"),
|
||||
],
|
||||
name="admin_bob"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize a v1 human message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
id: Optional unique identifier for the message.
|
||||
name: Optional identifier for the human user who sent this message.
|
||||
|
||||
"""
|
||||
self.id = _ensure_id(id)
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
else:
|
||||
self.content = content
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the message as a string.
|
||||
|
||||
Can be used as both property (``message.text``) and method (``message.text()``).
|
||||
|
||||
.. deprecated:: 0.4.0
|
||||
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
|
||||
instead. This method will be removed in 2.0.0.
|
||||
|
||||
"""
|
||||
text_value = "".join(
|
||||
block["text"] for block in self.content if types.is_text_block(block)
|
||||
)
|
||||
return cast("str", TextAccessor(text_value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessage:
|
||||
"""A system message containing instructions or context.
|
||||
|
||||
Represents system-level instructions or context that guides the AI's
|
||||
behavior and understanding of the conversation.
|
||||
|
||||
Attributes:
|
||||
type: Message type identifier, always ``'system'``.
|
||||
id: Unique identifier for the message.
|
||||
content: List of content blocks containing system instructions.
|
||||
|
||||
"""
|
||||
|
||||
id: str
|
||||
"""Used for serialization.
|
||||
|
||||
If the provider assigns a meaningful ID, it should be used here. Otherwise, a
|
||||
LangChain-generated ID will be used.
|
||||
|
||||
"""
|
||||
|
||||
content: list[types.ContentBlock]
|
||||
"""Message content as a list of content blocks."""
|
||||
|
||||
type: Literal["system"] = "system"
|
||||
"""The type of the message. Must be a string that is unique to the message type.
|
||||
|
||||
The purpose of this field is to allow for easy identification of the message type
|
||||
when deserializing messages.
|
||||
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""Optional identifier for the system component/context that generated this message.
|
||||
|
||||
Can be used to identify different system contexts or configurations.
|
||||
|
||||
Usage of this field is optional, and whether it's used or not is up to the
|
||||
model implementation.
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. python::
|
||||
|
||||
SystemMessage(
|
||||
content= [
|
||||
TextContentBlock("You are a helpful assistant"),
|
||||
],
|
||||
name="base_prompt"
|
||||
)
|
||||
|
||||
SystemMessage(
|
||||
content= [
|
||||
TextContentBlock("Advanced mode enabled"),
|
||||
],
|
||||
name="config_update"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
custom_role: Optional[str] = None
|
||||
"""If provided, a custom role for the system message.
|
||||
|
||||
Example: ``'developer'``.
|
||||
|
||||
Integration packages may use this field to assign the system message role if it
|
||||
contains a recognized value.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
custom_role: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize a v1 system message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
id: Optional unique identifier for the message.
|
||||
custom_role: If provided, a custom role for the system message.
|
||||
name: Optional identifier for the system component/context that generated
|
||||
this message.
|
||||
|
||||
"""
|
||||
self.id = _ensure_id(id)
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
else:
|
||||
self.content = content
|
||||
self.custom_role = custom_role
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the system message as a string.
|
||||
|
||||
Can be used as both property (``message.text``) and method (``message.text()``).
|
||||
|
||||
.. deprecated:: 0.4.0
|
||||
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
|
||||
instead. This method will be removed in 2.0.0.
|
||||
|
||||
"""
|
||||
text_value = "".join(
|
||||
block["text"] for block in self.content if types.is_text_block(block)
|
||||
)
|
||||
return cast("str", TextAccessor(text_value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMessage(ToolOutputMixin):
|
||||
"""A message containing the result of a tool execution.
|
||||
|
||||
Represents the output from executing a tool or function call,
|
||||
including the result data and execution status.
|
||||
|
||||
Attributes:
|
||||
type: Message type identifier, always ``'tool'``.
|
||||
id: Unique identifier for the message.
|
||||
tool_call_id: ID of the tool call this message responds to.
|
||||
content: The result content from tool execution.
|
||||
artifact: Optional app-side payload not intended for the model.
|
||||
name: Name of the tool/function that was executed to generate this message.
|
||||
status: Execution status ("success" or "error").
|
||||
|
||||
"""
|
||||
|
||||
id: str
|
||||
"""Used for serialization."""
|
||||
|
||||
tool_call_id: str
|
||||
"""ID of the tool call this message responds to.
|
||||
|
||||
This should match the ID of the tool call that this message is responding to.
|
||||
|
||||
"""
|
||||
|
||||
content: list[types.ContentBlock]
|
||||
"""Message content as a list of content blocks.
|
||||
|
||||
The tool's output should be included in the content, mapped to the appropriate
|
||||
content block type (e.g., text, image, etc.). For instance, if the tool call returns
|
||||
a string, it should be wrapped in a ``TextContentBlock``.
|
||||
|
||||
"""
|
||||
|
||||
type: Literal["tool"] = "tool"
|
||||
"""The type of the message. Must be a string that is unique to the message type.
|
||||
|
||||
The purpose of this field is to allow for easy identification of the message type
|
||||
when deserializing messages.
|
||||
|
||||
"""
|
||||
|
||||
artifact: Optional[Any] = None
|
||||
"""App-side payload not intended for model consumption.
|
||||
|
||||
Additional info and usage examples are available
`in the LangChain documentation <https://python.langchain.com/docs/concepts/tools/#tool-artifacts>`__.
|
||||
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""Name of the tool/function that was executed to generate this message.
|
||||
|
||||
.. important::
|
||||
This field is required by most chat model providers (OpenAI, Anthropic,
|
||||
Google, etc.) for proper tool calling. The name must match the tool that was
|
||||
called.
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. python::
|
||||
|
||||
ToolMessage(
|
||||
content= [
|
||||
TextContentBlock("42"),
|
||||
],
|
||||
name="calculator",
|
||||
tool_call_id="call_123"
|
||||
)
|
||||
|
||||
ToolMessage(
|
||||
content= [
|
||||
TextContentBlock("Weather is sunny"),
|
||||
],
|
||||
name="get_weather",
|
||||
tool_call_id="call_456"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
status: Literal["success", "error"] = "success"
|
||||
"""Execution status of the tool call.
|
||||
|
||||
Indicates whether the tool call was successful or encountered an error.
|
||||
Defaults to "success".
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: Union[str, list[types.ContentBlock]],
|
||||
tool_call_id: str,
|
||||
*,
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
artifact: Optional[Any] = None,
|
||||
status: Literal["success", "error"] = "success",
|
||||
):
|
||||
"""Initialize a v1 tool message.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list of content blocks.
|
||||
tool_call_id: ID of the tool call this message responds to.
|
||||
id: Optional unique identifier for the message.
|
||||
name: Name of the tool/function that was executed to generate this message.
|
||||
artifact: Optional app-side payload not intended for the model.
|
||||
status: Execution status (``'success'`` or ``'error'``).
|
||||
|
||||
"""
|
||||
self.id = _ensure_id(id)
|
||||
self.tool_call_id = tool_call_id
|
||||
if isinstance(content, str):
|
||||
self.content = [{"type": "text", "text": content}]
|
||||
else:
|
||||
self.content = content
|
||||
self.name = name
|
||||
self.artifact = artifact
|
||||
self.status = status
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract all text content from the tool message as a string.
|
||||
|
||||
Can be used as both property (``message.text``) and method (``message.text()``).
|
||||
|
||||
.. deprecated:: 0.4.0
|
||||
Calling ``.text()`` as a method is deprecated. Use ``.text`` as a property
|
||||
instead. This method will be removed in 2.0.0.
|
||||
|
||||
"""
|
||||
text_value = "".join(
|
||||
block["text"] for block in self.content if types.is_text_block(block)
|
||||
)
|
||||
return cast("str", TextAccessor(text_value))
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize computed fields after dataclass creation.
|
||||
|
||||
Ensures the tool message has a valid ID.
|
||||
|
||||
"""
|
||||
self.id = _ensure_id(self.id)
|
||||
|
||||
|
||||
# Alias for a message type that can be any of the defined message types
|
||||
MessageV1 = Union[
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
]
|
||||
MessageV1Types = get_args(MessageV1)
|
||||
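To illustrate how the union type is meant to be used (a sketch, not part of the diff; the conversation content is made up), a v1 conversation is just a list of these dataclasses:

```python
from langchain_core.v1.messages import (
    AIMessage,
    HumanMessage,
    MessageV1,
    SystemMessage,
    ToolMessage,
)

conversation: list[MessageV1] = [
    SystemMessage("You are a helpful assistant."),
    HumanMessage("What is 40 + 2?"),
    AIMessage(
        "Let me use the calculator.",
        tool_calls=[{"type": "tool_call", "name": "calculator",
                     "args": {"a": 40, "b": 2}, "id": "call_1"}],
    ),
    ToolMessage("42", tool_call_id="call_1", name="calculator"),
    AIMessage("40 + 2 = 42."),
]

# Every concrete message type exposes .text for the plain-text view.
assert conversation[-1].text == "40 + 2 = 42."
```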
@@ -1,3 +1,3 @@
"""langchain-core version information and utilities."""

VERSION = "0.3.72"
VERSION = "0.4.0.dev0"

@@ -16,7 +16,7 @@ dependencies = [
"pydantic>=2.7.4",
]
name = "langchain-core"
version = "0.3.72"
version = "0.4.0.dev0"
description = "Building applications with LLMs through composability"
readme = "README.md"
@@ -28,7 +28,7 @@ repository = "https://github.com/langchain-ai/langchain"
[dependency-groups]
lint = ["ruff<0.13,>=0.12.2"]
typing = [
"mypy<1.16,>=1.15",
"mypy<1.18,>=1.17.1",
"types-pyyaml<7.0.0.0,>=6.0.12.2",
"types-requests<3.0.0.0,>=2.28.11.5",
"langchain-text-splitters",
@@ -67,6 +67,7 @@ langchain-text-splitters = { path = "../text-splitters" }
strict = "True"
strict_bytes = "True"
enable_error_code = "deprecated"
disable_error_code = ["typeddict-unknown-key"]

# TODO: activate for 'strict' checking
disallow_any_generics = "False"
@@ -86,6 +87,7 @@ ignore = [
"FIX002", # Line contains TODO
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLC0414", # Enable re-export
"PLR09", # Too many something (arg, statements, etc)
"RUF012", # Doesn't play well with Pydantic
"TC001", # Doesn't play well with Pydantic
@@ -105,6 +107,7 @@ unfixable = ["PLW1510",]

flake8-annotations.allow-star-arg-any = true
flake8-annotations.mypy-init-return = true
flake8-builtins.ignorelist = ["id", "input", "type"]
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
pydocstyle.convention = "google"
@@ -11,6 +11,8 @@ from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.v1.messages import MessageV1
|
||||
|
||||
|
||||
class MyCustomAsyncHandler(AsyncCallbackHandler):
|
||||
@@ -18,7 +20,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -35,7 +37,9 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.v1.messages import MessageV1
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -285,7 +286,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
|
||||
@@ -16,6 +16,8 @@ from langchain_core.language_models import (
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.v1.messages import MessageV1
|
||||
from tests.unit_tests.stubs import (
|
||||
_any_id_ai_message,
|
||||
_any_id_ai_message_chunk,
|
||||
@@ -157,13 +159,13 @@ async def test_callback_handlers() -> None:
|
||||
"""Verify that model is implemented correctly with handlers working."""
|
||||
|
||||
class MyCustomAsyncHandler(AsyncCallbackHandler):
|
||||
def __init__(self, store: list[str]) -> None:
|
||||
def __init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None:
|
||||
self.store = store
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
messages: Union[list[list[BaseMessage]], list[MessageV1]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -178,9 +180,11 @@ async def test_callback_handlers() -> None:
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: Union[str, AIMessageChunkV1],
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
chunk: Optional[
|
||||
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
|
||||
] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
@@ -194,7 +198,7 @@ async def test_callback_handlers() -> None:
|
||||
]
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
tokens: list[str] = []
|
||||
tokens: list[Union[str, AIMessageChunkV1]] = []
|
||||
# New model
|
||||
results = [
|
||||
chunk
|
||||
|
||||
@@ -14,7 +14,10 @@ from langchain_core.language_models import (
|
||||
ParrotFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModelError,
|
||||
GenericFakeChatModelV1,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -29,6 +32,7 @@ from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
|
||||
from langchain_core.tracers.schemas import Run
|
||||
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
@@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None:
|
||||
)
|
||||
]
|
||||
assert messages == _normalize_messages(messages)
|
||||
|
||||
|
||||
def test_streaming_v1() -> None:
|
||||
chunks = [
|
||||
AIMessageChunkV1(
|
||||
[
|
||||
{
|
||||
"type": "reasoning",
|
||||
"reasoning": "Let's call a tool.",
|
||||
"index": 0,
|
||||
}
|
||||
]
|
||||
),
|
||||
AIMessageChunkV1(
|
||||
[],
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"type": "tool_call_chunk",
|
||||
"args": "",
|
||||
"name": "tool_name",
|
||||
"id": "call_123",
|
||||
"index": 1,
|
||||
},
|
||||
],
|
||||
),
|
||||
AIMessageChunkV1(
|
||||
[],
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"type": "tool_call_chunk",
|
||||
"args": '{"a',
|
||||
"name": "",
|
||||
"id": "",
|
||||
"index": 1,
|
||||
},
|
||||
],
|
||||
),
|
||||
AIMessageChunkV1(
|
||||
[],
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"type": "tool_call_chunk",
|
||||
"args": '": 1}',
|
||||
"name": "",
|
||||
"id": "",
|
||||
"index": 1,
|
||||
},
|
||||
],
|
||||
),
|
||||
]
|
||||
full: Optional[AIMessageChunkV1] = None
|
||||
for chunk in chunks:
|
||||
full = chunk if full is None else full + chunk
|
||||
|
||||
assert isinstance(full, AIMessageChunkV1)
|
||||
assert full.content == [
|
||||
{
|
||||
"type": "reasoning",
|
||||
"reasoning": "Let's call a tool.",
|
||||
"index": 0,
|
||||
},
|
||||
{
|
||||
"type": "tool_call_chunk",
|
||||
"args": '{"a": 1}',
|
||||
"name": "tool_name",
|
||||
"id": "call_123",
|
||||
"index": 1,
|
||||
},
|
||||
]
|
||||
|
||||
llm = GenericFakeChatModelV1(message_chunks=chunks)
|
||||
|
||||
full = None
|
||||
for chunk in llm.stream("anything"):
|
||||
full = chunk if full is None else full + chunk
|
||||
|
||||
assert isinstance(full, AIMessageChunkV1)
|
||||
assert full.content == [
|
||||
{
|
||||
"type": "reasoning",
|
||||
"reasoning": "Let's call a tool.",
|
||||
"index": 0,
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"args": {"a": 1},
|
||||
"name": "tool_name",
|
||||
"id": "call_123",
|
||||
},
|
||||
]
|
||||
|
||||
@@ -458,3 +458,23 @@ def test_cleanup_serialized() -> None:
|
||||
"name": "CustomChat",
|
||||
"type": "constructor",
|
||||
}
|
||||
|
||||
|
||||
def test_token_costs_are_zeroed_out() -> None:
|
||||
# We zero-out token costs for cache hits
|
||||
local_cache = InMemoryCache()
|
||||
messages = [
|
||||
AIMessage(
|
||||
content="Hello, how are you?",
|
||||
usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
|
||||
),
|
||||
]
|
||||
model = GenericFakeChatModel(messages=iter(messages), cache=local_cache)
|
||||
first_response = model.invoke("Hello")
|
||||
assert isinstance(first_response, AIMessage)
|
||||
assert first_response.usage_metadata
|
||||
|
||||
second_response = model.invoke("Hello")
|
||||
assert isinstance(second_response, AIMessage)
|
||||
assert second_response.usage_metadata
|
||||
assert second_response.usage_metadata["total_cost"] == 0 # type: ignore[typeddict-item]
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from langchain_core.load import Serializable, dumpd, load
|
||||
from langchain_core.load import Serializable, dumpd, dumps, load
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
@@ -276,3 +279,92 @@ def test_serialization_with_ignore_unserializable_fields() -> None:
|
||||
]
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# Tests for dumps() function
|
||||
def test_dumps_basic_serialization() -> None:
|
||||
"""Test basic string serialization with `dumps()`."""
|
||||
foo = Foo(bar=42, baz="test")
|
||||
json_str = dumps(foo)
|
||||
|
||||
# Should be valid JSON
|
||||
parsed = json.loads(json_str)
|
||||
assert parsed == {
|
||||
"id": ["tests", "unit_tests", "load", "test_serializable", "Foo"],
|
||||
"kwargs": {"bar": 42, "baz": "test"},
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
}
|
||||
|
||||
|
||||
def test_dumps_pretty_formatting() -> None:
|
||||
"""Test pretty printing functionality."""
|
||||
foo = Foo(bar=1, baz="hello")
|
||||
|
||||
# Test pretty=True with default indent
|
||||
pretty_json = dumps(foo, pretty=True)
|
||||
assert " " in pretty_json
|
||||
|
||||
# Test custom indent (4-space)
|
||||
custom_indent = dumps(foo, pretty=True, indent=4)
|
||||
assert " " in custom_indent
|
||||
|
||||
# Verify it's still valid JSON
|
||||
parsed = json.loads(pretty_json)
|
||||
assert parsed["kwargs"]["bar"] == 1
|
||||
|
||||
|
||||
def test_dumps_invalid_default_kwarg() -> None:
|
||||
"""Test that passing `'default'` as kwarg raises ValueError."""
|
||||
foo = Foo(bar=1, baz="test")
|
||||
|
||||
with pytest.raises(ValueError, match="`default` should not be passed to dumps"):
|
||||
dumps(foo, default=lambda x: x)
|
||||
|
||||
|
||||
def test_dumps_additional_json_kwargs() -> None:
|
||||
"""Test that additional JSON kwargs are passed through."""
|
||||
foo = Foo(bar=1, baz="test")
|
||||
|
||||
compact_json = dumps(foo, separators=(",", ":"))
|
||||
assert ", " not in compact_json # Should be compact
|
||||
|
||||
# Test sort_keys
|
||||
sorted_json = dumps(foo, sort_keys=True)
|
||||
parsed = json.loads(sorted_json)
|
||||
assert parsed == dumpd(foo)
|
||||
|
||||
|
||||
def test_dumps_non_serializable_object() -> None:
|
||||
"""Test `dumps()` behavior with non-serializable objects."""
|
||||
|
||||
class NonSerializable:
|
||||
def __init__(self, value: int) -> None:
|
||||
self.value = value
|
||||
|
||||
obj = NonSerializable(42)
|
||||
json_str = dumps(obj)
|
||||
|
||||
# Should create a "not_implemented" representation
|
||||
parsed = json.loads(json_str)
|
||||
assert parsed["lc"] == 1
|
||||
assert parsed["type"] == "not_implemented"
|
||||
assert "NonSerializable" in parsed["repr"]
|
||||
|
||||
|
||||
def test_dumps_mixed_data_structure() -> None:
|
||||
"""Test `dumps()` with complex nested data structures."""
|
||||
data = {
|
||||
"serializable": Foo(bar=1, baz="test"),
|
||||
"list": [1, 2, {"nested": "value"}],
|
||||
"primitive": "string",
|
||||
}
|
||||
|
||||
json_str = dumps(data)
|
||||
parsed = json.loads(json_str)
|
||||
|
||||
# Serializable object should be properly serialized
|
||||
assert parsed["serializable"]["type"] == "constructor"
|
||||
# Primitives should remain unchanged
|
||||
assert parsed["list"] == [1, 2, {"nested": "value"}]
|
||||
assert parsed["primitive"] == "string"
|
||||
|
||||
@@ -0,0 +1,913 @@
|
||||
"""Unit tests for ContentBlock factory functions."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages.content_blocks import (
|
||||
CodeInterpreterCall,
|
||||
CodeInterpreterOutput,
|
||||
CodeInterpreterResult,
|
||||
InvalidToolCall,
|
||||
ToolCallChunk,
|
||||
WebSearchCall,
|
||||
WebSearchResult,
|
||||
create_audio_block,
|
||||
create_citation,
|
||||
create_file_block,
|
||||
create_image_block,
|
||||
create_non_standard_block,
|
||||
create_plaintext_block,
|
||||
create_reasoning_block,
|
||||
create_text_block,
|
||||
create_tool_call,
|
||||
create_video_block,
|
||||
)
|
||||
|
||||
|
||||
def _validate_lc_uuid(id_value: str) -> None:
|
||||
"""Validate that the ID has ``lc_`` prefix and valid UUID suffix.
|
||||
|
||||
Args:
|
||||
id_value: The ID string to validate.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the ID doesn't have ``lc_`` prefix or invalid UUID.
|
||||
"""
|
||||
assert id_value.startswith("lc_"), f"ID should start with 'lc_' but got: {id_value}"
|
||||
# Validate the UUID part after the lc_ prefix
|
||||
UUID(id_value[3:])
|
||||
|
||||
|
||||
class TestTextBlockFactory:
|
||||
"""Test create_text_block factory function."""
|
||||
|
||||
def test_basic_creation(self) -> None:
|
||||
"""Test basic text block creation."""
|
||||
block = create_text_block("Hello world")
|
||||
|
||||
assert block["type"] == "text"
|
||||
assert block.get("text") == "Hello world"
|
||||
assert "id" in block
|
||||
id_value = block.get("id")
|
||||
assert id_value is not None, "block id is None"
|
||||
_validate_lc_uuid(id_value)
|
||||
|
||||
def test_with_custom_id(self) -> None:
|
||||
"""Test text block creation with custom ID."""
|
||||
custom_id = "custom-123"
|
||||
block = create_text_block("Hello", id=custom_id)
|
||||
|
||||
assert block.get("id") == custom_id
|
||||
|
||||
def test_with_annotations(self) -> None:
|
||||
"""Test text block creation with annotations."""
|
||||
citation = create_citation(url="https://example.com", title="Example")
|
||||
block = create_text_block("Hello", annotations=[citation])
|
||||
|
||||
assert block.get("annotations") == [citation]
|
||||
|
||||
def test_with_index(self) -> None:
|
||||
"""Test text block creation with index."""
|
||||
block = create_text_block("Hello", index=42)
|
||||
|
||||
assert block.get("index") == 42
|
||||
|
||||
def test_optional_fields_not_present_when_none(self) -> None:
|
||||
"""Test that optional fields are not included when None."""
|
||||
block = create_text_block("Hello")
|
||||
|
||||
assert "annotations" not in block
|
||||
assert "index" not in block
|
||||
|
||||
|
||||
class TestImageBlockFactory:
|
||||
"""Test create_image_block factory function."""
|
||||
|
||||
def test_with_url(self) -> None:
|
||||
"""Test image block creation with URL."""
|
||||
block = create_image_block(url="https://example.com/image.jpg")
|
||||
|
||||
assert block["type"] == "image"
|
||||
assert block.get("url") == "https://example.com/image.jpg"
|
||||
assert "id" in block
|
||||
id_value = block.get("id")
|
||||
assert id_value is not None, "block id is None"
|
||||
_validate_lc_uuid(id_value)
|
||||
|
||||
def test_with_base64(self) -> None:
|
||||
"""Test image block creation with base64 data."""
|
||||
block = create_image_block(
|
||||
base64="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ", mime_type="image/png"
|
||||
)
|
||||
|
||||
assert block.get("base64") == "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"
|
||||
assert block.get("mime_type") == "image/png"
|
||||
|
||||
def test_with_file_id(self) -> None:
|
||||
"""Test image block creation with file ID."""
|
||||
block = create_image_block(file_id="file-123")
|
||||
|
||||
assert block.get("file_id") == "file-123"
|
||||
|
||||
def test_no_source_raises_error(self) -> None:
|
||||
"""Test that missing all sources raises ValueError."""
|
||||
with pytest.raises(
|
||||
ValueError, match="Must provide one of: url, base64, or file_id"
|
||||
):
|
||||
create_image_block()
|
||||
|
||||
def test_with_index(self) -> None:
|
||||
"""Test image block creation with index."""
|
||||
block = create_image_block(url="https://example.com/image.jpg", index=1)
|
||||
|
||||
assert block.get("index") == 1
|
||||
|
||||
def test_optional_fields_not_present_when_not_provided(self) -> None:
|
||||
"""Test that optional fields are not included when not provided."""
|
||||
block = create_image_block(url="https://example.com/image.jpg")
|
||||
|
||||
assert "base64" not in block
|
||||
assert "file_id" not in block
|
||||
assert "mime_type" not in block
|
||||
assert "index" not in block
|
||||
|
||||
|
||||
class TestVideoBlockFactory:
|
||||
"""Test create_video_block factory function."""
|
||||
|
||||
def test_with_url(self) -> None:
|
||||
"""Test video block creation with URL."""
|
||||
block = create_video_block(url="https://example.com/video.mp4")
|
||||
|
||||
assert block["type"] == "video"
|
||||
assert block.get("url") == "https://example.com/video.mp4"
|
||||
|
||||
def test_with_base64(self) -> None:
|
||||
"""Test video block creation with base64 data."""
|
||||
block = create_video_block(
|
||||
base64="UklGRnoGAABXQVZFZm10IBAAAAABAAEA", mime_type="video/mp4"
|
||||
)
|
||||
|
||||
assert block.get("base64") == "UklGRnoGAABXQVZFZm10IBAAAAABAAEA"
|
||||
assert block.get("mime_type") == "video/mp4"
|
||||
|
||||
def test_no_source_raises_error(self) -> None:
|
||||
"""Test that missing all sources raises ValueError."""
|
||||
with pytest.raises(
|
||||
ValueError, match="Must provide one of: url, base64, or file_id"
|
||||
):
|
||||
create_video_block()
|
||||
|
||||
|
||||
class TestAudioBlockFactory:
|
||||
"""Test create_audio_block factory function."""
|
||||
|
||||
def test_with_url(self) -> None:
|
||||
"""Test audio block creation with URL."""
|
||||
block = create_audio_block(url="https://example.com/audio.mp3")
|
||||
|
||||
assert block["type"] == "audio"
|
||||
assert block.get("url") == "https://example.com/audio.mp3"
|
||||
|
||||
def test_with_base64(self) -> None:
|
||||
"""Test audio block creation with base64 data."""
|
||||
block = create_audio_block(
|
||||
base64="UklGRnoGAABXQVZFZm10IBAAAAABAAEA", mime_type="audio/mp3"
|
||||
)
|
||||
|
||||
assert block.get("base64") == "UklGRnoGAABXQVZFZm10IBAAAAABAAEA"
|
||||
assert block.get("mime_type") == "audio/mp3"
|
||||
|
||||
def test_no_source_raises_error(self) -> None:
|
||||
"""Test that missing all sources raises ValueError."""
|
||||
with pytest.raises(
|
||||
ValueError, match="Must provide one of: url, base64, or file_id"
|
||||
):
|
||||
create_audio_block()
|
||||
|
||||
|
||||
class TestFileBlockFactory:
|
||||
"""Test create_file_block factory function."""
|
||||
|
||||
def test_with_url(self) -> None:
|
||||
"""Test file block creation with URL."""
|
||||
block = create_file_block(url="https://example.com/document.pdf")
|
||||
|
||||
assert block["type"] == "file"
|
||||
assert block.get("url") == "https://example.com/document.pdf"
|
||||
|
||||
def test_with_base64(self) -> None:
|
||||
"""Test file block creation with base64 data."""
|
||||
block = create_file_block(
|
||||
base64="JVBERi0xLjQKJdPr6eEKMSAwIG9iago8PAovVHlwZSAvQ2F0YWxvZwo=",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
assert (
|
||||
block.get("base64")
|
||||
== "JVBERi0xLjQKJdPr6eEKMSAwIG9iago8PAovVHlwZSAvQ2F0YWxvZwo="
|
||||
)
|
||||
assert block.get("mime_type") == "application/pdf"
|
||||
|
||||
def test_no_source_raises_error(self) -> None:
|
||||
"""Test that missing all sources raises ValueError."""
|
||||
with pytest.raises(
|
||||
ValueError, match="Must provide one of: url, base64, or file_id"
|
||||
):
|
||||
create_file_block()
|
||||
|
||||
|
||||
class TestPlainTextBlockFactory:
|
||||
"""Test create_plain_text_block factory function."""
|
||||
|
||||
def test_basic_creation(self) -> None:
|
||||
"""Test basic plain text block creation."""
|
||||
block = create_plaintext_block("This is plain text content.")
|
||||
|
||||
assert block["type"] == "text-plain"
|
||||
assert block.get("mime_type") == "text/plain"
|
||||
assert block.get("text") == "This is plain text content."
|
||||
assert "id" in block
|
||||
id_value = block.get("id")
|
||||
assert id_value is not None, "block id is None"
|
||||
_validate_lc_uuid(id_value)
|
||||
|
||||
def test_with_title_and_context(self) -> None:
|
||||
"""Test plain text block creation with title and context."""
|
||||
block = create_plaintext_block(
|
||||
"Document content here.",
|
||||
title="Important Document",
|
||||
context="This document contains important information.",
|
||||
)
|
||||
|
||||
assert block.get("title") == "Important Document"
|
||||
assert block.get("context") == "This document contains important information."
|
||||
|
||||
def test_with_url(self) -> None:
|
||||
"""Test plain text block creation with URL."""
|
||||
block = create_plaintext_block(
|
||||
"Content", url="https://example.com/document.txt"
|
||||
)
|
||||
|
||||
assert block.get("url") == "https://example.com/document.txt"
|
||||
|
||||
|
||||
class TestToolCallFactory:
|
||||
"""Test create_tool_call factory function."""
|
||||
|
||||
def test_basic_creation(self) -> None:
|
||||
"""Test basic tool call creation."""
|
||||
block = create_tool_call("search", {"query": "python"})
|
||||
|
||||
assert block["type"] == "tool_call"
|
||||
assert block["name"] == "search"
|
||||
assert block["args"] == {"query": "python"}
|
||||
assert "id" in block
|
||||
id_value = block.get("id")
|
||||
assert id_value is not None, "block id is None"
|
||||
_validate_lc_uuid(id_value)
|
||||
|
||||
def test_with_custom_id(self) -> None:
|
||||
"""Test tool call creation with custom ID."""
|
||||
block = create_tool_call("search", {"query": "python"}, id="tool-123")
|
||||
|
||||
assert block.get("id") == "tool-123"
|
||||
|
||||
def test_with_index(self) -> None:
|
||||
"""Test tool call creation with index."""
|
||||
block = create_tool_call("search", {"query": "python"}, index=2)
|
||||
|
||||
assert block.get("index") == 2
|
||||
|
||||
|
||||
class TestReasoningBlockFactory:
|
||||
"""Test create_reasoning_block factory function."""
|
||||
|
||||
def test_basic_creation(self) -> None:
|
||||
"""Test basic reasoning block creation."""
|
||||
block = create_reasoning_block("Let me think about this problem...")
|
||||
|
||||
assert block["type"] == "reasoning"
|
||||
assert block.get("reasoning") == "Let me think about this problem..."
|
||||
assert "id" in block
|
||||
id_value = block.get("id")
|
||||
assert id_value is not None, "block id is None"
|
||||
_validate_lc_uuid(id_value)
|
||||
|
||||
@pytest.mark.xfail(reason="Optional fields not implemented yet")
|
||||
def test_with_signatures(self) -> None:
|
||||
"""Test reasoning block creation with signatures."""
|
||||
block = create_reasoning_block(
|
||||
"Thinking...",
|
||||
thought_signature="thought-sig-123", # type: ignore[call-arg]
|
||||
signature="auth-sig-456", # type: ignore[call-arg, unused-ignore]
|
||||
)
|
||||
|
||||
assert block.get("thought_signature") == "thought-sig-123"
|
||||
assert block.get("signature") == "auth-sig-456"
|
||||
|
||||
def test_with_index(self) -> None:
|
||||
"""Test reasoning block creation with index."""
|
||||
block = create_reasoning_block("Thinking...", index=3)
|
||||
|
||||
assert block.get("index") == 3
|
||||
|
||||
|
||||
class TestCitationFactory:
    """Test create_citation factory function."""

    def test_basic_creation(self) -> None:
        """Test basic citation creation."""
        block = create_citation()

        assert block["type"] == "citation"
        assert "id" in block
        id_value = block.get("id")
        assert id_value is not None, "block id is None"
        _validate_lc_uuid(id_value)

    def test_with_all_fields(self) -> None:
        """Test citation creation with all fields."""
        block = create_citation(
            url="https://example.com/source",
            title="Source Document",
            start_index=10,
            end_index=50,
            cited_text="This is the cited text.",
        )

        assert block.get("url") == "https://example.com/source"
        assert block.get("title") == "Source Document"
        assert block.get("start_index") == 10
        assert block.get("end_index") == 50
        assert block.get("cited_text") == "This is the cited text."

    def test_optional_fields_not_present_when_none(self) -> None:
        """Test that optional fields are not included when None."""
        block = create_citation()

        assert "url" not in block
        assert "title" not in block
        assert "start_index" not in block
        assert "end_index" not in block
        assert "cited_text" not in block


class TestNonStandardBlockFactory:
    """Test create_non_standard_block factory function."""

    def test_basic_creation(self) -> None:
        """Test basic non-standard block creation."""
        value = {"custom_field": "custom_value", "number": 42}
        block = create_non_standard_block(value)

        assert block["type"] == "non_standard"
        assert block["value"] == value
        assert "id" in block
        id_value = block.get("id")
        assert id_value is not None, "block id is None"
        _validate_lc_uuid(id_value)

    def test_with_index(self) -> None:
        """Test non-standard block creation with index."""
        value = {"data": "test"}
        block = create_non_standard_block(value, index=5)

        assert block.get("index") == 5

    def test_optional_fields_not_present_when_none(self) -> None:
        """Test that optional fields are not included when None."""
        value = {"data": "test"}
        block = create_non_standard_block(value)

        assert "index" not in block


class TestUUIDValidation:
    """Test UUID generation and validation behavior."""

    def test_custom_id_bypasses_lc_prefix_requirement(self) -> None:
        """Test that custom IDs can use any format (don't require lc_ prefix)."""
        custom_id = "custom-123"
        block = create_text_block("Hello", id=custom_id)

        assert block.get("id") == custom_id
        # Custom IDs should not be validated with lc_ prefix requirement

    def test_generated_ids_are_unique(self) -> None:
        """Test that multiple factory calls generate unique IDs."""
        blocks = [create_text_block("test") for _ in range(10)]
        ids = [block.get("id") for block in blocks]

        # All IDs should be unique
        assert len(set(ids)) == len(ids)

        # All generated IDs should have lc_ prefix
        for id_value in ids:
            _validate_lc_uuid(id_value or "")

    def test_empty_string_id_generates_new_uuid(self) -> None:
        """Test that empty string ID generates new UUID with lc_ prefix."""
        block = create_text_block("Hello", id="")

        id_value: str = block.get("id", "")
        assert id_value != ""
        _validate_lc_uuid(id_value)

    def test_generated_id_length(self) -> None:
        """Test that generated IDs have correct length (UUID4 + lc_ prefix)."""
        block = create_text_block("Hello")

        id_value = block.get("id")
        assert id_value is not None

        # UUID4 string length is 36 chars, plus 3 for "lc_" prefix = 39 total
        expected_length = 36 + 3
        assert len(id_value) == expected_length, (
            f"Expected length {expected_length}, got {len(id_value)}"
        )

        # Validate it's properly formatted
        _validate_lc_uuid(id_value)


class TestFactoryTypeConsistency:
    """Test that factory functions return correctly typed objects."""

    def test_factories_return_correct_types(self) -> None:
        """Test that all factory functions return the expected TypedDict types."""
        text_block = create_text_block("test")
        assert isinstance(text_block, dict)
        assert text_block["type"] == "text"

        image_block = create_image_block(url="https://example.com/image.jpg")
        assert isinstance(image_block, dict)
        assert image_block["type"] == "image"

        video_block = create_video_block(url="https://example.com/video.mp4")
        assert isinstance(video_block, dict)
        assert video_block["type"] == "video"

        audio_block = create_audio_block(url="https://example.com/audio.mp3")
        assert isinstance(audio_block, dict)
        assert audio_block["type"] == "audio"

        file_block = create_file_block(url="https://example.com/file.pdf")
        assert isinstance(file_block, dict)
        assert file_block["type"] == "file"

        plain_text_block = create_plaintext_block("content")
        assert isinstance(plain_text_block, dict)
        assert plain_text_block["type"] == "text-plain"

        tool_call = create_tool_call("tool", {"arg": "value"})
        assert isinstance(tool_call, dict)
        assert tool_call["type"] == "tool_call"

        reasoning_block = create_reasoning_block("reasoning")
        assert isinstance(reasoning_block, dict)
        assert reasoning_block["type"] == "reasoning"

        citation = create_citation()
        assert isinstance(citation, dict)
        assert citation["type"] == "citation"

        non_standard_block = create_non_standard_block({"data": "value"})
        assert isinstance(non_standard_block, dict)
        assert non_standard_block["type"] == "non_standard"


class TestExtraItems:
|
||||
"""Test that content blocks support extra items."""
|
||||
|
||||
def test_text_block_extras_field(self) -> None:
|
||||
"""Test that TextContentBlock properly supports the extras field."""
|
||||
block = create_text_block("Hello world")
|
||||
|
||||
block["extras"] = {
|
||||
"openai_metadata": {"model": "gpt-4", "temperature": 0.7},
|
||||
"anthropic_usage": {"input_tokens": 10, "output_tokens": 20},
|
||||
"custom_field": "any value",
|
||||
}
|
||||
|
||||
assert block["type"] == "text"
|
||||
assert block["text"] == "Hello world"
|
||||
assert "id" in block
|
||||
assert "extras" in block
|
||||
|
||||
extras = block.get("extras", {})
|
||||
assert extras.get("openai_metadata") == {"model": "gpt-4", "temperature": 0.7}
|
||||
expected_usage = {"input_tokens": 10, "output_tokens": 20}
|
||||
assert extras.get("anthropic_usage") == expected_usage
|
||||
assert extras.get("custom_field") == "any value"
|
||||
|
||||
def test_extra_items_do_not_interfere_with_standard_fields(self) -> None:
|
||||
"""Test that extra items don't interfere with standard field access."""
|
||||
block = create_text_block("Original text", index=1)
|
||||
|
||||
# Add many extra fields
|
||||
for i in range(10):
|
||||
block[f"extra_field_{i}"] = f"value_{i}" # type: ignore[literal-required]
|
||||
|
||||
# Standard fields should still work correctly
|
||||
assert block["type"] == "text"
|
||||
assert block["text"] == "Original text"
|
||||
assert block["index"] == 1 if "index" in block else None
|
||||
assert "id" in block
|
||||
|
||||
# Extra fields should also be accessible
|
||||
for i in range(10):
|
||||
assert block.get(f"extra_field_{i}") == f"value_{i}"
|
||||
|
||||
def test_extra_items_can_be_modified(self) -> None:
|
||||
"""Test that extra items can be modified after creation."""
|
||||
block = create_image_block(url="https://example.com/image.jpg")
|
||||
|
||||
# Add an extra field
|
||||
block["extras"] = {"status": "pending"}
|
||||
assert block["extras"].get("status") == "pending"
|
||||
|
||||
# Modify the extra field
|
||||
block["extras"] = {"status": "processed"}
|
||||
assert block["extras"].get("status") == "processed"
|
||||
|
||||
# Add more fields
|
||||
block["extras"] = {"metadata": {"version": 1}}
|
||||
metadata = block["extras"].get("metadata", {})
|
||||
assert isinstance(metadata, dict)
|
||||
assert metadata.get("version") == 1
|
||||
|
||||
# Modify nested extra field
|
||||
metadata["version"] = 2
|
||||
assert isinstance(metadata, dict)
|
||||
assert metadata.get("version") == 2
|
||||
|
||||
def test_all_content_blocks_support_extra_items(self) -> None:
|
||||
"""Test that all content block types support extra items."""
|
||||
# Test each content block type
|
||||
text_block = create_text_block("test")
|
||||
text_block["extras"] = {"text_extra": "a"}
|
||||
assert text_block.get("extras") == {"text_extra": "a"}
|
||||
|
||||
image_block = create_image_block(url="https://example.com/image.jpg")
|
||||
image_block["extras"] = {"image_extra": "a"}
|
||||
assert image_block.get("extras") == {"image_extra": "a"}
|
||||
|
||||
video_block = create_video_block(url="https://example.com/video.mp4")
|
||||
video_block["extras"] = {"video_extra": "a"}
|
||||
assert video_block.get("extras") == {"video_extra": "a"}
|
||||
|
||||
audio_block = create_audio_block(url="https://example.com/audio.mp3")
|
||||
audio_block["extras"] = {"audio_extra": "a"}
|
||||
assert audio_block.get("extras") == {"audio_extra": "a"}
|
||||
|
||||
file_block = create_file_block(url="https://example.com/file.pdf")
|
||||
file_block["extras"] = {"file_extra": "a"}
|
||||
assert file_block.get("extras") == {"file_extra": "a"}
|
||||
|
||||
plain_text_block = create_plaintext_block("content")
|
||||
plain_text_block["extras"] = {"plaintext_extra": "a"}
|
||||
assert plain_text_block.get("extras") == {"plaintext_extra": "a"}
|
||||
|
||||
tool_call = create_tool_call("tool", {"arg": "value"})
|
||||
tool_call["extras"] = {"tool_extra": "a"}
|
||||
assert tool_call.get("extras") == {"tool_extra": "a"}
|
||||
|
||||
reasoning_block = create_reasoning_block("reasoning")
|
||||
reasoning_block["extras"] = {"reasoning_extra": "a"}
|
||||
assert reasoning_block.get("extras") == {"reasoning_extra": "a"}
|
||||
|
||||
|
||||
class TestExtrasField:
|
||||
"""Test the explicit extras field across all content block types."""
|
||||
|
||||
def test_all_content_blocks_support_extras_field(self) -> None:
|
||||
"""Test that all content block types support the explicit extras field."""
|
||||
provider_metadata = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7,
|
||||
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||
}
|
||||
|
||||
# Test TextContentBlock
|
||||
text_block = create_text_block("test")
|
||||
text_block["extras"] = provider_metadata
|
||||
assert text_block.get("extras") == provider_metadata
|
||||
assert text_block["type"] == "text"
|
||||
|
||||
# Test ImageContentBlock
|
||||
image_block = create_image_block(url="https://example.com/image.jpg")
|
||||
image_block["extras"] = provider_metadata
|
||||
assert image_block.get("extras") == provider_metadata
|
||||
assert image_block["type"] == "image"
|
||||
|
||||
# Test VideoContentBlock
|
||||
video_block = create_video_block(url="https://example.com/video.mp4")
|
||||
video_block["extras"] = provider_metadata
|
||||
assert video_block.get("extras") == provider_metadata
|
||||
assert video_block["type"] == "video"
|
||||
|
||||
# Test AudioContentBlock
|
||||
audio_block = create_audio_block(url="https://example.com/audio.mp3")
|
||||
audio_block["extras"] = provider_metadata
|
||||
assert audio_block.get("extras") == provider_metadata
|
||||
assert audio_block["type"] == "audio"
|
||||
|
||||
# Test FileContentBlock
|
||||
file_block = create_file_block(url="https://example.com/file.pdf")
|
||||
file_block["extras"] = provider_metadata
|
||||
assert file_block.get("extras") == provider_metadata
|
||||
assert file_block["type"] == "file"
|
||||
|
||||
# Test PlainTextContentBlock
|
||||
plain_text_block = create_plaintext_block("content")
|
||||
plain_text_block["extras"] = provider_metadata
|
||||
assert plain_text_block.get("extras") == provider_metadata
|
||||
assert plain_text_block["type"] == "text-plain"
|
||||
|
||||
# Test ToolCall
|
||||
tool_call = create_tool_call("tool", {"arg": "value"})
|
||||
tool_call["extras"] = provider_metadata
|
||||
assert tool_call.get("extras") == provider_metadata
|
||||
assert tool_call["type"] == "tool_call"
|
||||
|
||||
# Test ReasoningContentBlock
|
||||
reasoning_block = create_reasoning_block("reasoning")
|
||||
reasoning_block["extras"] = provider_metadata
|
||||
assert reasoning_block.get("extras") == provider_metadata
|
||||
assert reasoning_block["type"] == "reasoning"
|
||||
|
||||
# Test Citation
|
||||
citation = create_citation()
|
||||
citation["extras"] = provider_metadata
|
||||
assert citation.get("extras") == provider_metadata
|
||||
assert citation["type"] == "citation"
|
||||
|
||||
def test_extras_field_is_optional(self) -> None:
|
||||
"""Test that the extras field is optional and blocks work without it."""
|
||||
# Create blocks without extras
|
||||
text_block = create_text_block("test")
|
||||
image_block = create_image_block(url="https://example.com/image.jpg")
|
||||
tool_call = create_tool_call("tool", {"arg": "value"})
|
||||
reasoning_block = create_reasoning_block("reasoning")
|
||||
citation = create_citation()
|
||||
|
||||
# Verify blocks work correctly without extras
|
||||
assert text_block["type"] == "text"
|
||||
assert image_block["type"] == "image"
|
||||
assert tool_call["type"] == "tool_call"
|
||||
assert reasoning_block["type"] == "reasoning"
|
||||
assert citation["type"] == "citation"
|
||||
|
||||
# Verify extras field is not present when not set
|
||||
assert "extras" not in text_block
|
||||
assert "extras" not in image_block
|
||||
assert "extras" not in tool_call
|
||||
assert "extras" not in reasoning_block
|
||||
assert "extras" not in citation
|
||||
|
||||
def test_extras_field_can_be_modified(self) -> None:
|
||||
"""Test that the extras field can be modified after creation."""
|
||||
block = create_text_block("test")
|
||||
|
||||
# Add extras
|
||||
block["extras"] = {"initial": "value"}
|
||||
assert block.get("extras") == {"initial": "value"}
|
||||
|
||||
# Modify extras
|
||||
block["extras"] = {"updated": "value", "count": 42}
|
||||
extras = block.get("extras", {})
|
||||
assert extras.get("updated") == "value"
|
||||
assert extras.get("count") == 42
|
||||
assert "initial" not in extras
|
||||
|
||||
# Update nested values in extras
|
||||
if "extras" in block:
|
||||
block["extras"]["nested"] = {"deep": "value"}
|
||||
extras = block.get("extras", {})
|
||||
nested = extras.get("nested", {})
|
||||
assert isinstance(nested, dict)
|
||||
assert nested.get("deep") == "value"
|
||||
|
||||
def test_extras_field_supports_various_data_types(self) -> None:
|
||||
"""Test that the extras field can store various data types."""
|
||||
block = create_text_block("test")
|
||||
|
||||
complex_extras = {
|
||||
"string_val": "test string",
|
||||
"int_val": 42,
|
||||
"float_val": 3.14,
|
||||
"bool_val": True,
|
||||
"none_val": None,
|
||||
"list_val": ["item1", "item2", {"nested": "in_list"}],
|
||||
"dict_val": {"nested": {"deeply": {"nested": "value"}}},
|
||||
}
|
||||
|
||||
block["extras"] = complex_extras
|
||||
|
||||
extras = block.get("extras", {})
|
||||
assert extras.get("string_val") == "test string"
|
||||
assert extras.get("int_val") == 42
|
||||
assert extras.get("float_val") == 3.14
|
||||
assert extras.get("bool_val") is True
|
||||
assert extras.get("none_val") is None
|
||||
|
||||
list_val = extras.get("list_val", [])
|
||||
assert isinstance(list_val, list)
|
||||
assert len(list_val) == 3
|
||||
assert list_val[0] == "item1"
|
||||
assert list_val[1] == "item2"
|
||||
assert isinstance(list_val[2], dict)
|
||||
assert list_val[2].get("nested") == "in_list"
|
||||
|
||||
dict_val = extras.get("dict_val", {})
|
||||
assert isinstance(dict_val, dict)
|
||||
nested = dict_val.get("nested", {})
|
||||
assert isinstance(nested, dict)
|
||||
deeply = nested.get("deeply", {})
|
||||
assert isinstance(deeply, dict)
|
||||
assert deeply.get("nested") == "value"
|
||||
|
||||
def test_extras_field_does_not_interfere_with_standard_fields(self) -> None:
|
||||
"""Test that the extras field doesn't interfere with standard fields."""
|
||||
# Create a complex block with all standard fields
|
||||
block = create_text_block(
|
||||
"Test content",
|
||||
annotations=[create_citation(url="https://example.com")],
|
||||
index=42,
|
||||
)
|
||||
|
||||
# Add extensive extras
|
||||
large_extras = {f"field_{i}": f"value_{i}" for i in range(100)}
|
||||
block["extras"] = large_extras
|
||||
|
||||
# Verify all standard fields still work
|
||||
assert block["type"] == "text"
|
||||
assert block["text"] == "Test content"
|
||||
assert block.get("index") == 42
|
||||
assert "id" in block
|
||||
assert "annotations" in block
|
||||
|
||||
annotations = block.get("annotations", [])
|
||||
assert len(annotations) == 1
|
||||
assert annotations[0]["type"] == "citation"
|
||||
|
||||
# Verify extras field works
|
||||
extras = block.get("extras", {})
|
||||
assert len(extras) == 100
|
||||
for i in range(100):
|
||||
assert extras.get(f"field_{i}") == f"value_{i}"
|
||||
|
||||
def test_special_content_blocks_support_extras_field(self) -> None:
|
||||
"""Test that special content blocks support extras field."""
|
||||
provider_metadata = {
|
||||
"provider": "openai",
|
||||
"request_id": "req_12345",
|
||||
"timing": {"start": 1234567890, "end": 1234567895},
|
||||
}
|
||||
|
||||
# Test ToolCallChunk
|
||||
tool_call_chunk: ToolCallChunk = {
|
||||
"type": "tool_call_chunk",
|
||||
"id": "tool_123",
|
||||
"name": "search",
|
||||
"args": '{"query": "test"}',
|
||||
"index": 0,
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert tool_call_chunk.get("extras") == provider_metadata
|
||||
assert tool_call_chunk["type"] == "tool_call_chunk"
|
||||
|
||||
# Test InvalidToolCall
|
||||
invalid_tool_call: InvalidToolCall = {
|
||||
"type": "invalid_tool_call",
|
||||
"id": "invalid_123",
|
||||
"name": "bad_tool",
|
||||
"args": "invalid json",
|
||||
"error": "JSON parse error",
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert invalid_tool_call.get("extras") == provider_metadata
|
||||
assert invalid_tool_call["type"] == "invalid_tool_call"
|
||||
|
||||
# Test WebSearchCall
|
||||
web_search_call: WebSearchCall = {
|
||||
"type": "web_search_call",
|
||||
"id": "search_123",
|
||||
"query": "python langchain",
|
||||
"index": 0,
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert web_search_call.get("extras") == provider_metadata
|
||||
assert web_search_call["type"] == "web_search_call"
|
||||
|
||||
# Test WebSearchResult
|
||||
web_search_result: WebSearchResult = {
|
||||
"type": "web_search_result",
|
||||
"id": "result_123",
|
||||
"urls": ["https://example.com", "https://test.com"],
|
||||
"index": 0,
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert web_search_result.get("extras") == provider_metadata
|
||||
assert web_search_result["type"] == "web_search_result"
|
||||
|
||||
# Test CodeInterpreterCall
|
||||
code_interpreter_call: CodeInterpreterCall = {
|
||||
"type": "code_interpreter_call",
|
||||
"id": "code_123",
|
||||
"language": "python",
|
||||
"code": "print('hello world')",
|
||||
"index": 0,
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert code_interpreter_call.get("extras") == provider_metadata
|
||||
assert code_interpreter_call["type"] == "code_interpreter_call"
|
||||
|
||||
# Test CodeInterpreterOutput
|
||||
code_interpreter_output: CodeInterpreterOutput = {
|
||||
"type": "code_interpreter_output",
|
||||
"id": "output_123",
|
||||
"return_code": 0,
|
||||
"stderr": "",
|
||||
"stdout": "hello world\n",
|
||||
"file_ids": ["file_123"],
|
||||
"index": 0,
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert code_interpreter_output.get("extras") == provider_metadata
|
||||
assert code_interpreter_output["type"] == "code_interpreter_output"
|
||||
|
||||
# Test CodeInterpreterResult
|
||||
code_interpreter_result: CodeInterpreterResult = {
|
||||
"type": "code_interpreter_result",
|
||||
"id": "result_123",
|
||||
"output": [code_interpreter_output],
|
||||
"index": 0,
|
||||
"extras": provider_metadata,
|
||||
}
|
||||
assert code_interpreter_result.get("extras") == provider_metadata
|
||||
assert code_interpreter_result["type"] == "code_interpreter_result"
|
||||
|
||||
def test_extras_field_is_not_required_for_special_blocks(self) -> None:
|
||||
"""Test that extras field is optional for all special content blocks."""
|
||||
# Create blocks without extras field
|
||||
tool_call_chunk: ToolCallChunk = {
|
||||
"id": "tool_123",
|
||||
"name": "search",
|
||||
"args": '{"query": "test"}',
|
||||
"index": 0,
|
||||
}
|
||||
|
||||
invalid_tool_call: InvalidToolCall = {
|
||||
"type": "invalid_tool_call",
|
||||
"id": "invalid_123",
|
||||
"name": "bad_tool",
|
||||
"args": "invalid json",
|
||||
"error": "JSON parse error",
|
||||
}
|
||||
|
||||
web_search_call: WebSearchCall = {
|
||||
"type": "web_search_call",
|
||||
"query": "python langchain",
|
||||
}
|
||||
|
||||
web_search_result: WebSearchResult = {
|
||||
"type": "web_search_result",
|
||||
"urls": ["https://example.com"],
|
||||
}
|
||||
|
||||
code_interpreter_call: CodeInterpreterCall = {
|
||||
"type": "code_interpreter_call",
|
||||
"code": "print('hello')",
|
||||
}
|
||||
|
||||
code_interpreter_output: CodeInterpreterOutput = {
|
||||
"type": "code_interpreter_output",
|
||||
"stdout": "hello\n",
|
||||
}
|
||||
|
||||
code_interpreter_result: CodeInterpreterResult = {
|
||||
"type": "code_interpreter_result",
|
||||
"output": [code_interpreter_output],
|
||||
}
|
||||
|
||||
# Verify they work without extras
|
||||
assert tool_call_chunk.get("name") == "search"
|
||||
assert invalid_tool_call["type"] == "invalid_tool_call"
|
||||
assert web_search_call["type"] == "web_search_call"
|
||||
assert web_search_result["type"] == "web_search_result"
|
||||
assert code_interpreter_call["type"] == "code_interpreter_call"
|
||||
assert code_interpreter_output["type"] == "code_interpreter_output"
|
||||
assert code_interpreter_result["type"] == "code_interpreter_result"
|
||||
|
||||
# Verify extras field is not present
|
||||
assert "extras" not in tool_call_chunk
|
||||
assert "extras" not in invalid_tool_call
|
||||
assert "extras" not in web_search_call
|
||||
assert "extras" not in web_search_result
|
||||
assert "extras" not in code_interpreter_call
|
||||
assert "extras" not in code_interpreter_output
|
||||
assert "extras" not in code_interpreter_result
|
||||
@@ -5,26 +5,48 @@ EXPECTED_ALL = [
|
||||
"_message_from_dict",
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"Annotation",
|
||||
"AnyMessage",
|
||||
"AudioContentBlock",
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"ContentBlock",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"Citation",
|
||||
"CodeInterpreterCall",
|
||||
"CodeInterpreterOutput",
|
||||
"CodeInterpreterResult",
|
||||
"DataContentBlock",
|
||||
"FileContentBlock",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"ImageContentBlock",
|
||||
"InvalidToolCall",
|
||||
"NonStandardAnnotation",
|
||||
"NonStandardContentBlock",
|
||||
"PlainTextContentBlock",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"TextContentBlock",
|
||||
"ToolCall",
|
||||
"ToolCallChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"VideoContentBlock",
|
||||
"WebSearchCall",
|
||||
"WebSearchResult",
|
||||
"ReasoningContentBlock",
|
||||
"RemoveMessage",
|
||||
"convert_to_messages",
|
||||
"get_buffer_string",
|
||||
"is_data_content_block",
|
||||
"is_reasoning_block",
|
||||
"is_text_block",
|
||||
"is_tool_call_block",
|
||||
"is_tool_call_chunk",
|
||||
"merge_content",
|
||||
"message_chunk_to_message",
|
||||
"message_to_dict",
|
||||
|
||||
343
libs/core/tests/unit_tests/messages/test_response_metadata.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Unit tests for ResponseMetadata TypedDict."""
|
||||
|
||||
from langchain_core.v1.messages import AIMessage, AIMessageChunk, ResponseMetadata
|
||||
|
||||
|
||||
class TestResponseMetadata:
|
||||
"""Test the ResponseMetadata TypedDict functionality."""
|
||||
|
||||
def test_response_metadata_basic_fields(self) -> None:
|
||||
"""Test ResponseMetadata with basic required fields."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
}
|
||||
|
||||
assert metadata.get("model_provider") == "openai"
|
||||
assert metadata.get("model_name") == "gpt-4"
|
||||
|
||||
def test_response_metadata_is_optional(self) -> None:
|
||||
"""Test that ResponseMetadata fields are optional due to total=False."""
|
||||
# Should be able to create empty ResponseMetadata
|
||||
metadata: ResponseMetadata = {}
|
||||
assert metadata == {}
|
||||
|
||||
# Should be able to create with just one field
|
||||
metadata_partial: ResponseMetadata = {"model_provider": "anthropic"}
|
||||
assert metadata_partial.get("model_provider") == "anthropic"
|
||||
assert "model_name" not in metadata_partial
|
||||
|
||||
def test_response_metadata_supports_extra_fields(self) -> None:
|
||||
"""Test that ResponseMetadata supports provider-specific extra fields."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4-turbo",
|
||||
# Extra fields should be allowed
|
||||
"system_fingerprint": "fp_12345",
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
"request_id": "req_abc123",
|
||||
}
|
||||
|
||||
assert metadata.get("model_provider") == "openai"
|
||||
assert metadata.get("model_name") == "gpt-4-turbo"
|
||||
assert metadata.get("system_fingerprint") == "fp_12345"
|
||||
assert metadata.get("logprobs") is None
|
||||
assert metadata.get("finish_reason") == "stop"
|
||||
assert metadata.get("request_id") == "req_abc123"
|
||||
|
||||
def test_response_metadata_various_data_types(self) -> None:
|
||||
"""Test that ResponseMetadata can store various data types in extra fields."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"string_field": "test_value",
|
||||
"int_field": 42,
|
||||
"float_field": 3.14,
|
||||
"bool_field": True,
|
||||
"none_field": None,
|
||||
"list_field": [1, 2, 3, "test"],
|
||||
"dict_field": {"nested": {"deeply": "nested_value"}},
|
||||
}
|
||||
|
||||
assert metadata.get("string_field") == "test_value"
|
||||
assert metadata.get("int_field") == 42
|
||||
assert metadata.get("float_field") == 3.14
|
||||
assert metadata.get("bool_field") is True
|
||||
assert metadata.get("none_field") is None
|
||||
|
||||
list_field = metadata.get("list_field")
|
||||
assert isinstance(list_field, list)
|
||||
assert list_field == [1, 2, 3, "test"]
|
||||
|
||||
dict_field = metadata.get("dict_field")
|
||||
assert isinstance(dict_field, dict)
|
||||
nested = dict_field.get("nested")
|
||||
assert isinstance(nested, dict)
|
||||
assert nested.get("deeply") == "nested_value"
|
||||
|
||||
def test_response_metadata_can_be_modified(self) -> None:
|
||||
"""Test that ResponseMetadata can be modified after creation."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
# Modify existing fields
|
||||
metadata["model_name"] = "gpt-4"
|
||||
assert metadata.get("model_name") == "gpt-4"
|
||||
|
||||
# Add new fields
|
||||
metadata["request_id"] = "req_12345"
|
||||
assert metadata.get("request_id") == "req_12345"
|
||||
|
||||
# Modify nested structures
|
||||
metadata["headers"] = {"x-request-id": "abc123"}
|
||||
metadata["headers"]["x-rate-limit"] = "100" # type: ignore[typeddict-item]
|
||||
|
||||
headers = metadata.get("headers")
|
||||
assert isinstance(headers, dict)
|
||||
assert headers.get("x-request-id") == "abc123"
|
||||
assert headers.get("x-rate-limit") == "100"
|
||||
|
||||
def test_response_metadata_provider_specific_examples(self) -> None:
|
||||
"""Test ResponseMetadata with realistic provider-specific examples."""
|
||||
# OpenAI-style metadata
|
||||
openai_metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4-turbo-2024-04-09",
|
||||
"system_fingerprint": "fp_abc123",
|
||||
"created": 1234567890,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
|
||||
assert openai_metadata.get("model_provider") == "openai"
|
||||
assert openai_metadata.get("system_fingerprint") == "fp_abc123"
|
||||
|
||||
# Anthropic-style metadata
|
||||
anthropic_metadata: ResponseMetadata = {
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
|
||||
assert anthropic_metadata.get("model_provider") == "anthropic"
|
||||
assert anthropic_metadata.get("stop_reason") == "end_turn"
|
||||
|
||||
# Custom provider metadata
|
||||
custom_metadata: ResponseMetadata = {
|
||||
"model_provider": "custom_llm_service",
|
||||
"model_name": "custom-model-v1",
|
||||
"service_tier": "premium",
|
||||
"rate_limit_info": {
|
||||
"requests_remaining": 100,
|
||||
"reset_time": "2024-01-01T00:00:00Z",
|
||||
},
|
||||
"response_time_ms": 1250,
|
||||
}
|
||||
|
||||
assert custom_metadata.get("service_tier") == "premium"
|
||||
rate_limit = custom_metadata.get("rate_limit_info")
|
||||
assert isinstance(rate_limit, dict)
|
||||
assert rate_limit.get("requests_remaining") == 100
|
||||
|
||||
|
||||
class TestResponseMetadataWithAIMessages:
|
||||
"""Test ResponseMetadata integration with AI message classes."""
|
||||
|
||||
def test_ai_message_with_response_metadata(self) -> None:
|
||||
"""Test AIMessage with ResponseMetadata."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"system_fingerprint": "fp_xyz789",
|
||||
}
|
||||
|
||||
message = AIMessage(content="Hello, world!", response_metadata=metadata)
|
||||
|
||||
assert message.response_metadata == metadata
|
||||
assert message.response_metadata.get("model_provider") == "openai"
|
||||
assert message.response_metadata.get("model_name") == "gpt-4"
|
||||
assert message.response_metadata.get("system_fingerprint") == "fp_xyz789"
|
||||
|
||||
def test_ai_message_chunk_with_response_metadata(self) -> None:
|
||||
"""Test AIMessageChunk with ResponseMetadata."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"stream_id": "stream_12345",
|
||||
}
|
||||
|
||||
chunk = AIMessageChunk(content="Hello", response_metadata=metadata)
|
||||
|
||||
assert chunk.response_metadata == metadata
|
||||
assert chunk.response_metadata.get("stream_id") == "stream_12345"
|
||||
|
||||
def test_ai_message_default_empty_response_metadata(self) -> None:
|
||||
"""Test that AIMessage creates empty ResponseMetadata by default."""
|
||||
message = AIMessage(content="Test message")
|
||||
|
||||
# Should have empty dict as default
|
||||
assert message.response_metadata == {}
|
||||
assert isinstance(message.response_metadata, dict)
|
||||
|
||||
def test_ai_message_chunk_default_empty_response_metadata(self) -> None:
|
||||
"""Test that AIMessageChunk creates empty ResponseMetadata by default."""
|
||||
chunk = AIMessageChunk(content="Test chunk")
|
||||
|
||||
# Should have empty dict as default
|
||||
assert chunk.response_metadata == {}
|
||||
assert isinstance(chunk.response_metadata, dict)
|
||||
|
||||
def test_response_metadata_merging_in_chunks(self) -> None:
|
||||
"""Test that ResponseMetadata is properly merged when adding AIMessageChunks."""
|
||||
metadata1: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"request_id": "req_123",
|
||||
"system_fingerprint": "fp_abc",
|
||||
}
|
||||
|
||||
metadata2: ResponseMetadata = {
|
||||
"stream_chunk": 1,
|
||||
"finish_reason": "length",
|
||||
}
|
||||
|
||||
chunk1 = AIMessageChunk(content="Hello ", response_metadata=metadata1)
|
||||
chunk2 = AIMessageChunk(content="world!", response_metadata=metadata2)
|
||||
|
||||
merged = chunk1 + chunk2
|
||||
|
||||
# Should have merged response_metadata
|
||||
assert merged.response_metadata.get("model_provider") == "openai"
|
||||
assert merged.response_metadata.get("model_name") == "gpt-4"
|
||||
assert merged.response_metadata.get("request_id") == "req_123"
|
||||
assert merged.response_metadata.get("stream_chunk") == 1
|
||||
assert merged.response_metadata.get("system_fingerprint") == "fp_abc"
|
||||
assert merged.response_metadata.get("finish_reason") == "length"
|
||||
|
||||
def test_response_metadata_modification_after_message_creation(self) -> None:
|
||||
"""Test that ResponseMetadata can be modified after message creation."""
|
||||
message = AIMessage(
|
||||
content="Initial message",
|
||||
response_metadata={"model_provider": "openai", "model_name": "gpt-3.5"},
|
||||
)
|
||||
|
||||
# Modify existing field
|
||||
message.response_metadata["model_name"] = "gpt-4"
|
||||
assert message.response_metadata.get("model_name") == "gpt-4"
|
||||
|
||||
# Add new field
|
||||
message.response_metadata["finish_reason"] = "stop"
|
||||
assert message.response_metadata.get("finish_reason") == "stop"
|
||||
|
||||
def test_response_metadata_with_none_values(self) -> None:
|
||||
"""Test ResponseMetadata handling of None values."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"system_fingerprint": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
|
||||
message = AIMessage(content="Test", response_metadata=metadata)
|
||||
|
||||
assert message.response_metadata.get("system_fingerprint") is None
|
||||
assert message.response_metadata.get("logprobs") is None
|
||||
assert "system_fingerprint" in message.response_metadata
|
||||
assert "logprobs" in message.response_metadata
|
||||
|
||||
|
||||
class TestResponseMetadataEdgeCases:
|
||||
"""Test edge cases and error conditions for ResponseMetadata."""
|
||||
|
||||
def test_response_metadata_with_complex_nested_structures(self) -> None:
|
||||
"""Test ResponseMetadata with deeply nested and complex structures."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "custom",
|
||||
"model_name": "complex-model",
|
||||
"complex_data": {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"deeply_nested": "value",
|
||||
"array": [
|
||||
{"item": 1, "metadata": {"nested": True}},
|
||||
{"item": 2, "metadata": {"nested": False}},
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
complex_data = metadata.get("complex_data")
|
||||
assert isinstance(complex_data, dict)
|
||||
level1 = complex_data.get("level1")
|
||||
assert isinstance(level1, dict)
|
||||
level2 = level1.get("level2")
|
||||
assert isinstance(level2, dict)
|
||||
level3 = level2.get("level3")
|
||||
assert isinstance(level3, dict)
|
||||
|
||||
assert level3.get("deeply_nested") == "value"
|
||||
array = level3.get("array")
|
||||
assert isinstance(array, list)
|
||||
assert len(array) == 2
|
||||
assert array[0]["item"] == 1
|
||||
assert array[0]["metadata"]["nested"] is True
|
||||
|
||||
def test_response_metadata_large_data(self) -> None:
|
||||
"""Test ResponseMetadata with large amounts of data."""
|
||||
# Create metadata with many fields
|
||||
large_metadata: ResponseMetadata = {
|
||||
"model_provider": "test_provider",
|
||||
"model_name": "test_model",
|
||||
}
|
||||
|
||||
# Add 100 extra fields
|
||||
for i in range(100):
|
||||
large_metadata[f"field_{i}"] = f"value_{i}" # type: ignore[literal-required]
|
||||
|
||||
message = AIMessage(content="Test", response_metadata=large_metadata)
|
||||
|
||||
# Verify all fields are accessible
|
||||
assert message.response_metadata.get("model_provider") == "test_provider"
|
||||
for i in range(100):
|
||||
assert message.response_metadata.get(f"field_{i}") == f"value_{i}"
|
||||
|
||||
def test_response_metadata_empty_vs_none(self) -> None:
|
||||
"""Test the difference between empty ResponseMetadata and None."""
|
||||
# Message with empty metadata
|
||||
message_empty = AIMessage(content="Test", response_metadata={})
|
||||
assert message_empty.response_metadata == {}
|
||||
assert isinstance(message_empty.response_metadata, dict)
|
||||
|
||||
# Message with None metadata (should become empty dict)
|
||||
message_none = AIMessage(content="Test", response_metadata=None)
|
||||
assert message_none.response_metadata == {}
|
||||
assert isinstance(message_none.response_metadata, dict)
|
||||
|
||||
# Default message (no metadata specified)
|
||||
message_default = AIMessage(content="Test")
|
||||
assert message_default.response_metadata == {}
|
||||
assert isinstance(message_default.response_metadata, dict)
|
||||
|
||||
def test_response_metadata_preserves_original_dict_type(self) -> None:
|
||||
"""Test that ResponseMetadata preserves the original dict when passed."""
|
||||
original_dict: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"custom_field": "custom_value",
|
||||
}
|
||||
|
||||
message = AIMessage(content="Test", response_metadata=original_dict)
|
||||
|
||||
# Should be the same dict object
|
||||
assert message.response_metadata is original_dict
|
||||
|
||||
# Modifications to the message's response_metadata should affect original
|
||||
message.response_metadata["new_field"] = "new_value"
|
||||
assert original_dict.get("new_field") == "new_value"
|
||||
@@ -0,0 +1,361 @@
|
||||
"""Unit tests for ResponseMetadata TypedDict."""
|
||||
|
||||
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, ResponseMetadata
|
||||
|
||||
|
||||
class TestResponseMetadata:
|
||||
"""Test the ResponseMetadata TypedDict functionality."""
|
||||
|
||||
def test_response_metadata_basic_fields(self) -> None:
|
||||
"""Test ResponseMetadata with basic required fields."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
}
|
||||
|
||||
assert metadata.get("model_provider") == "openai"
|
||||
assert metadata.get("model_name") == "gpt-4"
|
||||
|
||||
def test_response_metadata_is_optional(self) -> None:
|
||||
"""Test that ResponseMetadata fields are optional due to total=False."""
|
||||
# Should be able to create empty ResponseMetadata
|
||||
metadata: ResponseMetadata = {}
|
||||
assert metadata == {}
|
||||
|
||||
# Should be able to create with just one field
|
||||
metadata_partial: ResponseMetadata = {"model_provider": "anthropic"}
|
||||
assert metadata_partial.get("model_provider") == "anthropic"
|
||||
assert "model_name" not in metadata_partial
|
||||
|
||||
def test_response_metadata_supports_extra_fields(self) -> None:
|
||||
"""Test that ResponseMetadata supports provider-specific extra fields."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4-turbo",
|
||||
# Extra fields should be allowed
|
||||
"usage": {"input_tokens": 100, "output_tokens": 50},
|
||||
"system_fingerprint": "fp_12345",
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
|
||||
assert metadata.get("model_provider") == "openai"
|
||||
assert metadata.get("model_name") == "gpt-4-turbo"
|
||||
assert metadata.get("usage") == {"input_tokens": 100, "output_tokens": 50}
|
||||
assert metadata.get("system_fingerprint") == "fp_12345"
|
||||
assert metadata.get("logprobs") is None
|
||||
assert metadata.get("finish_reason") == "stop"
|
||||
|
||||
def test_response_metadata_various_data_types(self) -> None:
|
||||
"""Test that ResponseMetadata can store various data types in extra fields."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"string_field": "test_value", # type: ignore[typeddict-unknown-key]
|
||||
"int_field": 42, # type: ignore[typeddict-unknown-key]
|
||||
"float_field": 3.14, # type: ignore[typeddict-unknown-key]
|
||||
"bool_field": True, # type: ignore[typeddict-unknown-key]
|
||||
"none_field": None, # type: ignore[typeddict-unknown-key]
|
||||
"list_field": [1, 2, 3, "test"], # type: ignore[typeddict-unknown-key]
|
||||
"dict_field": { # type: ignore[typeddict-unknown-key]
|
||||
"nested": {"deeply": "nested_value"}
|
||||
},
|
||||
}
|
||||
|
||||
assert metadata.get("string_field") == "test_value" # type: ignore[typeddict-item]
|
||||
assert metadata.get("int_field") == 42 # type: ignore[typeddict-item]
|
||||
assert metadata.get("float_field") == 3.14 # type: ignore[typeddict-item]
|
||||
assert metadata.get("bool_field") is True # type: ignore[typeddict-item]
|
||||
assert metadata.get("none_field") is None # type: ignore[typeddict-item]
|
||||
|
||||
list_field = metadata.get("list_field") # type: ignore[typeddict-item]
|
||||
assert isinstance(list_field, list)
|
||||
assert list_field == [1, 2, 3, "test"]
|
||||
|
||||
dict_field = metadata.get("dict_field") # type: ignore[typeddict-item]
|
||||
assert isinstance(dict_field, dict)
|
||||
nested = dict_field.get("nested") # type: ignore[union-attr]
|
||||
assert isinstance(nested, dict)
|
||||
assert nested.get("deeply") == "nested_value" # type: ignore[union-attr]
|
||||
|
||||
def test_response_metadata_can_be_modified(self) -> None:
|
||||
"""Test that ResponseMetadata can be modified after creation."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
# Modify existing fields
|
||||
metadata["model_name"] = "gpt-4"
|
||||
assert metadata.get("model_name") == "gpt-4"
|
||||
|
||||
# Add new fields
|
||||
metadata["request_id"] = "req_12345" # type: ignore[typeddict-unknown-key]
|
||||
assert metadata.get("request_id") == "req_12345" # type: ignore[typeddict-item]
|
||||
|
||||
# Modify nested structures
|
||||
metadata["usage"] = {"input_tokens": 10} # type: ignore[typeddict-unknown-key]
|
||||
metadata["usage"]["output_tokens"] = 20 # type: ignore[typeddict-item]
|
||||
|
||||
usage = metadata.get("usage") # type: ignore[typeddict-item]
|
||||
assert isinstance(usage, dict)
|
||||
assert usage.get("input_tokens") == 10 # type: ignore[union-attr]
|
||||
assert usage.get("output_tokens") == 20 # type: ignore[union-attr]
|
||||
|
||||
def test_response_metadata_provider_specific_examples(self) -> None:
|
||||
"""Test ResponseMetadata with realistic provider-specific examples."""
|
||||
# OpenAI-style metadata
|
||||
openai_metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4-turbo-2024-04-09",
|
||||
"usage": { # type: ignore[typeddict-unknown-key]
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 25,
|
||||
"total_tokens": 75,
|
||||
},
|
||||
"system_fingerprint": "fp_abc123", # type: ignore[typeddict-unknown-key]
|
||||
"created": 1234567890, # type: ignore[typeddict-unknown-key]
|
||||
"logprobs": None, # type: ignore[typeddict-unknown-key]
|
||||
"finish_reason": "stop", # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
assert openai_metadata.get("model_provider") == "openai"
|
||||
assert openai_metadata.get("system_fingerprint") == "fp_abc123" # type: ignore[typeddict-item]
|
||||
|
||||
# Anthropic-style metadata
|
||||
anthropic_metadata: ResponseMetadata = {
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"usage": { # type: ignore[typeddict-unknown-key]
|
||||
"input_tokens": 75,
|
||||
"output_tokens": 30,
|
||||
},
|
||||
"stop_reason": "end_turn", # type: ignore[typeddict-unknown-key]
|
||||
"stop_sequence": None, # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
assert anthropic_metadata.get("model_provider") == "anthropic"
|
||||
assert anthropic_metadata.get("stop_reason") == "end_turn" # type: ignore[typeddict-item]
|
||||
|
||||
# Custom provider metadata
|
||||
custom_metadata: ResponseMetadata = {
|
||||
"model_provider": "custom_llm_service",
|
||||
"model_name": "custom-model-v1",
|
||||
"service_tier": "premium", # type: ignore[typeddict-unknown-key]
|
||||
"rate_limit_info": { # type: ignore[typeddict-unknown-key]
|
||||
"requests_remaining": 100,
|
||||
"reset_time": "2024-01-01T00:00:00Z",
|
||||
},
|
||||
"response_time_ms": 1250, # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
assert custom_metadata.get("service_tier") == "premium" # type: ignore[typeddict-item]
|
||||
rate_limit = custom_metadata.get("rate_limit_info") # type: ignore[typeddict-item]
|
||||
assert isinstance(rate_limit, dict)
|
||||
assert rate_limit.get("requests_remaining") == 100 # type: ignore[union-attr]
|
||||
|
||||
|
||||
class TestResponseMetadataWithAIMessages:
|
||||
"""Test ResponseMetadata integration with AI message classes."""
|
||||
|
||||
def test_ai_message_with_response_metadata(self) -> None:
|
||||
"""Test AIMessage with ResponseMetadata."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5}, # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
message = AIMessage(content="Hello, world!", response_metadata=metadata)
|
||||
|
||||
assert message.response_metadata == metadata
|
||||
assert message.response_metadata.get("model_provider") == "openai"
|
||||
assert message.response_metadata.get("model_name") == "gpt-4"
|
||||
|
||||
usage = message.response_metadata.get("usage") # type: ignore[typeddict-item]
|
||||
assert isinstance(usage, dict)
|
||||
assert usage.get("input_tokens") == 10 # type: ignore[union-attr]
|
||||
|
||||
def test_ai_message_chunk_with_response_metadata(self) -> None:
|
||||
"""Test AIMessageChunk with ResponseMetadata."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"stream_id": "stream_12345", # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
chunk = AIMessageChunk(content="Hello", response_metadata=metadata)
|
||||
|
||||
assert chunk.response_metadata == metadata
|
||||
assert chunk.response_metadata.get("stream_id") == "stream_12345" # type: ignore[typeddict-item]
|
||||
|
||||
def test_ai_message_default_empty_response_metadata(self) -> None:
|
||||
"""Test that AIMessage creates empty ResponseMetadata by default."""
|
||||
message = AIMessage(content="Test message")
|
||||
|
||||
# Should have empty dict as default
|
||||
assert message.response_metadata == {}
|
||||
assert isinstance(message.response_metadata, dict)
|
||||
|
||||
def test_ai_message_chunk_default_empty_response_metadata(self) -> None:
|
||||
"""Test that AIMessageChunk creates empty ResponseMetadata by default."""
|
||||
chunk = AIMessageChunk(content="Test chunk")
|
||||
|
||||
# Should have empty dict as default
|
||||
assert chunk.response_metadata == {}
|
||||
assert isinstance(chunk.response_metadata, dict)
|
||||
|
||||
def test_response_metadata_merging_in_chunks(self) -> None:
|
||||
"""Test that ResponseMetadata is properly merged when adding AIMessageChunks."""
|
||||
metadata1: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"request_id": "req_123", # type: ignore[typeddict-unknown-key]
|
||||
"usage": {"input_tokens": 10}, # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
metadata2: ResponseMetadata = {
|
||||
"stream_chunk": 1, # type: ignore[typeddict-unknown-key]
|
||||
"usage": {"output_tokens": 5}, # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
chunk1 = AIMessageChunk(content="Hello ", response_metadata=metadata1)
|
||||
chunk2 = AIMessageChunk(content="world!", response_metadata=metadata2)
|
||||
|
||||
merged = chunk1 + chunk2
|
||||
|
||||
# Should have merged response_metadata
|
||||
assert merged.response_metadata.get("model_provider") == "openai"
|
||||
assert merged.response_metadata.get("model_name") == "gpt-4"
|
||||
assert merged.response_metadata.get("request_id") == "req_123" # type: ignore[typeddict-item]
|
||||
assert merged.response_metadata.get("stream_chunk") == 1 # type: ignore[typeddict-item]
|
||||
|
||||
# Usage should be merged (from merge_dicts behavior)
|
||||
merged_usage = merged.response_metadata.get("usage") # type: ignore[typeddict-item]
|
||||
assert isinstance(merged_usage, dict)
|
||||
assert merged_usage.get("input_tokens") == 10 # type: ignore[union-attr]
|
||||
assert merged_usage.get("output_tokens") == 5 # type: ignore[union-attr]
|
||||
|
||||
def test_response_metadata_modification_after_message_creation(self) -> None:
|
||||
"""Test that ResponseMetadata can be modified after message creation."""
|
||||
message = AIMessage(
|
||||
content="Initial message",
|
||||
response_metadata={"model_provider": "openai", "model_name": "gpt-3.5"},
|
||||
)
|
||||
|
||||
# Modify existing field
|
||||
message.response_metadata["model_name"] = "gpt-4"
|
||||
assert message.response_metadata.get("model_name") == "gpt-4"
|
||||
|
||||
# Add new field
|
||||
message.response_metadata["finish_reason"] = "stop" # type: ignore[typeddict-unknown-key]
|
||||
assert message.response_metadata.get("finish_reason") == "stop" # type: ignore[typeddict-item]
|
||||
|
||||
def test_response_metadata_with_none_values(self) -> None:
|
||||
"""Test ResponseMetadata handling of None values."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"system_fingerprint": None, # type: ignore[typeddict-unknown-key]
|
||||
"logprobs": None, # type: ignore[typeddict-unknown-key]
|
||||
}
|
||||
|
||||
message = AIMessage(content="Test", response_metadata=metadata)
|
||||
|
||||
assert message.response_metadata.get("system_fingerprint") is None # type: ignore[typeddict-item]
|
||||
assert message.response_metadata.get("logprobs") is None # type: ignore[typeddict-item]
|
||||
assert "system_fingerprint" in message.response_metadata
|
||||
assert "logprobs" in message.response_metadata
|
||||
|
||||
|
||||
class TestResponseMetadataEdgeCases:
|
||||
"""Test edge cases and error conditions for ResponseMetadata."""
|
||||
|
||||
def test_response_metadata_with_complex_nested_structures(self) -> None:
|
||||
"""Test ResponseMetadata with deeply nested and complex structures."""
|
||||
metadata: ResponseMetadata = {
|
||||
"model_provider": "custom",
|
||||
"model_name": "complex-model",
|
||||
"complex_data": { # type: ignore[typeddict-unknown-key]
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"deeply_nested": "value",
|
||||
"array": [
|
||||
{"item": 1, "metadata": {"nested": True}},
|
||||
{"item": 2, "metadata": {"nested": False}},
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
complex_data = metadata.get("complex_data") # type: ignore[typeddict-item]
|
||||
assert isinstance(complex_data, dict)
|
||||
level1 = complex_data.get("level1") # type: ignore[union-attr]
|
||||
assert isinstance(level1, dict)
|
||||
level2 = level1.get("level2") # type: ignore[union-attr]
|
||||
assert isinstance(level2, dict)
|
||||
level3 = level2.get("level3") # type: ignore[union-attr]
|
||||
assert isinstance(level3, dict)
|
||||
|
||||
assert level3.get("deeply_nested") == "value" # type: ignore[union-attr]
|
||||
array = level3.get("array") # type: ignore[union-attr]
|
||||
assert isinstance(array, list)
|
||||
assert len(array) == 2 # type: ignore[arg-type]
|
||||
assert array[0]["item"] == 1 # type: ignore[index, typeddict-item]
|
||||
assert array[0]["metadata"]["nested"] is True # type: ignore[index, typeddict-item]
|
||||
|
||||
def test_response_metadata_large_data(self) -> None:
|
||||
"""Test ResponseMetadata with large amounts of data."""
|
||||
# Create metadata with many fields
|
||||
large_metadata: ResponseMetadata = {
|
||||
"model_provider": "test_provider",
|
||||
"model_name": "test_model",
|
||||
}
|
||||
|
||||
# Add 100 extra fields
|
||||
for i in range(100):
|
||||
large_metadata[f"field_{i}"] = f"value_{i}" # type: ignore[literal-required]
|
||||
|
||||
message = AIMessage(content="Test", response_metadata=large_metadata)
|
||||
|
||||
# Verify all fields are accessible
|
||||
assert message.response_metadata.get("model_provider") == "test_provider"
|
||||
for i in range(100):
|
||||
assert message.response_metadata.get(f"field_{i}") == f"value_{i}" # type: ignore[typeddict-item]
|
||||
|
||||
def test_response_metadata_empty_vs_none(self) -> None:
|
||||
"""Test the difference between empty ResponseMetadata and None."""
|
||||
# Message with empty metadata
|
||||
message_empty = AIMessage(content="Test", response_metadata={})
|
||||
assert message_empty.response_metadata == {}
|
||||
assert isinstance(message_empty.response_metadata, dict)
|
||||
|
||||
# Message with None metadata (should become empty dict)
|
||||
message_none = AIMessage(content="Test", response_metadata=None)
|
||||
assert message_none.response_metadata == {}
|
||||
assert isinstance(message_none.response_metadata, dict)
|
||||
|
||||
# Default message (no metadata specified)
|
||||
message_default = AIMessage(content="Test")
|
||||
assert message_default.response_metadata == {}
|
||||
assert isinstance(message_default.response_metadata, dict)
|
||||
|
||||
def test_response_metadata_preserves_original_dict_type(self) -> None:
|
||||
"""Test that ResponseMetadata preserves the original dict when passed."""
|
||||
original_dict = {
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"custom_field": "custom_value",
|
||||
}
|
||||
|
||||
message = AIMessage(content="Test", response_metadata=original_dict)
|
||||
|
||||
# Should be the same dict object
|
||||
assert message.response_metadata is original_dict
|
||||
|
||||
# Modifications to the message's response_metadata should affect original
|
||||
message.response_metadata["new_field"] = "new_value" # type: ignore[typeddict-unknown-key]
|
||||
assert original_dict.get("new_field") == "new_value" # type: ignore[typeddict-item]
|
||||
@@ -1221,15 +1221,30 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
{"type": "text", "text": "Text message"},
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "url",
|
||||
"url": "https://example.com/test.png",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "url", # backward compatibility
|
||||
"url": "https://example.com/test.png",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "image/png",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 string>",
|
||||
"mime_type": "image/png",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "application/pdf",
|
||||
"filename": "test.pdf",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"source_type": "base64",
|
||||
@@ -1244,11 +1259,20 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
"file_data": "data:application/pdf;base64,<base64 string>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file_id": "file-abc123",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"source_type": "id",
|
||||
"id": "file-abc123",
|
||||
},
|
||||
{
|
||||
"type": "audio",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "audio/wav",
|
||||
},
|
||||
{
|
||||
"type": "audio",
|
||||
"source_type": "base64",
|
||||
@@ -1268,7 +1292,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
result = convert_to_openai_messages(messages, text_format="block")
|
||||
assert len(result) == 1
|
||||
message = result[0]
|
||||
assert len(message["content"]) == 8
|
||||
assert len(message["content"]) == 13
|
||||
|
||||
# Test adding filename
|
||||
messages = [
|
||||
@@ -1276,8 +1300,7 @@ def test_convert_to_openai_messages_multimodal() -> None:
|
||||
content=[
|
||||
{
|
||||
"type": "file",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 string>",
|
||||
"base64": "<base64 string>",
|
||||
"mime_type": "application/pdf",
|
||||
},
|
||||
]
|
||||
|
||||
@@ -1,15 +1,19 @@
"""Module to test base parser implementations."""

from typing import Union

from typing_extensions import override

from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModelV1
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseTransformOutputParser,
)
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.v1.messages import AIMessage as AIMessageV1


def test_base_generation_parser() -> None:
@@ -20,7 +24,7 @@ def test_base_generation_parser() -> None:

@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str:
"""Parse a list of model Generations into a specific format.

@@ -32,16 +36,22 @@ def test_base_generation_parser() -> None:
partial: Whether to allow partial results. This is used for parsers
that support streaming
"""
if len(result) != 1:
msg = "This output parser can only be used with a single generation."
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
if isinstance(result, AIMessageV1):
content = result.text
else:
if len(result) != 1:
msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content

content = generation.message.content
assert isinstance(content, str)
return content.swapcase()

@@ -49,6 +59,10 @@ def test_base_generation_parser() -> None:
chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO"

model_v1 = GenericFakeChatModelV1(messages=iter([AIMessageV1("hEllo")]))
chain_v1 = model_v1 | StrInvertCase()
assert chain_v1.invoke("") == "HeLLO"


def test_base_transform_output_parser() -> None:
"""Test base transform output parser."""
@@ -62,7 +76,7 @@ def test_base_transform_output_parser() -> None:

@override
def parse_result(
self, result: list[Generation], *, partial: bool = False
self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
) -> str:
"""Parse a list of model Generations into a specific format.

@@ -74,15 +88,22 @@ def test_base_transform_output_parser() -> None:
partial: Whether to allow partial results. This is used for parsers
that support streaming
"""
if len(result) != 1:
msg = "This output parser can only be used with a single generation."
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
content = generation.message.content
if isinstance(result, AIMessageV1):
content = result.text
else:
if len(result) != 1:
msg = (
"This output parser can only be used with a single generation."
)
raise NotImplementedError(msg)
generation = result[0]
if not isinstance(generation, ChatGeneration):
# Say that this one only works with chat generations
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
assert isinstance(generation.message.content, str)
content = generation.message.content

assert isinstance(content, str)
return content.swapcase()

@@ -91,3 +112,8 @@ def test_base_transform_output_parser() -> None:
# inputs to models are ignored, response is hard-coded in model definition
chunks = list(chain.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]

model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"])
chain_v1 = model_v1 | StrInvertCase()
chunks = list(chain_v1.stream(""))
assert chunks == ["HELLO", " ", "WORLD", ""]
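
Condensed, the pattern introduced in this file is an output parser whose ``parse_result`` accepts either the legacy ``list[Generation]`` or a v1 ``AIMessage``. A self-contained sketch of that shape (a toy parser mirroring the tests, not the library's own implementation):

from typing import Union

from langchain_core.output_parsers import BaseTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.v1.messages import AIMessage as AIMessageV1


class StrInvertCase(BaseTransformOutputParser[str]):
    """Toy parser used in these tests: swaps the case of the model's text."""

    def parse(self, text: str) -> str:
        return text.swapcase()

    def parse_result(
        self, result: Union[list[Generation], AIMessageV1], *, partial: bool = False
    ) -> str:
        if isinstance(result, AIMessageV1):
            # v1 path: the parser receives the message itself, not Generations.
            return result.text.swapcase()
        # Legacy path: unwrap the single ChatGeneration.
        generation = result[0]
        assert isinstance(generation, ChatGeneration)
        return str(generation.message.content).swapcase()
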
@@ -16,6 +16,8 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.v1.messages import AIMessage as AIMessageV1
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1

STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
@@ -331,6 +333,14 @@ for message in STREAMED_MESSAGES:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)


STREAMED_MESSAGES_V1 = [
AIMessageChunkV1(
content=[],
tool_call_chunks=chunk.tool_call_chunks,
)
for chunk in STREAMED_MESSAGES_WITH_TOOL_CALLS
]

EXPECTED_STREAMED_JSON = [
{},
{"names": ["suz"]},
@@ -398,6 +408,19 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
assert actual == expected


def test_partial_json_output_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1

chain = input_iter | JsonOutputToolsParser()

actual = list(chain.stream(None))
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected


@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@@ -410,6 +433,20 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
assert actual == expected


async def test_partial_json_output_parser_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg

chain = input_iter | JsonOutputToolsParser()

actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@@ -429,6 +466,26 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
assert actual == expected


def test_partial_json_output_parser_return_id_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1

chain = input_iter | JsonOutputToolsParser(return_id=True)

actual = list(chain.stream(None))
expected: list = [[]] + [
[
{
"type": "NameCollector",
"args": chunk,
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
}
]
for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
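
The ``*_v1`` variants above pipe ``AIMessageChunkV1`` objects straight into the tool-call parsers. A self-contained sketch of that streaming flow (the chunk payloads and printed output are illustrative):

from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.v1.messages import AIMessageChunk as AIMessageChunkV1


def stream_chunks(_input):
    # Two partial chunks of one tool call; a shared `index` ties them together.
    yield AIMessageChunkV1(
        content=[],
        tool_call_chunks=[
            {"type": "tool_call_chunk", "name": "NameCollector", "args": '{"names": ["su', "id": "call_1", "index": 0}
        ],
    )
    yield AIMessageChunkV1(
        content=[],
        tool_call_chunks=[
            {"type": "tool_call_chunk", "name": None, "args": 'z"]}', "id": None, "index": 0}
        ],
    )


chain = stream_chunks | JsonOutputToolsParser()
for partial in chain.stream(None):
    print(partial)  # e.g. [{'type': 'NameCollector', 'args': {'names': ['suz']}}]
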
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@@ -439,6 +496,17 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
assert actual == expected


def test_partial_json_output_key_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1

chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected


@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@@ -450,6 +518,18 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
assert actual == expected


async def test_partial_json_output_parser_key_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg

chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
input_iter = _get_iter(use_tool_calls=use_tool_calls)
@@ -461,6 +541,17 @@ def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> N
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON


def test_partial_json_output_key_parser_first_only_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1

chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)

assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON


@pytest.mark.parametrize("use_tool_calls", [False, True])
async def test_partial_json_output_parser_key_async_first_only(
*,
@@ -475,6 +566,18 @@ async def test_partial_json_output_parser_key_async_first_only(
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON


async def test_partial_json_output_parser_key_async_first_only_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg

chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)

assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only(
*, use_tool_calls: bool
@@ -531,6 +634,42 @@ def test_json_output_key_tools_parser_multiple_tools_first_only(
assert output_no_id == {"a": 1}


def test_json_output_key_tools_parser_multiple_tools_first_only_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func", "name": "func", "args": {"a": 1}},
],
)

# Test with return_id=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)

# Should return the func tool call, not None
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1}
assert "id" in output

# Test with return_id=False
parser_no_id = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(message)

# Should return just the args
assert output_no_id == {"a": 1}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match(
*, use_tool_calls: bool
@@ -583,6 +722,44 @@ def test_json_output_key_tools_parser_multiple_tools_no_match(
assert output_no_id is None


def test_json_output_key_tools_parser_multiple_tools_no_match_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{
"type": "tool_call",
"id": "call_another",
"name": "another",
"args": {"c": 3},
},
],
)

# Test with return_id=True, first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)

# Should return None when no matches
assert output is None

# Test with return_id=False, first_tool_only=True
parser_no_id = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(message)

# Should return None when no matches
assert output_no_id is None


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools(
*, use_tool_calls: bool
@@ -643,6 +820,42 @@ def test_json_output_key_tools_parser_multiple_matching_tools(
assert output_all[1]["args"] == {"a": 3}


def test_json_output_key_tools_parser_multiple_matching_tools_v1() -> None:
message = AIMessageV1(
content=[],
tool_calls=[
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
],
)

# Test with first_tool_only=True - should return first matching
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(message)

assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1} # First matching tool call

# Test with first_tool_only=False - should return all matching
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(message)

assert len(output_all) == 2
assert output_all[0]["args"] == {"a": 1}
assert output_all[1]["args"] == {"a": 3}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
def create_message() -> AIMessage:
@@ -671,6 +884,35 @@ def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) ->
assert output_all == []


@pytest.mark.parametrize(
"empty_message",
[
AIMessageV1(content=[], tool_calls=[]),
AIMessageV1(content="", tool_calls=[]),
],
)
def test_json_output_key_tools_parser_empty_results_v1(
empty_message: AIMessageV1,
) -> None:
# Test with first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(empty_message)

# Should return None for empty results
assert output is None

# Test with first_tool_only=False
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(empty_message)

# Should return empty list for empty results
assert output_all == []


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations(
*, use_tool_calls: bool
@@ -746,6 +988,56 @@ def test_json_output_key_tools_parser_parameter_combinations(
assert output4 == [{"a": 1}, {"a": 3}]


def test_json_output_key_tools_parser_parameter_combinations_v1() -> None:
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
result = AIMessageV1(
content=[],
tool_calls=[
{
"type": "tool_call",
"id": "call_other",
"name": "other",
"args": {"b": 2},
},
{"type": "tool_call", "id": "call_func1", "name": "func", "args": {"a": 1}},
{"type": "tool_call", "id": "call_func2", "name": "func", "args": {"a": 3}},
],
)

# Test: first_tool_only=True, return_id=True
parser1 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output1 = parser1.parse_result(result)
assert output1["type"] == "func"
assert output1["args"] == {"a": 1}
assert "id" in output1

# Test: first_tool_only=True, return_id=False
parser2 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output2 = parser2.parse_result(result)
assert output2 == {"a": 1}

# Test: first_tool_only=False, return_id=True
parser3 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output3 = parser3.parse_result(result)
assert len(output3) == 2
assert all("id" in item for item in output3)
assert output3[0]["args"] == {"a": 1}
assert output3[1]["args"] == {"a": 3}

# Test: first_tool_only=False, return_id=False
parser4 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=False
)
output4 = parser4.parse_result(result)
assert output4 == [{"a": 1}, {"a": 3}]
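
As the combinations above show, ``first_tool_only`` decides between a single match and a list, ``return_id`` decides between the full tool-call envelope and just the parsed ``args``, and ``parse_result`` can be called directly on a v1 message. A compact sketch (identifiers are illustrative):

from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.v1.messages import AIMessage as AIMessageV1

msg = AIMessageV1(
    content=[],
    tool_calls=[
        {"type": "tool_call", "id": "call_1", "name": "func", "args": {"a": 1}},
        {"type": "tool_call", "id": "call_2", "name": "func", "args": {"a": 3}},
    ],
)

only_args = JsonOutputKeyToolsParser(key_name="func", first_tool_only=True, return_id=False)
assert only_args.parse_result(msg) == {"a": 1}

all_with_ids = JsonOutputKeyToolsParser(key_name="func", first_tool_only=False, return_id=True)
assert [o["args"] for o in all_with_ids.parse_result(msg)] == [{"a": 1}, {"a": 3}]
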
class Person(BaseModel):
age: int
hair_color: str
@@ -788,6 +1080,18 @@ def test_partial_pydantic_output_parser() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC


def test_partial_pydantic_output_parser_v1() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunkV1]:
yield from STREAMED_MESSAGES_V1

chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)

actual = list(chain.stream(None))
assert actual == EXPECTED_STREAMED_PYDANTIC


async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
@@ -800,6 +1104,19 @@ async def test_partial_pydantic_output_parser_async() -> None:
assert actual == EXPECTED_STREAMED_PYDANTIC


async def test_partial_pydantic_output_parser_async_v1() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunkV1]:
for msg in STREAMED_MESSAGES_V1:
yield msg

chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)

actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC


def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic
@@ -870,20 +1187,22 @@ def test_parse_with_different_pydantic_2_proper() -> None:

def test_max_tokens_error(caplog: Any) -> None:
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "NameCollector",
"args": {"names": ["suz", "jerm"]},
}
],
response_metadata={"stop_reason": "max_tokens"},
)
with pytest.raises(ValidationError):
_ = parser.invoke(message)
assert any(
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
for record, msg in zip(caplog.records, caplog.messages)
)
for msg_class in [AIMessage, AIMessageV1]:
message = msg_class(
content="",
tool_calls=[
{
"type": "tool_call",
"id": "call_OwL7f5PE",
"name": "NameCollector",
"args": {"names": ["suz", "jerm"]},
}
],
response_metadata={"stop_reason": "max_tokens"},
)
with pytest.raises(ValidationError):
_ = parser.invoke(message)
assert any(
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
for record, msg in zip(caplog.records, caplog.messages)
)
@@ -726,7 +726,7 @@
'description': '''
Allowance for errors made by LLM.

Here we add an `error` key to surface errors made during generation
Here we add an ``error`` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
@@ -752,6 +752,10 @@
]),
'title': 'Error',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -763,6 +767,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@@ -781,9 +789,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
@@ -998,12 +1007,23 @@

This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".

.. note::
``create_tool_call`` may also be used as a factory to create a
``ToolCall``. Benefits include:

* Automatic ID generation (when not provided)
* Required arguments strictly validated at creation time
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -1015,6 +1035,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -1026,9 +1050,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
@@ -1037,9 +1062,9 @@
'description': '''
A chunk of a tool call (e.g., as part of a stream).

When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
When merging ``ToolCallChunks`` (e.g., via ``AIMessageChunk.__add__``),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
values of ``index`` are equal and not ``None``.

Example:

@@ -1065,6 +1090,10 @@
]),
'title': 'Args',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -1105,9 +1134,9 @@
}),
}),
'required': list([
'id',
'name',
'args',
'id',
'index',
]),
'title': 'ToolCallChunk',
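
The ``ToolCallChunk`` docstring captured in this snapshot describes merging on ``AIMessageChunk.__add__``; a small sketch of that concatenation in practice (the split JSON payload is illustrative):

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="",
    tool_call_chunks=[{"name": "foo", "args": '{"a":', "id": "call_1", "index": 0}],
)
right = AIMessageChunk(
    content="",
    tool_call_chunks=[{"name": None, "args": " 1}", "id": None, "index": 0}],
)

merged = left + right
# Equal, non-None `index` values, so the string fields are concatenated.
assert merged.tool_call_chunks[0]["args"] == '{"a": 1}'
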
@@ -2158,7 +2187,7 @@
'description': '''
Allowance for errors made by LLM.

Here we add an `error` key to surface errors made during generation
Here we add an ``error`` key to surface errors made during generation
(e.g., invalid JSON arguments.)
''',
'properties': dict({
@@ -2184,6 +2213,10 @@
]),
'title': 'Error',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -2195,6 +2228,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'anyOf': list([
dict({
@@ -2213,9 +2250,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
@@ -2430,12 +2468,23 @@

This represents a request to call the tool named "foo" with arguments {"a": 1}
and an identifier of "123".

.. note::
``create_tool_call`` may also be used as a factory to create a
``ToolCall``. Benefits include:

* Automatic ID generation (when not provided)
* Required arguments strictly validated at creation time
''',
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -2447,6 +2496,10 @@
]),
'title': 'Id',
}),
'index': dict({
'title': 'Index',
'type': 'integer',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -2458,9 +2511,10 @@
}),
}),
'required': list([
'type',
'id',
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
@@ -2469,9 +2523,9 @@
'description': '''
A chunk of a tool call (e.g., as part of a stream).

When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
When merging ``ToolCallChunks`` (e.g., via ``AIMessageChunk.__add__``),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
values of ``index`` are equal and not ``None``.

Example:

@@ -2497,6 +2551,10 @@
]),
'title': 'Args',
}),
'extras': dict({
'title': 'Extras',
'type': 'object',
}),
'id': dict({
'anyOf': list([
dict({
@@ -2537,9 +2595,9 @@
}),
}),
'required': list([
'id',
'name',
'args',
'id',
'index',
]),
'title': 'ToolCallChunk',