mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 05:09:12 +00:00
Compare commits
108 Commits
langchain-
...
langchain-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e4106256d | ||
|
|
933f4ab230 | ||
|
|
42d8b3631a | ||
|
|
4200876531 | ||
|
|
5bbd5364f1 | ||
|
|
f3b12f8c0c | ||
|
|
e02b093d81 | ||
|
|
522203c752 | ||
|
|
c492b7d33a | ||
|
|
8c4a52a9cc | ||
|
|
0cc6584889 | ||
|
|
8696f9f3a0 | ||
|
|
e1ab188e82 | ||
|
|
6e1b0d0228 | ||
|
|
a111098230 | ||
|
|
9e7222618b | ||
|
|
8516a03a02 | ||
|
|
d0222964c1 | ||
|
|
b97307c8b4 | ||
|
|
1ad66e70dc | ||
|
|
76564edd3a | ||
|
|
1c51e1693d | ||
|
|
c0f886dc52 | ||
|
|
a267da6a3a | ||
|
|
0c63b18c1f | ||
|
|
915c1e3dfb | ||
|
|
8da2ace99d | ||
|
|
81cd73cfca | ||
|
|
e358846b39 | ||
|
|
3c598d25a6 | ||
|
|
e5aa0f938b | ||
|
|
79c46319dd | ||
|
|
c5d4dfefc0 | ||
|
|
6e853501ec | ||
|
|
fd1f3ca213 | ||
|
|
567a4ce5aa | ||
|
|
923ce84aa7 | ||
|
|
9379613132 | ||
|
|
c72a76237f | ||
|
|
f9cafcbcb0 | ||
|
|
1fce5543bc | ||
|
|
88e9e6bf55 | ||
|
|
7f0dd4b182 | ||
|
|
5557b86a54 | ||
|
|
caf4ae3a45 | ||
|
|
c88b75ca6a | ||
|
|
e409a85a28 | ||
|
|
40634d441a | ||
|
|
1d2a503ab8 | ||
|
|
b924c61440 | ||
|
|
efa10c8ef8 | ||
|
|
0a6c67ce6a | ||
|
|
ed771f2d2b | ||
|
|
63ba12d8e0 | ||
|
|
f785cf029b | ||
|
|
be7cd0756f | ||
|
|
51c6899850 | ||
|
|
163d6fe8ef | ||
|
|
7cee7fbfad | ||
|
|
4799ad95d0 | ||
|
|
88065d794b | ||
|
|
b27bfa6717 | ||
|
|
5adeaf0732 | ||
|
|
f9d91e19c5 | ||
|
|
4c7afb0d6c | ||
|
|
c1ff61669d | ||
|
|
54d6808c1e | ||
|
|
78468de2e5 | ||
|
|
76572f963b | ||
|
|
c0448f27ba | ||
|
|
179aaa4007 | ||
|
|
d072d592a1 | ||
|
|
78c454c130 | ||
|
|
5199555c0d | ||
|
|
5e31cd91a7 | ||
|
|
49a1f5dd47 | ||
|
|
d0cc9b022a | ||
|
|
a91bd2737a | ||
|
|
5ad2b8ce80 | ||
|
|
b78764599b | ||
|
|
2888e34f53 | ||
|
|
dd4418a503 | ||
|
|
a976f2071b | ||
|
|
5f98975be0 | ||
|
|
0529c991ce | ||
|
|
954abcce59 | ||
|
|
6ad515d34e | ||
|
|
99348e1614 | ||
|
|
2c742cc20d | ||
|
|
02f87203f7 | ||
|
|
56163481dd | ||
|
|
6aac2eeab5 | ||
|
|
559d8a4d13 | ||
|
|
ec9e8eb71c | ||
|
|
9399df7777 | ||
|
|
5fc1104d00 | ||
|
|
6777106fbe | ||
|
|
5f5287c3b0 | ||
|
|
615f8b0d47 | ||
|
|
9a9ab65030 | ||
|
|
241b6d2355 | ||
|
|
91e09ffee5 | ||
|
|
8e4bae351e | ||
|
|
0da201c1d5 | ||
|
|
29413a22e1 | ||
|
|
ae5a574aa5 | ||
|
|
5a0e82c31c | ||
|
|
8590b421c4 |
9
.github/scripts/check_diff.py
vendored
9
.github/scripts/check_diff.py
vendored
@@ -16,6 +16,10 @@ LANGCHAIN_DIRS = [
|
||||
"libs/experimental",
|
||||
]
|
||||
|
||||
# for 0.3rc, we are ignoring core dependents
|
||||
# in order to be able to get CI to pass for individual PRs.
|
||||
IGNORE_CORE_DEPENDENTS = True
|
||||
|
||||
# ignored partners are removed from dependents
|
||||
# but still run if directly edited
|
||||
IGNORED_PARTNERS = [
|
||||
@@ -104,7 +108,7 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
|
||||
{"working-directory": dir_, "python-version": f"3.{v}"}
|
||||
for v in range(8, 13)
|
||||
]
|
||||
min_python = "3.8"
|
||||
min_python = "3.9"
|
||||
max_python = "3.12"
|
||||
|
||||
# custom logic for specific directories
|
||||
@@ -184,6 +188,9 @@ if __name__ == "__main__":
|
||||
# for extended testing
|
||||
found = False
|
||||
for dir_ in LANGCHAIN_DIRS:
|
||||
if dir_ == "libs/core" and IGNORE_CORE_DEPENDENTS:
|
||||
dirs_to_run["extended-test"].add(dir_)
|
||||
continue
|
||||
if file.startswith(dir_):
|
||||
found = True
|
||||
if found:
|
||||
|
||||
@@ -11,7 +11,7 @@ if __name__ == "__main__":
|
||||
|
||||
# see if we're releasing an rc
|
||||
version = toml_data["tool"]["poetry"]["version"]
|
||||
releasing_rc = "rc" in version
|
||||
releasing_rc = "rc" in version or "dev" in version
|
||||
|
||||
# if not, iterate through dependencies and make sure none allow prereleases
|
||||
if not releasing_rc:
|
||||
|
||||
114
.github/workflows/_dependencies.yml
vendored
114
.github/workflows/_dependencies.yml
vendored
@@ -1,114 +0,0 @@
|
||||
name: dependencies
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
working-directory:
|
||||
required: true
|
||||
type: string
|
||||
description: "From which folder this pipeline executes"
|
||||
langchain-location:
|
||||
required: false
|
||||
type: string
|
||||
description: "Relative path to the langchain library folder"
|
||||
python-version:
|
||||
required: true
|
||||
type: string
|
||||
description: "Python version to use"
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
runs-on: ubuntu-latest
|
||||
name: dependency checks ${{ inputs.python-version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ inputs.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
cache-key: pydantic-cross-compat
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
|
||||
- name: Check imports with base dependencies
|
||||
shell: bash
|
||||
run: poetry run make check_imports
|
||||
|
||||
- name: Install test dependencies
|
||||
shell: bash
|
||||
run: poetry install --with test
|
||||
|
||||
- name: Install langchain editable
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
if: ${{ inputs.langchain-location }}
|
||||
env:
|
||||
LANGCHAIN_LOCATION: ${{ inputs.langchain-location }}
|
||||
run: |
|
||||
poetry run pip install -e "$LANGCHAIN_LOCATION"
|
||||
|
||||
- name: Install the opposite major version of pydantic
|
||||
# If normal tests use pydantic v1, here we'll use v2, and vice versa.
|
||||
shell: bash
|
||||
# airbyte currently doesn't support pydantic v2
|
||||
if: ${{ !startsWith(inputs.working-directory, 'libs/partners/airbyte') }}
|
||||
run: |
|
||||
# Determine the major part of pydantic version
|
||||
REGULAR_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
|
||||
|
||||
if [[ "$REGULAR_VERSION" == "1" ]]; then
|
||||
PYDANTIC_DEP=">=2.1,<3"
|
||||
TEST_WITH_VERSION="2"
|
||||
elif [[ "$REGULAR_VERSION" == "2" ]]; then
|
||||
PYDANTIC_DEP="<2"
|
||||
TEST_WITH_VERSION="1"
|
||||
else
|
||||
echo "Unexpected pydantic major version '$REGULAR_VERSION', cannot determine which version to use for cross-compatibility test."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install via `pip` instead of `poetry add` to avoid changing lockfile,
|
||||
# which would prevent caching from working: the cache would get saved
|
||||
# to a different key than where it gets loaded from.
|
||||
poetry run pip install "pydantic${PYDANTIC_DEP}"
|
||||
|
||||
# Ensure that the correct pydantic is installed now.
|
||||
echo "Checking pydantic version... Expecting ${TEST_WITH_VERSION}"
|
||||
|
||||
# Determine the major part of pydantic version
|
||||
CURRENT_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
|
||||
|
||||
# Check that the major part of pydantic version is as expected, if not
|
||||
# raise an error
|
||||
if [[ "$CURRENT_VERSION" != "$TEST_WITH_VERSION" ]]; then
|
||||
echo "Error: expected pydantic version ${CURRENT_VERSION} to have been installed, but found: ${TEST_WITH_VERSION}"
|
||||
exit 1
|
||||
fi
|
||||
echo "Found pydantic version ${CURRENT_VERSION}, as expected"
|
||||
- name: Run pydantic compatibility tests
|
||||
# airbyte currently doesn't support pydantic v2
|
||||
if: ${{ !startsWith(inputs.working-directory, 'libs/partners/airbyte') }}
|
||||
shell: bash
|
||||
run: make test
|
||||
|
||||
- name: Ensure the tests did not create any additional files
|
||||
shell: bash
|
||||
run: |
|
||||
set -eu
|
||||
|
||||
STATUS="$(git status)"
|
||||
echo "$STATUS"
|
||||
|
||||
# grep will exit non-zero if the target message isn't found,
|
||||
# and `set -e` above will cause the step to fail.
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
||||
15
.github/workflows/check_diffs.yml
vendored
15
.github/workflows/check_diffs.yml
vendored
@@ -89,19 +89,6 @@ jobs:
|
||||
python-version: ${{ matrix.job-configs.python-version }}
|
||||
secrets: inherit
|
||||
|
||||
dependencies:
|
||||
name: cd ${{ matrix.job-configs.working-directory }}
|
||||
needs: [ build ]
|
||||
if: ${{ needs.build.outputs.dependencies != '[]' }}
|
||||
strategy:
|
||||
matrix:
|
||||
job-configs: ${{ fromJson(needs.build.outputs.dependencies) }}
|
||||
uses: ./.github/workflows/_dependencies.yml
|
||||
with:
|
||||
working-directory: ${{ matrix.job-configs.working-directory }}
|
||||
python-version: ${{ matrix.job-configs.python-version }}
|
||||
secrets: inherit
|
||||
|
||||
extended-tests:
|
||||
name: "cd ${{ matrix.job-configs.working-directory }} / make extended_tests #${{ matrix.job-configs.python-version }}"
|
||||
needs: [ build ]
|
||||
@@ -149,7 +136,7 @@ jobs:
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
||||
ci_success:
|
||||
name: "CI Success"
|
||||
needs: [build, lint, test, compile-integration-tests, dependencies, extended-tests, test-doc-imports]
|
||||
needs: [build, lint, test, compile-integration-tests, extended-tests, test-doc-imports]
|
||||
if: |
|
||||
always()
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/scheduled_test.yml
vendored
2
.github/workflows/scheduled_test.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.11"
|
||||
working-directory:
|
||||
- "libs/partners/openai"
|
||||
|
||||
@@ -12,6 +12,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/
|
||||
test tests:
|
||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
|
||||
|
||||
# integration tests are run without the --disable-socket flag to allow network calls
|
||||
integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
@@ -23,6 +23,7 @@ pytest = "^7.4.3"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
pytest-socket = "^0.7.0"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
pytest-watcher = "^0.3.4"
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
@@ -22,7 +22,7 @@ integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --disable-socket --allow-unix-socket --snapshot-update --now . -- -vv -x tests/unit_tests
|
||||
poetry run ptw --disable-socket --allow-unix-socket --snapshot-update --now . -- -vv tests/unit_tests
|
||||
|
||||
check_imports: $(shell find langchain_community -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
@@ -45,7 +45,6 @@ lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/lint_imports.sh .
|
||||
./scripts/check_pickle.sh .
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
|
||||
|
||||
@@ -25,7 +25,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from langchain_community.tools.ainetwork.app import AINAppOps
|
||||
from langchain_community.tools.ainetwork.owner import AINOwnerOps
|
||||
@@ -36,8 +36,9 @@ class AINetworkToolkit(BaseToolkit):
|
||||
network: Optional[Literal["mainnet", "testnet"]] = "testnet"
|
||||
interface: Optional[Ain] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_interface(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_interface(cls, values: dict) -> Any:
|
||||
"""Set the interface if not provided.
|
||||
|
||||
If the interface is not provided, attempt to authenticate with the
|
||||
@@ -53,9 +54,10 @@ class AINetworkToolkit(BaseToolkit):
|
||||
values["interface"] = authenticate(network=values.get("network", "testnet"))
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
validate_all = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
validate_default=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport
|
||||
from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch
|
||||
@@ -26,8 +26,9 @@ class AmadeusToolkit(BaseToolkit):
|
||||
client: Client = Field(default_factory=authenticate)
|
||||
llm: Optional[BaseLanguageModel] = Field(default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.cassandra_database.tool import (
|
||||
GetSchemaCassandraDatabaseTool,
|
||||
@@ -24,8 +24,9 @@ class CassandraDatabaseToolkit(BaseToolkit):
|
||||
|
||||
db: CassandraDatabase = Field(exclude=True)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_community.tools.connery import ConneryService
|
||||
|
||||
@@ -23,8 +23,9 @@ class ConneryToolkit(BaseToolkit):
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_attributes(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_attributes(cls, values: dict) -> Any:
|
||||
"""
|
||||
Validate the attributes of the ConneryToolkit class.
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseTool, BaseToolkit
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_community.tools.file_management.copy import CopyFileTool
|
||||
from langchain_community.tools.file_management.delete import DeleteFileTool
|
||||
@@ -63,8 +63,9 @@ class FileManagementToolkit(BaseToolkit):
|
||||
selected_tools: Optional[List[str]] = None
|
||||
"""If provided, only provide the selected tools. Defaults to all."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_tools(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, values: dict) -> Any:
|
||||
selected_tools = values.get("selected_tools") or []
|
||||
for tool_name in selected_tools:
|
||||
if tool_name not in _FILE_TOOLS_MAP:
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.financial_datasets.balance_sheets import BalanceSheets
|
||||
from langchain_community.tools.financial_datasets.cash_flow_statements import (
|
||||
@@ -31,8 +31,9 @@ class FinancialDatasetsToolkit(BaseToolkit):
|
||||
super().__init__()
|
||||
self.api_wrapper = api_wrapper
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_community.tools.github.prompt import (
|
||||
COMMENT_ON_ISSUE_PROMPT,
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.gmail.create_draft import GmailCreateDraft
|
||||
from langchain_community.tools.gmail.get_message import GmailGetMessage
|
||||
@@ -117,8 +117,9 @@ class GmailToolkit(BaseToolkit):
|
||||
|
||||
api_resource: Resource = Field(default_factory=build_resource_service)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.tools.multion.close_session import MultionCloseSession
|
||||
from langchain_community.tools.multion.create_session import MultionCreateSession
|
||||
@@ -25,8 +26,9 @@ class MultionToolkit(BaseToolkit):
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.agent_toolkits.nla.tool import NLATool
|
||||
from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.office365.create_draft_message import (
|
||||
O365CreateDraftMessage,
|
||||
@@ -40,8 +40,9 @@ class O365Toolkit(BaseToolkit):
|
||||
|
||||
account: Account = Field(default_factory=authenticate)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -9,8 +9,8 @@ import yaml
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.agent_toolkits.openapi.planner_prompt import (
|
||||
API_CONTROLLER_PROMPT,
|
||||
@@ -69,7 +69,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
name: str = "requests_get"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
description: str = REQUESTS_GET_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
@@ -103,7 +103,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
name: str = "requests_post"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
description: str = REQUESTS_POST_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
@@ -134,7 +134,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
name: str = "requests_patch"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
description: str = REQUESTS_PATCH_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
@@ -167,7 +167,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
name: str = "requests_put"
|
||||
"""Tool name."""
|
||||
description = REQUESTS_PUT_TOOL_DESCRIPTION
|
||||
description: str = REQUESTS_PUT_TOOL_DESCRIPTION
|
||||
"""Tool description."""
|
||||
response_length: int = MAX_RESPONSE_LENGTH
|
||||
"""Maximum length of the response to be returned."""
|
||||
@@ -198,7 +198,7 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
|
||||
|
||||
name: str = "requests_delete"
|
||||
"""The name of the tool."""
|
||||
description = REQUESTS_DELETE_TOOL_DESCRIPTION
|
||||
description: str = REQUESTS_DELETE_TOOL_DESCRIPTION
|
||||
"""The description of the tool."""
|
||||
|
||||
response_length: Optional[int] = MAX_RESPONSE_LENGTH
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Type, cast
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Type, cast
|
||||
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseTool, BaseToolkit
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from langchain_community.tools.playwright.base import (
|
||||
BaseBrowserTool,
|
||||
@@ -68,12 +68,14 @@ class PlayWrightBrowserToolkit(BaseToolkit):
|
||||
sync_browser: Optional["SyncBrowser"] = None
|
||||
async_browser: Optional["AsyncBrowser"] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_imports_and_browser_provided(cls, values: dict) -> Any:
|
||||
"""Check that the arguments are valid."""
|
||||
lazy_import_playwright_browsers()
|
||||
if values.get("async_browser") is None and values.get("sync_browser") is None:
|
||||
|
||||
@@ -13,9 +13,9 @@ from langchain_core.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.powerbi.prompt import (
|
||||
QUESTION_TO_QUERY_BASE,
|
||||
@@ -63,8 +63,9 @@ class PowerBIToolkit(BaseToolkit):
|
||||
output_token_limit: Optional[int] = None
|
||||
tiktoken_model_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.slack.get_channel import SlackGetChannel
|
||||
from langchain_community.tools.slack.get_message import SlackGetMessage
|
||||
@@ -91,8 +91,9 @@ class SlackToolkit(BaseToolkit):
|
||||
|
||||
client: WebClient = Field(default_factory=login)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.spark_sql.tool import (
|
||||
InfoSparkSQLTool,
|
||||
@@ -27,8 +27,9 @@ class SparkSQLToolkit(BaseToolkit):
|
||||
db: SparkSQL = Field(exclude=True)
|
||||
llm: BaseLanguageModel = Field(exclude=True)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools.base import BaseToolkit
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
@@ -83,8 +83,9 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
"""Return string representation of SQL dialect to use."""
|
||||
return self.db.dialect
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
@@ -15,10 +15,11 @@ from langchain.agents.openai_assistant.base import OpenAIAssistantRunnable, Outp
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, ensure_config
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import openai
|
||||
@@ -209,14 +210,14 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
|
||||
as_agent: bool = False
|
||||
"""Use as a LangChain agent, compatible with the AgentExecutor."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_async_client(cls, values: dict) -> dict:
|
||||
if values["async_client"] is None:
|
||||
@model_validator(mode="after")
|
||||
def validate_async_client(self) -> Self:
|
||||
if self.async_client is None:
|
||||
import openai
|
||||
|
||||
api_key = values["client"].api_key
|
||||
values["async_client"] = openai.AsyncOpenAI(api_key=api_key)
|
||||
return values
|
||||
api_key = self.client.api_key
|
||||
self.async_client = openai.AsyncOpenAI(api_key=api_key)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_assistant(
|
||||
|
||||
@@ -22,9 +22,9 @@ from langchain_core.output_parsers import (
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.output_parsers.ernie_functions import (
|
||||
JsonOutputFunctionsParser,
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
AQL_FIX_PROMPT,
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
|
||||
@@ -22,8 +22,8 @@ from langchain_core.prompts import (
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import Runnable
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.cypher_utils import (
|
||||
CypherQueryCorrector,
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_GENERATION_PROMPT,
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain_core.callbacks.manager import CallbackManager, CallbackManagerFor
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
|
||||
@@ -10,7 +10,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.chains.prompt_selector import ConditionalPromptSelector
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import SPARQL_QA_PROMPT
|
||||
from langchain_community.graphs import NeptuneRdfGraph
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
GRAPHDB_QA_PROMPT,
|
||||
|
||||
@@ -11,7 +11,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
SPARQL_GENERATION_SELECT_PROMPT,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_community.utilities.requests import TextRequestsWrapper
|
||||
|
||||
@@ -38,9 +38,10 @@ class LLMRequestsChain(Chain):
|
||||
input_key: str = "url" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -58,8 +59,9 @@ class LLMRequestsChain(Chain):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa: F401
|
||||
|
||||
@@ -11,7 +11,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
from requests import Response
|
||||
|
||||
from langchain_community.tools.openapi.utils.api_models import APIOperation
|
||||
|
||||
@@ -17,8 +17,8 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field, validator
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from pydantic import ConfigDict, Field, validator
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
||||
SUPPORTED_VECTORSTORES,
|
||||
@@ -189,10 +189,11 @@ class PebbloRetrievalQA(Chain):
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
|
||||
|
||||
@@ -10,9 +10,9 @@ import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.env import get_runtime_environment
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from pydantic import BaseModel
|
||||
from requests import Response, request
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.llms.anthropic import _AnthropicCommon
|
||||
|
||||
@@ -91,9 +92,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
||||
"""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
|
||||
@@ -5,12 +5,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import convert_message_to_dict
|
||||
from langchain_community.chat_models.openai import (
|
||||
@@ -102,8 +102,9 @@ class ChatAnyscale(ChatOpenAI):
|
||||
|
||||
return {model["id"] for model in models_response.json()["data"]}
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anyscale_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
|
||||
@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_community.chat_models.openai import ChatOpenAI
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
@@ -44,7 +44,6 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
@@ -53,6 +52,13 @@ from langchain_core.utils import (
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from langchain_community.chat_models.llamacpp import (
|
||||
_lc_invalid_tool_call_to_openai_tool_call,
|
||||
@@ -375,11 +381,13 @@ class ChatBaichuan(BaseChatModel):
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -404,8 +412,9 @@ class ChatBaichuan(BaseChatModel):
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["baichuan_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"baichuan_api_base",
|
||||
|
||||
@@ -41,17 +41,18 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -248,7 +249,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
@@ -287,7 +288,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Joke(BaseModel):
|
||||
@@ -380,11 +381,13 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
endpoint: Optional[str] = None
|
||||
"""Endpoint of the Qianfan LLM, required if custom model used."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["qianfan_ak"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values, ["qianfan_ak", "api_key"], "QIANFAN_AK", default=""
|
||||
@@ -747,7 +750,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import QianfanChatEndpoint
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -768,7 +771,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import QianfanChatEndpoint
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -789,7 +792,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import QianfanChatEndpoint
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
|
||||
@@ -16,6 +16,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.chat_models.anthropic import (
|
||||
convert_messages_to_prompt_anthropic,
|
||||
@@ -231,8 +232,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
|
||||
return attributes
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
|
||||
@@ -19,6 +19,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.llms.cohere import BaseCohere
|
||||
|
||||
@@ -117,9 +118,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
chat.invoke(messages)
|
||||
"""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
||||
@@ -19,11 +19,11 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -111,11 +111,13 @@ class ChatCoze(BaseChatModel):
|
||||
"Streaming response" will provide real-time response of the model to the client, and
|
||||
the client needs to assemble the final reply based on the type of message. """
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["coze_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"coze_api_base",
|
||||
|
||||
@@ -13,8 +13,8 @@ from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
@@ -70,11 +70,13 @@ class ChatDappierAI(BaseChatModel):
|
||||
|
||||
dappier_api_key: Optional[SecretStr] = Field(None, description="Dappier API Token")
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["dappier_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "dappier_api_key", "DAPPIER_API_KEY")
|
||||
|
||||
@@ -54,11 +54,12 @@ from langchain_core.outputs import (
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
@@ -222,10 +223,9 @@ class ChatDeepInfra(BaseChatModel):
|
||||
streaming: bool = False
|
||||
max_retries: int = 1
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
@@ -291,8 +291,9 @@ class ChatDeepInfra(BaseChatModel):
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def init_defaults(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def init_defaults(cls, values: Dict) -> Any:
|
||||
"""Validate api key, python package exists, temperature, top_p, and top_k."""
|
||||
# For compatibility with LiteLLM
|
||||
api_key = get_from_dict_or_env(
|
||||
@@ -309,18 +310,18 @@ class ChatDeepInfra(BaseChatModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
if self.top_k is not None and self.top_k <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
||||
@@ -47,16 +47,17 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
@@ -296,8 +297,9 @@ class ChatEdenAI(BaseChatModel):
|
||||
|
||||
edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
|
||||
@@ -13,8 +13,8 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -108,8 +108,9 @@ class ErnieBotChat(BaseChatModel):
|
||||
|
||||
_lock = threading.Lock()
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["ernie_api_base"] = get_from_dict_or_env(
|
||||
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
|
||||
)
|
||||
|
||||
@@ -4,11 +4,11 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import convert_message_to_dict
|
||||
from langchain_community.chat_models.openai import (
|
||||
@@ -76,8 +76,9 @@ class ChatEverlyAI(ChatOpenAI):
|
||||
]
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment_override(cls, values: dict) -> dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment_override(cls, values: dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["openai_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
|
||||
@@ -32,9 +32,9 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from langchain_core.utils.env import get_from_dict_or_env
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import convert_message_to_dict
|
||||
|
||||
@@ -112,8 +112,9 @@ class ChatFireworks(BaseChatModel):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "fireworks"]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key in environment."""
|
||||
try:
|
||||
import fireworks.client
|
||||
|
||||
@@ -21,8 +21,8 @@ from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
|
||||
@@ -30,8 +30,9 @@ from langchain_core.language_models.chat_models import (
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.adapters.openai import (
|
||||
convert_dict_to_message,
|
||||
@@ -150,7 +151,7 @@ class GPTRouter(BaseChatModel):
|
||||
"""
|
||||
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
models_priority_list: List[GPTRouterModel] = Field(min_items=1)
|
||||
models_priority_list: List[GPTRouterModel] = Field(min_length=1)
|
||||
gpt_router_api_base: str = Field(default=None)
|
||||
"""WriteSonic GPTRouter custom endpoint"""
|
||||
gpt_router_api_key: Optional[SecretStr] = None
|
||||
@@ -167,8 +168,9 @@ class GPTRouter(BaseChatModel):
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: int = 256
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["gpt_router_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"gpt_router_api_base",
|
||||
@@ -185,8 +187,8 @@ class GPTRouter(BaseChatModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True, skip_on_failure=True)
|
||||
def post_init(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def post_init(self) -> Self:
|
||||
try:
|
||||
from gpt_router.client import GPTRouterClient
|
||||
|
||||
@@ -197,12 +199,14 @@ class GPTRouter(BaseChatModel):
|
||||
)
|
||||
|
||||
gpt_router_client = GPTRouterClient(
|
||||
values["gpt_router_api_base"],
|
||||
values["gpt_router_api_key"].get_secret_value(),
|
||||
self.gpt_router_api_base,
|
||||
self.gpt_router_api_key.get_secret_value()
|
||||
if self.gpt_router_api_key
|
||||
else None,
|
||||
)
|
||||
values["client"] = gpt_router_client
|
||||
self.client = gpt_router_client
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
|
||||
@@ -25,7 +25,8 @@ from langchain_core.outputs import (
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
@@ -76,17 +77,17 @@ class ChatHuggingFace(BaseChatModel):
|
||||
else self.tokenizer
|
||||
)
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_llm(self) -> Self:
|
||||
if not isinstance(
|
||||
values["llm"],
|
||||
self.llm,
|
||||
(HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
|
||||
):
|
||||
raise TypeError(
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
|
||||
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(self.llm)}"
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from langchain_core.messages import (
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
@@ -19,13 +19,13 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -138,11 +138,13 @@ class ChatHunyuan(BaseChatModel):
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,14 +62,15 @@ class ChatJavelinAIGateway(BaseChatModel):
|
||||
params: Optional[ChatParams] = None
|
||||
"""Parameters for the Javelin AI Gateway LLM."""
|
||||
|
||||
client: Any
|
||||
client: Any = None
|
||||
"""javelin client."""
|
||||
|
||||
javelin_api_key: Optional[SecretStr] = Field(None, alias="api_key")
|
||||
"""The API key for the Javelin AI Gateway."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
try:
|
||||
|
||||
@@ -40,13 +40,13 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -188,11 +188,13 @@ class JinaChat(BaseChatModel):
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
|
||||
@@ -26,7 +26,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.output_parsers.transform import BaseOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -543,8 +543,9 @@ class KineticaSqlResponse(BaseModel):
|
||||
dataframe: Any = Field(default=None)
|
||||
"""The Pandas dataframe containing the fetched data."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
class KineticaSqlOutputParser(BaseOutputParser[KineticaSqlResponse]):
|
||||
@@ -582,8 +583,9 @@ class KineticaSqlOutputParser(BaseOutputParser[KineticaSqlResponse]):
|
||||
kdbc: Any = Field(exclude=True)
|
||||
""" Kinetica DB connection. """
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> KineticaSqlResponse:
|
||||
df = self.kdbc.to_df(text)
|
||||
|
||||
@@ -23,8 +23,8 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from langchain_community.adapters.openai import (
|
||||
convert_message_to_dict,
|
||||
|
||||
@@ -52,11 +52,11 @@ from langchain_core.outputs import (
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,11 +84,13 @@ class LlamaEdgeChatService(BaseChatModel):
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
|
||||
@@ -46,15 +46,16 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatLlamaCpp(BaseChatModel):
|
||||
@@ -172,8 +173,8 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
verbose: bool = True
|
||||
"""Print verbose output to stderr."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that llama-cpp-python library is installed."""
|
||||
try:
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
@@ -184,7 +185,7 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
"use this embedding model: pip install llama-cpp-python"
|
||||
)
|
||||
|
||||
model_path = values["model_path"]
|
||||
model_path = self.model_path
|
||||
model_param_names = [
|
||||
"rope_freq_scale",
|
||||
"rope_freq_base",
|
||||
@@ -203,35 +204,35 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
"last_n_tokens_size",
|
||||
"verbose",
|
||||
]
|
||||
model_params = {k: values[k] for k in model_param_names}
|
||||
model_params = {k: getattr(self, k) for k in model_param_names}
|
||||
# For backwards compatibility, only include if non-null.
|
||||
if values["n_gpu_layers"] is not None:
|
||||
model_params["n_gpu_layers"] = values["n_gpu_layers"]
|
||||
if self.n_gpu_layers is not None:
|
||||
model_params["n_gpu_layers"] = self.n_gpu_layers
|
||||
|
||||
model_params.update(values["model_kwargs"])
|
||||
model_params.update(self.model_kwargs)
|
||||
|
||||
try:
|
||||
values["client"] = Llama(model_path, **model_params)
|
||||
self.client = Llama(model_path, **model_params)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not load Llama model from path: {model_path}. "
|
||||
f"Received error {e}"
|
||||
)
|
||||
|
||||
if values["grammar"] and values["grammar_path"]:
|
||||
grammar = values["grammar"]
|
||||
grammar_path = values["grammar_path"]
|
||||
if self.grammar and self.grammar_path:
|
||||
grammar = self.grammar
|
||||
grammar_path = self.grammar_path
|
||||
raise ValueError(
|
||||
"Can only pass in one of grammar and grammar_path. Received "
|
||||
f"{grammar=} and {grammar_path=}."
|
||||
)
|
||||
elif isinstance(values["grammar"], str):
|
||||
values["grammar"] = LlamaGrammar.from_string(values["grammar"])
|
||||
elif values["grammar_path"]:
|
||||
values["grammar"] = LlamaGrammar.from_file(values["grammar_path"])
|
||||
elif isinstance(self.grammar, str):
|
||||
self.grammar = LlamaGrammar.from_string(self.grammar)
|
||||
elif self.grammar_path:
|
||||
self.grammar = LlamaGrammar.from_file(self.grammar_path)
|
||||
else:
|
||||
pass
|
||||
return values
|
||||
return self
|
||||
|
||||
def _get_parameters(self, stop: Optional[List[str]]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -433,7 +434,7 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -465,7 +466,7 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -497,7 +498,7 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
|
||||
@@ -16,7 +16,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
|
||||
@@ -44,12 +44,18 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -267,7 +273,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
@@ -307,7 +313,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Joke(BaseModel):
|
||||
@@ -384,11 +390,13 @@ class MiniMaxChat(BaseChatModel):
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["minimax_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
@@ -694,7 +702,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import MiniMaxChat
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -715,7 +723,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import MiniMaxChat
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -737,7 +745,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import MiniMaxChat
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
|
||||
@@ -42,14 +42,14 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -38,10 +38,10 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from langchain_community.llms.oci_generative_ai import OCIGenAIBase
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
@@ -499,8 +499,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@@ -546,6 +548,9 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
|
||||
|
||||
chat_params = {**_model_kwargs, **kwargs, **oci_params}
|
||||
|
||||
if not self.model_id:
|
||||
raise ValueError("Model ID is required to chat")
|
||||
|
||||
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
|
||||
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
|
||||
else:
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from langchain_community.chat_models.openai import ChatOpenAI
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
@@ -44,13 +44,13 @@ from langchain_core.messages import (
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import (
|
||||
convert_dict_to_message,
|
||||
@@ -244,11 +244,13 @@ class ChatOpenAI(BaseChatModel):
|
||||
http_client: Union[Any, None] = None
|
||||
"""Optional httpx.Client."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
|
||||
@@ -17,8 +17,8 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
@@ -67,8 +67,9 @@ class PaiEasChatEndpoint(BaseChatModel):
|
||||
|
||||
timeout: Optional[int] = 5000
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["eas_service_url"] = get_from_dict_or_env(
|
||||
values, "eas_service_url", "EAS_SERVICE_URL"
|
||||
|
||||
@@ -35,8 +35,12 @@ from langchain_core.messages import (
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
from langchain_core.utils import (
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,14 +64,16 @@ class ChatPerplexity(BaseChatModel):
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "llama-3.1-sonar-small-128k-online"
|
||||
"""Model name."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
pplx_api_key: Optional[str] = Field(None, alias="api_key")
|
||||
pplx_api_key: Optional[str] = Field(
|
||||
default_factory=from_env("PPLX_API_KEY", default=None), alias="api_key"
|
||||
)
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
||||
@@ -81,15 +87,17 @@ class ChatPerplexity(BaseChatModel):
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"pplx_api_key": "PPLX_API_KEY"}
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -114,12 +122,9 @@ class ChatPerplexity(BaseChatModel):
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["pplx_api_key"] = get_from_dict_or_env(
|
||||
values, "pplx_api_key", "PPLX_API_KEY"
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
@@ -128,8 +133,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.OpenAI(
|
||||
api_key=values["pplx_api_key"], base_url="https://api.perplexity.ai"
|
||||
self.client = openai.OpenAI(
|
||||
api_key=self.pplx_api_key, base_url="https://api.perplexity.ai"
|
||||
)
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
@@ -137,7 +142,7 @@ class ChatPerplexity(BaseChatModel):
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
return values
|
||||
return self
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -38,15 +38,16 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from premai.api.chat_completions.v1_chat_completions_create import (
|
||||
@@ -306,10 +307,11 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
|
||||
client: Any
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@pre_init
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
|
||||
@@ -11,7 +11,6 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
@@ -19,6 +18,7 @@ from langchain_core.utils import (
|
||||
pre_init,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
|
||||
SUPPORTED_ROLES: List[str] = [
|
||||
"system",
|
||||
@@ -126,8 +126,9 @@ class ChatSnowflakeCortex(BaseChatModel):
|
||||
snowflake_role: Optional[str] = Field(default=None, alias="role")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon
|
||||
@@ -30,10 +30,11 @@ class SolarChat(SolarCommon, ChatOpenAI):
|
||||
max_tokens: int = Field(default=1024)
|
||||
|
||||
# this is needed to match ChatOpenAI superclass
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = "ignore"
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
|
||||
@@ -41,12 +41,12 @@ from langchain_core.outputs import (
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -296,11 +296,13 @@ class ChatSparkLLM(BaseChatModel):
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
@@ -326,8 +328,9 @@ class ChatSparkLLM(BaseChatModel):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["spark_app_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
["spark_app_id", "app_id"],
|
||||
|
||||
@@ -16,8 +16,8 @@ from langchain_core.language_models.chat_models import (
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from pydantic import ConfigDict, Field, SecretStr
|
||||
|
||||
|
||||
def _convert_role(role: str) -> str:
|
||||
@@ -89,11 +89,10 @@ class ChatNebula(BaseChatModel):
|
||||
|
||||
nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
if "nebula_api_key" in kwargs:
|
||||
|
||||
@@ -53,16 +53,17 @@ from langchain_core.outputs import (
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
from requests.exceptions import HTTPError
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@@ -348,7 +349,7 @@ class ChatTongyi(BaseChatModel):
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
@@ -386,7 +387,7 @@ class ChatTongyi(BaseChatModel):
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Joke(BaseModel):
|
||||
@@ -457,8 +458,9 @@ class ChatTongyi(BaseChatModel):
|
||||
max_retries: int = 10
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
||||
@@ -25,12 +25,12 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -115,8 +115,9 @@ class ChatYi(BaseChatModel):
|
||||
top_p: float = 0.7
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
kwargs["yi_api_key"] = convert_to_secret_str(
|
||||
|
||||
@@ -40,12 +40,12 @@ from langchain_core.messages import (
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -122,8 +122,9 @@ class ChatYuan2(BaseChatModel):
|
||||
repeat_penalty: Optional[float] = 1.18
|
||||
"""The penalty to apply to repeated tokens."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
@@ -141,8 +142,9 @@ class ChatYuan2(BaseChatModel):
|
||||
|
||||
return attributes
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
|
||||
@@ -50,11 +50,11 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -331,7 +331,7 @@ class ChatZhipuAI(BaseChatModel):
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
@@ -371,7 +371,7 @@ class ChatZhipuAI(BaseChatModel):
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Joke(BaseModel):
|
||||
@@ -480,11 +480,13 @@ class ChatZhipuAI(BaseChatModel):
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict[str, Any]) -> Any:
|
||||
values["zhipuai_api_key"] = get_from_dict_or_env(
|
||||
values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY"
|
||||
)
|
||||
@@ -773,7 +775,7 @@ class ChatZhipuAI(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -793,7 +795,7 @@ class ChatZhipuAI(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
@@ -814,7 +816,7 @@ class ChatZhipuAI(BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
|
||||
@@ -45,8 +45,9 @@ class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
|
||||
self.model_name, **self.model_kwargs
|
||||
)
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
|
||||
"""Compute similarity scores using a HuggingFace transformer model.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
|
||||
@@ -89,12 +89,14 @@ class SagemakerEndpointCrossEncoder(BaseModel, BaseCrossEncoder):
|
||||
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
try:
|
||||
import boto3
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks.base import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class DashScopeRerank(BaseDocumentCompressor):
|
||||
@@ -25,13 +25,15 @@ class DashScopeRerank(BaseDocumentCompressor):
|
||||
"""DashScope API key. Must be specified directly or via environment variable
|
||||
DASHSCOPE_API_KEY."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
if not values.get("client"):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
||||
|
||||
from langchain_core.callbacks.manager import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flashrank import Ranker, RerankRequest
|
||||
@@ -33,12 +33,14 @@ class FlashrankRerank(BaseDocumentCompressor):
|
||||
prefix_metadata: str = ""
|
||||
"""Prefix for flashrank_rerank metadata keys"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if "client" in values:
|
||||
return values
|
||||
|
||||
@@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
import requests
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
JINA_API_URL: str = "https://api.jina.ai/v1/rerank"
|
||||
|
||||
@@ -27,12 +27,14 @@ class JinaRerank(BaseDocumentCompressor):
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key exists in environment."""
|
||||
jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
|
||||
user_agent = values.get("user_agent", "langchain")
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.documents.compressor import (
|
||||
BaseDocumentCompressor,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
DEFAULT_LLM_LINGUA_INSTRUCTION = (
|
||||
"Given this documents, please answer the final question"
|
||||
@@ -35,9 +35,9 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
|
||||
"""The target number of compressed tokens"""
|
||||
rank_method: str = "longllmlingua"
|
||||
"""The ranking method to use"""
|
||||
model_config: dict = {}
|
||||
model_configuration: dict = Field(default_factory=dict, alias="model_config")
|
||||
"""Custom configuration for the model"""
|
||||
open_api_config: dict = {}
|
||||
open_api_config: dict = Field(default_factory=dict)
|
||||
"""open_api configuration"""
|
||||
instruction: str = DEFAULT_LLM_LINGUA_INSTRUCTION
|
||||
"""The instruction for the LLM"""
|
||||
@@ -52,8 +52,9 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
|
||||
lingua: Any
|
||||
"""The instance of the llm linqua"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that the python package exists in environment."""
|
||||
try:
|
||||
from llmlingua import PromptCompressor
|
||||
@@ -71,9 +72,11 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
|
||||
)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_context(docs: Sequence[Document]) -> List[str]:
|
||||
|
||||
@@ -5,7 +5,7 @@ import numpy as np
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.documents.compressor import BaseDocumentCompressor
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class RerankRequest:
|
||||
|
||||
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
from langchain_core.callbacks.manager import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import Field, PrivateAttr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rank_llm.data import Candidate, Query, Request
|
||||
@@ -36,12 +36,14 @@ class RankLLMRerank(BaseDocumentCompressor):
|
||||
"""OpenAI model name."""
|
||||
_retriever: Any = PrivateAttr()
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate python package exists in environment."""
|
||||
|
||||
if not values.get("client"):
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks.base import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
|
||||
class VolcengineRerank(BaseDocumentCompressor):
|
||||
@@ -32,13 +32,15 @@ class VolcengineRerank(BaseDocumentCompressor):
|
||||
top_n: Optional[int] = 3
|
||||
"""Number of documents to return."""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
if not values.get("client"):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
@@ -49,8 +49,9 @@ class ApifyDatasetLoader(BaseLoader, BaseModel):
|
||||
dataset_id=dataset_id, dataset_mapping_function=dataset_mapping_function
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate environment.
|
||||
|
||||
Args:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user