langchain/libs/community/tests/integration_tests/document_loaders/test_astradb.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path

import toml
import subprocess
import re

# Two directories above this file; with the documented scripts/ layout this is the repo root.
ROOT_DIR = Path(__file__).parents[1]


def main():
    """Bump mypy/ruff pins in every libs/*/pyproject.toml and silence new mypy errors.

    For each package under ``libs/``:
      1. rewrite the mypy (``^1.10``) and ruff (``^0.5``) dependency pins,
      2. re-lock, reinstall, and run mypy,
      3. append ``# type: ignore[...]`` to every line mypy now flags,
      4. re-format and re-sort imports with ruff.

    Packages whose pyproject lacks the ``typing``/``lint`` groups are skipped.
    """
    # Matches mypy output such as "path/to/file.py:123: error: ... [error-code]".
    # Compiled once (raw string: "\d" in a plain literal is an invalid escape).
    error_re = re.compile(r"^(.*):(\d+): error:.*\[(.*)\]")

    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            pyproject["tool"]["poetry"]["group"]["typing"]["dependencies"]["mypy"] = (
                "^1.10"
            )
            pyproject["tool"]["poetry"]["group"]["lint"]["dependencies"]["ruff"] = (
                "^0.5"
            )
        except KeyError:
            # Package doesn't declare these dependency groups; nothing to update.
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)

        cwd = "/".join(path.split("/")[:-1])
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )

        # Collect error codes keyed by (file, line); a line can carry several codes.
        # "err_path" deliberately does not shadow the outer `path` loop variable.
        to_ignore: dict[tuple[str, str], list[str]] = {}
        for log_line in completed.stdout.split("\n"):
            if match := error_re.match(log_line):
                err_path, line_no, error_type = match.groups()
                to_ignore.setdefault((err_path, line_no), []).append(error_type)
        print(len(to_ignore))

        for (error_path, line_no), error_types in to_ignore.items():
            all_errors = ", ".join(error_types)
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                # mypy sometimes reports paths outside the package tree; skip those.
                continue
            # Drop the trailing newline, append the ignore comment, restore it.
            file_lines[int(line_no) - 1] = (
                file_lines[int(line_no) - 1][:-1] + f"  # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.write("".join(file_lines))

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


# Run the bulk update only when executed as a script, not on import.
if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00

173 lines
5.9 KiB
Python

"""
Test of Astra DB document loader class `AstraDBLoader`
Required to run this test:
- a recent `astrapy` Python package available
- an Astra DB instance;
- the two environment variables set:
export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
- optionally this as well (otherwise defaults are used):
export ASTRA_DB_KEYSPACE="my_keyspace"
"""
from __future__ import annotations
import json
import os
import uuid
from typing import TYPE_CHECKING, AsyncIterator, Iterator
import pytest
from langchain_community.document_loaders.astradb import AstraDBLoader
if TYPE_CHECKING:
from astrapy.db import (
AstraDBCollection,
AsyncAstraDBCollection,
)
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")
def _has_env_vars() -> bool:
    """Return True when both mandatory Astra DB credentials are set and non-empty."""
    required = (ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_API_ENDPOINT)
    return all(required)
@pytest.fixture
def astra_db_collection() -> Iterator[AstraDBCollection]:
    """Create a throwaway Astra DB collection pre-seeded with test documents.

    Yields the collection and always drops it afterwards — the try/finally
    guarantees cleanup even if seeding or the test itself raises, so no
    test collections leak in the database.
    """
    from astrapy.db import AstraDB

    astra_db = AstraDB(
        token=ASTRA_DB_APPLICATION_TOKEN or "",
        api_endpoint=ASTRA_DB_API_ENDPOINT or "",
        namespace=ASTRA_DB_KEYSPACE,
    )
    # Random suffix keeps concurrent test runs from colliding.
    collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}"
    collection = astra_db.create_collection(collection_name)
    try:
        # 20 docs matching {"foo": "bar"} plus a mixed batch of 8 more.
        collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
        collection.insert_many(
            [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
        )
        yield collection
    finally:
        astra_db.delete_collection(collection_name)
@pytest.fixture
async def async_astra_db_collection() -> AsyncIterator[AsyncAstraDBCollection]:
    """Async counterpart of ``astra_db_collection``: seeded collection, always dropped.

    The try/finally guarantees the collection is deleted even if seeding or
    the test raises, preventing leaked test collections in the database.
    """
    from astrapy.db import AsyncAstraDB

    astra_db = AsyncAstraDB(
        token=ASTRA_DB_APPLICATION_TOKEN or "",
        api_endpoint=ASTRA_DB_API_ENDPOINT or "",
        namespace=ASTRA_DB_KEYSPACE,
    )
    # Random suffix keeps concurrent test runs from colliding.
    collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}"
    collection = await astra_db.create_collection(collection_name)
    try:
        # 20 docs matching {"foo": "bar"} plus a mixed batch of 8 more.
        await collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
        await collection.insert_many(
            [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
        )
        yield collection
    finally:
        await astra_db.delete_collection(collection_name)
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDB:
    """Integration tests for ``AstraDBLoader`` against a live Astra DB instance."""

    def test_astradb_loader(self, astra_db_collection: AstraDBCollection) -> None:
        """Sync load honours projection, limit and filter, and sets doc metadata."""
        loader = AstraDBLoader(
            astra_db_collection.collection_name,
            token=ASTRA_DB_APPLICATION_TOKEN,
            api_endpoint=ASTRA_DB_API_ENDPOINT,
            namespace=ASTRA_DB_KEYSPACE,
            nb_prefetched=1,
            projection={"foo": 1},
            find_options={"limit": 22},
            filter_criteria={"foo": "bar"},
        )
        documents = loader.load()
        assert len(documents) == 22

        expected_metadata = {
            "namespace": astra_db_collection.astra_db.namespace,
            "api_endpoint": astra_db_collection.astra_db.base_url,
            "collection": astra_db_collection.collection_name,
        }
        seen_ids: set = set()
        for document in documents:
            payload = json.loads(document.page_content)
            # Filter field is present; projected-out field is absent.
            assert payload["foo"] == "bar"
            assert "baz" not in payload
            # Every returned document is distinct.
            assert payload["_id"] not in seen_ids
            seen_ids.add(payload["_id"])
            assert document.metadata == expected_metadata

    def test_extraction_function(self, astra_db_collection: AstraDBCollection) -> None:
        """``extraction_function`` replaces the JSON page_content with its output."""
        loader = AstraDBLoader(
            astra_db_collection.collection_name,
            token=ASTRA_DB_APPLICATION_TOKEN,
            api_endpoint=ASTRA_DB_API_ENDPOINT,
            namespace=ASTRA_DB_KEYSPACE,
            find_options={"limit": 30},
            extraction_function=lambda x: x["foo"],
        )
        first_doc = next(loader.lazy_load())
        assert first_doc.page_content == "bar"

    async def test_astradb_loader_async(
        self, async_astra_db_collection: AsyncAstraDBCollection
    ) -> None:
        """Async load honours projection, limit and filter, and sets doc metadata."""
        # Seed additional documents on top of the fixture's data.
        await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20)
        await async_astra_db_collection.insert_many(
            [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4
        )
        loader = AstraDBLoader(
            async_astra_db_collection.collection_name,
            token=ASTRA_DB_APPLICATION_TOKEN,
            api_endpoint=ASTRA_DB_API_ENDPOINT,
            namespace=ASTRA_DB_KEYSPACE,
            nb_prefetched=1,
            projection={"foo": 1},
            find_options={"limit": 22},
            filter_criteria={"foo": "bar"},
        )
        documents = await loader.aload()
        assert len(documents) == 22

        expected_metadata = {
            "namespace": async_astra_db_collection.astra_db.namespace,
            "api_endpoint": async_astra_db_collection.astra_db.base_url,
            "collection": async_astra_db_collection.collection_name,
        }
        seen_ids: set = set()
        for document in documents:
            payload = json.loads(document.page_content)
            # Filter field is present; projected-out field is absent.
            assert payload["foo"] == "bar"
            assert "baz" not in payload
            # Every returned document is distinct.
            assert payload["_id"] not in seen_ids
            seen_ids.add(payload["_id"])
            assert document.metadata == expected_metadata

    async def test_extraction_function_async(
        self, async_astra_db_collection: AsyncAstraDBCollection
    ) -> None:
        """Async ``extraction_function`` replaces page_content with its output."""
        loader = AstraDBLoader(
            async_astra_db_collection.collection_name,
            token=ASTRA_DB_APPLICATION_TOKEN,
            api_endpoint=ASTRA_DB_API_ENDPOINT,
            namespace=ASTRA_DB_KEYSPACE,
            find_options={"limit": 30},
            extraction_function=lambda x: x["foo"],
        )
        doc = await loader.alazy_load().__anext__()
        assert doc.page_content == "bar"