Compare commits

...

91 Commits

Author SHA1 Message Date
ccurme
d0222964c1 groq, fireworks, text-splitters (#26104) 2024-09-05 13:51:41 -04:00
Chester Curme
b97307c8b4 Merge branch 'v0.3rc' into v0.3/dev_releases 2024-09-05 13:42:00 -04:00
ccurme
1ad66e70dc text-splitters[major]: update core dep + drop support for python 3.8 (#26102) 2024-09-05 13:41:28 -04:00
Bagatur
76564edd3a openai[patch]: update configurable model dumps (#26101) 2024-09-05 13:26:40 -04:00
Eugene Yurtsev
1c51e1693d core[patch]: Fix issue with adapter utility for pydantic repr (#26099)
This repr will be deleted prior to release -- it's temporarily here to
make it easy to separate code changes in langchain vs. code changes
stemming from breaking changes in pydantic
2024-09-05 12:27:01 -04:00
Bagatur
c0f886dc52 fix core version 2024-09-05 11:57:52 -04:00
Eugene Yurtsev
a267da6a3a core[minor]: Add type overload for secret_from_env factory (#26091)
Add type overload
2024-09-05 11:52:19 -04:00
Bagatur
0c63b18c1f ci 2024-09-05 11:47:56 -04:00
Bagatur
915c1e3dfb Merge branch 'v0.3rc' into v0.3/dev_releases 2024-09-05 11:41:32 -04:00
Bagatur
8da2ace99d openai[patch]: update snapshots (#26098) 2024-09-05 11:41:14 -04:00
Bagatur
81cd73cfca openai 0.2.0.dev0, anthropic 0.2.0.dev0 2024-09-05 11:07:47 -04:00
ccurme
e358846b39 core[patch]: add bedrock to load.mapping (#26094) 2024-09-05 10:56:46 -04:00
Eugene Yurtsev
3c598d25a6 core[minor]: Add get_input_jsonschema, get_output_jsonschema, get_config_jsonschema (#26034)
This PR adds methods to directly get the json schema for inputs,
outputs, and config.
Currently, it's delegating to the underlying pydantic implementation,
but this may be changed in the future to be independent of pydantic.
2024-09-05 10:36:42 -04:00
ccurme
e5aa0f938b mongo[major]: upgrade pydantic (#26053) 2024-09-05 09:05:41 -04:00
Bagatur
79c46319dd couchbase[patch]: rm pydantic usage (#26068) 2024-09-04 16:29:14 -07:00
ccurme
c5d4dfefc0 prompty[major]: upgrade pydantic (#26056) 2024-09-04 19:26:18 -04:00
ccurme
6e853501ec voyageai[major]: upgrade pydantic (#26070) 2024-09-04 18:59:13 -04:00
Bagatur
fd1f3ca213 exa[major]: use pydantic v2 (#26069) 2024-09-04 15:02:05 -07:00
Bagatur
567a4ce5aa box[major]: use pydantic v2 (#26067) 2024-09-04 14:51:53 -07:00
ccurme
923ce84aa7 robocorp[major]: upgrade pydantic (#26062) 2024-09-04 17:10:15 -04:00
Eugene Yurtsev
9379613132 langchain[major]: Upgrade langchain to be pydantic 2 compatible (#26050)
Upgrading the langchain package to be pydantic 2 compatible.

Had to remove some parts of unit tests in parsers that were relying on
spying on methods, since that fails with pydantic 2. The unit tests don't
seem particularly good, so they can be rewritten at a future date.

Depends on: https://github.com/langchain-ai/langchain/pull/26057

Most of this PR was done using gritql for code mods, followed by some
fixes done manually to account for changes in pydantic 2.
2024-09-04 16:59:07 -04:00
Bagatur
c72a76237f cherry-pick 88e9e6b (#26063) 2024-09-04 13:50:42 -07:00
Bagatur
f9cafcbcb0 pinecone[patch]: rm pydantic lint script (#26052) 2024-09-04 13:49:09 -07:00
Bagatur
1fce5543bc poetry lock 2024-09-04 13:44:51 -07:00
Bagatur
88e9e6bf55 core,standard-tests[patch]: add Ser/Des test and update serialization mapping (#26042) 2024-09-04 13:38:03 -07:00
Bagatur
7f0dd4b182 fmt 2024-09-04 13:31:29 -07:00
Bagatur
5557b86a54 fmt 2024-09-04 13:31:29 -07:00
Bagatur
caf4ae3a45 fmt 2024-09-04 13:31:28 -07:00
Bagatur
c88b75ca6a fmt 2024-09-04 13:30:02 -07:00
Bagatur
e409a85a28 fmt 2024-09-04 13:29:24 -07:00
Bagatur
40634d441a make 2024-09-04 13:29:24 -07:00
Bagatur
1d2a503ab8 standard-tests[patch]: add Ser/Des test 2024-09-04 13:29:20 -07:00
ccurme
b924c61440 qdrant[major]: drop support for python 3.8 (#26061) 2024-09-04 16:22:54 -04:00
Eugene Yurtsev
efa10c8ef8 core[minor]: Add message chunks to AnyMessage (#26057)
Adds the chunk variant of each Message to AnyMessage.

Required for this PR:
https://github.com/langchain-ai/langchain/pull/26050/files
2024-09-04 15:36:22 -04:00
ccurme
0a6c67ce6a nomic: drop support for python 3.8 (#26055) 2024-09-04 15:30:00 -04:00
ccurme
ed771f2d2b huggingface[major]: upgrade pydantic (#26048) 2024-09-04 15:08:43 -04:00
ccurme
63ba12d8e0 milvus: drop support for python 3.8 (#26051)
to be consistent with core
2024-09-04 14:54:45 -04:00
Bagatur
f785cf029b pinecone[major]: Update to pydantic v2 (#26039) 2024-09-04 11:28:54 -07:00
ccurme
be7cd0756f ollama[major]: upgrade pydantic (#26044) 2024-09-04 13:54:52 -04:00
ccurme
51c6899850 groq[major]: upgrade pydantic (#26036) 2024-09-04 13:41:40 -04:00
ccurme
163d6fe8ef anthropic: update pydantic (#26000)
Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic
2024-09-04 13:35:51 -04:00
ccurme
7cee7fbfad mistralai: update pydantic (#25995)
Migrated with gritql: https://github.com/eyurtsev/migrate-pydantic
2024-09-04 13:26:17 -04:00
ccurme
4799ad95d0 core[patch]: remove warnings from protected namespaces on RunnableSerializable (#26040) 2024-09-04 13:10:08 -04:00
Bagatur
88065d794b fmt 2024-09-04 09:52:01 -07:00
Bagatur
b27bfa6717 pinecone[major]: Update to pydantic v2 2024-09-04 09:50:39 -07:00
Bagatur
5adeaf0732 openai[major]: switch to pydantic v2 (#26001) 2024-09-04 09:18:29 -07:00
Bagatur
f9d91e19c5 fireworks[major]: switch to pydantic v2 (#26004) 2024-09-04 09:18:10 -07:00
Bagatur
4c7afb0d6c Update libs/partners/openai/langchain_openai/llms/base.py 2024-09-03 23:36:19 -07:00
Bagatur
c1ff61669d Update libs/partners/openai/langchain_openai/llms/base.py 2024-09-03 23:36:14 -07:00
Bagatur
54d6808c1e Update libs/partners/openai/langchain_openai/llms/azure.py 2024-09-03 23:36:08 -07:00
Bagatur
78468de2e5 Update libs/partners/openai/langchain_openai/llms/azure.py 2024-09-03 23:36:02 -07:00
Bagatur
76572f963b Update libs/partners/openai/langchain_openai/embeddings/base.py 2024-09-03 23:35:56 -07:00
Bagatur
c0448f27ba Update libs/partners/openai/langchain_openai/embeddings/base.py 2024-09-03 23:35:51 -07:00
Bagatur
179aaa4007 Update libs/partners/openai/langchain_openai/embeddings/azure.py 2024-09-03 23:35:43 -07:00
Bagatur
d072d592a1 Update libs/partners/openai/langchain_openai/embeddings/azure.py 2024-09-03 23:35:35 -07:00
Bagatur
78c454c130 Update libs/partners/openai/langchain_openai/chat_models/base.py 2024-09-03 23:35:30 -07:00
Bagatur
5199555c0d Update libs/partners/openai/langchain_openai/chat_models/base.py 2024-09-03 23:35:26 -07:00
Bagatur
5e31cd91a7 Update libs/partners/openai/langchain_openai/chat_models/azure.py 2024-09-03 23:35:21 -07:00
Bagatur
49a1f5dd47 Update libs/partners/openai/langchain_openai/chat_models/azure.py 2024-09-03 23:35:15 -07:00
Bagatur
d0cc9b022a Update libs/partners/fireworks/langchain_fireworks/chat_models.py 2024-09-03 23:30:56 -07:00
Bagatur
a91bd2737a Update libs/partners/fireworks/langchain_fireworks/chat_models.py 2024-09-03 23:30:49 -07:00
Bagatur
5ad2b8ce80 Merge branch 'v0.3rc' into bagatur/fireworks_0.3 2024-09-03 23:29:07 -07:00
Bagatur
b78764599b Merge branch 'v0.3rc' into bagatur/openai_attempt_2 2024-09-03 23:28:50 -07:00
Bagatur
2888e34f53 infra: remove pydantic v1 tests (#26006) 2024-09-03 23:27:52 -07:00
Bagatur
dd4418a503 rm requires 2024-09-03 23:26:13 -07:00
Bagatur
a976f2071b Merge branch 'v0.3rc' into bagatur/rm_pydantic_v1_ci 2024-09-03 19:06:22 -07:00
Eugene Yurtsev
5f98975be0 core[patch]: Fix injected args in tool signature (#25991)
- Fix injected args in tool signature
- Fix another unit test that was using the wrong namespace import in
pydantic
2024-09-03 21:53:50 -04:00
Bagatur
0529c991ce rm 2024-09-03 18:02:12 -07:00
Bagatur
954abcce59 infra: remove pydantic v1 tests 2024-09-03 18:01:34 -07:00
Bagatur
6ad515d34e Merge branch 'v0.3rc' into bagatur/fireworks_0.3 2024-09-03 17:51:46 -07:00
Bagatur
99348e1614 Merge branch 'v0.3rc' into bagatur/openai_attempt_2 2024-09-03 17:51:27 -07:00
Bagatur
2c742cc20d standard-tests[major]: use pydantic v2 (#26005) 2024-09-03 17:50:45 -07:00
Bagatur
02f87203f7 standard-tests[major]: use pydantic v2 2024-09-03 17:48:20 -07:00
Bagatur
56163481dd fmt 2024-09-03 17:46:41 -07:00
Bagatur
6aac2eeab5 fmt 2024-09-03 17:42:22 -07:00
Bagatur
559d8a4d13 fireworks[major]: switch to pydantic v2 2024-09-03 17:41:28 -07:00
Bagatur
ec9e8eb71c fmt 2024-09-03 17:24:24 -07:00
Bagatur
9399df7777 fmt 2024-09-03 16:57:42 -07:00
Bagatur
5fc1104d00 fmt 2024-09-03 16:51:14 -07:00
Bagatur
6777106fbe fmt 2024-09-03 16:50:17 -07:00
Bagatur
5f5287c3b0 fmt 2024-09-03 16:48:53 -07:00
Bagatur
615f8b0d47 openai[major]: switch to pydantic v2 2024-09-03 16:33:35 -07:00
Bagatur
9a9ab65030 merge master correctly (#25999) 2024-09-03 14:57:29 -07:00
Bagatur
241b6d2355 Revert "merge master (#25997)" (#25998) 2024-09-03 14:55:28 -07:00
Bagatur
91e09ffee5 merge master (#25997)
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-09-03 14:51:26 -07:00
Eugene Yurtsev
8e4bae351e core[major]: Drop python 3.8 support (#25996)
Drop python 3.8 support, as it reaches EOL in October 2024
2024-09-03 14:47:27 -07:00
Erick Friis
0da201c1d5 core: fix lint 0.3rc (#25993) 2024-09-03 17:13:52 -04:00
Erick Friis
29413a22e1 infra: also run lint/test on rc (#25992) 2024-09-03 14:02:49 -07:00
Eugene Yurtsev
ae5a574aa5 core[major]: Upgrade langchain-core to pydantic 2 (#25986)
This PR upgrades core to pydantic 2.

It involves a combination of manual changes together with automated code
mods using gritql.

Changes and known issues:

1. Current models override __repr__ to be consistent with pydantic 1
(this will be removed in a follow up PR)
Related:
https://github.com/langchain-ai/langchain/pull/25986/files#diff-e5bd296179b7a72fcd4ea5cfa28b145beaf787da057e6d122aa76ee0bb8132c9R74
2. Issue with decorator for BaseChatModel
(https://github.com/langchain-ai/langchain/pull/25986/files#diff-932bf3b314b268754ef640a5b8f52da96f9024fb81dd388dcd166b5713ecdf66R202)
-- cc @baskaryan
3. The `name` attribute on the base Runnable does not have a default -- it
was raising a pydantic warning due to the override. We need to see if
there's a way to fix this without making a breaking change for folks with
custom runnables.
(https://github.com/langchain-ai/langchain/pull/25986/files#diff-836773d27f8565f4dd45e9d6cf828920f89991a880c098b7511e0d3bb78a8a0dR238)
4. Likely can remove hard-coded RunnableBranch name
(https://github.com/langchain-ai/langchain/pull/25986/files#diff-72894b94f70b1bfc908eb4d53f5ff90bb33bf8a4240a5e34cae48ddc62ac313aR147)
5. `model_*` namespace is reserved in pydantic. We'll need to specify
`protected_namespaces`
6. create_model does not have a cached path yet
7. get_input_schema() in many places has been updated to be explicit
about whether parameters are required or optional
8. injected tool args aren't picked up properly (losing type annotation)
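
For item 5, a common pydantic 2 workaround is to opt out of the reserved namespace explicitly; a minimal sketch (class and field names are hypothetical, not from this PR):

```python
# Pydantic 2 reserves the `model_` prefix for its own methods, so fields
# like `model_name` emit a UserWarning unless protected_namespaces is
# cleared. Hypothetical model for illustration:
from pydantic import BaseModel, ConfigDict


class ChatParams(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    model_name: str  # would warn at class definition without the config
```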

For posterity the following gritql migrations were used:

```
engine marzano(0.1)
language python

or {
    `from $IMPORT import $...` where {
        $IMPORT <: contains `pydantic_v1`,
        $IMPORT => `pydantic`
    },
    `$X.update_forward_refs` => `$X.model_rebuild`,
  // This pattern still needs fixing as it fails (populate_by_name vs.
  // allow_population_by_field_name)
  class_definition($name, $body) as $C where {
      $name <: `Config`,
      $body <: block($statements),
      $t = "",
      $statements <: some bubble($t) assignment(left=$x, right=$y) as $A where {    
        or {
            $x <: `allow_population_by_field_name` where {
                $t += `populate_by_name=$y,`
            },
            $t += `$x=$y,`
        }
      },
      $C => `model_config = ConfigDict($t)`,
      add_import(source="pydantic", name="ConfigDict")
  }
}

```
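
The special case in the bubble above is essentially a key rename during re-serialization. A rough plain-Python stand-in for that logic (illustrative only, not the gritql engine):

```python
# Collect `name = value` assignments from a v1 `class Config:` body and
# emit the equivalent v2 `model_config = ConfigDict(...)` line. The v1
# key allow_population_by_field_name maps to the v2 populate_by_name.
V1_TO_V2_KEYS = {"allow_population_by_field_name": "populate_by_name"}


def config_to_configdict(assignments: dict) -> str:
    parts = [
        f"{V1_TO_V2_KEYS.get(key, key)}={value}"
        for key, value in assignments.items()
    ]
    return f"model_config = ConfigDict({', '.join(parts)})"
```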



```
engine marzano(0.1)
language python

`@root_validator(pre=True)` as $decorator where {
    $decorator <: before function_definition($body, $return_type),
    $decorator => `@model_validator(mode="before")\n@classmethod`,
    add_import(source="pydantic", name="model_validator"),
    $return_type => `Any`
}
```
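
Applied to a validator like `check_blob_is_valid` in the diffs further down this page, the rewrite produces code of this shape (minimal hypothetical model; the real class carries more fields):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, model_validator


class Blob(BaseModel):
    data: Optional[str] = None
    path: Optional[str] = None

    # Was `@root_validator(pre=True)` under pydantic 1.
    @model_validator(mode="before")
    @classmethod
    def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
        if "data" not in values and "path" not in values:
            raise ValueError("Either data or path must be provided")
        return values
```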

```
engine marzano(0.1)
language python

`@root_validator(pre=False, skip_on_failure=True)` as $decorator where {
    $decorator <: before function_definition($body, $parameters, $return_type) where {
        $body <: contains bubble or {
            `values["$Q"]` => `self.$Q`,
            `values.get("$Q")` => `(self.$Q or None)`,
            `values.get($Q, $...)` as $V where {
                $Q <: contains `"$QName"`,
                $V => `self.$QName`,
            },
            `return $Q` => `return self`
        }
    },
    $decorator => `@model_validator(mode="after")`,
    // Silly work around a bug in grit
    // Adding Self to pydantic and then will replace it with one from typing
    add_import(source="pydantic", name="model_validator"),
    $parameters => `self`,
    $return_type => `Self`
}

```

```
grit apply --language python '`Self` where { add_import(source="typing_extensions", name="Self")}'
```
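
For the simplest of these rewrites -- redirecting `pydantic_v1` imports to `pydantic` -- a stdlib regex sketch gives the flavor (the actual migration used the gritql pattern above, not this code):

```python
import re


def migrate_pydantic_imports(source: str) -> str:
    # Rewrite `from <pkg>.pydantic_v1 import ...` to `from pydantic import ...`.
    return re.sub(
        r"^from \S*pydantic_v1 import ",
        "from pydantic import ",
        source,
        flags=re.MULTILINE,
    )
```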
2024-09-03 16:30:44 -04:00
Erick Friis
5a0e82c31c infra: fix 0.3rc ci check (#25988) 2024-09-03 12:20:08 -07:00
Erick Friis
8590b421c4 infra: ignore core dependents for 0.3rc (#25980) 2024-09-03 11:06:45 -07:00
356 changed files with 16950 additions and 7531 deletions

View File

@@ -16,6 +16,10 @@ LANGCHAIN_DIRS = [
"libs/experimental",
]
# for 0.3rc, we are ignoring core dependents
# in order to be able to get CI to pass for individual PRs.
IGNORE_CORE_DEPENDENTS = True
# ignored partners are removed from dependents
# but still run if directly edited
IGNORED_PARTNERS = [
@@ -104,7 +108,7 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
{"working-directory": dir_, "python-version": f"3.{v}"}
for v in range(8, 13)
]
-min_python = "3.8"
+min_python = "3.9"
max_python = "3.12"
# custom logic for specific directories
@@ -184,6 +188,9 @@ if __name__ == "__main__":
# for extended testing
found = False
for dir_ in LANGCHAIN_DIRS:
if dir_ == "libs/core" and IGNORE_CORE_DEPENDENTS:
dirs_to_run["extended-test"].add(dir_)
continue
if file.startswith(dir_):
found = True
if found:

View File

@@ -11,7 +11,7 @@ if __name__ == "__main__":
# see if we're releasing an rc
version = toml_data["tool"]["poetry"]["version"]
-releasing_rc = "rc" in version
+releasing_rc = "rc" in version or "dev" in version
# if not, iterate through dependencies and make sure none allow prereleases
if not releasing_rc:

View File

@@ -1,114 +0,0 @@
name: dependencies
on:
workflow_call:
inputs:
working-directory:
required: true
type: string
description: "From which folder this pipeline executes"
langchain-location:
required: false
type: string
description: "Relative path to the langchain library folder"
python-version:
required: true
type: string
description: "Python version to use"
env:
POETRY_VERSION: "1.7.1"
jobs:
build:
defaults:
run:
working-directory: ${{ inputs.working-directory }}
runs-on: ubuntu-latest
name: dependency checks ${{ inputs.python-version }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ inputs.python-version }} + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
python-version: ${{ inputs.python-version }}
poetry-version: ${{ env.POETRY_VERSION }}
working-directory: ${{ inputs.working-directory }}
cache-key: pydantic-cross-compat
- name: Install dependencies
shell: bash
run: poetry install
- name: Check imports with base dependencies
shell: bash
run: poetry run make check_imports
- name: Install test dependencies
shell: bash
run: poetry install --with test
- name: Install langchain editable
working-directory: ${{ inputs.working-directory }}
if: ${{ inputs.langchain-location }}
env:
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
run: |
poetry run pip install -e "$LANGCHAIN_LOCATION"
- name: Install the opposite major version of pydantic
# If normal tests use pydantic v1, here we'll use v2, and vice versa.
shell: bash
# airbyte currently doesn't support pydantic v2
if: ${{ !startsWith(inputs.working-directory, 'libs/partners/airbyte') }}
run: |
# Determine the major part of pydantic version
REGULAR_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
if [[ "$REGULAR_VERSION" == "1" ]]; then
PYDANTIC_DEP=">=2.1,<3"
TEST_WITH_VERSION="2"
elif [[ "$REGULAR_VERSION" == "2" ]]; then
PYDANTIC_DEP="<2"
TEST_WITH_VERSION="1"
else
echo "Unexpected pydantic major version '$REGULAR_VERSION', cannot determine which version to use for cross-compatibility test."
exit 1
fi
# Install via `pip` instead of `poetry add` to avoid changing lockfile,
# which would prevent caching from working: the cache would get saved
# to a different key than where it gets loaded from.
poetry run pip install "pydantic${PYDANTIC_DEP}"
# Ensure that the correct pydantic is installed now.
echo "Checking pydantic version... Expecting ${TEST_WITH_VERSION}"
# Determine the major part of pydantic version
CURRENT_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
# Check that the major part of pydantic version is as expected, if not
# raise an error
if [[ "$CURRENT_VERSION" != "$TEST_WITH_VERSION" ]]; then
echo "Error: expected pydantic version ${TEST_WITH_VERSION} to have been installed, but found: ${CURRENT_VERSION}"
exit 1
fi
echo "Found pydantic version ${CURRENT_VERSION}, as expected"
- name: Run pydantic compatibility tests
# airbyte currently doesn't support pydantic v2
if: ${{ !startsWith(inputs.working-directory, 'libs/partners/airbyte') }}
shell: bash
run: make test
- name: Ensure the tests did not create any additional files
shell: bash
run: |
set -eu
STATUS="$(git status)"
echo "$STATUS"
# grep will exit non-zero if the target message isn't found,
# and `set -e` above will cause the step to fail.
echo "$STATUS" | grep 'nothing to commit, working tree clean'

View File

@@ -89,19 +89,6 @@ jobs:
python-version: ${{ matrix.job-configs.python-version }}
secrets: inherit
dependencies:
name: cd ${{ matrix.job-configs.working-directory }}
needs: [ build ]
if: ${{ needs.build.outputs.dependencies != '[]' }}
strategy:
matrix:
job-configs: ${{ fromJson(needs.build.outputs.dependencies) }}
uses: ./.github/workflows/_dependencies.yml
with:
working-directory: ${{ matrix.job-configs.working-directory }}
python-version: ${{ matrix.job-configs.python-version }}
secrets: inherit
extended-tests:
name: "cd ${{ matrix.job-configs.working-directory }} / make extended_tests #${{ matrix.job-configs.python-version }}"
needs: [ build ]
@@ -149,7 +136,7 @@ jobs:
echo "$STATUS" | grep 'nothing to commit, working tree clean'
ci_success:
name: "CI Success"
needs: [build, lint, test, compile-integration-tests, dependencies, extended-tests, test-doc-imports]
needs: [build, lint, test, compile-integration-tests, extended-tests, test-doc-imports]
if: |
always()
runs-on: ubuntu-latest

View File

@@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.8"
- "3.9"
- "3.11"
working-directory:
- "libs/partners/openai"

View File

@@ -12,6 +12,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/
test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
# integration tests are run without the --disable-socket flag to allow network calls
integration_test integration_tests:
poetry run pytest $(TEST_FILE)

View File

@@ -23,6 +23,7 @@ pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
pytest-socket = "^0.7.0"
langchain-core = { path = "../../core", develop = true }
pytest-watcher = "^0.3.4"
[tool.poetry.group.codespell]
optional = true

View File

@@ -102,6 +102,16 @@ def test_serializable_mapping() -> None:
"modifier",
"RemoveMessage",
),
("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
"langchain_mistralai",
"chat_models",
"ChatMistralAI",
),
("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq",
"chat_models",
"ChatGroq",
),
}
serializable_modules = import_all_modules("langchain")

View File

@@ -39,7 +39,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff

View File

@@ -18,6 +18,8 @@ from typing import (
Union,
)
from pydantic import ConfigDict
from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Runnable,
@@ -229,8 +231,9 @@ class ContextSet(RunnableSerializable):
keys: Mapping[str, Optional[Runnable]]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,

View File

@@ -20,13 +20,14 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Sequence, Union
from pydantic import BaseModel, Field
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
get_buffer_string,
)
from langchain_core.pydantic_v1 import BaseModel, Field
class BaseChatMessageHistory(ABC):

View File

@@ -4,10 +4,12 @@ import contextlib
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, model_validator
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils.pydantic import v1_repr
PathLike = Union[str, PurePath]
@@ -110,9 +112,10 @@ class Blob(BaseMedia):
path: Optional[PathLike] = None
"""Location where the original content was found."""
class Config:
arbitrary_types_allowed = True
frozen = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
frozen=True,
)
@property
def source(self) -> Optional[str]:
@@ -127,8 +130,9 @@ class Blob(BaseMedia):
return cast(Optional[str], self.metadata["source"])
return str(self.path) if self.path else None
@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
@model_validator(mode="before")
@classmethod
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
@@ -293,3 +297,7 @@ class Document(BaseMedia):
return f"page_content='{self.page_content}' metadata={self.metadata}"
else:
return f"page_content='{self.page_content}'"
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Sequence
from pydantic import BaseModel
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import run_in_executor

View File

@@ -4,8 +4,9 @@
import hashlib
from typing import List
from pydantic import BaseModel
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
class FakeEmbeddings(Embeddings, BaseModel):

View File

@@ -3,9 +3,10 @@
import re
from typing import Callable, Dict, List
from pydantic import BaseModel, validator
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, validator
def _get_length_based(text: str) -> int:

View File

@@ -5,9 +5,10 @@ from __future__ import annotations
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from pydantic import BaseModel, ConfigDict
from langchain_core.documents import Document
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
@@ -42,9 +43,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
vectorstore_kwargs: Optional[Dict[str, Any]] = None
"""Extra arguments passed to similarity_search function of the vectorstore."""
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@staticmethod
def _example_to_text(

View File

@@ -12,6 +12,8 @@ from typing import (
Optional,
)
from pydantic import Field
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@@ -20,7 +22,6 @@ from langchain_core.callbacks import (
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

View File

@@ -25,10 +25,11 @@ from typing import (
cast,
)
from pydantic import model_validator
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.indexing.base import DocumentIndex, RecordManager
from langchain_core.pydantic_v1 import root_validator
from langchain_core.vectorstores import VectorStore
# Magic UUID to use as a namespace for hashing.
@@ -68,8 +69,9 @@ class _HashedDocument(Document):
def is_lc_serializable(cls) -> bool:
return False
@root_validator(pre=True)
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})

View File

@@ -1,12 +1,13 @@
import uuid
from typing import Any, Dict, List, Optional, Sequence, cast
from pydantic import Field
from langchain_core._api import beta
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.indexing import UpsertResponse
from langchain_core.indexing.base import DeleteResponse, DocumentIndex
from langchain_core.pydantic_v1 import Field
@beta(message="Introduced in version 0.2.29. Underlying abstraction subject to change.")

View File

@@ -18,6 +18,7 @@ from typing import (
Union,
)
from pydantic import BaseModel, ConfigDict, Field, validator
from typing_extensions import TypeAlias, TypedDict
from langchain_core._api import deprecated
@@ -28,7 +29,6 @@ from langchain_core.messages import (
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -113,7 +113,11 @@ class BaseLanguageModel(
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity)
# Repr = False is consistent with pydantic 1 if verbose = False
# We can relax this for pydantic 2?
# TODO(0.3): Resolve repr for verbose
# Modified just to get unit tests to pass.
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
@@ -126,6 +130,10 @@ class BaseLanguageModel(
)
"""Optional encoder to use for counting tokens."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@validator("verbose", pre=True, always=True, allow_reuse=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.

View File

@@ -23,6 +23,13 @@ from typing import (
cast,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
@@ -57,11 +64,6 @@ from langchain_core.outputs import (
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
@@ -193,14 +195,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
""" # noqa: E501
callback_manager: Optional[BaseCallbackManager] = deprecated(
name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
)(
Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)
# TODO(0.3): Figure out how to re-apply deprecated decorator
# callback_manager: Optional[BaseCallbackManager] = deprecated(
# name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
# )(
# Field(
# default=None,
# exclude=True,
# description="Callback manager to add to the run trace.",
# )
# )
callback_manager: Optional[BaseCallbackManager] = Field(
default=None,
exclude=True,
description="Callback manager to add to the run trace.",
)
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
@@ -218,8 +226,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
- If False (default), will always use streaming case if available.
"""
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
Args:
@@ -240,8 +249,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
# --- Runnable methods ---

View File

@@ -27,6 +27,7 @@ from typing import (
)
import yaml
from pydantic import ConfigDict, Field, model_validator
from tenacity import (
RetryCallState,
before_sleep_log,
@@ -62,7 +63,6 @@ from langchain_core.messages import (
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor
@@ -300,11 +300,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(

View File

@@ -17,6 +17,7 @@ DEFAULT_NAMESPACES = [
"langchain_core",
"langchain_community",
"langchain_anthropic",
"langchain_groq",
]
ALL_SERIALIZABLE_MAPPINGS = {

View File

@@ -271,6 +271,11 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"chat_models",
"ChatAnthropic",
),
("langchain_groq", "chat_models", "ChatGroq"): (
"langchain_groq",
"chat_models",
"ChatGroq",
),
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
"langchain_fireworks",
"chat_models",
@@ -287,6 +292,17 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"chat_models",
"ChatVertexAI",
),
("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
"langchain_mistralai",
"chat_models",
"ChatMistralAI",
),
("langchain", "chat_models", "bedrock", "ChatBedrock"): (
"langchain_aws",
"chat_models",
"bedrock",
"ChatBedrock",
),
("langchain", "schema", "output", "ChatGenerationChunk"): (
"langchain_core",
"outputs",

View File

@@ -10,9 +10,10 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict
from typing_extensions import NotRequired
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import v1_repr
class BaseSerialized(TypedDict):
@@ -80,7 +81,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
Exception: If the key is not in the model.
"""
try:
return model.__fields__[key].get_default() != value
return model.model_fields[key].get_default() != value
except Exception:
return True
@@ -161,16 +162,25 @@ class Serializable(BaseModel, ABC):
For example, for the class `langchain.llms.openai.OpenAI`, the id is
["langchain", "llms", "openai", "OpenAI"].
"""
return [*cls.get_lc_namespace(), cls.__name__]
# Pydantic generics change the class name, so recover the original
# name from the generic origin when one exists.
if (
"origin" in cls.__pydantic_generic_metadata__
and cls.__pydantic_generic_metadata__["origin"] is not None
):
original_name = cls.__pydantic_generic_metadata__["origin"].__name__
else:
original_name = cls.__name__
return [*cls.get_lc_namespace(), original_name]
class Config:
extra = "ignore"
model_config = ConfigDict(
extra="ignore",
)
def __repr_args__(self) -> Any:
return [
(k, v)
for k, v in super().__repr_args__()
if (k not in self.__fields__ or try_neq_default(v, k, self))
if (k not in self.model_fields or try_neq_default(v, k, self))
]
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
@@ -184,12 +194,15 @@ class Serializable(BaseModel, ABC):
secrets = dict()
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
and _is_field_useful(self, k, v)
}
lc_kwargs = {}
for k, v in self:
if not _is_field_useful(self, k, v):
continue
# Do nothing if the field is excluded
if k in self.model_fields and self.model_fields[k].exclude:
continue
lc_kwargs[k] = getattr(self, k, v)
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
@@ -221,8 +234,10 @@ class Serializable(BaseModel, ABC):
# that are not present in the fields.
for key in list(secrets):
value = secrets[key]
if key in this.__fields__:
secrets[this.__fields__[key].alias] = value
if key in this.model_fields:
alias = this.model_fields[key].alias
if alias is not None:
secrets[alias] = value
lc_kwargs.update(this.lc_attributes)
# include all secrets, even if not specified in kwargs
@@ -244,6 +259,10 @@ class Serializable(BaseModel, ABC):
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
"""Check if a field is useful as a constructor argument.
@@ -259,9 +278,13 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
If the field is not required and the value is None, it is useful if the
default value is different from the value.
"""
field = inst.__fields__.get(key)
field = inst.model_fields.get(key)
if not field:
return False
if field.is_required():
return True
# Handle edge case: a value cannot be converted to a boolean (e.g. a
# Pandas DataFrame).
try:
@@ -269,6 +292,17 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
except Exception as _:
value_is_truthy = False
if value_is_truthy:
return True
# Value is still falsy here!
if field.default_factory is dict and isinstance(value, dict):
return False
# Value is still falsy here!
if field.default_factory is list and isinstance(value, list):
return False
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
# Pandas DataFrames).
try:
@@ -282,7 +316,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
except Exception as _:
value_neq_default = False
return field.required is True or value_is_truthy or value_neq_default
# If value is falsy and does not match the default
return value_is_truthy or value_neq_default
def _replace_secrets(

View File
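Two v2 details drive the serialization changes above: `__fields__` becomes `model_fields`, and a field's `alias` is `None` unless one was set explicitly (hence the new `if alias is not None` guard around the secrets mapping). A small illustration with a hypothetical `Account` model, assuming pydantic v2:

```python
from typing import Optional

from pydantic import BaseModel, Field


class Account(BaseModel):
    # Hypothetical fields, chosen to show defaults and aliases.
    region: str = "us-east-1"
    api_key: Optional[str] = Field(default=None, alias="apiKey")


# v1 spelled this Account.__fields__; v2 uses model_fields.
print(Account.model_fields["region"].get_default())  # -> us-east-1

# In v2 the alias is None unless explicitly set, so code that indexes
# by alias must guard against None.
print(Account.model_fields["region"].alias)   # -> None
print(Account.model_fields["api_key"].alias)  # -> apiKey
```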

@@ -13,6 +13,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import ConfigDict
from langchain_core.load.serializable import Serializable
from langchain_core.runnables import run_in_executor
@@ -47,8 +49,9 @@ class BaseMemory(Serializable, ABC):
pass
""" # noqa: E501
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
@abstractmethod

View File

@@ -1,7 +1,8 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union
from typing_extensions import TypedDict
from pydantic import model_validator
from typing_extensions import Self, TypedDict
from langchain_core.messages.base import (
BaseMessage,
@@ -24,7 +25,6 @@ from langchain_core.messages.tool import (
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
@@ -111,8 +111,9 @@ class AIMessage(BaseMessage):
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator(pre=True)
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
@model_validator(mode="before")
@classmethod
def _backwards_compat_tool_calls(cls, values: dict) -> Any:
check_additional_kwargs = not any(
values.get(k)
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
@@ -204,7 +205,7 @@ class AIMessage(BaseMessage):
return (base.strip() + "\n" + "\n".join(lines)).strip()
AIMessage.update_forward_refs()
AIMessage.model_rebuild()
class AIMessageChunk(AIMessage, BaseMessageChunk):
@@ -238,8 +239,8 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator(pre=False, skip_on_failure=True)
def init_tool_calls(cls, values: dict) -> dict:
@model_validator(mode="after")
def init_tool_calls(self) -> Self:
"""Initialize tool calls from tool call chunks.
Args:
@@ -251,35 +252,35 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
Raises:
ValueError: If the tool call chunks are malformed.
"""
if not values["tool_call_chunks"]:
if values["tool_calls"]:
values["tool_call_chunks"] = [
if not self.tool_call_chunks:
if self.tool_calls:
self.tool_call_chunks = [
create_tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in values["tool_calls"]
for tc in self.tool_calls
]
if values["invalid_tool_calls"]:
tool_call_chunks = values.get("tool_call_chunks", [])
if self.invalid_tool_calls:
tool_call_chunks = self.tool_call_chunks
tool_call_chunks.extend(
[
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in values["invalid_tool_calls"]
for tc in self.invalid_tool_calls
]
)
values["tool_call_chunks"] = tool_call_chunks
self.tool_call_chunks = tool_call_chunks
return values
return self
tool_calls = []
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
for chunk in self.tool_call_chunks:
try:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
if isinstance(args_, dict):
tool_calls.append(
create_tool_call(
@@ -299,9 +300,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
error=None,
)
)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
return values
self.tool_calls = tool_calls
self.invalid_tool_calls = invalid_tool_calls
return self
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):

View File
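The second validator rewrite in the hunk above — `@root_validator(pre=False, skip_on_failure=True)` becoming `@model_validator(mode="after")` — changes the method from a classmethod over a `values` dict into an instance method that mutates the built model and returns `Self`. A reduced sketch with a hypothetical `Chunk` model (assuming pydantic v2):

```python
import json
from typing import List

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class Chunk(BaseModel):
    tool_calls: List[dict] = []
    tool_call_chunks: List[dict] = []

    @model_validator(mode="after")
    def init_tool_calls(self) -> Self:
        # v1 received `values: dict`; v2 works directly on the instance,
        # so values["x"] reads/writes become attribute access.
        if not self.tool_call_chunks and self.tool_calls:
            self.tool_call_chunks = [
                {"name": tc["name"], "args": json.dumps(tc["args"])}
                for tc in self.tool_calls
            ]
        return self


chunk = Chunk(tool_calls=[{"name": "search", "args": {"q": "pydantic"}}])
print(chunk.tool_call_chunks[0]["name"])  # -> search
```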

@@ -2,11 +2,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
from pydantic import ConfigDict, Field
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.interactive_env import is_interactive_env
from langchain_core.utils.pydantic import v1_repr
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -51,8 +53,9 @@ class BaseMessage(Serializable):
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""
class Config:
extra = Extra.allow
model_config = ConfigDict(
extra="allow",
)
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
@@ -108,6 +111,10 @@ class BaseMessage(Serializable):
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
def __repr__(self) -> str:
# TODO(0.3): Remove this override after confirming unit tests!
return v1_repr(self)
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],

View File

@@ -25,7 +25,7 @@ class ChatMessage(BaseMessage):
return ["langchain", "schema", "messages"]
ChatMessage.update_forward_refs()
ChatMessage.model_rebuild()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):

View File
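The mechanical `update_forward_refs()` → `model_rebuild()` swap repeated across these message modules serves the same purpose under v2: resolving string (forward) annotations after the class body is complete. A self-contained sketch with a hypothetical self-referencing `Node` model:

```python
from typing import List

from pydantic import BaseModel


class Node(BaseModel):
    value: int
    # Forward reference: "Node" is still being defined here.
    children: List["Node"] = []


# v1: Node.update_forward_refs()
Node.model_rebuild()

tree = Node(value=1, children=[{"value": 2}])
print(tree.children[0].value)  # -> 2
```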

@@ -32,7 +32,7 @@ class FunctionMessage(BaseMessage):
return ["langchain", "schema", "messages"]
FunctionMessage.update_forward_refs()
FunctionMessage.model_rebuild()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):

View File

@@ -56,7 +56,7 @@ class HumanMessage(BaseMessage):
super().__init__(content=content, **kwargs)
HumanMessage.update_forward_refs()
HumanMessage.model_rebuild()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):

View File

@@ -33,4 +33,4 @@ class RemoveMessage(BaseMessage):
return ["langchain", "schema", "messages"]
RemoveMessage.update_forward_refs()
RemoveMessage.model_rebuild()

View File

@@ -50,7 +50,7 @@ class SystemMessage(BaseMessage):
super().__init__(content=content, **kwargs)
SystemMessage.update_forward_refs()
SystemMessage.model_rebuild()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):

View File

@@ -1,6 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import Field
from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
@@ -70,6 +71,11 @@ class ToolMessage(BaseMessage):
.. versionadded:: 0.2.24
"""
additional_kwargs: dict = Field(default_factory=dict, repr=False)
"""Currently inherited from BaseMessage, but not used."""
response_metadata: dict = Field(default_factory=dict, repr=False)
"""Currently inherited from BaseMessage, but not used."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.
@@ -88,7 +94,7 @@ class ToolMessage(BaseMessage):
super().__init__(content=content, **kwargs)
ToolMessage.update_forward_refs()
ToolMessage.model_rebuild()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):

View File

@@ -52,6 +52,12 @@ AnyMessage = Union[
SystemMessage,
FunctionMessage,
ToolMessage,
AIMessageChunk,
HumanMessageChunk,
ChatMessageChunk,
SystemMessageChunk,
FunctionMessageChunk,
ToolMessageChunk,
]

View File

@@ -13,8 +13,6 @@ from typing import (
Union,
)
from typing_extensions import get_args
from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
@@ -166,10 +164,11 @@ class BaseOutputParser(
Raises:
TypeError: If the class doesn't have an inferable OutputType.
"""
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) > 0:
return metadata["args"][0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "

View File
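The rewritten `OutputType` lookup walks the MRO for `__pydantic_generic_metadata__` instead of relying on `__orig_bases__`. That attribute is v2's hook for recovering a generic model's unparameterized origin and its bound type arguments — the same mechanism the earlier `get_lc_id` hunk uses, since parameterizing a generic model changes its `__name__`. A sketch with a hypothetical generic `Parser` model:

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class Parser(BaseModel, Generic[T]):
    """Hypothetical generic model standing in for BaseOutputParser."""


# Parameterizing the generic produces a new class whose metadata
# records both the unparameterized origin and the bound arguments.
meta = Parser[int].__pydantic_generic_metadata__
print(meta["origin"] is Parser)  # -> True
print(meta["args"])  # -> (int,)
```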

@@ -5,7 +5,9 @@ from json import JSONDecodeError
from typing import Any, List, Optional, Type, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic # pydantic: ignore
import pydantic
from pydantic import SkipValidation
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@@ -22,7 +24,7 @@ if PYDANTIC_MAJOR_VERSION < 2:
PydanticBaseModel = pydantic.BaseModel
else:
from pydantic.v1 import BaseModel # pydantic: ignore
from pydantic.v1 import BaseModel
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
@@ -40,7 +42,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
"""The Pydantic object to use for validation.
If None, no validation is performed."""

View File
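`Annotated[..., SkipValidation()]`, applied here to `pydantic_object` and later to `tools` and `messages`, tells v2 to store the value as-is without coercion or validation — which is what lets these fields hold values (such as pydantic-v1 model classes) that v2 validation would otherwise reject. A sketch with a hypothetical `Holder` model:

```python
from typing import Type

from pydantic import BaseModel, SkipValidation
from typing_extensions import Annotated


class Legacy:  # not a pydantic model at all
    pass


class Holder(BaseModel):
    # Without SkipValidation, pydantic would validate the value against
    # Type[BaseModel] and reject Legacy; with it, the value is stored as-is.
    pydantic_object: Annotated[Type[BaseModel], SkipValidation()]


holder = Holder(pydantic_object=Legacy)
print(holder.pydantic_object is Legacy)  # -> True
```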

@@ -4,6 +4,7 @@ import re
from abc import abstractmethod
from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from typing import Optional as Optional
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
yield [part]
ListOutputParser.model_rebuild()
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""

View File

@@ -3,6 +3,7 @@ import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch # type: ignore[import]
from pydantic import BaseModel, model_validator
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -230,8 +230,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
determine which schema to use.
"""
@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: Dict) -> Any:
"""Validate the pydantic schema.
Args:
@@ -267,11 +268,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema.model_validate_json(_result)
else:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args) # type: ignore
else:
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
return pydantic_args

View File

@@ -3,13 +3,15 @@ import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional
from pydantic import SkipValidation, ValidationError
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel
@@ -252,7 +254,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: List[TypeBaseModel]
tools: Annotated[List[TypeBaseModel], SkipValidation()]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all

View File

@@ -1,7 +1,9 @@
import json
from typing import Generic, List, Optional, Type
import pydantic # pydantic: ignore
import pydantic
from pydantic import SkipValidation
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
@@ -16,7 +18,7 @@ from langchain_core.utils.pydantic import (
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model."""
pydantic_object: Type[TBaseModel] # type: ignore
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
@@ -111,6 +113,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return self.pydantic_object
PydanticOutputParser.model_rebuild()
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}

View File

@@ -1,4 +1,5 @@
from typing import List
from typing import Optional as Optional
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -24,3 +25,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
StrOutputParser.model_rebuild()

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Union
from typing import List, Literal, Union
from pydantic import model_validator
from typing_extensions import Self
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts
@@ -30,8 +32,8 @@ class ChatGeneration(Generation):
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator(pre=False, skip_on_failure=True)
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="after")
def set_text(self) -> Self:
"""Set the text attribute to be the contents of the message.
Args:
@@ -45,12 +47,12 @@ class ChatGeneration(Generation):
"""
try:
text = ""
if isinstance(values["message"].content, str):
text = values["message"].content
if isinstance(self.message.content, str):
text = self.message.content
# HACK: Assumes text in content blocks in OpenAI format.
# Uses first text block.
elif isinstance(values["message"].content, list):
for block in values["message"].content:
elif isinstance(self.message.content, list):
for block in self.message.content:
if isinstance(block, str):
text = block
break
@@ -61,10 +63,10 @@ class ChatGeneration(Generation):
pass
else:
pass
values["text"] = text
self.text = text
except (KeyError, AttributeError) as e:
raise ValueError("Error while initializing ChatGeneration") from e
return values
return self
@classmethod
def get_lc_namespace(cls) -> List[str]:

View File

@@ -1,7 +1,8 @@
from typing import List, Optional
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
class ChatResult(BaseModel):

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from typing import List, Optional, Union
from langchain_core.outputs.generation import Generation
from pydantic import BaseModel
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
class LLMResult(BaseModel):
@@ -16,7 +18,9 @@ class LLMResult(BaseModel):
wants to return.
"""
generations: List[List[Generation]]
generations: List[
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
]
"""Generated outputs.
The first dimension of the list represents completions for different input

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from uuid import UUID
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class RunInfo(BaseModel):

View File

@@ -18,6 +18,8 @@ from typing import (
)
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
@@ -25,7 +27,6 @@ from langchain_core.prompt_values import (
PromptValue,
StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model
@@ -64,28 +65,26 @@ class BasePromptTemplate(
tags: Optional[List[str]] = None
"""Tags to be used for tracing."""
@root_validator(pre=False, skip_on_failure=True)
def validate_variable_names(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_variable_names(self) -> Self:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
if "stop" in self.input_variables:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
if "stop" in self.partial_variables:
raise ValueError(
"Cannot have a partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
overall = set(self.input_variables).intersection(self.partial_variables)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
return self
@classmethod
def get_lc_namespace(cls) -> List[str]:
@@ -99,8 +98,9 @@ class BasePromptTemplate(
Returns True."""
return True
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def OutputType(self) -> Any:

View File

@@ -21,6 +21,14 @@ from typing import (
overload,
)
from pydantic import (
Field,
PositiveInt,
SkipValidation,
model_validator,
)
from typing_extensions import Annotated
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
@@ -38,7 +46,6 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@@ -207,8 +214,14 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
def __init__(
self, variable_name: str, *, optional: bool = False, **kwargs: Any
) -> None:
# mypy can't detect the init which is defined in the parent class
# b/c these are BaseModel classes.
super().__init__( # type: ignore
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
@@ -922,7 +935,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
""" # noqa: E501
messages: List[MessageLike]
messages: Annotated[List[MessageLike], SkipValidation()]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
@@ -1038,8 +1051,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
@model_validator(mode="before")
@classmethod
def validate_input_variables(cls, values: dict) -> Any:
"""Validate input variables.
If input_variables is not set, it will be set to the union of

View File

@@ -5,6 +5,14 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
)
from typing_extensions import Self
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import (
@@ -18,7 +26,6 @@ from langchain_core.prompts.string import (
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
class _FewShotPromptTemplateMixin(BaseModel):
@@ -32,12 +39,14 @@ class _FewShotPromptTemplateMixin(BaseModel):
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def check_examples_and_selector(cls, values: Dict) -> Any:
"""Check that one and only one of examples/example_selector is provided.
Args:
@@ -139,28 +148,29 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
kwargs["input_variables"] = kwargs["example_prompt"].input_variables
super().__init__(**kwargs)
@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def template_is_valid(self) -> Self:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
if self.validate_template:
check_valid_template(
values["prefix"] + values["suffix"],
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
self.prefix + self.suffix,
self.template_format,
self.input_variables + list(self.partial_variables),
)
elif values.get("template_format"):
values["input_variables"] = [
elif self.template_format or None:
self.input_variables = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
self.prefix + self.suffix, self.template_format
)
if var not in values["partial_variables"]
if var not in self.partial_variables
]
return values
return self
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
@@ -365,9 +375,10 @@ class FewShotChatMessagePromptTemplate(
"""Return whether or not the class is serializable."""
return False
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.

View File

@@ -3,12 +3,14 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
)
from langchain_core.pydantic_v1 import Extra, root_validator
class FewShotPromptWithTemplates(StringPromptTemplate):
@@ -45,8 +47,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "few_shot_with_templates"]
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def check_examples_and_selector(cls, values: Dict) -> Any:
"""Check that one and only one of examples/example_selector is provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
@@ -62,15 +65,15 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
return values
@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def template_is_valid(self) -> Self:
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables)
expected_input_variables |= set(values["partial_variables"])
if values["prefix"] is not None:
expected_input_variables |= set(values["prefix"].input_variables)
if self.validate_template:
input_variables = self.input_variables
expected_input_variables = set(self.suffix.input_variables)
expected_input_variables |= set(self.partial_variables)
if self.prefix is not None:
expected_input_variables |= set(self.prefix.input_variables)
missing_vars = expected_input_variables.difference(input_variables)
if missing_vars:
raise ValueError(
@@ -78,16 +81,17 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
self.input_variables = sorted(
set(self.suffix.input_variables)
| set(self.prefix.input_variables if self.prefix else [])
- set(self.partial_variables)
)
return values
return self
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
def _get_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:

View File

@@ -1,8 +1,9 @@
from typing import Any, List
from pydantic import Field
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.utils import image as image_utils

View File

@@ -1,9 +1,11 @@
from typing import Any, Dict, List, Tuple
from typing import Optional as Optional
from pydantic import model_validator
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.pydantic_v1 import root_validator
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
@@ -34,8 +36,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "pipeline"]
@root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def get_input_variables(cls, values: Dict) -> Any:
"""Get input variables."""
created_variables = set()
all_variables = set()
@@ -106,3 +109,6 @@ class PipelinePromptTemplate(BasePromptTemplate):
@property
def _prompt_type(self) -> str:
raise ValueError
PipelinePromptTemplate.model_rebuild()

View File

@@ -6,6 +6,8 @@ import warnings
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, model_validator
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
@@ -13,7 +15,6 @@ from langchain_core.prompts.string import (
get_template_variables,
mustache_schema,
)
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig
@@ -73,8 +74,9 @@ class PromptTemplate(StringPromptTemplate):
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
def pre_init_validation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def pre_init_validation(cls, values: Dict) -> Any:
"""Check that template and input variables are consistent."""
if values.get("template") is None:
# Will let pydantic fail with a ValidationError if template

View File
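The `PipelinePromptTemplate` and `PromptTemplate` hunks show the other half of the validator migration: `@root_validator(pre=True)` becomes `@model_validator(mode="before")` stacked on an explicit `@classmethod`, and the return type widens to `Any` because the raw input may not be a dict. A hedged sketch of that shape (the `TemplateConfig` model and its inference rule are illustrative only):

```python
import re
from typing import Any, List

from pydantic import BaseModel, model_validator


class TemplateConfig(BaseModel):
    # Hypothetical model that infers input_variables before validation.
    template: str
    input_variables: List[str] = []

    @model_validator(mode="before")
    @classmethod
    def infer_variables(cls, values: Any) -> Any:
        # "before" validators see the raw input prior to field coercion,
        # mirroring v1's @root_validator(pre=True). Guard on dict since
        # v2 may also pass non-dict inputs here.
        if isinstance(values, dict) and not values.get("input_variables"):
            values["input_variables"] = sorted(
                set(re.findall(r"{(\w+)}", values.get("template", "")))
            )
        return values
```

The explicit `@classmethod` is required in v2; v1 implicitly treated root validators as class methods.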

@@ -7,10 +7,11 @@ from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set, Tuple, Type
from pydantic import BaseModel, create_model
import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env

View File

@@ -11,13 +11,14 @@ from typing import (
Union,
)
from pydantic import BaseModel, Field
from langchain_core._api.beta_decorator import beta
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.prompts.chat import (
ChatPromptTemplate,
MessageLikeRepresentation,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.base import (
Other,
Runnable,

View File

@@ -26,6 +26,7 @@ from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import ConfigDict
from typing_extensions import TypedDict
from langchain_core._api import deprecated
@@ -126,8 +127,9 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
""" # noqa: E501
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
_new_arg_supported: bool = False
_expects_other_args: bool = False

View File
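The retriever hunk above is the recurring config migration in this PR: the v1 inner `class Config:` is replaced by a `model_config = ConfigDict(...)` class attribute. A minimal sketch, using a hypothetical non-pydantic `Handle` type to show why `arbitrary_types_allowed` matters:

```python
from pydantic import BaseModel, ConfigDict


class Handle:
    # Arbitrary non-pydantic type; only accepted as a field because of
    # arbitrary_types_allowed in the model config below.
    def __init__(self, ref: str) -> None:
        self.ref = ref


class Wrapper(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    handle: Handle
```

String settings also change spelling in v2, e.g. `extra = Extra.forbid` becomes `extra="forbid"`, as in the few-shot prompt hunk earlier in this diff.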

@@ -35,7 +35,8 @@ from typing import (
overload,
)
from typing_extensions import Literal, get_args
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints
from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
)
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -236,25 +235,58 @@ class Runnable(Generic[Input, Output], ABC):
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
""" # noqa: E501
name: Optional[str] = None
name: Optional[str]
"""The name of the Runnable. Used for debugging and tracing."""
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
"""Get the name of the Runnable."""
name = name or self.name or self.__class__.__name__
if suffix:
if name[0].isupper():
return name + suffix.title()
else:
return name + "_" + suffix.lower()
if name:
name_ = name
elif hasattr(self, "name") and self.name:
name_ = self.name
else:
return name
# Here we handle a case where the runnable subclass is also a pydantic
# model.
cls = self.__class__
# Then it's a pydantic sub-class, and we have to check
# whether it's a generic, and if so recover the original name.
if (
hasattr(
cls,
"__pydantic_generic_metadata__",
)
and "origin" in cls.__pydantic_generic_metadata__
and cls.__pydantic_generic_metadata__["origin"] is not None
):
name_ = cls.__pydantic_generic_metadata__["origin"].__name__
else:
name_ = cls.__name__
if suffix:
if name_[0].isupper():
return name_ + suffix.title()
else:
return name_ + "_" + suffix.lower()
else:
return name_
@property
def InputType(self) -> Type[Input]:
"""The type of input this Runnable accepts specified as a type annotation."""
# First loop through all parent classes and if any of them is
# a pydantic model, we will pick up the generic parameterization
# from that model via the __pydantic_generic_metadata__ attribute.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][0]
# If we didn't find a pydantic model in the parent classes,
# then loop through __orig_bases__. This corresponds to
# Runnables that are not pydantic models.
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
@@ -268,6 +300,14 @@ class Runnable(Generic[Input, Output], ABC):
@property
def OutputType(self) -> Type[Output]:
"""The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this handles the case where the
# runnable is a parameterized generic pydantic model.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][1]
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
@@ -302,14 +342,42 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.InputType
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Input"),
__root__=(root_type, None),
__root__=root_type,
)
def get_input_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
"""Get a JSON schema that represents the input to the Runnable.
Args:
config: A config to use when generating the schema.
Returns:
A JSON schema that represents the input to the Runnable.
Example:
.. code-block:: python
from langchain_core.runnables import RunnableLambda
def add_one(x: int) -> int:
return x + 1
runnable = RunnableLambda(add_one)
print(runnable.get_input_jsonschema())
.. versionadded:: 0.3.0
"""
return self.get_input_schema(config).model_json_schema()
@property
def output_schema(self) -> Type[BaseModel]:
"""The type of output this Runnable produces specified as a pydantic model."""
@@ -334,14 +402,42 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=(root_type, None),
__root__=root_type,
)
def get_output_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable.
Args:
config: A config to use when generating the schema.
Returns:
A JSON schema that represents the output of the Runnable.
Example:
.. code-block:: python
from langchain_core.runnables import RunnableLambda
def add_one(x: int) -> int:
return x + 1
runnable = RunnableLambda(add_one)
print(runnable.get_output_jsonschema())
.. versionadded:: 0.3.0
"""
return self.get_output_schema(config).model_json_schema()
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
"""List configurable fields for this Runnable."""
@@ -381,15 +477,34 @@ class Runnable(Generic[Input, Output], ABC):
else None
)
return create_model( # type: ignore[call-overload]
self.get_name("Config"),
# May need to create a TypedDict instead to implement NotRequired!
all_fields = {
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
for field_name, field_type in get_type_hints(RunnableConfig).items()
if field_name in [i for i in include if i != "configurable"]
},
}
model = create_model( # type: ignore[call-overload]
self.get_name("Config"), **all_fields
)
return model
def get_config_jsonschema(
self, *, include: Optional[Sequence[str]] = None
) -> Dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable.
Args:
include: A list of fields to include in the config schema.
Returns:
A JSON schema that represents the output of the Runnable.
.. versionadded:: 0.3.0
"""
return self.config_schema(include=include).model_json_schema()
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Return a graph representation of this Runnable."""
@@ -579,7 +694,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel(kwargs))
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
""" --- Public API --- """
@@ -2129,7 +2244,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
iterator_ = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
if accepts_config(transformer):
@@ -2314,7 +2428,12 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
"""Runnable that can be serialized to JSON."""
name: Optional[str] = None
"""The name of the Runnable. Used for debugging and tracing."""
model_config = ConfigDict(
# Suppress warnings from pydantic protected namespaces
# (e.g., `model_`)
protected_namespaces=(),
)
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
"""Serialize the Runnable to JSON.
@@ -2369,10 +2488,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
from langchain_core.runnables.configurable import RunnableConfigurableFields
for key in kwargs:
if key not in self.__fields__:
if key not in self.model_fields:
raise ValueError(
f"Configuration key {key} not found in {self}: "
f"available keys are {self.__fields__.keys()}"
f"available keys are {self.model_fields.keys()}"
)
return RunnableConfigurableFields(default=self, fields=kwargs)
@@ -2447,13 +2566,13 @@ def _seq_input_schema(
return first.get_input_schema(config)
elif isinstance(first, RunnableAssign):
next_input_schema = _seq_input_schema(steps[1:], config)
if not next_input_schema.__custom_root_type__:
if not issubclass(next_input_schema, RootModel):
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceInput",
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
for k, v in next_input_schema.model_fields.items()
if k not in first.mapper.steps__
},
)
@@ -2474,36 +2593,36 @@ def _seq_output_schema(
elif isinstance(last, RunnableAssign):
mapper_output_schema = last.mapper.get_output_schema(config)
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if not issubclass(prev_output_schema, RootModel):
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
},
**{
k: (v.annotation, v.default)
for k, v in mapper_output_schema.__fields__.items()
for k, v in mapper_output_schema.model_fields.items()
},
},
)
elif isinstance(last, RunnablePick):
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if not issubclass(prev_output_schema, RootModel):
# it's a dict as expected
if isinstance(last.keys, list):
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
if k in last.keys
},
)
else:
field = prev_output_schema.__fields__[last.keys]
field = prev_output_schema.model_fields[last.keys]
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
__root__=(field.annotation, field.default),
@@ -2665,8 +2784,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
"""
return True
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Type[Input]:
@@ -3402,8 +3522,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
@@ -3450,7 +3571,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
**{
k: (v.annotation, v.default)
for step in self.steps__.values()
for k, v in step.get_input_schema(config).__fields__.items()
for k, v in step.get_input_schema(config).model_fields.items()
if k != "__root__"
},
)
@@ -3468,11 +3589,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
Returns:
The output schema of the Runnable.
"""
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
)
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
return create_model(self.get_name("Output"), **fields)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -3882,6 +4000,8 @@ class RunnableGenerator(Runnable[Input, Output]):
atransform: Optional[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
] = None,
*,
name: Optional[str] = None,
) -> None:
"""Initialize a RunnableGenerator.
@@ -3909,9 +4029,9 @@ class RunnableGenerator(Runnable[Input, Output]):
)
try:
self.name = func_for_name.__name__
self.name = name or func_for_name.__name__
except AttributeError:
pass
self.name = "RunnableGenerator"
@property
def InputType(self) -> Any:
@@ -4183,15 +4303,13 @@ class RunnableLambda(Runnable[Input, Output]):
if all(
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
):
fields = {item[1:-1]: (Any, ...) for item in items}
# It's a dict, lol
return create_model(
self.get_name("Input"),
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
)
return create_model(self.get_name("Input"), **fields)
else:
return create_model(
self.get_name("Input"),
__root__=(List[Any], None),
__root__=List[Any],
)
if self.InputType != Any:
@@ -4200,7 +4318,7 @@ class RunnableLambda(Runnable[Input, Output]):
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
self.get_name("Input"),
**{key: (Any, None) for key in dict_keys}, # type: ignore
**{key: (Any, ...) for key in dict_keys}, # type: ignore
)
return super().get_input_schema(config)
@@ -4728,8 +4846,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Any:
@@ -4756,10 +4875,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
schema = self.bound.get_output_schema(config)
return create_model(
self.get_name("Output"),
__root__=(
List[schema], # type: ignore
None,
),
__root__=List[schema], # type: ignore[valid-type]
)
@property
@@ -4979,8 +5095,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
@@ -5316,7 +5433,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
RunnableBindingBase.model_rebuild()
class RunnableBinding(RunnableBindingBase[Input, Output]):

View File
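The large `runnables/base.py` section above repeatedly swaps v1 introspection idioms for their v2 equivalents: `schema.__custom_root_type__` checks become `issubclass(schema, RootModel)`, and `__fields__` becomes `model_fields`. A small self-contained sketch of those two checks (the `DictInput`/`describe` names are illustrative, not LangChain APIs):

```python
from typing import List, Type

from pydantic import BaseModel, RootModel, create_model


class DictInput(BaseModel):
    question: str
    context: str


# A root model wraps a single value rather than named fields.
ListInput = RootModel[List[str]]


def describe(schema: Type[BaseModel]) -> List[str]:
    # v2 replacement for the v1 __custom_root_type__ flag: root-wrapped
    # schemas are RootModel subclasses; dict-shaped schemas expose their
    # fields through model_fields (formerly __fields__).
    if issubclass(schema, RootModel):
        return ["__root__"]
    return sorted(schema.model_fields)


# create_model still builds dict-shaped schemas from (type, default) pairs,
# as the _seq_input_schema / _seq_output_schema helpers do above.
Built = create_model("Built", a=(int, ...), b=(str, "x"))
```

The same distinction drives `get_input_jsonschema` and friends: `model_json_schema()` (v2) replaces the v1 `schema()` call on whichever model the check selects.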

@@ -14,8 +14,9 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Runnable,
RunnableLike,
@@ -134,10 +135,21 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
super().__init__(
branches=_branches,
default=default_,
# Hard-coding a name here because RunnableBranch is a generic
# and with pydantic 2, the generated class name would include
# the parameterized type, which is not what we want.
# e.g., we'd get RunnableBranch[Input, Output] instead of RunnableBranch
# for the name. This information is already captured in the
# input and output types.
name="RunnableBranch",
) # type: ignore[call-arg]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@@ -20,7 +20,8 @@ from typing import (
)
from weakref import WeakValueDictionary
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, ConfigDict
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
@@ -58,8 +59,9 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:
@@ -373,28 +375,33 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
Returns:
List[ConfigurableFieldSpec]: The configuration specs.
"""
return get_unique_config_specs(
[
(
# TODO(0.3): This change removes field_info which isn't needed in pydantic 2
config_specs = []
for field_name, spec in self.fields.items():
if isinstance(spec, ConfigurableField):
config_specs.append(
ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description
or self.default.__fields__[field_name].field_info.description,
or self.default.model_fields[field_name].description,
annotation=spec.annotation
or self.default.__fields__[field_name].annotation,
or self.default.model_fields[field_name].annotation,
default=getattr(self.default, field_name),
is_shared=spec.is_shared,
)
if isinstance(spec, ConfigurableField)
else make_options_spec(
spec, self.default.__fields__[field_name].field_info.description
)
else:
config_specs.append(
make_options_spec(
spec, self.default.model_fields[field_name].description
)
)
for field_name, spec in self.fields.items()
]
+ list(self.default.config_specs)
)
config_specs.extend(self.default.config_specs)
return get_unique_config_specs(config_specs)
def configurable_fields(
self, **kwargs: AnyConfigurableField
@@ -436,7 +443,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
init_params = {
k: v
for k, v in self.default.__dict__.items()
if k in self.default.__fields__
if k in self.default.model_fields
}
return (
self.default.__class__(**{**init_params, **configurable}),

View File
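The `RunnableConfigurableFields` hunk drops the v1 `__fields__[name].field_info` indirection: in pydantic 2, `model_fields[name]` is itself the `FieldInfo`, so `description`, `annotation`, and `default` hang directly off it. A brief sketch (the `Settings` model is hypothetical):

```python
from pydantic import BaseModel, Field


class Settings(BaseModel):
    temperature: float = Field(default=0.7, description="Sampling temperature")


# v1: Settings.__fields__["temperature"].field_info.description
# v2: Settings.model_fields["temperature"].description
info = Settings.model_fields["temperature"]
```

This is why the diff can delete the `field_info` hop entirely rather than aliasing it.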

@@ -18,8 +18,9 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
@@ -107,8 +108,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
must accept a dictionary as input."""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Type[Input]:

View File

@@ -22,8 +22,9 @@ from typing import (
)
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
@@ -235,7 +236,9 @@ def node_data_json(
json = (
{
"type": "schema",
"data": node.data.schema(),
"data": node.data.model_json_schema(
schema_generator=_IgnoreUnserializable
),
}
if with_schemas
else {

View File

@@ -13,9 +13,10 @@ from typing import (
Union,
)
from pydantic import BaseModel
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
@@ -372,28 +373,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
super_schema = super().get_input_schema(config)
if super_schema.__custom_root_type__ or not super_schema.schema().get(
"properties"
):
from langchain_core.messages import BaseMessage
# TODO(0.3): Verify that this change was correct
# Not enough tests and unclear on why the previous implementation was
# necessary.
from langchain_core.messages import BaseMessage
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
fields["__root__"] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
fields: Dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
...,
)
elif self.input_messages_key:
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
else:
return super_schema
fields["__root__"] = (Sequence[BaseMessage], ...)
return create_model( # type: ignore[call-overload]
"RunnableWithChatHistoryInput",
**fields,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False

View File

@@ -21,7 +21,8 @@ from typing import (
cast,
)
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, RootModel
from langchain_core.runnables.base import (
Other,
Runnable,
@@ -227,7 +228,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
A Runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel(kwargs))
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -419,7 +420,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
if not map_input_schema.__custom_root_type__:
if not issubclass(map_input_schema, RootModel):
# ie. it's a dict
return map_input_schema
@@ -430,20 +431,22 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if (
not map_input_schema.__custom_root_type__
and not map_output_schema.__custom_root_type__
if not issubclass(map_input_schema, RootModel) and not issubclass(
map_output_schema, RootModel
):
# ie. both are dicts
fields = {}
for name, field_info in map_input_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
for name, field_info in map_output_schema.model_fields.items():
fields[name] = (field_info.annotation, field_info.default)
return create_model( # type: ignore[call-overload]
"RunnableAssignOutput",
**{
k: (v.type_, v.default)
for s in (map_input_schema, map_output_schema)
for k, v in s.__fields__.items()
},
**fields,
)
elif not map_output_schema.__custom_root_type__:
elif not issubclass(map_output_schema, RootModel):
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema

View File

@@ -12,6 +12,7 @@ from typing import (
cast,
)
from pydantic import ConfigDict
from typing_extensions import TypedDict
from langchain_core.runnables.base import (
@@ -83,8 +84,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@@ -28,12 +28,18 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
from pydantic import BaseModel, ConfigDict, RootModel
from pydantic import create_model as _create_model_base # pydantic :ignore
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
from typing_extensions import TypeGuard
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent
Input = TypeVar("Input", contravariant=True)
@@ -350,7 +356,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
tree = ast.parse(textwrap.dedent(code))
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
return sorted(visitor.keys) if visitor.keys else None
except (SyntaxError, TypeError, OSError, SystemError):
return None
@@ -699,9 +705,57 @@ class _RootEventFilter:
return include
class _SchemaConfig(BaseConfig):
arbitrary_types_allowed = True
frozen = True
_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
NO_DEFAULT = object()
def create_base_class(
name: str, type_: Any, default_: object = NO_DEFAULT
) -> Type[BaseModel]:
"""Create a base class."""
def schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_
def model_json_schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_
base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
"__module__": "langchain_core.runnables.utils",
}
if default_ is not NO_DEFAULT:
base_class_attributes["root"] = default_
custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[BaseModel], custom_root_type)
def create_model(
@@ -717,6 +771,21 @@ def create_model(
Returns:
Type[BaseModel]: The created model.
"""
# Move this to caching path
if "__root__" in field_definitions:
if len(field_definitions) > 1:
raise NotImplementedError(
"When specifying __root__ no other "
f"fields should be provided. Got {field_definitions}"
)
arg = field_definitions["__root__"]
if isinstance(arg, tuple):
named_root_model = create_base_class(__model_name, arg[0], arg[1])
else:
named_root_model = create_base_class(__model_name, arg)
return named_root_model
try:
return _create_model_cached(__model_name, **field_definitions)
except TypeError:

View File
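The `runnables/utils.py` section above introduces `create_base_class`, which builds a named `RootModel` subclass dynamically and overrides `model_json_schema` so the generated JSON schema carries the requested title. A condensed sketch of that idea under the same assumptions (simplified: no default handling, kwargs passed through):

```python
from typing import Any, Dict, List, Type, cast

from pydantic import BaseModel, ConfigDict, RootModel


def make_root_model(name: str, type_: Any) -> Type[BaseModel]:
    """Build a named RootModel subclass whose JSON schema reports
    the given title instead of the auto-generated class name."""

    def model_json_schema(cls: Type[BaseModel], **kwargs: Any) -> Dict[str, Any]:
        # Delegate to RootModel's schema generation, then pin the title.
        schema_ = super(cls, cls).model_json_schema(**kwargs)  # type: ignore[misc]
        schema_["title"] = name
        return schema_

    attrs = {
        "__annotations__": {"root": type_},
        "model_config": ConfigDict(arbitrary_types_allowed=True),
        "model_json_schema": classmethod(model_json_schema),
        "__module__": __name__,
    }
    # type() with a "root" annotation is how a RootModel subclass is
    # created dynamically, matching the create_base_class helper above.
    return cast(Type[BaseModel], type(name, (RootModel,), attrs))
```

The patched `create_model` wrapper above routes any `__root__` field definition through this helper, preserving the v1 custom-root-type behavior on top of v2's `RootModel`.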

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence, Union
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class Visitor(ABC):
@@ -127,7 +127,8 @@ class Comparison(FilterDirective):
def __init__(
self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
) -> None:
super().__init__(
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
comparator=comparator, attribute=attribute, value=value, **kwargs
)
@@ -145,8 +146,11 @@ class Operation(FilterDirective):
def __init__(
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
):
super().__init__(operator=operator, arguments=arguments, **kwargs)
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
operator=operator, arguments=arguments, **kwargs
)
class StructuredQuery(Expr):
@@ -165,5 +169,8 @@ class StructuredQuery(Expr):
filter: Optional[FilterDirective],
limit: Optional[int] = None,
**kwargs: Any,
):
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
query=query, filter=filter, limit=limit, **kwargs
)

View File

@@ -19,12 +19,25 @@ from typing import (
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)
from typing_extensions import Annotated, TypeVar, get_args, get_origin
from pydantic import (
BaseModel,
ConfigDict,
Extra,
Field,
SkipValidation,
ValidationError,
model_validator,
validate_arguments,
)
from typing_extensions import Annotated
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -33,16 +46,7 @@ from langchain_core.callbacks import (
CallbackManager,
Callbacks,
)
from langchain_core.load import Serializable
from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
ValidationError,
root_validator,
validate_arguments,
)
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
@@ -59,6 +63,7 @@ from langchain_core.utils.function_calling import (
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
get_fields,
is_basemodel_subclass,
is_pydantic_v1_subclass,
is_pydantic_v2_subclass,
@@ -204,20 +209,64 @@ def create_schema_from_function(
"""
# https://docs.pydantic.dev/latest/usage/validation_decorator/
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
sig = inspect.signature(func)
# Let's ignore `self` and `cls` arguments for class and instance methods
if func.__qualname__ and "." in func.__qualname__:
# Then it likely belongs in a class namespace
in_class = True
else:
in_class = False
has_args = False
has_kwargs = False
for param in sig.parameters.values():
if param.kind == param.VAR_POSITIONAL:
has_args = True
elif param.kind == param.VAR_KEYWORD:
has_kwargs = True
inferred_model = validated.model # type: ignore
filter_args = filter_args if filter_args is not None else FILTERED_ARGS
for arg in filter_args:
if arg in inferred_model.__fields__:
del inferred_model.__fields__[arg]
if filter_args:
filter_args_ = filter_args
else:
# Handle classmethods and instance methods
existing_params: List[str] = list(sig.parameters.keys())
if existing_params and existing_params[0] in ("self", "cls") and in_class:
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
else:
filter_args_ = list(FILTERED_ARGS)
for existing_param in existing_params:
if not include_injected and _is_injected_arg_type(
sig.parameters[existing_param].annotation
):
filter_args_.append(existing_param)
description, arg_descriptions = _infer_arg_descriptions(
func,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
)
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(
inferred_model, func, filter_args=filter_args, include_injected=include_injected
)
valid_properties = []
for field in get_fields(inferred_model):
if not has_args:
if field == "args":
continue
if not has_kwargs:
if field == "kwargs":
continue
if field == "v__duplicate_kwargs": # Internal pydantic field
continue
if field not in filter_args_:
valid_properties.append(field)
return _create_subset_model(
f"{model_name}Schema",
inferred_model,
@@ -274,7 +323,10 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[TypeBaseModel] = None
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
default=None, description="The tool schema."
)
"""Pydantic model class to validate and parse the tool's input arguments.
Args schema should be either:
@@ -345,8 +397,9 @@ class ChildTool(BaseTool):
)
super().__init__(**kwargs)
class Config(Serializable.Config):
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def is_single_input(self) -> bool:
@@ -416,7 +469,7 @@ class ChildTool(BaseTool):
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
key_ = next(iter(get_fields(input_args).keys()))
input_args.validate({key_: tool_input})
return tool_input
else:
@@ -429,8 +482,9 @@ class ChildTool(BaseTool):
}
return tool_input
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
Args:
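The `raise_deprecation` change above is the standard pydantic 2 validator migration: `@root_validator(pre=True)` becomes `@model_validator(mode="before")` stacked on `@classmethod`, returning `Any` instead of `Dict`. A minimal sketch of the pattern (the key-renaming body is a simplification; the real method also emits a deprecation warning):

```python
from typing import Any, Dict

from pydantic import BaseModel, model_validator


class ChildTool(BaseModel):
    callbacks: Any = None

    # v1 spelling, removed in the diff:
    #     @root_validator(pre=True)
    #     def raise_deprecation(cls, values: Dict) -> Dict: ...
    @model_validator(mode="before")
    @classmethod
    def raise_deprecation(cls, values: Dict) -> Any:
        # "before" validators see the raw input, so deprecated keys can be
        # remapped before field validation runs.
        if isinstance(values, dict) and "callback_manager" in values:
            values["callbacks"] = values.pop("callback_manager")
        return values
```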

View File

@@ -1,8 +1,9 @@
import inspect
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.runnables import Runnable
from langchain_core.tools.base import BaseTool
from langchain_core.tools.simple import Tool

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from functools import partial
from typing import Optional
from pydantic import BaseModel, Field
from langchain_core.callbacks import Callbacks
from langchain_core.prompts import (
BasePromptTemplate,
@@ -10,7 +12,6 @@ from langchain_core.prompts import (
aformat_document,
format_document,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools.simple import Tool

View File

@@ -1,14 +1,24 @@
from __future__ import annotations
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
)
from pydantic import BaseModel
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolCall
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
BaseTool,
@@ -155,3 +165,6 @@ class Tool(BaseTool):
args_schema=args_schema,
**kwargs,
)
Tool.model_rebuild()
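The trailing `Tool.model_rebuild()` call is needed because pydantic 2 can leave a model's core schema only partially built when annotations contain forward references; `model_rebuild()` forces resolution once all names are importable. A small self-contained illustration of the same idea:

```python
from typing import Optional

from pydantic import BaseModel


class Node(BaseModel):
    value: int
    next: Optional["Node"] = None  # forward reference to the class itself


# Resolve the deferred "Node" reference so the schema is fully built,
# analogous to the Tool.model_rebuild() added in the diff.
Node.model_rebuild()
```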

View File

@@ -2,14 +2,26 @@ from __future__ import annotations
import textwrap
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Type,
Union,
)
from pydantic import BaseModel, Field, SkipValidation
from typing_extensions import Annotated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolCall
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
FILTERED_ARGS,
@@ -24,7 +36,9 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
..., description="The tool schema."
)
"""The input arguments' schema."""
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""

View File

@@ -11,7 +11,6 @@ from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain_core._api import deprecated
from langchain_core.outputs import LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@@ -83,7 +82,8 @@ class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
response: Optional[LLMResult] = None
# Temporarily removed, but we will completely remove LLMRun
# response: Optional[LLMResult] = None
@deprecated("0.1.0", alternative="Run", removal="1.0")

View File

@@ -23,11 +23,11 @@ from typing import (
cast,
)
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -85,7 +85,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
removal="1.0",
)
def convert_pydantic_to_openai_function(
model: Type[BaseModel],
model: Type,
*,
name: Optional[str] = None,
description: Optional[str] = None,
@@ -109,7 +109,10 @@ def convert_pydantic_to_openai_function(
else:
schema = model.schema() # Pydantic 1
schema = dereference_refs(schema)
schema.pop("definitions", None)
if "definitions" in schema: # pydantic 1
schema.pop("definitions", None)
if "$defs" in schema: # pydantic 2
schema.pop("$defs", None)
title = schema.pop("title", "")
default_description = schema.pop("description", "")
return {
@@ -193,11 +196,13 @@ def convert_python_function_to_openai_function(
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
visited: Dict = {}
from pydantic.v1 import BaseModel
model = cast(
Type[BaseModel],
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
)
return convert_pydantic_to_openai_function(model)
return convert_pydantic_to_openai_function(model) # type: ignore
_MAX_TYPED_DICT_RECURSION = 25
@@ -209,6 +214,9 @@ def _convert_any_typed_dicts_to_pydantic(
visited: Dict,
depth: int = 0,
) -> Type:
from pydantic.v1 import Field as Field_v1
from pydantic.v1 import create_model as create_model_v1
if type_ in visited:
return visited[type_]
elif depth >= _MAX_TYPED_DICT_RECURSION:
@@ -242,7 +250,7 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs["description"] = arg_desc
else:
pass
fields[arg] = (new_arg_type, Field(**field_kwargs))
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited
@@ -250,8 +258,8 @@ def _convert_any_typed_dicts_to_pydantic(
field_kwargs = {"default": ...}
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field(**field_kwargs))
model = create_model(typed_dict.__name__, **fields)
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
model = create_model_v1(typed_dict.__name__, **fields)
model.__doc__ = description
visited[typed_dict] = model
return model
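The converter above now pins `Field` and `create_model` to the `pydantic.v1` namespace explicitly, since under pydantic 2 the bare imports would produce v2 objects. A sketch of the same construction outside the helper, assuming pydantic 2 is installed (the `pydantic.v1` namespace does not exist under pydantic 1); field names are borrowed from the `move_file` example elsewhere in this diff:

```python
from pydantic.v1 import Field as Field_v1
from pydantic.v1 import create_model as create_model_v1

# Field definitions are (type, FieldInfo) tuples, mirroring how the
# TypedDict converter assembles its `fields` dict.
fields = {
    "source_path": (str, Field_v1(..., description="file to move")),
    "destination_path": (str, Field_v1(..., description="target location")),
}
MoveFile = create_model_v1("move_file", **fields)
```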

View File

@@ -7,9 +7,10 @@ import textwrap
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
import pydantic # pydantic: ignore
from langchain_core.pydantic_v1 import BaseModel, root_validator
import pydantic
from pydantic import BaseModel, root_validator
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema
def get_pydantic_major_version() -> int:
@@ -76,13 +77,13 @@ def is_basemodel_subclass(cls: Type) -> bool:
return False
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
from pydantic import BaseModel as BaseModelV1Proper
if issubclass(cls, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV2):
return True
@@ -104,13 +105,13 @@ def is_basemodel_instance(obj: Any) -> bool:
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
from pydantic import BaseModel as BaseModelV1Proper
if isinstance(obj, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
if isinstance(obj, BaseModelV2):
return True
@@ -146,7 +147,7 @@ def pre_init(func: Callable) -> Any:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.__fields__
fields = cls.model_fields
for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled
# If yes, then set the field name to the alias
@@ -155,9 +156,13 @@ def pre_init(func: Callable) -> Any:
if cls.Config.allow_population_by_field_name:
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if hasattr(cls, "model_config"):
if cls.model_config.get("populate_by_name"):
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if name not in values or values[name] is None:
if not field_info.required:
if not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
else:
@@ -169,6 +174,46 @@ def pre_init(func: Callable) -> Any:
return wrapper
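The `pre_init` changes above hinge on two v1-to-v2 API renames: `cls.__fields__` becomes `cls.model_fields`, and the boolean attribute `field_info.required` becomes the method `field_info.is_required()`. A minimal sketch:

```python
from pydantic import BaseModel


class M(BaseModel):
    a: int       # no default, so required
    b: int = 3   # has a default, so optional

# v1: field_info.required (attribute); v2: field_info.is_required() (method),
# read from cls.model_fields rather than cls.__fields__.
required = {name: info.is_required() for name, info in M.model_fields.items()}
```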
class _IgnoreUnserializable(GenerateJsonSchema):
"""A JSON schema generator that ignores unknown types.
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
"""
def handle_invalid_for_json_schema(
self, schema: core_schema.CoreSchema, error_info: str
) -> JsonSchemaValue:
return {}
def v1_repr(obj: BaseModel) -> str:
"""Return a repr of the pydantic object as a string,
consistent with the pydantic.v1 repr format.
"""
if not is_basemodel_instance(obj):
raise TypeError(f"Expected a pydantic BaseModel, got {type(obj)}")
repr_ = []
for name, field in get_fields(obj).items():
value = getattr(obj, name)
if isinstance(value, BaseModel):
repr_.append(f"{name}={v1_repr(value)}")
else:
if field.exclude:
continue
if not field.is_required():
if not value:
continue
if field.default == value:
continue
repr_.append(f"{name}={repr(value)}")
args = ", ".join(repr_)
return f"{obj.__class__.__name__}({args})"

def _create_subset_model_v1(
name: str,
model: Type[BaseModel],
@@ -178,12 +223,20 @@ def _create_subset_model_v1(
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
from langchain_core.pydantic_v1 import create_model
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import create_model
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import create_model # type: ignore
else:
raise NotImplementedError(
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
)
fields = {}
for field_name in field_names:
field = model.__fields__[field_name]
# Using pydantic v1, so we can access __fields__ as a dict.
field = model.__fields__[field_name] # type: ignore
t = (
# this isn't perfect but should work for most functions
field.outer_type_
@@ -208,8 +261,8 @@ def _create_subset_model_v2(
fn_description: Optional[str] = None,
) -> Type[pydantic.BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
from pydantic import create_model # pydantic: ignore
from pydantic.fields import FieldInfo # pydantic: ignore
from pydantic import create_model
from pydantic.fields import FieldInfo
descriptions_ = descriptions or {}
fields = {}
@@ -222,6 +275,17 @@ def _create_subset_model_v2(
fields[field_name] = (field.annotation, field_info)
rtn = create_model(name, **fields) # type: ignore
# TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations.
# This is done to preserve __annotations__ when working with pydantic 2.x
# and using the Annotated type with TypedDict.
# Comment out the following line, to trigger the relevant test case.
selected_annotations = [
(name, annotation)
for name, annotation in model.__annotations__.items()
if name in field_names
]
rtn.__annotations__ = dict(selected_annotations)
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
return rtn
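`_create_subset_model_v2` builds the subset by passing each field's annotation together with its `FieldInfo` into `create_model`, then copies over the matching `__annotations__` entries. A standalone sketch of the core step (without the annotation-preservation workaround):

```python
from pydantic import BaseModel, create_model


class Full(BaseModel):
    """Docstring carried over to the subset."""

    a: int
    b: str = "x"
    c: float = 0.0


field_names = ["a", "b"]
# Each definition is (annotation, FieldInfo), the same shape the helper
# assembles before calling create_model.
Subset = create_model(
    "FullSchema",
    **{n: (Full.model_fields[n].annotation, Full.model_fields[n]) for n in field_names},
)
```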
@@ -248,7 +312,7 @@ def _create_subset_model(
fn_description=fn_description,
)
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(model, BaseModelV1):
return _create_subset_model_v1(

View File

@@ -10,9 +10,9 @@ from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
from packaging.version import parse
from pydantic import SecretStr
from requests import HTTPError, Response
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils.pydantic import (
is_pydantic_v1_subclass,
)
@@ -353,7 +353,7 @@ def from_env(
@overload
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
def secret_from_env(key: Union[str, Sequence[str]], /) -> Callable[[], SecretStr]: ...
@overload

View File

@@ -42,8 +42,9 @@ from typing import (
TypeVar,
)
from pydantic import ConfigDict, Field, model_validator
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor
@@ -984,11 +985,13 @@ class VectorStoreRetriever(BaseRetriever):
"mmr",
)
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@root_validator(pre=True)
def validate_search_type(cls, values: Dict) -> Dict:
@model_validator(mode="before")
@classmethod
def validate_search_type(cls, values: Dict) -> Any:
"""Validate search type.
Args:

View File

@@ -51,7 +51,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
f"and Y has shape {Y.shape}."
)
try:
import simsimd as simd
import simsimd as simd # type: ignore[import-not-found]
X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32)
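The hunk above only adds a `type: ignore` to the optional `simsimd` import; the surrounding helper computes row-wise cosine similarity, falling back to plain numpy when the accelerated path is unavailable. A pure-numpy sketch of that fallback (not the exact langchain-core implementation):

```python
import numpy as np


def cosine_similarity(X, Y):
    # Normalize each row to unit length, then a matrix product gives the
    # pairwise cosine similarity between rows of X and rows of Y.
    X = np.asarray(X, dtype=np.float32)
    Y = np.asarray(Y, dtype=np.float32)
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    Yn = Y / np.linalg.norm(Y, axis=1, keepdims=True)
    return Xn @ Yn.T
```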

libs/core/poetry.lock (generated, 97 changed lines)
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@@ -11,9 +11,6 @@ files = [
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
[[package]]
name = "anyio"
version = "4.4.0"
@@ -185,9 +182,6 @@ files = [
{file = "babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316"},
]
[package.dependencies]
pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""}
[package.extras]
dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"]
@@ -693,28 +687,6 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke
perf = ["ipython"]
test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "importlib-resources"
version = "6.4.4"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"},
{file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"},
]
[package.dependencies]
zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"]
type = ["pytest-mypy"]
[[package]]
name = "iniconfig"
version = "2.0.0"
@@ -920,11 +892,9 @@ files = [
attrs = ">=22.2.0"
fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
jsonschema-specifications = ">=2023.03.6"
pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
referencing = ">=0.28.4"
rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
@@ -948,7 +918,6 @@ files = [
]
[package.dependencies]
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
referencing = ">=0.31.0"
[[package]]
@@ -1148,7 +1117,6 @@ files = [
async-lru = ">=1.0.0"
httpx = ">=0.25.0"
importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
importlib-resources = {version = ">=1.4", markers = "python_version < \"3.9\""}
ipykernel = ">=6.5.0"
jinja2 = ">=3.0.3"
jupyter-core = "*"
@@ -1555,43 +1523,6 @@ jupyter-server = ">=1.8,<3"
[package.extras]
test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"]
[[package]]
name = "numpy"
version = "1.24.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"},
{file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"},
{file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"},
{file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"},
{file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"},
{file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"},
{file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"},
{file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"},
{file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"},
{file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"},
{file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"},
{file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"},
{file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"},
{file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"},
{file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"},
{file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"},
{file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"},
{file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"},
{file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"},
{file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"},
{file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"},
{file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"},
{file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"},
{file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"},
{file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"},
{file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"},
{file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"},
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
]
[[package]]
name = "numpy"
version = "1.26.4"
@@ -1776,17 +1707,6 @@ files = [
{file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
]
[[package]]
name = "pkgutil-resolve-name"
version = "1.3.10"
description = "Resolve a name to an object."
optional = false
python-versions = ">=3.6"
files = [
{file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
{file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
]
[[package]]
name = "platformdirs"
version = "4.2.2"
@@ -2178,17 +2098,6 @@ files = [
{file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"},
]
[[package]]
name = "pytz"
version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]
[[package]]
name = "pywin32"
version = "306"
@@ -3219,5 +3128,5 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "e86e28d75744b77f8c2173b58e950c2d23c95cacad8ee5d3110309c4248f7c09"
python-versions = ">=3.9,<4.0"
content-hash = "7f2ce36878754aeb498d961452c156ab63e95bc5c6bcf3ca29acae325062aed9"

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-core"
version = "0.2.38"
version = "0.3.0.dev1"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -23,7 +23,7 @@ ignore_missing_imports = true
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-core%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
python = ">=3.9,<4.0"
langsmith = "^0.1.75"
tenacity = "^8.1.0,!=8.4.0"
jsonpatch = "^1.33"

View File

@@ -3,9 +3,9 @@ import warnings
from typing import Any, Dict
import pytest
from pydantic import BaseModel
from langchain_core._api.beta_decorator import beta, warn_beta
from langchain_core.pydantic_v1 import BaseModel
@pytest.mark.parametrize(

View File

@@ -3,13 +3,13 @@ import warnings
from typing import Any, Dict
import pytest
from pydantic import BaseModel
from langchain_core._api.deprecation import (
deprecated,
rename_parameter,
warn_deprecated,
)
from langchain_core.pydantic_v1 import BaseModel
@pytest.mark.parametrize(

View File

@@ -4,9 +4,10 @@ from itertools import chain
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
class BaseFakeCallbackHandler(BaseModel):
@@ -256,7 +257,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
return self
@@ -390,5 +392,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
return self

View File

@@ -9,7 +9,6 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
AnyStr,
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
@@ -70,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
]
assert len({chunk.id for chunk in chunks}) == 1
@@ -89,29 +88,23 @@ async def test_generic_fake_chat_model_stream() -> None:
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id=AnyStr(),
_AnyIdAIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'},
},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={"function_call": {"arguments": ","}},
id=AnyStr(),
_AnyIdAIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
},
id=AnyStr(),
),
]
assert len({chunk.id for chunk in chunks}) == 1

View File

@@ -1,4 +1,5 @@
import time
from typing import Optional as Optional
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
@@ -220,6 +221,9 @@ class SerializableModel(GenericFakeChatModel):
return True
SerializableModel.model_rebuild()
def test_serialization_with_rate_limiter() -> None:
"""Test model serialization with rate limiter."""
from langchain_core.load import dumps

View File

@@ -1,8 +1,9 @@
from typing import Dict
from pydantic import ConfigDict, Field
from langchain_core.load import Serializable, dumpd
from langchain_core.load.serializable import _is_field_useful
from langchain_core.pydantic_v1 import Field
def test_simple_serialization() -> None:
@@ -40,8 +41,9 @@ def test_simple_serialization_is_serializable() -> None:
def test_simple_serialization_secret() -> None:
"""Test handling of secrets."""
from pydantic import SecretStr
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import SecretStr
class Foo(Serializable):
bar: int
@@ -97,8 +99,9 @@ def test__is_field_useful() -> None:
# Make sure works for fields without default.
z: ArrayObj
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
assert _is_field_useful(foo, "x", foo.x)
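The `Config` class replacement above is the other recurring migration in this PR: an inner `class Config` becomes a `model_config = ConfigDict(...)` class attribute. A minimal sketch reusing the arbitrary-type setup from this test:

```python
from pydantic import BaseModel, ConfigDict


class NonBoolObj:
    """Arbitrary non-pydantic type, as in the test above."""


class Foo(BaseModel):
    # v1 spelling, removed in the diff:
    #     class Config:
    #         arbitrary_types_allowed = True
    # v2 replacement:
    model_config = ConfigDict(arbitrary_types_allowed=True)

    y: NonBoolObj
```

Without `arbitrary_types_allowed`, pydantic would reject `NonBoolObj` at class-definition time because it cannot build a schema for it.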

View File

@@ -1,6 +1,7 @@
"""Module to test base parser implementations."""
from typing import List
from typing import Optional as Optional
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import GenericFakeChatModel
@@ -46,6 +47,8 @@ def test_base_generation_parser() -> None:
assert isinstance(content, str)
return content.swapcase() # type: ignore
StrInvertCase.model_rebuild()
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO"
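The `StrInvertCase.model_rebuild()` call added above (like `SerializableModel.model_rebuild()` earlier) is the pydantic v2 way to finish building a model whose annotations contain forward references that were not resolvable when the class body ran; the v1 equivalent was `update_forward_refs()`. A minimal sketch with a hypothetical self-referencing model:

```python
from typing import Optional

from pydantic import BaseModel


class Node(BaseModel):
    value: int
    # "Node" is a forward reference: the class does not exist yet
    # while its own body is being evaluated.
    next: Optional["Node"] = None


# model_rebuild() re-runs schema construction now that "Node" is defined;
# it is idempotent, so calling it when resolution already succeeded is safe.
Node.model_rebuild()

chain = Node(value=1, next=Node(value=2))
assert chain.next is not None and chain.next.value == 2
```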

View File

@@ -2,12 +2,12 @@ import json
from typing import Any, AsyncIterator, Iterator, Tuple
import pytest
from pydantic import BaseModel
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.json import (
SimpleJsonOutputParser,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.json import parse_json_markdown, parse_partial_json
from tests.unit_tests.pydantic_utils import _schema

View File

@@ -2,6 +2,7 @@ import json
from typing import Any, Dict
import pytest
from pydantic import BaseModel
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
@@ -10,7 +11,6 @@ from langchain_core.output_parsers.openai_functions import (
PydanticOutputFunctionsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
def test_json_output_function_parser() -> None:

View File

@@ -1,6 +1,7 @@
from typing import Any, AsyncIterator, Iterator, List
import pytest
from pydantic import BaseModel, Field
from langchain_core.messages import (
AIMessage,
@@ -14,7 +15,6 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
STREAMED_MESSAGES: list = [
@@ -531,7 +531,7 @@ async def test_partial_pydantic_output_parser_async() -> None:
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic # pydantic: ignore
import pydantic
class Forecast(pydantic.v1.BaseModel):
temperature: int
@@ -566,7 +566,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2."""
import pydantic # pydantic: ignore
import pydantic
class Forecast(pydantic.BaseModel):
temperature: int
@@ -601,7 +601,7 @@ def test_parse_with_different_pydantic_2_proper() -> None:
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 1."""
import pydantic # pydantic: ignore
import pydantic
class Forecast(pydantic.BaseModel):
temperature: int

View File

@@ -3,22 +3,17 @@
from enum import Enum
from typing import Literal, Optional
import pydantic # pydantic: ignore
import pydantic
import pytest
from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as V1BaseModel
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
V1BaseModel = pydantic.BaseModel
if PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel # pydantic: ignore
V1BaseModel = BaseModel # type: ignore
from langchain_core.utils.pydantic import TBaseModel
class ForecastV2(pydantic.BaseModel):
@@ -194,7 +189,7 @@ def test_pydantic_output_parser_type_inference() -> None:
def test_format_instructions_preserves_language() -> None:
"""Test format instructions does not attempt to encode into ascii."""
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
description = (
"你好, こんにちは, नमस्ते, Bonjour, Hola, "

File diff suppressed because it is too large.

View File

@@ -4,9 +4,12 @@ from pathlib import Path
from typing import Any, List, Tuple, Union, cast
import pytest
from pydantic import ValidationError
from syrupy import SnapshotAssertion
from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
from langchain_core._api.deprecation import (
LangChainPendingDeprecationWarning,
)
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
@@ -28,8 +31,6 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
_convert_to_message,
)
from langchain_core.pydantic_v1 import ValidationError
from tests.unit_tests.pydantic_utils import _schema
@pytest.fixture
@@ -794,21 +795,21 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert prompt_all_required.optional_variables == []
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
assert _schema(prompt_all_required.input_schema) == snapshot(name="required")
assert prompt_all_required.get_input_jsonschema() == snapshot(name="required")
prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
)
# input variables only lists required variables
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
assert _schema(prompt_optional.input_schema) == snapshot(name="partial")
assert prompt_optional.get_input_jsonschema() == snapshot(name="partial")
def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
prompt = ChatPromptTemplate.from_messages(
[("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
)
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placholder")
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placeholder")
assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
assert dumpd(prompt) == snapshot(name="chat_prompt")
assert load(dumpd(prompt)) == prompt

View File

@@ -7,7 +7,6 @@ import pytest
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
def test_prompt_valid() -> None:
@@ -70,10 +69,10 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(foo="bar") == "This is a bar test."
assert prompt.input_variables == ["foo"]
assert _schema(prompt.input_schema) == {
assert prompt.get_input_jsonschema() == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "string"}},
"properties": {"foo": {"title": "Foo", "type": "string", "default": None}},
}
# Multiple input variables.
@@ -81,12 +80,12 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test."
assert prompt.input_variables == ["bar", "foo"]
assert _schema(prompt.input_schema) == {
assert prompt.get_input_jsonschema() == {
"title": "PromptInput",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "string"},
"foo": {"title": "Foo", "type": "string"},
"bar": {"title": "Bar", "type": "string", "default": None},
"foo": {"title": "Foo", "type": "string", "default": None},
},
}
@@ -95,12 +94,12 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
assert prompt.input_variables == ["bar", "foo"]
assert _schema(prompt.input_schema) == {
assert prompt.get_input_jsonschema() == {
"title": "PromptInput",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "string"},
"foo": {"title": "Foo", "type": "string"},
"bar": {"title": "Bar", "type": "string", "default": None},
"foo": {"title": "Foo", "type": "string", "default": None},
},
}
@@ -111,23 +110,23 @@ def test_mustache_prompt_from_template() -> None:
"This foo is a bar test baz."
)
assert prompt.input_variables == ["foo", "obj"]
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {
"foo": {"title": "Foo", "type": "string"},
"obj": {"$ref": "#/definitions/obj"},
},
"definitions": {
assert prompt.get_input_jsonschema() == {
"$defs": {
"obj": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"},
"foo": {"default": None, "title": "Foo", "type": "string"},
},
"title": "obj",
"type": "object",
"properties": {
"foo": {"title": "Foo", "type": "string"},
"bar": {"title": "Bar", "type": "string"},
},
}
},
"properties": {
"foo": {"default": None, "title": "Foo", "type": "string"},
"obj": {"allOf": [{"$ref": "#/$defs/obj"}], "default": None},
},
"title": "PromptInput",
"type": "object",
}
# . variables
@@ -135,7 +134,7 @@ def test_mustache_prompt_from_template() -> None:
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.")
assert prompt.input_variables == []
assert _schema(prompt.input_schema) == {
assert prompt.get_input_jsonschema() == {
"title": "PromptInput",
"type": "object",
"properties": {},
@@ -152,17 +151,19 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
"definitions": {
assert prompt.get_input_jsonschema() == {
"$defs": {
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"}
},
"title": "foo",
"type": "object",
"properties": {"bar": {"title": "Bar", "type": "string"}},
}
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
# more complex nested section/context variables
@@ -184,26 +185,28 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
"definitions": {
"foo": {
"title": "foo",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "string"},
"baz": {"$ref": "#/definitions/baz"},
"quux": {"title": "Quux", "type": "string"},
},
},
assert prompt.get_input_jsonschema() == {
"$defs": {
"baz": {
"properties": {
"qux": {"default": None, "title": "Qux", "type": "string"}
},
"title": "baz",
"type": "object",
"properties": {"qux": {"title": "Qux", "type": "string"}},
},
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"},
"baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None},
"quux": {"default": None, "title": "Quux", "type": "string"},
},
"title": "foo",
"type": "object",
},
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
# triply nested section/context variables
@@ -239,39 +242,43 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
"definitions": {
"foo": {
"title": "foo",
"type": "object",
"properties": {
"bar": {"title": "Bar", "type": "string"},
"baz": {"$ref": "#/definitions/baz"},
"quux": {"title": "Quux", "type": "string"},
},
},
"baz": {
"title": "baz",
"type": "object",
"properties": {"qux": {"$ref": "#/definitions/qux"}},
},
"qux": {
"title": "qux",
"type": "object",
"properties": {
"foobar": {"title": "Foobar", "type": "string"},
"barfoo": {"$ref": "#/definitions/barfoo"},
},
},
assert prompt.get_input_jsonschema() == {
"$defs": {
"barfoo": {
"properties": {
"foobar": {"default": None, "title": "Foobar", "type": "string"}
},
"title": "barfoo",
"type": "object",
"properties": {"foobar": {"title": "Foobar", "type": "string"}},
},
"baz": {
"properties": {
"qux": {"allOf": [{"$ref": "#/$defs/qux"}], "default": None}
},
"title": "baz",
"type": "object",
},
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"},
"baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None},
"quux": {"default": None, "title": "Quux", "type": "string"},
},
"title": "foo",
"type": "object",
},
"qux": {
"properties": {
"barfoo": {"allOf": [{"$ref": "#/$defs/barfoo"}], "default": None},
"foobar": {"default": None, "title": "Foobar", "type": "string"},
},
"title": "qux",
"type": "object",
},
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
# section/context variables with repeats
@@ -287,19 +294,20 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert _schema(prompt.input_schema) == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"$ref": "#/definitions/foo"}},
"definitions": {
assert prompt.get_input_jsonschema() == {
"$defs": {
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"}
},
"title": "foo",
"type": "object",
"properties": {"bar": {"title": "Bar", "type": "string"}},
}
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
template = """This{{^foo}}
no foos
{{/foo}}is a test."""
@@ -310,10 +318,10 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert _schema(prompt.input_schema) == {
assert prompt.get_input_jsonschema() == {
"properties": {"foo": {"default": None, "title": "Foo", "type": "object"}},
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "object"}},
}
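The expected schemas in this file change shape because pydantic v2 generates JSON Schema differently from v1: nested models land under `"$defs"` (v1 used `"definitions"`), and optional variables carry an explicit `"default": None`. The same shape can be reproduced with plain pydantic; the `create_model` construction below is an assumption used to mirror the test's all-optional prompt inputs, not how the prompt class builds them internally:

```python
from pydantic import create_model

# Every mustache variable is optional with a None default, as in the test.
Obj = create_model("obj", bar=(str, None), foo=(str, None))
PromptInput = create_model("PromptInput", foo=(str, None), obj=(Obj, None))

schema = PromptInput.model_json_schema()

# pydantic v2 emits "$defs" (v1 used "definitions") and records None defaults.
assert schema["title"] == "PromptInput"
assert "$defs" in schema and "obj" in schema["$defs"]
assert schema["properties"]["foo"]["default"] is None
assert schema["$defs"]["obj"]["properties"]["bar"]["default"] is None
```

Note that whether the nested reference is wrapped as `{"allOf": [{"$ref": ...}], "default": None}` or emitted as a bare `"$ref"` with a sibling `default` depends on the pydantic 2.x minor version, which is why the assertions above avoid pinning that detail.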

View File

@@ -1,12 +1,14 @@
from functools import partial
from inspect import isclass
from typing import Any, Dict, Type, Union, cast
from typing import Optional as Optional
from pydantic import BaseModel
from langchain_core.language_models import FakeListChatModel
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -34,6 +36,9 @@ class FakeStructuredChatModel(FakeListChatModel):
return "fake-messages-list-chat-model"
FakeStructuredChatModel.model_rebuild()
def test_structured_prompt_pydantic() -> None:
class OutputSchema(BaseModel):
name: str

View File

@@ -1,20 +1,10 @@
"""Helper utilities for pydantic.
This module includes helper utilities to ease the migration from pydantic v1 to v2.
They're meant to be used in the following way:
1) Use utility code to help (selected) unit tests pass without modifications
2) Upgrade the unit tests to match pydantic 2
3) Stop using the utility code
"""
from typing import Any
from langchain_core.utils.pydantic import is_basemodel_subclass
# Function to replace allOf with $ref
def _replace_all_of_with_ref(schema: Any) -> None:
"""Replace allOf with $ref in the schema."""
def replace_all_of_with_ref(schema: Any) -> None:
if isinstance(schema, dict):
# If the schema has an allOf key with a single item that contains a $ref
if (
@@ -30,13 +20,13 @@ def _replace_all_of_with_ref(schema: Any) -> None:
# Recursively process nested schemas
for value in schema.values():
if isinstance(value, (dict, list)):
_replace_all_of_with_ref(value)
replace_all_of_with_ref(value)
elif isinstance(schema, list):
for item in schema:
_replace_all_of_with_ref(item)
replace_all_of_with_ref(item)
def _remove_bad_none_defaults(schema: Any) -> None:
def remove_all_none_default(schema: Any) -> None:
"""Removing all none defaults.
Pydantic v1 did not generate these, but Pydantic v2 does.
@@ -56,39 +46,48 @@ def _remove_bad_none_defaults(schema: Any) -> None:
break # Null type explicitly defined
else:
del value["default"]
_remove_bad_none_defaults(value)
remove_all_none_default(value)
elif isinstance(value, list):
for item in value:
_remove_bad_none_defaults(item)
remove_all_none_default(item)
elif isinstance(schema, list):
for item in schema:
_remove_bad_none_defaults(item)
remove_all_none_default(item)
def _remove_enum_description(obj: Any) -> None:
"""Remove the description from enums."""
if isinstance(obj, dict):
if "enum" in obj:
if "description" in obj and obj["description"] == "An enumeration.":
del obj["description"]
for value in obj.values():
_remove_enum_description(value)
elif isinstance(obj, list):
for item in obj:
_remove_enum_description(item)
def _schema(obj: Any) -> dict:
"""Get the schema of a pydantic model in the pydantic v1 style.
This will attempt to map the schema as close as possible to the pydantic v1 schema.
"""
"""Return the schema of the object."""
# Remap to old style schema
if not is_basemodel_subclass(obj):
raise TypeError(
f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
)
if not hasattr(obj, "model_json_schema"): # V1 model
return obj.schema()
# Then we're using V2 models internally.
raise AssertionError(
"Hi there! Looks like you're attempting to upgrade to Pydantic v2. If so: \n"
"1) remove this exception\n"
"2) confirm that the old unit tests pass, and if not look for difference\n"
"3) update the unit tests to match the new schema\n"
"4) remove this utility function\n"
)
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
if "$defs" in schema_:
schema_["definitions"] = schema_["$defs"]
del schema_["$defs"]
_replace_all_of_with_ref(schema_)
_remove_bad_none_defaults(schema_)
if "default" in schema_ and schema_["default"] is None:
del schema_["default"]
replace_all_of_with_ref(schema_)
remove_all_none_default(schema_)
_remove_enum_description(schema_)
return schema_
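The two helpers above are plain dict transforms, so their effect is easy to check on a hand-written pydantic-v2-style schema. A condensed, self-contained sketch of the same normalization (the sample schema fragment is illustrative, shaped like the ones in the mustache prompt tests):

```python
from typing import Any


def replace_all_of_with_ref(schema: Any) -> None:
    """Collapse {"allOf": [{"$ref": ...}]} wrappers into a bare {"$ref": ...}."""
    if isinstance(schema, dict):
        all_of = schema.get("allOf")
        if isinstance(all_of, list) and len(all_of) == 1 and "$ref" in all_of[0]:
            schema["$ref"] = all_of[0]["$ref"]
            del schema["allOf"]
        for value in schema.values():
            if isinstance(value, (dict, list)):
                replace_all_of_with_ref(value)
    elif isinstance(schema, list):
        for item in schema:
            replace_all_of_with_ref(item)


def remove_all_none_default(schema: Any) -> None:
    """Drop "default": None entries, which pydantic v2 adds but v1 did not."""
    if isinstance(schema, dict):
        for value in schema.values():
            if isinstance(value, dict):
                if "default" in value and value["default"] is None:
                    del value["default"]
                remove_all_none_default(value)
            elif isinstance(value, list):
                for item in value:
                    remove_all_none_default(item)
    elif isinstance(schema, list):
        for item in schema:
            remove_all_none_default(item)


# A pydantic-v2-style fragment, before normalization:
schema = {
    "title": "PromptInput",
    "type": "object",
    "properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
}
replace_all_of_with_ref(schema)
remove_all_none_default(schema)
# After normalization the property looks like pydantic v1 output:
assert schema["properties"]["foo"] == {"$ref": "#/$defs/foo"}
```

Unlike the original `_remove_bad_none_defaults`, this condensed version drops every `"default": None` unconditionally; the full helper keeps the default when the field's `anyOf` explicitly allows a null type.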

View File

@@ -1,8 +1,9 @@
from typing import Any, Dict, Optional
import pytest
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import (
ConfigurableField,
RunnableConfig,
@@ -14,19 +15,21 @@ class MyRunnable(RunnableSerializable[str, str]):
my_property: str = Field(alias="my_property_alias")
_my_hidden_property: str = ""
class Config:
allow_population_by_field_name = True
model_config = ConfigDict(
populate_by_name=True,
)
@root_validator(pre=True)
def my_error(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def my_error(cls, values: Dict[str, Any]) -> Any:
if "_my_hidden_property" in values:
raise ValueError("Cannot set _my_hidden_property")
return values
@root_validator(pre=False, skip_on_failure=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["_my_hidden_property"] = values["my_property"]
return values
@model_validator(mode="after")
def build_extra(self) -> Self:
self._my_hidden_property = self.my_property
return self
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
return input + self._my_hidden_property
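The hunk above shows the canonical validator migration: `@root_validator(pre=True)` becomes `@model_validator(mode="before")` (a classmethod receiving the raw input values), and `@root_validator(pre=False, skip_on_failure=True)` becomes `@model_validator(mode="after")` (an instance method that mutates and returns `self`). A standalone sketch of the same shape, without the Runnable machinery (the model name is illustrative):

```python
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Field, model_validator


class MyModel(BaseModel):
    # v1's allow_population_by_field_name = True becomes populate_by_name:
    model_config = ConfigDict(populate_by_name=True)

    my_property: str = Field(alias="my_property_alias")
    _my_hidden_property: str = ""  # leading underscore -> private attribute

    @model_validator(mode="before")
    @classmethod
    def my_error(cls, values: Dict[str, Any]) -> Any:
        # Runs on raw input, before field validation (v1: pre=True).
        if "_my_hidden_property" in values:
            raise ValueError("Cannot set _my_hidden_property")
        return values

    @model_validator(mode="after")
    def build_extra(self) -> "MyModel":
        # Runs on the validated instance (v1: pre=False, skip_on_failure=True).
        self._my_hidden_property = self.my_property
        return self


m = MyModel(my_property="hello")
assert m._my_hidden_property == "hello"
```

Note the mode="after" validator no longer receives a `values` dict: it operates on the constructed instance and must return it.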

View File

@@ -13,6 +13,7 @@ from typing import (
)
import pytest
from pydantic import BaseModel
from syrupy import SnapshotAssertion
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -25,7 +26,6 @@ from langchain_core.load import dumps
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import (
Runnable,
RunnableBinding,

View File

@@ -1,5 +1,6 @@
from typing import Optional
from pydantic import BaseModel
from syrupy import SnapshotAssertion
from langchain_core.language_models import FakeListLLM
@@ -7,11 +8,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Edge, Graph, Node
from langchain_core.runnables.graph_mermaid import _escape_node_label
from tests.unit_tests.pydantic_utils import _schema
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
@@ -19,10 +18,10 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
graph = StrOutputParser().get_graph()
first_node = graph.first_node()
assert first_node is not None
assert _schema(first_node.data) == _schema(runnable.input_schema) # type: ignore[union-attr]
assert first_node.data.schema() == runnable.get_input_jsonschema() # type: ignore[union-attr]
last_node = graph.last_node()
assert last_node is not None
assert _schema(last_node.data) == _schema(runnable.output_schema) # type: ignore[union-attr]
assert last_node.data.schema() == runnable.get_output_jsonschema() # type: ignore[union-attr]
assert len(graph.nodes) == 3
assert len(graph.edges) == 2
assert graph.edges[0].source == first_node.id

Some files were not shown because too many files have changed in this diff.