mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 08:10:25 +00:00
Compare commits
91 Commits
langchain-
...
langchain-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d0222964c1 | ||
|
|
b97307c8b4 | ||
|
|
1ad66e70dc | ||
|
|
76564edd3a | ||
|
|
1c51e1693d | ||
|
|
c0f886dc52 | ||
|
|
a267da6a3a | ||
|
|
0c63b18c1f | ||
|
|
915c1e3dfb | ||
|
|
8da2ace99d | ||
|
|
81cd73cfca | ||
|
|
e358846b39 | ||
|
|
3c598d25a6 | ||
|
|
e5aa0f938b | ||
|
|
79c46319dd | ||
|
|
c5d4dfefc0 | ||
|
|
6e853501ec | ||
|
|
fd1f3ca213 | ||
|
|
567a4ce5aa | ||
|
|
923ce84aa7 | ||
|
|
9379613132 | ||
|
|
c72a76237f | ||
|
|
f9cafcbcb0 | ||
|
|
1fce5543bc | ||
|
|
88e9e6bf55 | ||
|
|
7f0dd4b182 | ||
|
|
5557b86a54 | ||
|
|
caf4ae3a45 | ||
|
|
c88b75ca6a | ||
|
|
e409a85a28 | ||
|
|
40634d441a | ||
|
|
1d2a503ab8 | ||
|
|
b924c61440 | ||
|
|
efa10c8ef8 | ||
|
|
0a6c67ce6a | ||
|
|
ed771f2d2b | ||
|
|
63ba12d8e0 | ||
|
|
f785cf029b | ||
|
|
be7cd0756f | ||
|
|
51c6899850 | ||
|
|
163d6fe8ef | ||
|
|
7cee7fbfad | ||
|
|
4799ad95d0 | ||
|
|
88065d794b | ||
|
|
b27bfa6717 | ||
|
|
5adeaf0732 | ||
|
|
f9d91e19c5 | ||
|
|
4c7afb0d6c | ||
|
|
c1ff61669d | ||
|
|
54d6808c1e | ||
|
|
78468de2e5 | ||
|
|
76572f963b | ||
|
|
c0448f27ba | ||
|
|
179aaa4007 | ||
|
|
d072d592a1 | ||
|
|
78c454c130 | ||
|
|
5199555c0d | ||
|
|
5e31cd91a7 | ||
|
|
49a1f5dd47 | ||
|
|
d0cc9b022a | ||
|
|
a91bd2737a | ||
|
|
5ad2b8ce80 | ||
|
|
b78764599b | ||
|
|
2888e34f53 | ||
|
|
dd4418a503 | ||
|
|
a976f2071b | ||
|
|
5f98975be0 | ||
|
|
0529c991ce | ||
|
|
954abcce59 | ||
|
|
6ad515d34e | ||
|
|
99348e1614 | ||
|
|
2c742cc20d | ||
|
|
02f87203f7 | ||
|
|
56163481dd | ||
|
|
6aac2eeab5 | ||
|
|
559d8a4d13 | ||
|
|
ec9e8eb71c | ||
|
|
9399df7777 | ||
|
|
5fc1104d00 | ||
|
|
6777106fbe | ||
|
|
5f5287c3b0 | ||
|
|
615f8b0d47 | ||
|
|
9a9ab65030 | ||
|
|
241b6d2355 | ||
|
|
91e09ffee5 | ||
|
|
8e4bae351e | ||
|
|
0da201c1d5 | ||
|
|
29413a22e1 | ||
|
|
ae5a574aa5 | ||
|
|
5a0e82c31c | ||
|
|
8590b421c4 |
9
.github/scripts/check_diff.py
vendored
9
.github/scripts/check_diff.py
vendored
@@ -16,6 +16,10 @@ LANGCHAIN_DIRS = [
|
||||
"libs/experimental",
|
||||
]
|
||||
|
||||
# for 0.3rc, we are ignoring core dependents
|
||||
# in order to be able to get CI to pass for individual PRs.
|
||||
IGNORE_CORE_DEPENDENTS = True
|
||||
|
||||
# ignored partners are removed from dependents
|
||||
# but still run if directly edited
|
||||
IGNORED_PARTNERS = [
|
||||
@@ -104,7 +108,7 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
|
||||
{"working-directory": dir_, "python-version": f"3.{v}"}
|
||||
for v in range(8, 13)
|
||||
]
|
||||
min_python = "3.8"
|
||||
min_python = "3.9"
|
||||
max_python = "3.12"
|
||||
|
||||
# custom logic for specific directories
|
||||
@@ -184,6 +188,9 @@ if __name__ == "__main__":
|
||||
# for extended testing
|
||||
found = False
|
||||
for dir_ in LANGCHAIN_DIRS:
|
||||
if dir_ == "libs/core" and IGNORE_CORE_DEPENDENTS:
|
||||
dirs_to_run["extended-test"].add(dir_)
|
||||
continue
|
||||
if file.startswith(dir_):
|
||||
found = True
|
||||
if found:
|
||||
|
||||
@@ -11,7 +11,7 @@ if __name__ == "__main__":
|
||||
|
||||
# see if we're releasing an rc
|
||||
version = toml_data["tool"]["poetry"]["version"]
|
||||
releasing_rc = "rc" in version
|
||||
releasing_rc = "rc" in version or "dev" in version
|
||||
|
||||
# if not, iterate through dependencies and make sure none allow prereleases
|
||||
if not releasing_rc:
|
||||
|
||||
114
.github/workflows/_dependencies.yml
vendored
114
.github/workflows/_dependencies.yml
vendored
@@ -1,114 +0,0 @@
|
||||
name: dependencies
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
working-directory:
|
||||
required: true
|
||||
type: string
|
||||
description: "From which folder this pipeline executes"
|
||||
langchain-location:
|
||||
required: false
|
||||
type: string
|
||||
description: "Relative path to the langchain library folder"
|
||||
python-version:
|
||||
required: true
|
||||
type: string
|
||||
description: "Python version to use"
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
runs-on: ubuntu-latest
|
||||
name: dependency checks ${{ inputs.python-version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ inputs.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
cache-key: pydantic-cross-compat
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
|
||||
- name: Check imports with base dependencies
|
||||
shell: bash
|
||||
run: poetry run make check_imports
|
||||
|
||||
- name: Install test dependencies
|
||||
shell: bash
|
||||
run: poetry install --with test
|
||||
|
||||
- name: Install langchain editable
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
if: ${{ inputs.langchain-location }}
|
||||
env:
|
||||
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
|
||||
run: |
|
||||
poetry run pip install -e "$LANGCHAIN_LOCATION"
|
||||
|
||||
- name: Install the opposite major version of pydantic
|
||||
# If normal tests use pydantic v1, here we'll use v2, and vice versa.
|
||||
shell: bash
|
||||
# airbyte currently doesn't support pydantic v2
|
||||
if: ${{ !startsWith(inputs.working-directory, 'libs/partners/airbyte') }}
|
||||
run: |
|
||||
# Determine the major part of pydantic version
|
||||
REGULAR_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
|
||||
|
||||
if [[ "$REGULAR_VERSION" == "1" ]]; then
|
||||
PYDANTIC_DEP=">=2.1,<3"
|
||||
TEST_WITH_VERSION="2"
|
||||
elif [[ "$REGULAR_VERSION" == "2" ]]; then
|
||||
PYDANTIC_DEP="<2"
|
||||
TEST_WITH_VERSION="1"
|
||||
else
|
||||
echo "Unexpected pydantic major version '$REGULAR_VERSION', cannot determine which version to use for cross-compatibility test."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install via `pip` instead of `poetry add` to avoid changing lockfile,
|
||||
# which would prevent caching from working: the cache would get saved
|
||||
# to a different key than where it gets loaded from.
|
||||
poetry run pip install "pydantic${PYDANTIC_DEP}"
|
||||
|
||||
# Ensure that the correct pydantic is installed now.
|
||||
echo "Checking pydantic version... Expecting ${TEST_WITH_VERSION}"
|
||||
|
||||
# Determine the major part of pydantic version
|
||||
CURRENT_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
|
||||
|
||||
# Check that the major part of pydantic version is as expected, if not
|
||||
# raise an error
|
||||
if [[ "$CURRENT_VERSION" != "$TEST_WITH_VERSION" ]]; then
|
||||
echo "Error: expected pydantic version ${CURRENT_VERSION} to have been installed, but found: ${TEST_WITH_VERSION}"
|
||||
exit 1
|
||||
fi
|
||||
echo "Found pydantic version ${CURRENT_VERSION}, as expected"
|
||||
- name: Run pydantic compatibility tests
|
||||
# airbyte currently doesn't support pydantic v2
|
||||
if: ${{ !startsWith(inputs.working-directory, 'libs/partners/airbyte') }}
|
||||
shell: bash
|
||||
run: make test
|
||||
|
||||
- name: Ensure the tests did not create any additional files
|
||||
shell: bash
|
||||
run: |
|
||||
set -eu
|
||||
|
||||
STATUS="$(git status)"
|
||||
echo "$STATUS"
|
||||
|
||||
# grep will exit non-zero if the target message isn't found,
|
||||
# and `set -e` above will cause the step to fail.
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
||||
15
.github/workflows/check_diffs.yml
vendored
15
.github/workflows/check_diffs.yml
vendored
@@ -89,19 +89,6 @@ jobs:
|
||||
python-version: ${{ matrix.job-configs.python-version }}
|
||||
secrets: inherit
|
||||
|
||||
dependencies:
|
||||
name: cd ${{ matrix.job-configs.working-directory }}
|
||||
needs: [ build ]
|
||||
if: ${{ needs.build.outputs.dependencies != '[]' }}
|
||||
strategy:
|
||||
matrix:
|
||||
job-configs: ${{ fromJson(needs.build.outputs.dependencies) }}
|
||||
uses: ./.github/workflows/_dependencies.yml
|
||||
with:
|
||||
working-directory: ${{ matrix.job-configs.working-directory }}
|
||||
python-version: ${{ matrix.job-configs.python-version }}
|
||||
secrets: inherit
|
||||
|
||||
extended-tests:
|
||||
name: "cd ${{ matrix.job-configs.working-directory }} / make extended_tests #${{ matrix.job-configs.python-version }}"
|
||||
needs: [ build ]
|
||||
@@ -149,7 +136,7 @@ jobs:
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
||||
ci_success:
|
||||
name: "CI Success"
|
||||
needs: [build, lint, test, compile-integration-tests, dependencies, extended-tests, test-doc-imports]
|
||||
needs: [build, lint, test, compile-integration-tests, extended-tests, test-doc-imports]
|
||||
if: |
|
||||
always()
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/scheduled_test.yml
vendored
2
.github/workflows/scheduled_test.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.11"
|
||||
working-directory:
|
||||
- "libs/partners/openai"
|
||||
|
||||
@@ -12,6 +12,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/
|
||||
test tests:
|
||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
|
||||
|
||||
# integration tests are run without the --disable-socket flag to allow network calls
|
||||
integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
@@ -23,6 +23,7 @@ pytest = "^7.4.3"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
pytest-socket = "^0.7.0"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
pytest-watcher = "^0.3.4"
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
@@ -102,6 +102,16 @@ def test_serializable_mapping() -> None:
|
||||
"modifier",
|
||||
"RemoveMessage",
|
||||
),
|
||||
("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
|
||||
"langchain_mistralai",
|
||||
"chat_models",
|
||||
"ChatMistralAI",
|
||||
),
|
||||
("langchain_groq", "chat_models", "ChatGroq"): (
|
||||
"langchain_groq",
|
||||
"chat_models",
|
||||
"ChatGroq",
|
||||
),
|
||||
}
|
||||
serializable_modules = import_all_modules("langchain")
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/lint_imports.sh
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
@@ -229,8 +231,9 @@ class ContextSet(RunnableSerializable):
|
||||
|
||||
keys: Mapping[str, Optional[Runnable]]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -20,13 +20,14 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class BaseChatMessageHistory(ABC):
|
||||
|
||||
@@ -4,10 +4,12 @@ import contextlib
|
||||
import mimetypes
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Generator, List, Literal, Mapping, Optional, Union, cast
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
@@ -110,9 +112,10 @@ class Blob(BaseMedia):
|
||||
path: Optional[PathLike] = None
|
||||
"""Location where the original content was found."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
@@ -127,8 +130,9 @@ class Blob(BaseMedia):
|
||||
return cast(Optional[str], self.metadata["source"])
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
@@ -293,3 +297,7 @@ class Document(BaseMedia):
|
||||
return f"page_content='{self.page_content}' metadata={self.metadata}"
|
||||
else:
|
||||
return f"page_content='{self.page_content}'"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
@@ -3,9 +3,10 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
import hashlib
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings, BaseModel):
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
import re
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from langchain_core.example_selectors.base import BaseExampleSelector
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, validator
|
||||
|
||||
|
||||
def _get_length_based(text: str) -> int:
|
||||
|
||||
@@ -5,9 +5,10 @@ from __future__ import annotations
|
||||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.example_selectors.base import BaseExampleSelector
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -42,9 +43,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
vectorstore_kwargs: Optional[Dict[str, Any]] = None
|
||||
"""Extra arguments passed to similarity_search function of the vectorstore."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _example_to_text(
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@@ -20,7 +22,6 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||
|
||||
|
||||
@@ -25,10 +25,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_core.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.indexing.base import DocumentIndex, RecordManager
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
# Magic UUID to use as a namespace for hashing.
|
||||
@@ -68,8 +69,9 @@ class _HashedDocument(Document):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
@root_validator(pre=True)
|
||||
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Root validator to calculate content and metadata hash."""
|
||||
content = values.get("page_content", "")
|
||||
metadata = values.get("metadata", {})
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Sequence, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.indexing import UpsertResponse
|
||||
from langchain_core.indexing.base import DeleteResponse, DocumentIndex
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
|
||||
@beta(message="Introduced in version 0.2.29. Underlying abstraction subject to change.")
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||
from typing_extensions import TypeAlias, TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@@ -28,7 +29,6 @@ from langchain_core.messages import (
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, validator
|
||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
|
||||
@@ -113,7 +113,11 @@ class BaseLanguageModel(
|
||||
|
||||
Caching is not currently supported for streaming methods of models.
|
||||
"""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
# Repr = False is consistent with pydantic 1 if verbose = False
|
||||
# We can relax this for pydantic 2?
|
||||
# TODO(0.3): Resolve repr for verbose
|
||||
# Modified just to get unit tests to pass.
|
||||
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Callbacks to add to the run trace."""
|
||||
@@ -126,6 +130,10 @@ class BaseLanguageModel(
|
||||
)
|
||||
"""Optional encoder to use for counting tokens."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@validator("verbose", pre=True, always=True, allow_reuse=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
@@ -23,6 +23,13 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
@@ -57,11 +64,6 @@ from langchain_core.outputs import (
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.rate_limiters import BaseRateLimiter
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
@@ -193,14 +195,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
callback_manager: Optional[BaseCallbackManager] = deprecated(
|
||||
name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
|
||||
)(
|
||||
Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description="Callback manager to add to the run trace.",
|
||||
)
|
||||
# TODO(0.3): Figure out how to re-apply deprecated decorator
|
||||
# callback_manager: Optional[BaseCallbackManager] = deprecated(
|
||||
# name="callback_manager", since="0.1.7", removal="1.0", alternative="callbacks"
|
||||
# )(
|
||||
# Field(
|
||||
# default=None,
|
||||
# exclude=True,
|
||||
# description="Callback manager to add to the run trace.",
|
||||
# )
|
||||
# )
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description="Callback manager to add to the run trace.",
|
||||
)
|
||||
|
||||
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
|
||||
@@ -218,8 +226,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
- If False (default), will always use streaming case if available.
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: Dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used.
|
||||
|
||||
Args:
|
||||
@@ -240,8 +249,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import yaml
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
before_sleep_log,
|
||||
@@ -62,7 +63,6 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
@@ -300,11 +300,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""[DEPRECATED]"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: Dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
|
||||
@@ -17,6 +17,7 @@ DEFAULT_NAMESPACES = [
|
||||
"langchain_core",
|
||||
"langchain_community",
|
||||
"langchain_anthropic",
|
||||
"langchain_groq",
|
||||
]
|
||||
|
||||
ALL_SERIALIZABLE_MAPPINGS = {
|
||||
|
||||
@@ -271,6 +271,11 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
"chat_models",
|
||||
"ChatAnthropic",
|
||||
),
|
||||
("langchain_groq", "chat_models", "ChatGroq"): (
|
||||
"langchain_groq",
|
||||
"chat_models",
|
||||
"ChatGroq",
|
||||
),
|
||||
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
|
||||
"langchain_fireworks",
|
||||
"chat_models",
|
||||
@@ -287,6 +292,17 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
"chat_models",
|
||||
"ChatVertexAI",
|
||||
),
|
||||
("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
|
||||
"langchain_mistralai",
|
||||
"chat_models",
|
||||
"ChatMistralAI",
|
||||
),
|
||||
("langchain", "chat_models", "bedrock", "ChatBedrock"): (
|
||||
"langchain_aws",
|
||||
"chat_models",
|
||||
"bedrock",
|
||||
"ChatBedrock",
|
||||
),
|
||||
("langchain", "schema", "output", "ChatGenerationChunk"): (
|
||||
"langchain_core",
|
||||
"outputs",
|
||||
|
||||
@@ -10,9 +10,10 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
|
||||
@@ -80,7 +81,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
Exception: If the key is not in the model.
|
||||
"""
|
||||
try:
|
||||
return model.__fields__[key].get_default() != value
|
||||
return model.model_fields[key].get_default() != value
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
@@ -161,16 +162,25 @@ class Serializable(BaseModel, ABC):
|
||||
For example, for the class `langchain.llms.openai.OpenAI`, the id is
|
||||
["langchain", "llms", "openai", "OpenAI"].
|
||||
"""
|
||||
return [*cls.get_lc_namespace(), cls.__name__]
|
||||
# Pydantic generics change the class name. So we need to do the following
|
||||
if (
|
||||
"origin" in cls.__pydantic_generic_metadata__
|
||||
and cls.__pydantic_generic_metadata__["origin"] is not None
|
||||
):
|
||||
original_name = cls.__pydantic_generic_metadata__["origin"].__name__
|
||||
else:
|
||||
original_name = cls.__name__
|
||||
return [*cls.get_lc_namespace(), original_name]
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
model_config = ConfigDict(
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or try_neq_default(v, k, self))
|
||||
if (k not in self.model_fields or try_neq_default(v, k, self))
|
||||
]
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
@@ -184,12 +194,15 @@ class Serializable(BaseModel, ABC):
|
||||
|
||||
secrets = dict()
|
||||
# Get latest values for kwargs if there is an attribute with same name
|
||||
lc_kwargs = {
|
||||
k: getattr(self, k, v)
|
||||
for k, v in self
|
||||
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
|
||||
and _is_field_useful(self, k, v)
|
||||
}
|
||||
lc_kwargs = {}
|
||||
for k, v in self:
|
||||
if not _is_field_useful(self, k, v):
|
||||
continue
|
||||
# Do nothing if the field is excluded
|
||||
if k in self.model_fields and self.model_fields[k].exclude:
|
||||
continue
|
||||
|
||||
lc_kwargs[k] = getattr(self, k, v)
|
||||
|
||||
# Merge the lc_secrets and lc_attributes from every class in the MRO
|
||||
for cls in [None, *self.__class__.mro()]:
|
||||
@@ -221,8 +234,10 @@ class Serializable(BaseModel, ABC):
|
||||
# that are not present in the fields.
|
||||
for key in list(secrets):
|
||||
value = secrets[key]
|
||||
if key in this.__fields__:
|
||||
secrets[this.__fields__[key].alias] = value
|
||||
if key in this.model_fields:
|
||||
alias = this.model_fields[key].alias
|
||||
if alias is not None:
|
||||
secrets[alias] = value
|
||||
lc_kwargs.update(this.lc_attributes)
|
||||
|
||||
# include all secrets, even if not specified in kwargs
|
||||
@@ -244,6 +259,10 @@ class Serializable(BaseModel, ABC):
|
||||
def to_json_not_implemented(self) -> SerializedNotImplemented:
|
||||
return to_json_not_implemented(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
|
||||
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
"""Check if a field is useful as a constructor argument.
|
||||
@@ -259,9 +278,13 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
If the field is not required and the value is None, it is useful if the
|
||||
default value is different from the value.
|
||||
"""
|
||||
field = inst.__fields__.get(key)
|
||||
field = inst.model_fields.get(key)
|
||||
if not field:
|
||||
return False
|
||||
|
||||
if field.is_required():
|
||||
return True
|
||||
|
||||
# Handle edge case: a value cannot be converted to a boolean (e.g. a
|
||||
# Pandas DataFrame).
|
||||
try:
|
||||
@@ -269,6 +292,17 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
except Exception as _:
|
||||
value_is_truthy = False
|
||||
|
||||
if value_is_truthy:
|
||||
return True
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is dict and isinstance(value, dict):
|
||||
return False
|
||||
|
||||
# Value is still falsy here!
|
||||
if field.default_factory is list and isinstance(value, list):
|
||||
return False
|
||||
|
||||
# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
|
||||
# Pandas DataFrames).
|
||||
try:
|
||||
@@ -282,7 +316,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
except Exception as _:
|
||||
value_neq_default = False
|
||||
|
||||
return field.required is True or value_is_truthy or value_neq_default
|
||||
# If value is falsy and does not match the default
|
||||
return value_is_truthy or value_neq_default
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
|
||||
@@ -13,6 +13,8 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
@@ -47,8 +49,9 @@ class BaseMemory(Serializable, ABC):
|
||||
pass
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self, TypedDict
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@@ -24,7 +25,6 @@ from langchain_core.messages.tool import (
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call_chunk as create_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
|
||||
@@ -111,8 +111,9 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator(pre=True)
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> Any:
|
||||
check_additional_kwargs = not any(
|
||||
values.get(k)
|
||||
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
|
||||
@@ -204,7 +205,7 @@ class AIMessage(BaseMessage):
|
||||
return (base.strip() + "\n" + "\n".join(lines)).strip()
|
||||
|
||||
|
||||
AIMessage.update_forward_refs()
|
||||
AIMessage.model_rebuild()
|
||||
|
||||
|
||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
@@ -238,8 +239,8 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def init_tool_calls(cls, values: dict) -> dict:
|
||||
@model_validator(mode="after")
|
||||
def init_tool_calls(self) -> Self:
|
||||
"""Initialize tool calls from tool call chunks.
|
||||
|
||||
Args:
|
||||
@@ -251,35 +252,35 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
Raises:
|
||||
ValueError: If the tool call chunks are malformed.
|
||||
"""
|
||||
if not values["tool_call_chunks"]:
|
||||
if values["tool_calls"]:
|
||||
values["tool_call_chunks"] = [
|
||||
if not self.tool_call_chunks:
|
||||
if self.tool_calls:
|
||||
self.tool_call_chunks = [
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"],
|
||||
args=json.dumps(tc["args"]),
|
||||
id=tc["id"],
|
||||
index=None,
|
||||
)
|
||||
for tc in values["tool_calls"]
|
||||
for tc in self.tool_calls
|
||||
]
|
||||
if values["invalid_tool_calls"]:
|
||||
tool_call_chunks = values.get("tool_call_chunks", [])
|
||||
if self.invalid_tool_calls:
|
||||
tool_call_chunks = self.tool_call_chunks
|
||||
tool_call_chunks.extend(
|
||||
[
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"], args=tc["args"], id=tc["id"], index=None
|
||||
)
|
||||
for tc in values["invalid_tool_calls"]
|
||||
for tc in self.invalid_tool_calls
|
||||
]
|
||||
)
|
||||
values["tool_call_chunks"] = tool_call_chunks
|
||||
self.tool_call_chunks = tool_call_chunks
|
||||
|
||||
return values
|
||||
return self
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
for chunk in values["tool_call_chunks"]:
|
||||
for chunk in self.tool_call_chunks:
|
||||
try:
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
|
||||
if isinstance(args_, dict):
|
||||
tool_calls.append(
|
||||
create_tool_call(
|
||||
@@ -299,9 +300,9 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
values["tool_calls"] = tool_calls
|
||||
values["invalid_tool_calls"] = invalid_tool_calls
|
||||
return values
|
||||
self.tool_calls = tool_calls
|
||||
self.invalid_tool_calls = invalid_tool_calls
|
||||
return self
|
||||
|
||||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
|
||||
if isinstance(other, AIMessageChunk):
|
||||
|
||||
@@ -2,11 +2,13 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.utils import get_bolded_text
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
from langchain_core.utils.pydantic import v1_repr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
@@ -51,8 +53,9 @@ class BaseMessage(Serializable):
|
||||
"""An optional unique identifier for the message. This should ideally be
|
||||
provided by the provider/model which created the message."""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
@@ -108,6 +111,10 @@ class BaseMessage(Serializable):
|
||||
def pretty_print(self) -> None:
|
||||
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO(0.3): Remove this override after confirming unit tests!
|
||||
return v1_repr(self)
|
||||
|
||||
|
||||
def merge_content(
|
||||
first_content: Union[str, List[Union[str, Dict]]],
|
||||
|
||||
@@ -25,7 +25,7 @@ class ChatMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
ChatMessage.update_forward_refs()
|
||||
ChatMessage.model_rebuild()
|
||||
|
||||
|
||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
|
||||
@@ -32,7 +32,7 @@ class FunctionMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
FunctionMessage.update_forward_refs()
|
||||
FunctionMessage.model_rebuild()
|
||||
|
||||
|
||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
|
||||
@@ -56,7 +56,7 @@ class HumanMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
HumanMessage.update_forward_refs()
|
||||
HumanMessage.model_rebuild()
|
||||
|
||||
|
||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
|
||||
@@ -33,4 +33,4 @@ class RemoveMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
|
||||
RemoveMessage.update_forward_refs()
|
||||
RemoveMessage.model_rebuild()
|
||||
|
||||
@@ -50,7 +50,7 @@ class SystemMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
SystemMessage.update_forward_refs()
|
||||
SystemMessage.model_rebuild()
|
||||
|
||||
|
||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
@@ -70,6 +71,11 @@ class ToolMessage(BaseMessage):
|
||||
.. versionadded:: 0.2.24
|
||||
"""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict, repr=False)
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
response_metadata: dict = Field(default_factory=dict, repr=False)
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
@@ -88,7 +94,7 @@ class ToolMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
ToolMessage.update_forward_refs()
|
||||
ToolMessage.model_rebuild()
|
||||
|
||||
|
||||
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
|
||||
@@ -52,6 +52,12 @@ AnyMessage = Union[
|
||||
SystemMessage,
|
||||
FunctionMessage,
|
||||
ToolMessage,
|
||||
AIMessageChunk,
|
||||
HumanMessageChunk,
|
||||
ChatMessageChunk,
|
||||
SystemMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
ToolMessageChunk,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain_core.language_models import LanguageModelOutput
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
@@ -166,10 +164,11 @@ class BaseOutputParser(
|
||||
Raises:
|
||||
TypeError: If the class doesn't have an inferable OutputType.
|
||||
"""
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 1:
|
||||
return type_args[0]
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) > 0:
|
||||
return metadata["args"][0]
|
||||
|
||||
raise TypeError(
|
||||
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||
|
||||
@@ -5,7 +5,9 @@ from json import JSONDecodeError
|
||||
from typing import Any, List, Optional, Type, TypeVar, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
import pydantic # pydantic: ignore
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
@@ -22,7 +24,7 @@ if PYDANTIC_MAJOR_VERSION < 2:
|
||||
PydanticBaseModel = pydantic.BaseModel
|
||||
|
||||
else:
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
|
||||
@@ -40,7 +42,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
|
||||
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
|
||||
"""The Pydantic object to use for validation.
|
||||
If None, no validation is performed."""
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
@@ -122,6 +123,9 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
yield [part]
|
||||
|
||||
|
||||
ListOutputParser.model_rebuild()
|
||||
|
||||
|
||||
class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"""Parse the output of an LLM call to a comma-separated list."""
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import (
|
||||
@@ -11,7 +12,6 @@ from langchain_core.output_parsers import (
|
||||
)
|
||||
from langchain_core.output_parsers.json import parse_partial_json
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
@@ -230,8 +230,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
determine which schema to use.
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_schema(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_schema(cls, values: Dict) -> Any:
|
||||
"""Validate the pydantic schema.
|
||||
|
||||
Args:
|
||||
@@ -267,11 +268,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""
|
||||
_result = super().parse_result(result)
|
||||
if self.args_only:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
if hasattr(self.pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = self.pydantic_schema.model_validate_json(_result)
|
||||
else:
|
||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||
else:
|
||||
fn_name = _result["name"]
|
||||
_args = _result["arguments"]
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
if hasattr(self.pydantic_schema, "model_validate_json"):
|
||||
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args) # type: ignore
|
||||
else:
|
||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
|
||||
return pydantic_args
|
||||
|
||||
|
||||
|
||||
@@ -3,13 +3,15 @@ import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
from langchain_core.messages.tool import invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import ValidationError
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
@@ -252,7 +254,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: List[TypeBaseModel]
|
||||
tools: Annotated[List[TypeBaseModel], SkipValidation()]
|
||||
"""The tools to parse."""
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
from typing import Generic, List, Optional, Type
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
@@ -16,7 +18,7 @@ from langchain_core.utils.pydantic import (
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Type[TBaseModel] # type: ignore
|
||||
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
|
||||
"""The pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
@@ -111,6 +113,9 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return self.pydantic_object
|
||||
|
||||
|
||||
PydanticOutputParser.model_rebuild()
|
||||
|
||||
|
||||
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import List
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
|
||||
@@ -24,3 +25,6 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
||||
|
||||
StrOutputParser.model_rebuild()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
|
||||
@@ -30,8 +32,8 @@ class ChatGeneration(Generation):
|
||||
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="after")
|
||||
def set_text(self) -> Self:
|
||||
"""Set the text attribute to be the contents of the message.
|
||||
|
||||
Args:
|
||||
@@ -45,12 +47,12 @@ class ChatGeneration(Generation):
|
||||
"""
|
||||
try:
|
||||
text = ""
|
||||
if isinstance(values["message"].content, str):
|
||||
text = values["message"].content
|
||||
if isinstance(self.message.content, str):
|
||||
text = self.message.content
|
||||
# HACK: Assumes text in content blocks in OpenAI format.
|
||||
# Uses first text block.
|
||||
elif isinstance(values["message"].content, list):
|
||||
for block in values["message"].content:
|
||||
elif isinstance(self.message.content, list):
|
||||
for block in self.message.content:
|
||||
if isinstance(block, str):
|
||||
text = block
|
||||
break
|
||||
@@ -61,10 +63,10 @@ class ChatGeneration(Generation):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
values["text"] = text
|
||||
self.text = text
|
||||
except (KeyError, AttributeError) as e:
|
||||
raise ValueError("Error while initializing ChatGeneration") from e
|
||||
return values
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class ChatResult(BaseModel):
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
|
||||
from langchain_core.outputs.generation import Generation, GenerationChunk
|
||||
from langchain_core.outputs.run_info import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
@@ -16,7 +18,9 @@ class LLMResult(BaseModel):
|
||||
wants to return.
|
||||
"""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
generations: List[
|
||||
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
|
||||
]
|
||||
"""Generated outputs.
|
||||
|
||||
The first dimension of the list represents completions for different input
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RunInfo(BaseModel):
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import (
|
||||
)
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from langchain_core.prompt_values import (
|
||||
@@ -25,7 +27,6 @@ from langchain_core.prompt_values import (
|
||||
PromptValue,
|
||||
StringPromptValue,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import ensure_config
|
||||
from langchain_core.runnables.utils import create_model
|
||||
@@ -64,28 +65,26 @@ class BasePromptTemplate(
|
||||
tags: Optional[List[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_variable_names(self) -> Self:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
if "stop" in self.input_variables:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
if "stop" in self.partial_variables:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
overall = set(self.input_variables).intersection(self.partial_variables)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@@ -99,8 +98,9 @@ class BasePromptTemplate(
|
||||
Returns True."""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
|
||||
@@ -21,6 +21,14 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
PositiveInt,
|
||||
SkipValidation,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import (
|
||||
@@ -38,7 +46,6 @@ from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@@ -207,8 +214,14 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||
def __init__(
|
||||
self, variable_name: str, *, optional: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
# mypy can't detect the init which is defined in the parent class
|
||||
# b/c these are BaseModel classes.
|
||||
super().__init__( # type: ignore
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
@@ -922,7 +935,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
messages: List[MessageLike]
|
||||
messages: Annotated[List[MessageLike], SkipValidation()]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
@@ -1038,8 +1051,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_input_variables(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_input_variables(cls, values: dict) -> Any:
|
||||
"""Validate input variables.
|
||||
|
||||
If input_variables is not set, it will be set to the union of
|
||||
|
||||
@@ -5,6 +5,14 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts.chat import (
|
||||
@@ -18,7 +26,6 @@ from langchain_core.prompts.string import (
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class _FewShotPromptTemplateMixin(BaseModel):
|
||||
@@ -32,12 +39,14 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided.
|
||||
|
||||
Args:
|
||||
@@ -139,28 +148,29 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
kwargs["input_variables"] = kwargs["example_prompt"].input_variables
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def template_is_valid(self) -> Self:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
if self.validate_template:
|
||||
check_valid_template(
|
||||
values["prefix"] + values["suffix"],
|
||||
values["template_format"],
|
||||
values["input_variables"] + list(values["partial_variables"]),
|
||||
self.prefix + self.suffix,
|
||||
self.template_format,
|
||||
self.input_variables + list(self.partial_variables),
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
elif self.template_format or None:
|
||||
self.input_variables = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["prefix"] + values["suffix"], values["template_format"]
|
||||
self.prefix + self.suffix, self.template_format
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
if var not in self.partial_variables
|
||||
]
|
||||
return values
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with inputs generating a string.
|
||||
@@ -365,9 +375,10 @@ class FewShotChatMessagePromptTemplate(
|
||||
"""Return whether or not the class is serializable."""
|
||||
return False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
|
||||
|
||||
class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
@@ -45,8 +47,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "few_shot_with_templates"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
@@ -62,15 +65,15 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def template_is_valid(self) -> Self:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
input_variables = values["input_variables"]
|
||||
expected_input_variables = set(values["suffix"].input_variables)
|
||||
expected_input_variables |= set(values["partial_variables"])
|
||||
if values["prefix"] is not None:
|
||||
expected_input_variables |= set(values["prefix"].input_variables)
|
||||
if self.validate_template:
|
||||
input_variables = self.input_variables
|
||||
expected_input_variables = set(self.suffix.input_variables)
|
||||
expected_input_variables |= set(self.partial_variables)
|
||||
if self.prefix is not None:
|
||||
expected_input_variables |= set(self.prefix.input_variables)
|
||||
missing_vars = expected_input_variables.difference(input_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
@@ -78,16 +81,17 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
f"prefix/suffix expected {expected_input_variables}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(
|
||||
set(values["suffix"].input_variables)
|
||||
| set(values["prefix"].input_variables if values["prefix"] else [])
|
||||
- set(values["partial_variables"])
|
||||
self.input_variables = sorted(
|
||||
set(self.suffix.input_variables)
|
||||
| set(self.prefix.input_variables if self.prefix else [])
|
||||
- set(self.partial_variables)
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
if self.examples is not None:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.utils import image as image_utils
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Optional as Optional
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
||||
@@ -34,8 +36,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "pipeline"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_input_variables(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def get_input_variables(cls, values: Dict) -> Any:
|
||||
"""Get input variables."""
|
||||
created_variables = set()
|
||||
all_variables = set()
|
||||
@@ -106,3 +109,6 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
PipelinePromptTemplate.model_rebuild()
|
||||
|
||||
@@ -6,6 +6,8 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
@@ -13,7 +15,6 @@ from langchain_core.prompts.string import (
|
||||
get_template_variables,
|
||||
mustache_schema,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
@@ -73,8 +74,9 @@ class PromptTemplate(StringPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_init_validation(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_init_validation(cls, values: Dict) -> Any:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values.get("template") is None:
|
||||
# Will let pydantic fail with a ValidationError if template
|
||||
|
||||
@@ -7,10 +7,11 @@ from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
import langchain_core.utils.mustache as mustache
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, create_model
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@@ -11,13 +11,14 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.language_models.base import BaseLanguageModel
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
MessageLikeRepresentation,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
|
||||
@@ -26,6 +26,7 @@ from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@@ -126,8 +127,9 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
|
||||
@@ -35,7 +35,8 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal, get_args
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||
from typing_extensions import Literal, get_args, get_type_hints
|
||||
|
||||
from langchain_core._api import beta_decorator
|
||||
from langchain_core.load.dump import dumpd
|
||||
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
|
||||
SerializedConstructor,
|
||||
SerializedNotImplemented,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
_set_config_context,
|
||||
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
|
||||
)
|
||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@@ -236,25 +235,58 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
|
||||
""" # noqa: E501
|
||||
|
||||
name: Optional[str] = None
|
||||
name: Optional[str]
|
||||
"""The name of the Runnable. Used for debugging and tracing."""
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get the name of the Runnable."""
|
||||
name = name or self.name or self.__class__.__name__
|
||||
if suffix:
|
||||
if name[0].isupper():
|
||||
return name + suffix.title()
|
||||
else:
|
||||
return name + "_" + suffix.lower()
|
||||
if name:
|
||||
name_ = name
|
||||
elif hasattr(self, "name") and self.name:
|
||||
name_ = self.name
|
||||
else:
|
||||
return name
|
||||
# Here we handle a case where the runnable subclass is also a pydantic
|
||||
# model.
|
||||
cls = self.__class__
|
||||
# Then it's a pydantic sub-class, and we have to check
|
||||
# whether it's a generic, and if so recover the original name.
|
||||
if (
|
||||
hasattr(
|
||||
cls,
|
||||
"__pydantic_generic_metadata__",
|
||||
)
|
||||
and "origin" in cls.__pydantic_generic_metadata__
|
||||
and cls.__pydantic_generic_metadata__["origin"] is not None
|
||||
):
|
||||
name_ = cls.__pydantic_generic_metadata__["origin"].__name__
|
||||
else:
|
||||
name_ = cls.__name__
|
||||
|
||||
if suffix:
|
||||
if name_[0].isupper():
|
||||
return name_ + suffix.title()
|
||||
else:
|
||||
return name_ + "_" + suffix.lower()
|
||||
else:
|
||||
return name_
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
"""The type of input this Runnable accepts specified as a type annotation."""
|
||||
# First loop through all parent classes and if any of them is
|
||||
# a pydantic model, we will pick up the generic parameterization
|
||||
# from that model via the __pydantic_generic_metadata__ attribute.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][0]
|
||||
|
||||
# If we didn't find a pydantic model in the parent classes,
|
||||
# then loop through __orig_bases__. This corresponds to
|
||||
# Runnables that are not pydantic models.
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 2:
|
||||
@@ -268,6 +300,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
"""The type of output this Runnable produces specified as a type annotation."""
|
||||
# First loop through bases -- this will help generic
|
||||
# any pydantic models.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
metadata = base.__pydantic_generic_metadata__
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][1]
|
||||
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
type_args = get_args(cls)
|
||||
if type_args and len(type_args) == 2:
|
||||
@@ -302,14 +342,42 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.InputType
|
||||
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
__root__=(root_type, None),
|
||||
__root__=root_type,
|
||||
)
|
||||
|
||||
def get_input_jsonschema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get a JSON schema that represents the input to the Runnable.
|
||||
|
||||
Args:
|
||||
config: A config to use when generating the schema.
|
||||
|
||||
Returns:
|
||||
A JSON schema that represents the input to the Runnable.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
def add_one(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
runnable = RunnableLambda(add_one)
|
||||
|
||||
print(runnable.get_input_jsonschema())
|
||||
|
||||
.. versionadded:: 0.3.0
|
||||
"""
|
||||
return self.get_input_schema(config).model_json_schema()
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
"""The type of output this Runnable produces specified as a pydantic model."""
|
||||
@@ -334,14 +402,42 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.OutputType
|
||||
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.get_name("Output"),
|
||||
__root__=(root_type, None),
|
||||
__root__=root_type,
|
||||
)
|
||||
|
||||
def get_output_jsonschema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get a JSON schema that represents the output of the Runnable.
|
||||
|
||||
Args:
|
||||
config: A config to use when generating the schema.
|
||||
|
||||
Returns:
|
||||
A JSON schema that represents the output of the Runnable.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
def add_one(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
runnable = RunnableLambda(add_one)
|
||||
|
||||
print(runnable.get_output_jsonschema())
|
||||
|
||||
.. versionadded:: 0.3.0
|
||||
"""
|
||||
return self.get_output_schema(config).model_json_schema()
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
"""List configurable fields for this Runnable."""
|
||||
@@ -381,15 +477,34 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
else None
|
||||
)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.get_name("Config"),
|
||||
# Many need to create a typed dict instead to implement NotRequired!
|
||||
all_fields = {
|
||||
**({"configurable": (configurable, None)} if configurable else {}),
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||
for field_name, field_type in get_type_hints(RunnableConfig).items()
|
||||
if field_name in [i for i in include if i != "configurable"]
|
||||
},
|
||||
}
|
||||
model = create_model( # type: ignore[call-overload]
|
||||
self.get_name("Config"), **all_fields
|
||||
)
|
||||
return model
|
||||
|
||||
def get_config_jsonschema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get a JSON schema that represents the output of the Runnable.
|
||||
|
||||
Args:
|
||||
include: A list of fields to include in the config schema.
|
||||
|
||||
Returns:
|
||||
A JSON schema that represents the output of the Runnable.
|
||||
|
||||
.. versionadded:: 0.3.0
|
||||
"""
|
||||
return self.config_schema(include=include).model_json_schema()
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
"""Return a graph representation of this Runnable."""
|
||||
@@ -579,7 +694,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
from langchain_core.runnables.passthrough import RunnableAssign
|
||||
|
||||
return self | RunnableAssign(RunnableParallel(kwargs))
|
||||
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@@ -2129,7 +2244,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
iterator_ = None
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_config(transformer):
|
||||
@@ -2314,7 +2428,12 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
"""Runnable that can be serialized to JSON."""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""The name of the Runnable. Used for debugging and tracing."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
# Suppress warnings from pydantic protected namespaces
|
||||
# (e.g., `model_`)
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
|
||||
"""Serialize the Runnable to JSON.
|
||||
@@ -2369,10 +2488,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
from langchain_core.runnables.configurable import RunnableConfigurableFields
|
||||
|
||||
for key in kwargs:
|
||||
if key not in self.__fields__:
|
||||
if key not in self.model_fields:
|
||||
raise ValueError(
|
||||
f"Configuration key {key} not found in {self}: "
|
||||
f"available keys are {self.__fields__.keys()}"
|
||||
f"available keys are {self.model_fields.keys()}"
|
||||
)
|
||||
|
||||
return RunnableConfigurableFields(default=self, fields=kwargs)
|
||||
@@ -2447,13 +2566,13 @@ def _seq_input_schema(
|
||||
return first.get_input_schema(config)
|
||||
elif isinstance(first, RunnableAssign):
|
||||
next_input_schema = _seq_input_schema(steps[1:], config)
|
||||
if not next_input_schema.__custom_root_type__:
|
||||
if not issubclass(next_input_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceInput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in next_input_schema.__fields__.items()
|
||||
for k, v in next_input_schema.model_fields.items()
|
||||
if k not in first.mapper.steps__
|
||||
},
|
||||
)
|
||||
@@ -2474,36 +2593,36 @@ def _seq_output_schema(
|
||||
elif isinstance(last, RunnableAssign):
|
||||
mapper_output_schema = last.mapper.get_output_schema(config)
|
||||
prev_output_schema = _seq_output_schema(steps[:-1], config)
|
||||
if not prev_output_schema.__custom_root_type__:
|
||||
if not issubclass(prev_output_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
**{
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.__fields__.items()
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
},
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in mapper_output_schema.__fields__.items()
|
||||
for k, v in mapper_output_schema.model_fields.items()
|
||||
},
|
||||
},
|
||||
)
|
||||
elif isinstance(last, RunnablePick):
|
||||
prev_output_schema = _seq_output_schema(steps[:-1], config)
|
||||
if not prev_output_schema.__custom_root_type__:
|
||||
if not issubclass(prev_output_schema, RootModel):
|
||||
# it's a dict as expected
|
||||
if isinstance(last.keys, list):
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in prev_output_schema.__fields__.items()
|
||||
for k, v in prev_output_schema.model_fields.items()
|
||||
if k in last.keys
|
||||
},
|
||||
)
|
||||
else:
|
||||
field = prev_output_schema.__fields__[last.keys]
|
||||
field = prev_output_schema.model_fields[last.keys]
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceOutput",
|
||||
__root__=(field.annotation, field.default),
|
||||
@@ -2665,8 +2784,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
return True
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
@@ -3402,8 +3522,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_name(
|
||||
self, suffix: Optional[str] = None, *, name: Optional[str] = None
|
||||
@@ -3450,7 +3571,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps__.values()
|
||||
for k, v in step.get_input_schema(config).__fields__.items()
|
||||
for k, v in step.get_input_schema(config).model_fields.items()
|
||||
if k != "__root__"
|
||||
},
|
||||
)
|
||||
@@ -3468,11 +3589,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
Returns:
|
||||
The output schema of the Runnable.
|
||||
"""
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.get_name("Output"),
|
||||
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
|
||||
)
|
||||
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
|
||||
return create_model(self.get_name("Output"), **fields)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
@@ -3882,6 +4000,8 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
atransform: Optional[
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
|
||||
] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize a RunnableGenerator.
|
||||
|
||||
@@ -3909,9 +4029,9 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
try:
|
||||
self.name = func_for_name.__name__
|
||||
self.name = name or func_for_name.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
self.name = "RunnableGenerator"
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
@@ -4183,15 +4303,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if all(
|
||||
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
|
||||
):
|
||||
fields = {item[1:-1]: (Any, ...) for item in items}
|
||||
# It's a dict, lol
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||
)
|
||||
return create_model(self.get_name("Input"), **fields)
|
||||
else:
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
__root__=(List[Any], None),
|
||||
__root__=List[Any],
|
||||
)
|
||||
|
||||
if self.InputType != Any:
|
||||
@@ -4200,7 +4318,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||
return create_model(
|
||||
self.get_name("Input"),
|
||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||
**{key: (Any, ...) for key in dict_keys}, # type: ignore
|
||||
)
|
||||
|
||||
return super().get_input_schema(config)
|
||||
@@ -4728,8 +4846,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
@@ -4756,10 +4875,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model(
|
||||
self.get_name("Output"),
|
||||
__root__=(
|
||||
List[schema], # type: ignore
|
||||
None,
|
||||
),
|
||||
__root__=List[schema], # type: ignore[valid-type]
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -4979,8 +5095,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -5316,7 +5433,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
RunnableBindingBase.model_rebuild()
|
||||
|
||||
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
|
||||
@@ -14,8 +14,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableLike,
|
||||
@@ -134,10 +135,21 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
runnable = coerce_to_runnable(runnable)
|
||||
_branches.append((condition, runnable))
|
||||
|
||||
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
branches=_branches,
|
||||
default=default_,
|
||||
# Hard-coding a name here because RunnableBranch is a generic
|
||||
# and with pydantic 2, the class name with pydantic will capture
|
||||
# include the parameterized type, which is not what we want.
|
||||
# e.g., we'd get RunnableBranch[Input, Output] instead of RunnableBranch
|
||||
# for the name. This information is already captured in the
|
||||
# input and output types.
|
||||
name="RunnableBranch",
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -20,7 +20,8 @@ from typing import (
|
||||
)
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -58,8 +59,9 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
config: Optional[RunnableConfig] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -373,28 +375,33 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
Returns:
|
||||
List[ConfigurableFieldSpec]: The configuration specs.
|
||||
"""
|
||||
return get_unique_config_specs(
|
||||
[
|
||||
(
|
||||
# TODO(0.3): This change removes field_info which isn't needed in pydantic 2
|
||||
config_specs = []
|
||||
|
||||
for field_name, spec in self.fields.items():
|
||||
if isinstance(spec, ConfigurableField):
|
||||
config_specs.append(
|
||||
ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
name=spec.name,
|
||||
description=spec.description
|
||||
or self.default.__fields__[field_name].field_info.description,
|
||||
or self.default.model_fields[field_name].description,
|
||||
annotation=spec.annotation
|
||||
or self.default.__fields__[field_name].annotation,
|
||||
or self.default.model_fields[field_name].annotation,
|
||||
default=getattr(self.default, field_name),
|
||||
is_shared=spec.is_shared,
|
||||
)
|
||||
if isinstance(spec, ConfigurableField)
|
||||
else make_options_spec(
|
||||
spec, self.default.__fields__[field_name].field_info.description
|
||||
)
|
||||
else:
|
||||
config_specs.append(
|
||||
make_options_spec(
|
||||
spec, self.default.model_fields[field_name].description
|
||||
)
|
||||
)
|
||||
for field_name, spec in self.fields.items()
|
||||
]
|
||||
+ list(self.default.config_specs)
|
||||
)
|
||||
|
||||
config_specs.extend(self.default.config_specs)
|
||||
|
||||
return get_unique_config_specs(config_specs)
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: AnyConfigurableField
|
||||
@@ -436,7 +443,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
init_params = {
|
||||
k: v
|
||||
for k, v in self.default.__dict__.items()
|
||||
if k in self.default.__fields__
|
||||
if k in self.default.model_fields
|
||||
}
|
||||
return (
|
||||
self.default.__class__(**{**init_params, **configurable}),
|
||||
|
||||
@@ -18,8 +18,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
@@ -107,8 +108,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
|
||||
must accept a dictionary as input."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
|
||||
@@ -22,8 +22,9 @@ from typing import (
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
@@ -235,7 +236,9 @@ def node_data_json(
|
||||
json = (
|
||||
{
|
||||
"type": "schema",
|
||||
"data": node.data.schema(),
|
||||
"data": node.data.model_json_schema(
|
||||
schema_generator=_IgnoreUnserializable
|
||||
),
|
||||
}
|
||||
if with_schemas
|
||||
else {
|
||||
|
||||
@@ -13,9 +13,10 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.load.load import load
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import (
|
||||
@@ -372,28 +373,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
super_schema = super().get_input_schema(config)
|
||||
if super_schema.__custom_root_type__ or not super_schema.schema().get(
|
||||
"properties"
|
||||
):
|
||||
from langchain_core.messages import BaseMessage
|
||||
# TODO(0.3): Verify that this change was correct
|
||||
# Not enough tests and unclear on why the previous implementation was
|
||||
# necessary.
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
fields: Dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
...,
|
||||
)
|
||||
elif self.input_messages_key:
|
||||
fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
|
||||
else:
|
||||
return super_schema
|
||||
fields["__root__"] = (Sequence[BaseMessage], ...)
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableWithChatHistoryInput",
|
||||
**fields,
|
||||
)
|
||||
|
||||
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
@@ -21,7 +21,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, RootModel
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
@@ -227,7 +228,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
A Runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel(kwargs))
|
||||
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
|
||||
def invoke(
|
||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@@ -419,7 +420,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
if not map_input_schema.__custom_root_type__:
|
||||
if not issubclass(map_input_schema, RootModel):
|
||||
# ie. it's a dict
|
||||
return map_input_schema
|
||||
|
||||
@@ -430,20 +431,22 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
map_output_schema = self.mapper.get_output_schema(config)
|
||||
if (
|
||||
not map_input_schema.__custom_root_type__
|
||||
and not map_output_schema.__custom_root_type__
|
||||
if not issubclass(map_input_schema, RootModel) and not issubclass(
|
||||
map_output_schema, RootModel
|
||||
):
|
||||
# ie. both are dicts
|
||||
fields = {}
|
||||
|
||||
for name, field_info in map_input_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
|
||||
for name, field_info in map_output_schema.model_fields.items():
|
||||
fields[name] = (field_info.annotation, field_info.default)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableAssignOutput",
|
||||
**{
|
||||
k: (v.type_, v.default)
|
||||
for s in (map_input_schema, map_output_schema)
|
||||
for k, v in s.__fields__.items()
|
||||
},
|
||||
**fields,
|
||||
)
|
||||
elif not map_output_schema.__custom_root_type__:
|
||||
elif not issubclass(map_output_schema, RootModel):
|
||||
# ie. only map output is a dict
|
||||
# ie. input type is either unknown or inferred incorrectly
|
||||
return map_output_schema
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
@@ -83,8 +84,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -28,12 +28,18 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, RootModel
|
||||
from pydantic import create_model as _create_model_base # pydantic :ignore
|
||||
from pydantic.json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaMode,
|
||||
)
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseConfig, BaseModel
|
||||
from langchain_core.pydantic_v1 import create_model as _create_model_base
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
|
||||
Input = TypeVar("Input", contravariant=True)
|
||||
@@ -350,7 +356,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = IsFunctionArgDict()
|
||||
visitor.visit(tree)
|
||||
return list(visitor.keys) if visitor.keys else None
|
||||
return sorted(visitor.keys) if visitor.keys else None
|
||||
except (SyntaxError, TypeError, OSError, SystemError):
|
||||
return None
|
||||
|
||||
@@ -699,9 +705,57 @@ class _RootEventFilter:
|
||||
return include
|
||||
|
||||
|
||||
class _SchemaConfig(BaseConfig):
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
||||
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def create_base_class(
|
||||
name: str, type_: Any, default_: object = NO_DEFAULT
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
|
||||
def schema(
|
||||
cls: Type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
) -> Dict[str, Any]:
|
||||
# Complains about schema not being defined in superclass
|
||||
schema_ = super(cls, cls).schema( # type: ignore[misc]
|
||||
by_alias=by_alias, ref_template=ref_template
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
def model_json_schema(
|
||||
cls: Type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = "validation",
|
||||
) -> Dict[str, Any]:
|
||||
# Complains about model_json_schema not being defined in superclass
|
||||
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
|
||||
by_alias=by_alias,
|
||||
ref_template=ref_template,
|
||||
schema_generator=schema_generator,
|
||||
mode=mode,
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
base_class_attributes = {
|
||||
"__annotations__": {"root": type_},
|
||||
"model_config": ConfigDict(arbitrary_types_allowed=True),
|
||||
"schema": classmethod(schema),
|
||||
"model_json_schema": classmethod(model_json_schema),
|
||||
"__module__": "langchain_core.runnables.utils",
|
||||
}
|
||||
|
||||
if default_ is not NO_DEFAULT:
|
||||
base_class_attributes["root"] = default_
|
||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||
return cast(Type[BaseModel], custom_root_type)
|
||||
|
||||
|
||||
def create_model(
|
||||
@@ -717,6 +771,21 @@ def create_model(
|
||||
Returns:
|
||||
Type[BaseModel]: The created model.
|
||||
"""
|
||||
|
||||
# Move this to caching path
|
||||
if "__root__" in field_definitions:
|
||||
if len(field_definitions) > 1:
|
||||
raise NotImplementedError(
|
||||
"When specifying __root__ no other "
|
||||
f"fields should be provided. Got {field_definitions}"
|
||||
)
|
||||
|
||||
arg = field_definitions["__root__"]
|
||||
if isinstance(arg, tuple):
|
||||
named_root_model = create_base_class(__model_name, arg[0], arg[1])
|
||||
else:
|
||||
named_root_model = create_base_class(__model_name, arg)
|
||||
return named_root_model
|
||||
try:
|
||||
return _create_model_cached(__model_name, **field_definitions)
|
||||
except TypeError:
|
||||
|
||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Visitor(ABC):
|
||||
@@ -127,7 +127,8 @@ class Comparison(FilterDirective):
|
||||
def __init__(
|
||||
self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
comparator=comparator, attribute=attribute, value=value, **kwargs
|
||||
)
|
||||
|
||||
@@ -145,8 +146,11 @@ class Operation(FilterDirective):
|
||||
|
||||
def __init__(
|
||||
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
|
||||
):
|
||||
super().__init__(operator=operator, arguments=arguments, **kwargs)
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
operator=operator, arguments=arguments, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class StructuredQuery(Expr):
|
||||
@@ -165,5 +169,8 @@ class StructuredQuery(Expr):
|
||||
filter: Optional[FilterDirective],
|
||||
limit: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
query=query, filter=filter, limit=limit, **kwargs
|
||||
)
|
||||
|
||||
@@ -19,12 +19,25 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, TypeVar, get_args, get_origin
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Extra,
|
||||
Field,
|
||||
SkipValidation,
|
||||
ValidationError,
|
||||
model_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@@ -33,16 +46,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManager,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.messages import ToolCall, ToolMessage
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
ValidationError,
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
@@ -59,6 +63,7 @@ from langchain_core.utils.function_calling import (
|
||||
from langchain_core.utils.pydantic import (
|
||||
TypeBaseModel,
|
||||
_create_subset_model,
|
||||
get_fields,
|
||||
is_basemodel_subclass,
|
||||
is_pydantic_v1_subclass,
|
||||
is_pydantic_v2_subclass,
|
||||
@@ -204,20 +209,64 @@ def create_schema_from_function(
|
||||
"""
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||
if func.__qualname__ and "." in func.__qualname__:
|
||||
# Then it likely belongs in a class namespace
|
||||
in_class = True
|
||||
else:
|
||||
in_class = False
|
||||
|
||||
has_args = False
|
||||
has_kwargs = False
|
||||
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == param.VAR_POSITIONAL:
|
||||
has_args = True
|
||||
elif param.kind == param.VAR_KEYWORD:
|
||||
has_kwargs = True
|
||||
|
||||
inferred_model = validated.model # type: ignore
|
||||
filter_args = filter_args if filter_args is not None else FILTERED_ARGS
|
||||
for arg in filter_args:
|
||||
if arg in inferred_model.__fields__:
|
||||
del inferred_model.__fields__[arg]
|
||||
|
||||
if filter_args:
|
||||
filter_args_ = filter_args
|
||||
else:
|
||||
# Handle classmethods and instance methods
|
||||
existing_params: List[str] = list(sig.parameters.keys())
|
||||
if existing_params and existing_params[0] in ("self", "cls") and in_class:
|
||||
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
|
||||
else:
|
||||
filter_args_ = list(FILTERED_ARGS)
|
||||
|
||||
for existing_param in existing_params:
|
||||
if not include_injected and _is_injected_arg_type(
|
||||
sig.parameters[existing_param].annotation
|
||||
):
|
||||
filter_args_.append(existing_param)
|
||||
|
||||
description, arg_descriptions = _infer_arg_descriptions(
|
||||
func,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
valid_properties = _get_filtered_args(
|
||||
inferred_model, func, filter_args=filter_args, include_injected=include_injected
|
||||
)
|
||||
valid_properties = []
|
||||
for field in get_fields(inferred_model):
|
||||
if not has_args:
|
||||
if field == "args":
|
||||
continue
|
||||
if not has_kwargs:
|
||||
if field == "kwargs":
|
||||
continue
|
||||
|
||||
if field == "v__duplicate_kwargs": # Internal pydantic field
|
||||
continue
|
||||
|
||||
if field not in filter_args_:
|
||||
valid_properties.append(field)
|
||||
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema",
|
||||
inferred_model,
|
||||
@@ -274,7 +323,10 @@ class ChildTool(BaseTool):
|
||||
|
||||
You can provide few-shot examples as a part of the description.
|
||||
"""
|
||||
args_schema: Optional[TypeBaseModel] = None
|
||||
|
||||
args_schema: Annotated[Optional[TypeBaseModel], SkipValidation()] = Field(
|
||||
default=None, description="The tool schema."
|
||||
)
|
||||
"""Pydantic model class to validate and parse the tool's input arguments.
|
||||
|
||||
Args schema should be either:
|
||||
@@ -345,8 +397,9 @@ class ChildTool(BaseTool):
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class Config(Serializable.Config):
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
@@ -416,7 +469,7 @@ class ChildTool(BaseTool):
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
key_ = next(iter(input_args.__fields__.keys()))
|
||||
key_ = next(iter(get_fields(input_args).keys()))
|
||||
input_args.validate({key_: tool_input})
|
||||
return tool_input
|
||||
else:
|
||||
@@ -429,8 +482,9 @@ class ChildTool(BaseTool):
|
||||
}
|
||||
return tool_input
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: Dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools.base import BaseTool
|
||||
from langchain_core.tools.simple import Tool
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
@@ -10,7 +12,6 @@ from langchain_core.prompts import (
|
||||
aformat_document,
|
||||
format_document,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.tools.simple import Tool
|
||||
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
BaseTool,
|
||||
@@ -155,3 +165,6 @@ class Tool(BaseTool):
|
||||
args_schema=args_schema,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Tool.model_rebuild()
|
||||
|
||||
@@ -2,14 +2,26 @@ from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, SkipValidation
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableConfig, run_in_executor
|
||||
from langchain_core.tools.base import (
|
||||
FILTERED_ARGS,
|
||||
@@ -24,7 +36,9 @@ class StructuredTool(BaseTool):
|
||||
"""Tool that can operate on any number of inputs."""
|
||||
|
||||
description: str = ""
|
||||
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
|
||||
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
|
||||
..., description="The tool schema."
|
||||
)
|
||||
"""The input arguments' schema."""
|
||||
func: Optional[Callable[..., Any]]
|
||||
"""The function to run when the tool is called."""
|
||||
|
||||
@@ -11,7 +11,6 @@ from langsmith.schemas import RunBase as BaseRunV2
|
||||
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
|
||||
@@ -83,7 +82,8 @@ class LLMRun(BaseRun):
|
||||
"""Class for LLMRun."""
|
||||
|
||||
prompts: List[str]
|
||||
response: Optional[LLMResult] = None
|
||||
# Temporarily, remove but we will completely remove LLMRun
|
||||
# response: Optional[LLMResult] = None
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
|
||||
@@ -23,11 +23,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
@@ -85,7 +85,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_function(
|
||||
model: Type[BaseModel],
|
||||
model: Type,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
@@ -109,7 +109,10 @@ def convert_pydantic_to_openai_function(
|
||||
else:
|
||||
schema = model.schema() # Pydantic 1
|
||||
schema = dereference_refs(schema)
|
||||
schema.pop("definitions", None)
|
||||
if "definitions" in schema: # pydantic 1
|
||||
schema.pop("definitions", None)
|
||||
if "$defs" in schema: # pydantic 2
|
||||
schema.pop("$defs", None)
|
||||
title = schema.pop("title", "")
|
||||
default_description = schema.pop("description", "")
|
||||
return {
|
||||
@@ -193,11 +196,13 @@ def convert_python_function_to_openai_function(
|
||||
|
||||
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
|
||||
visited: Dict = {}
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
model = cast(
|
||||
Type[BaseModel],
|
||||
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
||||
)
|
||||
return convert_pydantic_to_openai_function(model)
|
||||
return convert_pydantic_to_openai_function(model) # type: ignore
|
||||
|
||||
|
||||
_MAX_TYPED_DICT_RECURSION = 25
|
||||
@@ -209,6 +214,9 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
visited: Dict,
|
||||
depth: int = 0,
|
||||
) -> Type:
|
||||
from pydantic.v1 import Field as Field_v1
|
||||
from pydantic.v1 import create_model as create_model_v1
|
||||
|
||||
if type_ in visited:
|
||||
return visited[type_]
|
||||
elif depth >= _MAX_TYPED_DICT_RECURSION:
|
||||
@@ -242,7 +250,7 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
field_kwargs["description"] = arg_desc
|
||||
else:
|
||||
pass
|
||||
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
else:
|
||||
new_arg_type = _convert_any_typed_dicts_to_pydantic(
|
||||
arg_type, depth=depth + 1, visited=visited
|
||||
@@ -250,8 +258,8 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
field_kwargs = {"default": ...}
|
||||
if arg_desc := arg_descriptions.get(arg):
|
||||
field_kwargs["description"] = arg_desc
|
||||
fields[arg] = (new_arg_type, Field(**field_kwargs))
|
||||
model = create_model(typed_dict.__name__, **fields)
|
||||
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
|
||||
model = create_model_v1(typed_dict.__name__, **fields)
|
||||
model.__doc__ = description
|
||||
visited[typed_dict] = model
|
||||
return model
|
||||
|
||||
@@ -7,9 +7,10 @@ import textwrap
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
import pydantic
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
@@ -76,13 +77,13 @@ def is_basemodel_subclass(cls: Type) -> bool:
|
||||
return False
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
if issubclass(cls, BaseModelV1Proper):
|
||||
return True
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(cls, BaseModelV2):
|
||||
return True
|
||||
@@ -104,13 +105,13 @@ def is_basemodel_instance(obj: Any) -> bool:
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
"""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore
|
||||
from pydantic import BaseModel as BaseModelV1Proper
|
||||
|
||||
if isinstance(obj, BaseModelV1Proper):
|
||||
return True
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if isinstance(obj, BaseModelV2):
|
||||
return True
|
||||
@@ -146,7 +147,7 @@ def pre_init(func: Callable) -> Any:
|
||||
Dict[str, Any]: The values to initialize the model with.
|
||||
"""
|
||||
# Insert default values
|
||||
fields = cls.__fields__
|
||||
fields = cls.model_fields
|
||||
for name, field_info in fields.items():
|
||||
# Check if allow_population_by_field_name is enabled
|
||||
# If yes, then set the field name to the alias
|
||||
@@ -155,9 +156,13 @@ def pre_init(func: Callable) -> Any:
|
||||
if cls.Config.allow_population_by_field_name:
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if hasattr(cls, "model_config"):
|
||||
if cls.model_config.get("populate_by_name"):
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
|
||||
if name not in values or values[name] is None:
|
||||
if not field_info.required:
|
||||
if not field_info.is_required():
|
||||
if field_info.default_factory is not None:
|
||||
values[name] = field_info.default_factory()
|
||||
else:
|
||||
@@ -169,6 +174,46 @@ def pre_init(func: Callable) -> Any:
|
||||
return wrapper
|
||||
|
||||
|
||||
class _IgnoreUnserializable(GenerateJsonSchema):
|
||||
"""A JSON schema generator that ignores unknown types.
|
||||
|
||||
https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process
|
||||
"""
|
||||
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: core_schema.CoreSchema, error_info: str
|
||||
) -> JsonSchemaValue:
|
||||
return {}
|
||||
|
||||
|
||||
def v1_repr(obj: BaseModel) -> str:
|
||||
"""Return the schema of the object as a string.
|
||||
|
||||
Get a repr for the pydantic object which is consistent with pydantic.v1.
|
||||
"""
|
||||
if not is_basemodel_instance(obj):
|
||||
raise TypeError(f"Expected a pydantic BaseModel, got {type(obj)}")
|
||||
repr_ = []
|
||||
for name, field in get_fields(obj).items():
|
||||
value = getattr(obj, name)
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
repr_.append(f"{name}={v1_repr(value)}")
|
||||
else:
|
||||
if field.exclude:
|
||||
continue
|
||||
if not field.is_required():
|
||||
if not value:
|
||||
continue
|
||||
if field.default == value:
|
||||
continue
|
||||
|
||||
repr_.append(f"{name}={repr(value)}")
|
||||
|
||||
args = ", ".join(repr_)
|
||||
return f"{obj.__class__.__name__}({args})"
|
||||
|
||||
|
||||
def _create_subset_model_v1(
|
||||
name: str,
|
||||
model: Type[BaseModel],
|
||||
@@ -178,12 +223,20 @@ def _create_subset_model_v1(
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
from langchain_core.pydantic_v1 import create_model
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import create_model
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import create_model # type: ignore
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
|
||||
)
|
||||
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
field = model.__fields__[field_name]
|
||||
# Using pydantic v1 so can access __fields__ as a dict.
|
||||
field = model.__fields__[field_name] # type: ignore
|
||||
t = (
|
||||
# this isn't perfect but should work for most functions
|
||||
field.outer_type_
|
||||
@@ -208,8 +261,8 @@ def _create_subset_model_v2(
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[pydantic.BaseModel]:
|
||||
"""Create a pydantic model with a subset of the model fields."""
|
||||
from pydantic import create_model # pydantic: ignore
|
||||
from pydantic.fields import FieldInfo # pydantic: ignore
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
descriptions_ = descriptions or {}
|
||||
fields = {}
|
||||
@@ -222,6 +275,17 @@ def _create_subset_model_v2(
|
||||
fields[field_name] = (field.annotation, field_info)
|
||||
rtn = create_model(name, **fields) # type: ignore
|
||||
|
||||
# TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations.
|
||||
# This is done to preserve __annotations__ when working with pydantic 2.x
|
||||
# and using the Annotated type with TypedDict.
|
||||
# Comment out the following line, to trigger the relevant test case.
|
||||
selected_annotations = [
|
||||
(name, annotation)
|
||||
for name, annotation in model.__annotations__.items()
|
||||
if name in field_names
|
||||
]
|
||||
|
||||
rtn.__annotations__ = dict(selected_annotations)
|
||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||
return rtn
|
||||
|
||||
@@ -248,7 +312,7 @@ def _create_subset_model(
|
||||
fn_description=fn_description,
|
||||
)
|
||||
elif PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
if issubclass(model, BaseModelV1):
|
||||
return _create_subset_model_v1(
|
||||
|
||||
@@ -10,9 +10,9 @@ from importlib.metadata import version
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import SecretStr
|
||||
from requests import HTTPError, Response
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from langchain_core.utils.pydantic import (
|
||||
is_pydantic_v1_subclass,
|
||||
)
|
||||
@@ -353,7 +353,7 @@ def from_env(
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
|
||||
def secret_from_env(key: Union[str, Sequence[str]], /) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
||||
@@ -42,8 +42,9 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
@@ -984,11 +985,13 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
"mmr",
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_search_type(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_search_type(cls, values: Dict) -> Any:
|
||||
"""Validate search type.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -51,7 +51,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
f"and Y has shape {Y.shape}."
|
||||
)
|
||||
try:
|
||||
import simsimd as simd
|
||||
import simsimd as simd # type: ignore[import-not-found]
|
||||
|
||||
X = np.array(X, dtype=np.float32)
|
||||
Y = np.array(Y, dtype=np.float32)
|
||||
|
||||
97
libs/core/poetry.lock
generated
97
libs/core/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
@@ -11,9 +11,6 @@ files = [
|
||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.4.0"
|
||||
@@ -185,9 +182,6 @@ files = [
|
||||
{file = "babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"]
|
||||
|
||||
@@ -693,28 +687,6 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke
|
||||
perf = ["ipython"]
|
||||
test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "importlib-resources"
|
||||
version = "6.4.4"
|
||||
description = "Read resources from Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"},
|
||||
{file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"]
|
||||
type = ["pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.0.0"
|
||||
@@ -920,11 +892,9 @@ files = [
|
||||
attrs = ">=22.2.0"
|
||||
fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
|
||||
isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
jsonschema-specifications = ">=2023.03.6"
|
||||
pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
|
||||
referencing = ">=0.28.4"
|
||||
rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
@@ -948,7 +918,6 @@ files = [
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
|
||||
referencing = ">=0.31.0"
|
||||
|
||||
[[package]]
|
||||
@@ -1148,7 +1117,6 @@ files = [
|
||||
async-lru = ">=1.0.0"
|
||||
httpx = ">=0.25.0"
|
||||
importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
|
||||
importlib-resources = {version = ">=1.4", markers = "python_version < \"3.9\""}
|
||||
ipykernel = ">=6.5.0"
|
||||
jinja2 = ">=3.0.3"
|
||||
jupyter-core = "*"
|
||||
@@ -1555,43 +1523,6 @@ jupyter-server = ">=1.8,<3"
|
||||
[package.extras]
|
||||
test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.24.4"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"},
|
||||
{file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"},
|
||||
{file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"},
|
||||
{file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"},
|
||||
{file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"},
|
||||
{file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"},
|
||||
{file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"},
|
||||
{file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"},
|
||||
{file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"},
|
||||
{file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"},
|
||||
{file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"},
|
||||
{file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"},
|
||||
{file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"},
|
||||
{file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"},
|
||||
{file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"},
|
||||
{file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"},
|
||||
{file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"},
|
||||
{file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"},
|
||||
{file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"},
|
||||
{file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"},
|
||||
{file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"},
|
||||
{file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"},
|
||||
{file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"},
|
||||
{file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"},
|
||||
{file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"},
|
||||
{file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"},
|
||||
{file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"},
|
||||
{file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
@@ -1776,17 +1707,6 @@ files = [
|
||||
{file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pkgutil-resolve-name"
|
||||
version = "1.3.10"
|
||||
description = "Resolve a name to an object."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
|
||||
{file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "platformdirs"
|
||||
version = "4.2.2"
|
||||
@@ -2178,17 +2098,6 @@ files = [
|
||||
{file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2024.1"
|
||||
description = "World timezone definitions, modern and historical"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
|
||||
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pywin32"
|
||||
version = "306"
|
||||
@@ -3219,5 +3128,5 @@ type = ["pytest-mypy"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "e86e28d75744b77f8c2173b58e950c2d23c95cacad8ee5d3110309c4248f7c09"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "7f2ce36878754aeb498d961452c156ab63e95bc5c6bcf3ca29acae325062aed9"
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "langchain-core"
|
||||
version = "0.2.38"
|
||||
version = "0.3.0.dev1"
|
||||
description = "Building applications with LLMs through composability"
|
||||
authors = []
|
||||
license = "MIT"
|
||||
@@ -23,7 +23,7 @@ ignore_missing_imports = true
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-core%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
python = ">=3.9,<4.0"
|
||||
langsmith = "^0.1.75"
|
||||
tenacity = "^8.1.0,!=8.4.0"
|
||||
jsonpatch = "^1.33"
|
||||
|
||||
@@ -3,9 +3,9 @@ import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core._api.beta_decorator import beta, warn_beta
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -3,13 +3,13 @@ import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core._api.deprecation import (
|
||||
deprecated,
|
||||
rename_parameter,
|
||||
warn_deprecated,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -4,9 +4,10 @@ from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -256,7 +257,8 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_error_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
|
||||
@@ -390,5 +392,6 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
||||
# Overriding since BaseModel has __deepcopy__ method as well
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
|
||||
return self
|
||||
|
||||
@@ -9,7 +9,6 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from tests.unit_tests.stubs import (
|
||||
AnyStr,
|
||||
_AnyIdAIMessage,
|
||||
_AnyIdAIMessageChunk,
|
||||
_AnyIdHumanMessage,
|
||||
@@ -70,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
|
||||
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
|
||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
|
||||
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -89,29 +88,23 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
|
||||
assert chunks == [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"name": "move_file"}},
|
||||
id=AnyStr(),
|
||||
_AnyIdAIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
_AnyIdAIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '{\n "source_path": "foo"'},
|
||||
},
|
||||
id=AnyStr(),
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"arguments": ","}},
|
||||
id=AnyStr(),
|
||||
_AnyIdAIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
_AnyIdAIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
|
||||
},
|
||||
id=AnyStr(),
|
||||
),
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.caches import InMemoryCache
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
@@ -220,6 +221,9 @@ class SerializableModel(GenericFakeChatModel):
|
||||
return True
|
||||
|
||||
|
||||
SerializableModel.model_rebuild()
|
||||
|
||||
|
||||
def test_serialization_with_rate_limiter() -> None:
|
||||
"""Test model serialization with rate limiter."""
|
||||
from langchain_core.load import dumps
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_core.load import Serializable, dumpd
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
|
||||
def test_simple_serialization() -> None:
|
||||
@@ -40,8 +41,9 @@ def test_simple_serialization_is_serializable() -> None:
|
||||
|
||||
def test_simple_serialization_secret() -> None:
|
||||
"""Test handling of secrets."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
|
||||
class Foo(Serializable):
|
||||
bar: int
|
||||
@@ -97,8 +99,9 @@ def test__is_field_useful() -> None:
|
||||
# Make sure works for fields without default.
|
||||
z: ArrayObj
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
|
||||
assert _is_field_useful(foo, "x", foo.x)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module to test base parser implementations."""
|
||||
|
||||
from typing import List
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
@@ -46,6 +47,8 @@ def test_base_generation_parser() -> None:
|
||||
assert isinstance(content, str)
|
||||
return content.swapcase() # type: ignore
|
||||
|
||||
StrInvertCase.model_rebuild()
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
|
||||
chain = model | StrInvertCase()
|
||||
assert chain.invoke("") == "HeLLO"
|
||||
|
||||
@@ -2,12 +2,12 @@ import json
|
||||
from typing import Any, AsyncIterator, Iterator, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.json import (
|
||||
SimpleJsonOutputParser,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core.utils.json import parse_json_markdown, parse_partial_json
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
@@ -10,7 +11,6 @@ from langchain_core.output_parsers.openai_functions import (
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
def test_json_output_function_parser() -> None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any, AsyncIterator, Iterator, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -14,7 +15,6 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
STREAMED_MESSAGES: list = [
|
||||
@@ -531,7 +531,7 @@ async def test_partial_pydantic_output_parser_async() -> None:
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
"""Test with pydantic.v1.BaseModel from pydantic 2."""
|
||||
import pydantic # pydantic: ignore
|
||||
import pydantic
|
||||
|
||||
class Forecast(pydantic.v1.BaseModel):
|
||||
temperature: int
|
||||
@@ -566,7 +566,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
|
||||
def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 2."""
|
||||
import pydantic # pydantic: ignore
|
||||
import pydantic
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
@@ -601,7 +601,7 @@ def test_parse_with_different_pydantic_2_proper() -> None:
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
|
||||
def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
"""Test with pydantic.BaseModel from pydantic 1."""
|
||||
import pydantic # pydantic: ignore
|
||||
import pydantic
|
||||
|
||||
class Forecast(pydantic.BaseModel):
|
||||
temperature: int
|
||||
|
||||
@@ -3,22 +3,17 @@
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pydantic # pydantic: ignore
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel
|
||||
|
||||
V1BaseModel = pydantic.BaseModel
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import BaseModel # pydantic: ignore
|
||||
|
||||
V1BaseModel = BaseModel # type: ignore
|
||||
from langchain_core.utils.pydantic import TBaseModel
|
||||
|
||||
|
||||
class ForecastV2(pydantic.BaseModel):
|
||||
@@ -194,7 +189,7 @@ def test_pydantic_output_parser_type_inference() -> None:
|
||||
|
||||
def test_format_instructions_preserves_language() -> None:
|
||||
"""Test format instructions does not attempt to encode into ascii."""
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
description = (
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,9 +4,12 @@ from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
|
||||
from langchain_core._api.deprecation import (
|
||||
LangChainPendingDeprecationWarning,
|
||||
)
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -28,8 +31,6 @@ from langchain_core.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
_convert_to_message,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import ValidationError
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -794,21 +795,21 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
|
||||
assert prompt_all_required.optional_variables == []
|
||||
with pytest.raises(ValidationError):
|
||||
prompt_all_required.input_schema(input="")
|
||||
assert _schema(prompt_all_required.input_schema) == snapshot(name="required")
|
||||
assert prompt_all_required.get_input_jsonschema() == snapshot(name="required")
|
||||
prompt_optional = ChatPromptTemplate(
|
||||
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
|
||||
)
|
||||
# input variables only lists required variables
|
||||
assert set(prompt_optional.input_variables) == {"input"}
|
||||
prompt_optional.input_schema(input="") # won't raise error
|
||||
assert _schema(prompt_optional.input_schema) == snapshot(name="partial")
|
||||
assert prompt_optional.get_input_jsonschema() == snapshot(name="partial")
|
||||
|
||||
|
||||
def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
|
||||
)
|
||||
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placholder")
|
||||
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placeholder")
|
||||
assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
|
||||
assert dumpd(prompt) == snapshot(name="chat_prompt")
|
||||
assert load(dumpd(prompt)) == prompt
|
||||
|
||||
@@ -7,7 +7,6 @@ import pytest
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
def test_prompt_valid() -> None:
|
||||
@@ -70,10 +69,10 @@ def test_mustache_prompt_from_template() -> None:
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(foo="bar") == "This is a bar test."
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"title": "Foo", "type": "string"}},
|
||||
"properties": {"foo": {"title": "Foo", "type": "string", "default": None}},
|
||||
}
|
||||
|
||||
# Multiple input variables.
|
||||
@@ -81,12 +80,12 @@ def test_mustache_prompt_from_template() -> None:
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test."
|
||||
assert prompt.input_variables == ["bar", "foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"foo": {"title": "Foo", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "string", "default": None},
|
||||
"foo": {"title": "Foo", "type": "string", "default": None},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -95,12 +94,12 @@ def test_mustache_prompt_from_template() -> None:
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
|
||||
assert prompt.input_variables == ["bar", "foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"foo": {"title": "Foo", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "string", "default": None},
|
||||
"foo": {"title": "Foo", "type": "string", "default": None},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -111,23 +110,23 @@ def test_mustache_prompt_from_template() -> None:
|
||||
"This foo is a bar test baz."
|
||||
)
|
||||
assert prompt.input_variables == ["foo", "obj"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"title": "Foo", "type": "string"},
|
||||
"obj": {"$ref": "#/definitions/obj"},
|
||||
},
|
||||
"definitions": {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"$defs": {
|
||||
"obj": {
|
||||
"properties": {
|
||||
"bar": {"default": None, "title": "Bar", "type": "string"},
|
||||
"foo": {"default": None, "title": "Foo", "type": "string"},
|
||||
},
|
||||
"title": "obj",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"title": "Foo", "type": "string"},
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
},
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"foo": {"default": None, "title": "Foo", "type": "string"},
|
||||
"obj": {"allOf": [{"$ref": "#/$defs/obj"}], "default": None},
|
||||
},
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
# . variables
|
||||
@@ -135,7 +134,7 @@ def test_mustache_prompt_from_template() -> None:
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.")
|
||||
assert prompt.input_variables == []
|
||||
assert _schema(prompt.input_schema) == {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -152,17 +151,19 @@ def test_mustache_prompt_from_template() -> None:
|
||||
is a test."""
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"$ref": "#/definitions/foo"}},
|
||||
"definitions": {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"$defs": {
|
||||
"foo": {
|
||||
"properties": {
|
||||
"bar": {"default": None, "title": "Bar", "type": "string"}
|
||||
},
|
||||
"title": "foo",
|
||||
"type": "object",
|
||||
"properties": {"bar": {"title": "Bar", "type": "string"}},
|
||||
}
|
||||
},
|
||||
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
# more complex nested section/context variables
|
||||
@@ -184,26 +185,28 @@ def test_mustache_prompt_from_template() -> None:
|
||||
is a test."""
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"$ref": "#/definitions/foo"}},
|
||||
"definitions": {
|
||||
"foo": {
|
||||
"title": "foo",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"baz": {"$ref": "#/definitions/baz"},
|
||||
"quux": {"title": "Quux", "type": "string"},
|
||||
},
|
||||
},
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"$defs": {
|
||||
"baz": {
|
||||
"properties": {
|
||||
"qux": {"default": None, "title": "Qux", "type": "string"}
|
||||
},
|
||||
"title": "baz",
|
||||
"type": "object",
|
||||
"properties": {"qux": {"title": "Qux", "type": "string"}},
|
||||
},
|
||||
"foo": {
|
||||
"properties": {
|
||||
"bar": {"default": None, "title": "Bar", "type": "string"},
|
||||
"baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None},
|
||||
"quux": {"default": None, "title": "Quux", "type": "string"},
|
||||
},
|
||||
"title": "foo",
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
# triply nested section/context variables
|
||||
@@ -239,39 +242,43 @@ def test_mustache_prompt_from_template() -> None:
|
||||
is a test."""
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"$ref": "#/definitions/foo"}},
|
||||
"definitions": {
|
||||
"foo": {
|
||||
"title": "foo",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"baz": {"$ref": "#/definitions/baz"},
|
||||
"quux": {"title": "Quux", "type": "string"},
|
||||
},
|
||||
},
|
||||
"baz": {
|
||||
"title": "baz",
|
||||
"type": "object",
|
||||
"properties": {"qux": {"$ref": "#/definitions/qux"}},
|
||||
},
|
||||
"qux": {
|
||||
"title": "qux",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foobar": {"title": "Foobar", "type": "string"},
|
||||
"barfoo": {"$ref": "#/definitions/barfoo"},
|
||||
},
|
||||
},
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"$defs": {
|
||||
"barfoo": {
|
||||
"properties": {
|
||||
"foobar": {"default": None, "title": "Foobar", "type": "string"}
|
||||
},
|
||||
"title": "barfoo",
|
||||
"type": "object",
|
||||
"properties": {"foobar": {"title": "Foobar", "type": "string"}},
|
||||
},
|
||||
"baz": {
|
||||
"properties": {
|
||||
"qux": {"allOf": [{"$ref": "#/$defs/qux"}], "default": None}
|
||||
},
|
||||
"title": "baz",
|
||||
"type": "object",
|
||||
},
|
||||
"foo": {
|
||||
"properties": {
|
||||
"bar": {"default": None, "title": "Bar", "type": "string"},
|
||||
"baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None},
|
||||
"quux": {"default": None, "title": "Quux", "type": "string"},
|
||||
},
|
||||
"title": "foo",
|
||||
"type": "object",
|
||||
},
|
||||
"qux": {
|
||||
"properties": {
|
||||
"barfoo": {"allOf": [{"$ref": "#/$defs/barfoo"}], "default": None},
|
||||
"foobar": {"default": None, "title": "Foobar", "type": "string"},
|
||||
},
|
||||
"title": "qux",
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
# section/context variables with repeats
|
||||
@@ -287,19 +294,20 @@ def test_mustache_prompt_from_template() -> None:
|
||||
is a test."""
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"$ref": "#/definitions/foo"}},
|
||||
"definitions": {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"$defs": {
|
||||
"foo": {
|
||||
"properties": {
|
||||
"bar": {"default": None, "title": "Bar", "type": "string"}
|
||||
},
|
||||
"title": "foo",
|
||||
"type": "object",
|
||||
"properties": {"bar": {"title": "Bar", "type": "string"}},
|
||||
}
|
||||
},
|
||||
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
template = """This{{^foo}}
|
||||
no foos
|
||||
{{/foo}}is a test."""
|
||||
@@ -310,10 +318,10 @@ def test_mustache_prompt_from_template() -> None:
|
||||
is a test."""
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert _schema(prompt.input_schema) == {
|
||||
assert prompt.get_input_jsonschema() == {
|
||||
"properties": {"foo": {"default": None, "title": "Foo", "type": "object"}},
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"title": "Foo", "type": "object"}},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from typing import Any, Dict, Type, Union, cast
|
||||
from typing import Optional as Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from langchain_core.load.dump import dumps
|
||||
from langchain_core.load.load import loads
|
||||
from langchain_core.prompts.structured import StructuredPrompt
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
@@ -34,6 +36,9 @@ class FakeStructuredChatModel(FakeListChatModel):
|
||||
return "fake-messages-list-chat-model"
|
||||
|
||||
|
||||
FakeStructuredChatModel.model_rebuild()
|
||||
|
||||
|
||||
def test_structured_prompt_pydantic() -> None:
|
||||
class OutputSchema(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -1,20 +1,10 @@
|
||||
"""Helper utilities for pydantic.
|
||||
|
||||
This module includes helper utilities to ease the migration from pydantic v1 to v2.
|
||||
|
||||
They're meant to be used in the following way:
|
||||
|
||||
1) Use utility code to help (selected) unit tests pass without modifications
|
||||
2) Upgrade the unit tests to match pydantic 2
|
||||
3) Stop using the utility code
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
# Function to replace allOf with $ref
|
||||
def _replace_all_of_with_ref(schema: Any) -> None:
|
||||
"""Replace allOf with $ref in the schema."""
|
||||
def replace_all_of_with_ref(schema: Any) -> None:
|
||||
if isinstance(schema, dict):
|
||||
# If the schema has an allOf key with a single item that contains a $ref
|
||||
if (
|
||||
@@ -30,13 +20,13 @@ def _replace_all_of_with_ref(schema: Any) -> None:
|
||||
# Recursively process nested schemas
|
||||
for value in schema.values():
|
||||
if isinstance(value, (dict, list)):
|
||||
_replace_all_of_with_ref(value)
|
||||
replace_all_of_with_ref(value)
|
||||
elif isinstance(schema, list):
|
||||
for item in schema:
|
||||
_replace_all_of_with_ref(item)
|
||||
replace_all_of_with_ref(item)
|
||||
|
||||
|
||||
def _remove_bad_none_defaults(schema: Any) -> None:
|
||||
def remove_all_none_default(schema: Any) -> None:
|
||||
"""Removing all none defaults.
|
||||
|
||||
Pydantic v1 did not generate these, but Pydantic v2 does.
|
||||
@@ -56,39 +46,48 @@ def _remove_bad_none_defaults(schema: Any) -> None:
|
||||
break # Null type explicitly defined
|
||||
else:
|
||||
del value["default"]
|
||||
_remove_bad_none_defaults(value)
|
||||
remove_all_none_default(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_remove_bad_none_defaults(item)
|
||||
remove_all_none_default(item)
|
||||
elif isinstance(schema, list):
|
||||
for item in schema:
|
||||
_remove_bad_none_defaults(item)
|
||||
remove_all_none_default(item)
|
||||
|
||||
|
||||
def _remove_enum_description(obj: Any) -> None:
|
||||
"""Remove the description from enums."""
|
||||
if isinstance(obj, dict):
|
||||
if "enum" in obj:
|
||||
if "description" in obj and obj["description"] == "An enumeration.":
|
||||
del obj["description"]
|
||||
for value in obj.values():
|
||||
_remove_enum_description(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
_remove_enum_description(item)
|
||||
|
||||
|
||||
def _schema(obj: Any) -> dict:
|
||||
"""Get the schema of a pydantic model in the pydantic v1 style.
|
||||
|
||||
This will attempt to map the schema as close as possible to the pydantic v1 schema.
|
||||
"""
|
||||
"""Return the schema of the object."""
|
||||
# Remap to old style schema
|
||||
if not is_basemodel_subclass(obj):
|
||||
raise TypeError(
|
||||
f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
|
||||
)
|
||||
if not hasattr(obj, "model_json_schema"): # V1 model
|
||||
return obj.schema()
|
||||
|
||||
# Then we're using V2 models internally.
|
||||
raise AssertionError(
|
||||
"Hi there! Looks like you're attempting to upgrade to Pydantic v2. If so: \n"
|
||||
"1) remove this exception\n"
|
||||
"2) confirm that the old unit tests pass, and if not look for difference\n"
|
||||
"3) update the unit tests to match the new schema\n"
|
||||
"4) remove this utility function\n"
|
||||
)
|
||||
|
||||
schema_ = obj.model_json_schema(ref_template="#/definitions/{model}")
|
||||
if "$defs" in schema_:
|
||||
schema_["definitions"] = schema_["$defs"]
|
||||
del schema_["$defs"]
|
||||
|
||||
_replace_all_of_with_ref(schema_)
|
||||
_remove_bad_none_defaults(schema_)
|
||||
if "default" in schema_ and schema_["default"] is None:
|
||||
del schema_["default"]
|
||||
|
||||
replace_all_of_with_ref(schema_)
|
||||
remove_all_none_default(schema_)
|
||||
_remove_enum_description(schema_)
|
||||
|
||||
return schema_
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,9 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables import (
|
||||
ConfigurableField,
|
||||
RunnableConfig,
|
||||
@@ -14,19 +15,21 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
my_property: str = Field(alias="my_property_alias")
|
||||
_my_hidden_property: str = ""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def my_error(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def my_error(cls, values: Dict[str, Any]) -> Any:
|
||||
if "_my_hidden_property" in values:
|
||||
raise ValueError("Cannot set _my_hidden_property")
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values["_my_hidden_property"] = values["my_property"]
|
||||
return values
|
||||
@model_validator(mode="after")
|
||||
def build_extra(self) -> Self:
|
||||
self._my_hidden_property = self.my_property
|
||||
return self
|
||||
|
||||
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
|
||||
return input + self._my_hidden_property
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
@@ -25,7 +26,6 @@ from langchain_core.load import dumps
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.language_models import FakeListLLM
|
||||
@@ -7,11 +8,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.graph import Edge, Graph, Node
|
||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||
from tests.unit_tests.pydantic_utils import _schema
|
||||
|
||||
|
||||
def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
@@ -19,10 +18,10 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
||||
graph = StrOutputParser().get_graph()
|
||||
first_node = graph.first_node()
|
||||
assert first_node is not None
|
||||
assert _schema(first_node.data) == _schema(runnable.input_schema) # type: ignore[union-attr]
|
||||
assert first_node.data.schema() == runnable.get_input_jsonschema() # type: ignore[union-attr]
|
||||
last_node = graph.last_node()
|
||||
assert last_node is not None
|
||||
assert _schema(last_node.data) == _schema(runnable.output_schema) # type: ignore[union-attr]
|
||||
assert last_node.data.schema() == runnable.get_output_jsonschema() # type: ignore[union-attr]
|
||||
assert len(graph.nodes) == 3
|
||||
assert len(graph.edges) == 2
|
||||
assert graph.edges[0].source == first_node.id
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user