mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00

commit 6afb135baa (parent b641be2edf)
astradb: move to langchain-datastax repo (#18354)
.github/scripts/check_diff.py (vendored, 8 changes)

@@ -49,9 +49,11 @@ if __name__ == "__main__":
                     dirs_to_run["extended-test"].add(dir_)
         elif file.startswith("libs/partners"):
             partner_dir = file.split("/")[2]
-            if os.path.isdir(f"libs/partners/{partner_dir}") and os.listdir(
-                f"libs/partners/{partner_dir}"
-            ) != ["README.md"]:
+            if os.path.isdir(f"libs/partners/{partner_dir}") and [
+                filename
+                for filename in os.listdir(f"libs/partners/{partner_dir}")
+                if not filename.startswith(".")
+            ] != ["README.md"]:
                 dirs_to_run["test"].add(f"libs/partners/{partner_dir}")
             # Skip if the directory was deleted or is just a tombstone readme
         elif file.startswith("libs/"):
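The rewritten predicate ignores hidden files when deciding whether a partner directory is a tombstone (only a README.md left behind). Previously, a stray dotfile made the listing differ from ["README.md"], so a tombstoned directory was still scheduled for testing. A self-contained sketch of the old and new checks (the throwaway directory layout here is made up for illustration):

```python
import os
import tempfile

# Build a throwaway "tombstoned" partner directory containing only a
# README.md plus a hidden file, then apply both predicates to it.
with tempfile.TemporaryDirectory() as partner_dir:
    open(os.path.join(partner_dir, "README.md"), "w").close()
    open(os.path.join(partner_dir, ".gitignore"), "w").close()

    # Old check: the hidden file makes the listing differ from ["README.md"],
    # so the tombstone is (wrongly) treated as a live package.
    old_is_live = os.listdir(partner_dir) != ["README.md"]

    # New check: hidden files are filtered out first, so only README.md
    # remains and the directory is correctly skipped.
    visible = [f for f in os.listdir(partner_dir) if not f.startswith(".")]
    new_is_live = visible != ["README.md"]

    print(old_is_live, new_is_live)  # True False
```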
.github/workflows/api_doc_build.yml (vendored, 10 changes)

@@ -20,11 +20,19 @@ jobs:
         with:
           repository: langchain-ai/langchain-google
           path: langchain-google
+      - uses: actions/checkout@v4
+        with:
+          repository: langchain-ai/langchain-datastax
+          path: langchain-datastax
       - name: Move google libs
         run: |
-          rm -rf langchain/libs/partners/google-genai langchain/libs/partners/google-vertexai
+          rm -rf \
+            langchain/libs/partners/google-genai \
+            langchain/libs/partners/google-vertexai \
+            langchain/libs/partners/astradb
           mv langchain-google/libs/genai langchain/libs/partners/google-genai
           mv langchain-google/libs/vertexai langchain/libs/partners/google-vertexai
+          mv langchain-datastax/libs/astradb langchain/libs/partners/astradb

       - name: Set Git config
         working-directory: langchain
libs/partners/astradb/LICENSE (deleted)

@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/astradb/Makefile (deleted)

@@ -1,66 +0,0 @@
SHELL := /bin/bash
.PHONY: all format lint test tests integration_test integration_tests spell_check help

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
INTEGRATION_TEST_FILE ?= tests/integration_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)

integration_test:
	poetry run pytest $(INTEGRATION_TEST_FILE)

integration_tests:
	poetry run pytest $(INTEGRATION_TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/astradb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_astradb
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_astradb -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports               - check imports'
	@echo 'format                      - run code formatters'
	@echo 'lint                        - run linters'
	@echo 'test                        - run unit tests'
	@echo 'tests                       - run unit tests'
	@echo 'test TEST_FILE=<test_file>  - run all tests in file'
libs/partners/astradb/README.md (68 lines replaced with 3)

@@ -1,68 +1,3 @@

New content:

This package has moved!

https://github.com/langchain-ai/langchain-datastax/tree/main/libs/astradb

Removed content:

# langchain-astradb

This package contains the LangChain integrations for using DataStax Astra DB.

> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Apache Cassandra® and made conveniently available
> through an easy-to-use JSON API.

_**Note.** For a short transitional period, only some of the Astra DB integration classes are contained in this package (the remaining ones being still in `langchain-community`). In a short while, and surely by version 0.2 of LangChain, all of the Astra DB support will be removed from `langchain-community` and included in this package._

## Installation and Setup

Installation of this partner package:

```bash
pip install langchain-astradb
```

## Integrations overview

### Vector Store

```python
from langchain_astradb import AstraDBVectorStore

my_store = AstraDBVectorStore(
    embedding=my_embeddings,
    collection_name="my_store",
    api_endpoint="https://...",
    token="AstraCS:...",
)
```

### Chat message history

```python
from langchain_astradb import AstraDBChatMessageHistory

message_history = AstraDBChatMessageHistory(
    session_id="test-session",
    api_endpoint="...",
    token="...",
)
```

### Store

```python
from langchain_astradb import AstraDBStore

store = AstraDBStore(
    collection_name="my_kv_store",
    api_endpoint="...",
    token="...",
)
```

### Byte Store

```python
from langchain_astradb import AstraDBByteStore

store = AstraDBByteStore(
    collection_name="my_kv_store",
    api_endpoint="...",
    token="...",
)
```

## Reference

See the [LangChain docs page](https://python.langchain.com/docs/integrations/providers/astradb) for a more detailed listing.
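The move should be transparent at the code level: the distribution is still published as langchain-astradb, now developed in langchain-ai/langchain-datastax. A short post-move sketch, under my assumption (based on the tombstone URL and the unchanged package name in the pyproject.toml below) that the import path is unchanged; the endpoint and token are placeholders:

```python
# pip install langchain-astradb  (now released from langchain-ai/langchain-datastax)
from langchain_astradb import AstraDBStore

# Placeholder credentials; real values come from your Astra DB dashboard.
store = AstraDBStore(
    collection_name="my_kv_store",
    api_endpoint="https://...",
    token="AstraCS:...",
)
```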
libs/partners/astradb/langchain_astradb/__init__.py (deleted)

@@ -1,10 +0,0 @@
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore

__all__ = [
    "AstraDBByteStore",
    "AstraDBStore",
    "AstraDBChatMessageHistory",
    "AstraDBVectorStore",
]
libs/partners/astradb/langchain_astradb/chat_message_histories.py (deleted)

@@ -1,148 +0,0 @@
"""Astra DB-based chat message history, based on astrapy."""
from __future__ import annotations

import json
import time
from typing import List, Optional, Sequence

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)

from langchain_astradb.utils.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)

DEFAULT_COLLECTION_NAME = "langchain_message_store"


class AstraDBChatMessageHistory(BaseChatMessageHistory):
    def __init__(
        self,
        *,
        session_id: str,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
    ) -> None:
        """Chat message history that stores history in Astra DB.

        Args:
            session_id: arbitrary key that is used to store the messages
                of a single chat session.
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
        """
        self.astra_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
        )

        self.collection = self.astra_env.collection
        self.async_collection = self.astra_env.async_collection

        self.session_id = session_id
        self.collection_name = collection_name

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve all session messages from DB"""
        self.astra_env.ensure_db_setup()
        message_blobs = [
            doc["body_blob"]
            for doc in sorted(
                self.collection.paginated_find(
                    filter={
                        "session_id": self.session_id,
                    },
                    projection={
                        "timestamp": 1,
                        "body_blob": 1,
                    },
                ),
                key=lambda _doc: _doc["timestamp"],
            )
        ]
        items = [json.loads(message_blob) for message_blob in message_blobs]
        messages = messages_from_dict(items)
        return messages

    @messages.setter
    def messages(self, messages: List[BaseMessage]) -> None:
        raise NotImplementedError("Use add_messages instead")

    async def aget_messages(self) -> List[BaseMessage]:
        await self.astra_env.aensure_db_setup()
        docs = self.async_collection.paginated_find(
            filter={
                "session_id": self.session_id,
            },
            projection={
                "timestamp": 1,
                "body_blob": 1,
            },
        )
        sorted_docs = sorted(
            [doc async for doc in docs],
            key=lambda _doc: _doc["timestamp"],
        )
        message_blobs = [doc["body_blob"] for doc in sorted_docs]
        items = [json.loads(message_blob) for message_blob in message_blobs]
        messages = messages_from_dict(items)
        return messages

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        self.astra_env.ensure_db_setup()
        docs = [
            {
                "timestamp": time.time(),
                "session_id": self.session_id,
                "body_blob": json.dumps(message_to_dict(message)),
            }
            for message in messages
        ]
        self.collection.chunked_insert_many(docs)

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        await self.astra_env.aensure_db_setup()
        docs = [
            {
                "timestamp": time.time(),
                "session_id": self.session_id,
                "body_blob": json.dumps(message_to_dict(message)),
            }
            for message in messages
        ]
        await self.async_collection.chunked_insert_many(docs)

    def clear(self) -> None:
        self.astra_env.ensure_db_setup()
        self.collection.delete_many(filter={"session_id": self.session_id})

    async def aclear(self) -> None:
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_many(filter={"session_id": self.session_id})
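A minimal round-trip sketch against the class above. It assumes live Astra DB credentials in the environment (the same variables the integration tests below use), so it is not runnable offline:

```python
import os

from langchain_core.messages import AIMessage, HumanMessage

from langchain_astradb import AstraDBChatMessageHistory

history = AstraDBChatMessageHistory(
    session_id="demo-session",
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
)
# Each message is stored as a JSON blob tagged with session_id and timestamp...
history.add_messages(
    [HumanMessage(content="Hello"), AIMessage(content="Hi there")]
)
# ...and read back sorted by timestamp via the `messages` property.
print(history.messages)
```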
libs/partners/astradb/langchain_astradb/storage.py (deleted)

@@ -1,217 +0,0 @@
from __future__ import annotations

import base64
from abc import ABC, abstractmethod
from typing import (
    Any,
    AsyncIterator,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.stores import BaseStore, ByteStore

from langchain_astradb.utils.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)

V = TypeVar("V")


class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
    """Base class for the DataStax AstraDB data store."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
        self.collection = self.astra_env.collection
        self.async_collection = self.astra_env.async_collection

    @abstractmethod
    def decode_value(self, value: Any) -> Optional[V]:
        """Decodes value from Astra DB"""

    @abstractmethod
    def encode_value(self, value: Optional[V]) -> Any:
        """Encodes value for Astra DB"""

    def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
        self.astra_env.ensure_db_setup()
        docs_dict = {}
        for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
            docs_dict[doc["_id"]] = doc.get("value")
        return [self.decode_value(docs_dict.get(key)) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
        await self.astra_env.aensure_db_setup()
        docs_dict = {}
        async for doc in self.async_collection.paginated_find(
            filter={"_id": {"$in": list(keys)}}
        ):
            docs_dict[doc["_id"]] = doc.get("value")
        return [self.decode_value(docs_dict.get(key)) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        self.astra_env.ensure_db_setup()
        for k, v in key_value_pairs:
            self.collection.upsert_one({"_id": k, "value": self.encode_value(v)})

    async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        await self.astra_env.aensure_db_setup()
        for k, v in key_value_pairs:
            await self.async_collection.upsert_one(
                {"_id": k, "value": self.encode_value(v)}
            )

    def mdelete(self, keys: Sequence[str]) -> None:
        self.astra_env.ensure_db_setup()
        self.collection.delete_many(filter={"_id": {"$in": list(keys)}})

    async def amdelete(self, keys: Sequence[str]) -> None:
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        self.astra_env.ensure_db_setup()
        docs = self.collection.paginated_find()
        for doc in docs:
            key = doc["_id"]
            if not prefix or key.startswith(prefix):
                yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        await self.astra_env.aensure_db_setup()
        async for doc in self.async_collection.paginated_find():
            key = doc["_id"]
            if not prefix or key.startswith(prefix):
                yield key


class AstraDBStore(AstraDBBaseStore[Any]):
    def __init__(
        self,
        collection_name: str,
        *,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        namespace: Optional[str] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        pre_delete_collection: bool = False,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        """BaseStore implementation using DataStax AstraDB as the underlying store.

        The value type can be any type serializable by json.dumps.
        Can be used to store embeddings with the CacheBackedEmbeddings.

        Documents in the AstraDB collection will have the format

        .. code-block:: json

            {
                "_id": "<key>",
                "value": <value>
            }

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
                OFF).
            pre_delete_collection: whether to delete the collection
                before creating it. If False and the collection already exists,
                the collection will be used as is.
        """
        super().__init__(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
        )

    def decode_value(self, value: Any) -> Any:
        return value

    def encode_value(self, value: Any) -> Any:
        return value


class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
    def __init__(
        self,
        *,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        namespace: Optional[str] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        pre_delete_collection: bool = False,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        """ByteStore implementation using DataStax AstraDB as the underlying store.

        The bytes values are converted to base64-encoded strings.
        Documents in the AstraDB collection will have the format

        .. code-block:: json

            {
                "_id": "<key>",
                "value": "<base64 string value>"
            }

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
                OFF).
            pre_delete_collection: whether to delete the collection
                before creating it. If False and the collection already exists,
                the collection will be used as is.
        """
        super().__init__(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
        )

    def decode_value(self, value: Any) -> Optional[bytes]:
        if value is None:
            return None
        return base64.b64decode(value)

    def encode_value(self, value: Optional[bytes]) -> Any:
        if value is None:
            return None
        return base64.b64encode(value).decode("ascii")
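The byte store's encode/decode pair is plain base64. A self-contained sketch of the exact round trip `AstraDBByteStore` performs on each value (no database involved):

```python
import base64

value = b"value1"

# encode_value: bytes -> ASCII-safe string stored in the document's "value" field
encoded = base64.b64encode(value).decode("ascii")
assert encoded == "dmFsdWUx"  # the same literal checked in the tests below

# decode_value: stored string -> original bytes
assert base64.b64decode(encoded) == value
```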
libs/partners/astradb/langchain_astradb/utils/astradb.py (deleted)

@@ -1,152 +0,0 @@
from __future__ import annotations

import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import Awaitable, Optional, Union

import langchain_core
from astrapy.db import AstraDB, AsyncAstraDB


class SetupMode(Enum):
    SYNC = 1
    ASYNC = 2
    OFF = 3


class _AstraDBEnvironment:
    def __init__(
        self,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
    ) -> None:
        self.token = token
        self.api_endpoint = api_endpoint
        astra_db = astra_db_client
        async_astra_db = async_astra_db_client
        self.namespace = namespace

        # Conflicting-arg checks:
        if astra_db_client is not None or async_astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
                    "AstraDBEnvironment if passing 'token' and 'api_endpoint'."
                )

        if token and api_endpoint:
            astra_db = AstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )
            async_astra_db = AsyncAstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )

        if astra_db:
            self.astra_db = astra_db.copy()
            if async_astra_db:
                self.async_astra_db = async_astra_db.copy()
            else:
                self.async_astra_db = self.astra_db.to_async()
        elif async_astra_db:
            self.async_astra_db = async_astra_db.copy()
            self.astra_db = self.async_astra_db.to_sync()
        else:
            raise ValueError(
                "Must provide 'astra_db_client' or 'async_astra_db_client' or "
                "'token' and 'api_endpoint'"
            )

        self.astra_db.set_caller(
            caller_name="langchain",
            caller_version=getattr(langchain_core, "__version__", None),
        )
        self.async_astra_db.set_caller(
            caller_name="langchain",
            caller_version=getattr(langchain_core, "__version__", None),
        )


class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    def __init__(
        self,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: Union[int, Awaitable[int], None] = None,
        metric: Optional[str] = None,
    ) -> None:
        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

        super().__init__(
            token, api_endpoint, astra_db_client, async_astra_db_client, namespace
        )
        self.collection_name = collection_name
        self.collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self.astra_db,
        )

        self.async_collection = AsyncAstraDBCollection(
            collection_name=collection_name,
            astra_db=self.async_astra_db,
        )

        self.async_setup_db_task: Optional[Task] = None
        if setup_mode == SetupMode.ASYNC:
            async_astra_db = self.async_astra_db

            async def _setup_db() -> None:
                if pre_delete_collection:
                    await async_astra_db.delete_collection(collection_name)
                if inspect.isawaitable(embedding_dimension):
                    dimension = await embedding_dimension
                else:
                    dimension = embedding_dimension
                await async_astra_db.create_collection(
                    collection_name, dimension=dimension, metric=metric
                )

            self.async_setup_db_task = asyncio.create_task(_setup_db())
        elif setup_mode == SetupMode.SYNC:
            if pre_delete_collection:
                self.astra_db.delete_collection(collection_name)
            if inspect.isawaitable(embedding_dimension):
                raise ValueError(
                    "Cannot use an awaitable embedding_dimension with async_setup "
                    "set to False"
                )
            self.astra_db.create_collection(
                collection_name,
                dimension=embedding_dimension,  # type: ignore[arg-type]
                metric=metric,
            )

    def ensure_db_setup(self) -> None:
        if self.async_setup_db_task:
            try:
                self.async_setup_db_task.result()
            except InvalidStateError:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: AstraDB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                )

    async def aensure_db_setup(self) -> None:
        if self.async_setup_db_task:
            await self.async_setup_db_task
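The `SetupMode.ASYNC` path above defers collection creation to an `asyncio` task and relies on `Task.result()` raising `InvalidStateError` when a sync caller races ahead of the setup. A stripped-down, astrapy-free sketch of that pattern (the class and sleep are stand-ins, not the real implementation):

```python
import asyncio
from asyncio import InvalidStateError, Task


class DeferredSetup:
    """Mimics the SetupMode.ASYNC bookkeeping of _AstraDBCollectionEnvironment."""

    def __init__(self) -> None:
        async def _setup_db() -> None:
            await asyncio.sleep(0.1)  # stand-in for delete_collection/create_collection

        # Must run inside an event loop, exactly like the SetupMode.ASYNC branch.
        self.async_setup_db_task: Task = asyncio.create_task(_setup_db())

    def ensure_db_setup(self) -> None:
        try:
            # Task.result() raises InvalidStateError while the task is still pending.
            self.async_setup_db_task.result()
        except InvalidStateError:
            raise ValueError("Asynchronous setup of the DB not finished.")

    async def aensure_db_setup(self) -> None:
        await self.async_setup_db_task


async def main() -> None:
    env = DeferredSetup()
    try:
        env.ensure_db_setup()  # called too early: setup is still pending
    except ValueError as exc:
        print(exc)
    await env.aensure_db_setup()  # the async path simply awaits the task
    env.ensure_db_setup()  # succeeds once the task has completed


asyncio.run(main())
```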
libs/partners/astradb/langchain_astradb/utils/mmr.py (deleted)

@@ -1,87 +0,0 @@
"""
Tools for the Maximal Marginal Relevance (MMR) reranking.
Duplicated from langchain_community to avoid cross-dependencies.

Functions "maximal_marginal_relevance" and "cosine_similarity"
are duplicated in this utility respectively from modules:
    - "libs/community/langchain_community/vectorstores/utils.py"
    - "libs/community/langchain_community/utils/math.py"
"""

import logging
from typing import List, Union

import numpy as np

logger = logging.getLogger(__name__)

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}."
        )
    try:
        import simsimd as simd  # type: ignore

        X = np.array(X, dtype=np.float32)
        Y = np.array(Y, dtype=np.float32)
        Z = 1 - simd.cdist(X, Y, metric="cosine")
        if isinstance(Z, float):
            return np.array([Z])
        return Z
    except ImportError:
        logger.info(
            "Unable to import simsimd, defaulting to NumPy implementation. If you want "
            "to use simsimd please install with `pip install simsimd`."
        )
        X_norm = np.linalg.norm(X, axis=1)
        Y_norm = np.linalg.norm(Y, axis=1)
        # Ignore divide-by-zero runtime warnings; those cases are handled below.
        with np.errstate(divide="ignore", invalid="ignore"):
            similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
        similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
        return similarity


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance."""
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs
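A small numeric check of the two functions above. It assumes the module is importable, for example saved locally as mmr.py (a hypothetical filename for illustration); only numpy is required, since the simsimd path falls back on ImportError:

```python
import numpy as np

from mmr import cosine_similarity, maximal_marginal_relevance  # the module above

query = np.array([1.0, 0.0])
docs = [
    [1.0, 0.0],   # identical to the query
    [0.99, 0.1],  # near-duplicate of the first document
    [0.0, 1.0],   # orthogonal, i.e. maximally "diverse"
]

print(cosine_similarity([query], docs)[0])  # approx. [1.0, 0.995, 0.0]

# Relevance-heavy weighting keeps the near-duplicate...
print(maximal_marginal_relevance(query, docs, lambda_mult=0.9, k=2))  # [0, 1]
# ...while diversity-heavy weighting penalizes its redundancy and
# prefers the orthogonal document instead.
print(maximal_marginal_relevance(query, docs, lambda_mult=0.3, k=2))  # [0, 2]
```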
libs/partners/astradb/poetry.lock (generated, 1821 changes): diff suppressed because it is too large.
libs/partners/astradb/pyproject.toml (deleted)

@@ -1,92 +0,0 @@
[tool.poetry]
name = "langchain-astradb"
version = "0.0.1"
description = "An integration package connecting Astra DB and LangChain"
authors = []
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.5"
astrapy = "^0.7.5"
numpy = "^1"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
pytest-dotenv = "^0.5.2"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]
langchain = { path = "../../langchain", develop = true }
langchain-community = { path = "../../community", develop = true }
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config        any warnings encountered while parsing the `pytest`
#                        section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
  "requires: mark tests as requiring a specific library",
  "asyncio: mark tests as requiring asyncio",
  "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
libs/partners/astradb/scripts/check_imports.py (deleted)

@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
libs/partners/astradb/scripts/check_pydantic.sh (deleted)

@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
libs/partners/astradb/scripts/lint_imports.sh (deleted)

@@ -1,17 +0,0 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
libs/partners/astradb/tests/integration_tests/.env.example (deleted)

@@ -1,5 +0,0 @@
# astra db
ASTRA_DB_API_ENDPOINT=https://your_astra_db_id-your_region.apps.astra.datastax.com
ASTRA_DB_APPLICATION_TOKEN=AstraCS:your_astra_db_application_token
# ASTRA_DB_KEYSPACE=your_astra_db_namespace
# ASTRA_DB_SKIP_COLLECTION_DELETIONS=true
libs/partners/astradb/tests/integration_tests/conftest.py (deleted)

@@ -1,19 +0,0 @@
# Getting the absolute path of the current file's directory
import os

ABS_PATH = os.path.dirname(os.path.abspath(__file__))

# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))


# Loading the .env file if it exists
def _load_env() -> None:
    dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
    if os.path.exists(dotenv_path):
        from dotenv import load_dotenv

        load_dotenv(dotenv_path)


_load_env()
libs/partners/astradb/tests/integration_tests/test_chat_message_histories.py (deleted)

@@ -1,198 +0,0 @@
import os
from typing import AsyncIterable, Iterable

import pytest
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import AIMessage, HumanMessage

from langchain_astradb.chat_message_histories import (
    AstraDBChatMessageHistory,
)
from langchain_astradb.utils.astradb import SetupMode


def _has_env_vars() -> bool:
    return all(
        [
            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
            "ASTRA_DB_API_ENDPOINT" in os.environ,
        ]
    )


@pytest.fixture(scope="function")
def history1() -> Iterable[AstraDBChatMessageHistory]:
    history1 = AstraDBChatMessageHistory(
        session_id="session-test-1",
        collection_name="langchain_cmh_test",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield history1
    history1.collection.astra_db.delete_collection("langchain_cmh_test")


@pytest.fixture(scope="function")
def history2() -> Iterable[AstraDBChatMessageHistory]:
    history2 = AstraDBChatMessageHistory(
        session_id="session-test-2",
        collection_name="langchain_cmh_test",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield history2
    history2.collection.astra_db.delete_collection("langchain_cmh_test")


@pytest.fixture
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
    history1 = AstraDBChatMessageHistory(
        session_id="async-session-test-1",
        collection_name="langchain_cmh_test",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield history1
    await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")


@pytest.fixture(scope="function")
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
    history2 = AstraDBChatMessageHistory(
        session_id="async-session-test-2",
        collection_name="langchain_cmh_test",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield history2
    await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")


@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
    """Test the memory with a message store."""
    memory = ConversationBufferMemory(
        memory_key="baz",
        chat_memory=history1,
        return_messages=True,
    )

    assert memory.chat_memory.messages == []

    # add some messages
    memory.chat_memory.add_messages(
        [
            AIMessage(content="This is me, the AI"),
            HumanMessage(content="This is me, the human"),
        ]
    )

    messages = memory.chat_memory.messages
    expected = [
        AIMessage(content="This is me, the AI"),
        HumanMessage(content="This is me, the human"),
    ]
    assert messages == expected

    # clear the store
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []


@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_with_message_store_async(
    async_history1: AstraDBChatMessageHistory,
) -> None:
    """Test the memory with a message store."""
    memory = ConversationBufferMemory(
        memory_key="baz",
        chat_memory=async_history1,
        return_messages=True,
    )

    assert await memory.chat_memory.aget_messages() == []

    # add some messages
    await memory.chat_memory.aadd_messages(
        [
            AIMessage(content="This is me, the AI"),
            HumanMessage(content="This is me, the human"),
        ]
    )

    messages = await memory.chat_memory.aget_messages()
    expected = [
        AIMessage(content="This is me, the AI"),
        HumanMessage(content="This is me, the human"),
    ]
    assert messages == expected

    # clear the store
    await memory.chat_memory.aclear()

    assert await memory.chat_memory.aget_messages() == []


@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_separate_session_ids(
    history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
) -> None:
    """Test that separate session IDs do not share entries."""
    memory1 = ConversationBufferMemory(
        memory_key="mk1",
        chat_memory=history1,
        return_messages=True,
    )
    memory2 = ConversationBufferMemory(
        memory_key="mk2",
        chat_memory=history2,
        return_messages=True,
    )

    memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])

    assert memory2.chat_memory.messages == []

    memory2.chat_memory.clear()

    assert memory1.chat_memory.messages != []

    memory1.chat_memory.clear()

    assert memory1.chat_memory.messages == []


@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_separate_session_ids_async(
    async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
) -> None:
    """Test that separate session IDs do not share entries."""
    memory1 = ConversationBufferMemory(
        memory_key="mk1",
        chat_memory=async_history1,
        return_messages=True,
    )
    memory2 = ConversationBufferMemory(
        memory_key="mk2",
        chat_memory=async_history2,
        return_messages=True,
    )

    await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])

    assert await memory2.chat_memory.aget_messages() == []

    await memory2.chat_memory.aclear()

    assert await memory1.chat_memory.aget_messages() != []

    await memory1.chat_memory.aclear()

    assert await memory1.chat_memory.aget_messages() == []
libs/partners/astradb/tests/integration_tests/test_compile.py (deleted)

@@ -1,7 +0,0 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
libs/partners/astradb/tests/integration_tests/test_storage.py (deleted)

@@ -1,176 +0,0 @@
"""Implement integration tests for AstraDB storage."""
from __future__ import annotations

import os

import pytest
from astrapy.db import AstraDB, AsyncAstraDB

from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.utils.astradb import SetupMode


def _has_env_vars() -> bool:
    return all(
        [
            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
            "ASTRA_DB_API_ENDPOINT" in os.environ,
        ]
    )


@pytest.fixture
def astra_db() -> AstraDB:
    from astrapy.db import AstraDB

    return AstraDB(
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )


@pytest.fixture
def async_astra_db() -> AsyncAstraDB:
    from astrapy.db import AsyncAstraDB

    return AsyncAstraDB(
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )


def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
    store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
    store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
    return store


def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
    store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
    store.mset([("key1", b"value1"), ("key2", b"value2")])
    return store


async def init_async_store(
    async_astra_db: AsyncAstraDB, collection_name: str
) -> AstraDBStore:
    store = AstraDBStore(
        collection_name=collection_name,
        async_astra_db_client=async_astra_db,
        setup_mode=SetupMode.ASYNC,
    )
    await store.amset([("key1", [0.1, 0.2]), ("key2", "value2")])
    return store


@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
    def test_mget(self, astra_db: AstraDB) -> None:
        """Test AstraDBStore mget method."""
        collection_name = "lc_test_store_mget"
        try:
            store = init_store(astra_db, collection_name)
            assert store.mget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
        finally:
            astra_db.delete_collection(collection_name)

    async def test_amget(self, async_astra_db: AsyncAstraDB) -> None:
        """Test AstraDBStore amget method."""
        collection_name = "lc_test_store_mget"
        try:
            store = await init_async_store(async_astra_db, collection_name)
            assert await store.amget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
        finally:
            await async_astra_db.delete_collection(collection_name)

    def test_mset(self, astra_db: AstraDB) -> None:
        """Test that multiple keys can be set with AstraDBStore."""
        collection_name = "lc_test_store_mset"
        try:
            store = init_store(astra_db, collection_name)
            result = store.collection.find_one({"_id": "key1"})
            assert result["data"]["document"]["value"] == [0.1, 0.2]
            result = store.collection.find_one({"_id": "key2"})
            assert result["data"]["document"]["value"] == "value2"
        finally:
            astra_db.delete_collection(collection_name)

    async def test_amset(self, async_astra_db: AsyncAstraDB) -> None:
        """Test that multiple keys can be set with AstraDBStore."""
        collection_name = "lc_test_store_mset"
        try:
            store = await init_async_store(async_astra_db, collection_name)
            result = await store.async_collection.find_one({"_id": "key1"})
            assert result["data"]["document"]["value"] == [0.1, 0.2]
            result = await store.async_collection.find_one({"_id": "key2"})
            assert result["data"]["document"]["value"] == "value2"
        finally:
            await async_astra_db.delete_collection(collection_name)

    def test_mdelete(self, astra_db: AstraDB) -> None:
        """Test that deletion works as expected."""
        collection_name = "lc_test_store_mdelete"
        try:
            store = init_store(astra_db, collection_name)
            store.mdelete(["key1", "key2"])
            result = store.mget(["key1", "key2"])
            assert result == [None, None]
        finally:
            astra_db.delete_collection(collection_name)

    async def test_amdelete(self, async_astra_db: AsyncAstraDB) -> None:
        """Test that deletion works as expected."""
        collection_name = "lc_test_store_mdelete"
        try:
            store = await init_async_store(async_astra_db, collection_name)
            await store.amdelete(["key1", "key2"])
            result = await store.amget(["key1", "key2"])
            assert result == [None, None]
        finally:
            await async_astra_db.delete_collection(collection_name)

    def test_yield_keys(self, astra_db: AstraDB) -> None:
        collection_name = "lc_test_store_yield_keys"
        try:
            store = init_store(astra_db, collection_name)
            assert set(store.yield_keys()) == {"key1", "key2"}
            assert set(store.yield_keys(prefix="key")) == {"key1", "key2"}
            assert set(store.yield_keys(prefix="lang")) == set()
        finally:
            astra_db.delete_collection(collection_name)

    async def test_ayield_keys(self, async_astra_db: AsyncAstraDB) -> None:
        collection_name = "lc_test_store_yield_keys"
        try:
            store = await init_async_store(async_astra_db, collection_name)
            assert {key async for key in store.ayield_keys()} == {"key1", "key2"}
            assert {key async for key in store.ayield_keys(prefix="key")} == {
                "key1",
                "key2",
            }
            assert {key async for key in store.ayield_keys(prefix="lang")} == set()
        finally:
            await async_astra_db.delete_collection(collection_name)

    def test_bytestore_mget(self, astra_db: AstraDB) -> None:
        """Test AstraDBByteStore mget method."""
        collection_name = "lc_test_bytestore_mget"
        try:
            store = init_bytestore(astra_db, collection_name)
            assert store.mget(["key1", "key2"]) == [b"value1", b"value2"]
        finally:
            astra_db.delete_collection(collection_name)

    def test_bytestore_mset(self, astra_db: AstraDB) -> None:
        """Test that multiple keys can be set with AstraDBByteStore."""
        collection_name = "lc_test_bytestore_mset"
        try:
            store = init_bytestore(astra_db, collection_name)
            result = store.collection.find_one({"_id": "key1"})
            assert result["data"]["document"]["value"] == "dmFsdWUx"
            result = store.collection.find_one({"_id": "key2"})
            assert result["data"]["document"]["value"] == "dmFsdWUy"
        finally:
            astra_db.delete_collection(collection_name)
@ -1,868 +0,0 @@
|
|||||||
"""
|
|
||||||
Test of Astra DB vector store class `AstraDBVectorStore`
|
|
||||||
|
|
||||||
Required to run this test:
|
|
||||||
- a recent `astrapy` Python package available
|
|
||||||
- an Astra DB instance;
|
|
||||||
- the two environment variables set:
|
|
||||||
export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
|
|
||||||
export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
|
|
||||||
- optionally this as well (otherwise defaults are used):
|
|
||||||
export ASTRA_DB_KEYSPACE="my_keyspace"
|
|
||||||
- optionally:
|
|
||||||
export ASTRA_DB_SKIP_COLLECTION_DELETIONS="0" ("1" = no deletions, default)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
import math
import os
from typing import Iterable, List, Optional, TypedDict

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from langchain_astradb.vectorstores import AstraDBVectorStore

# Faster testing (no actual collection deletions). Off by default (=full tests)
SKIP_COLLECTION_DELETE = (
    int(os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")) != 0
)

COLLECTION_NAME_DIM2 = "lc_test_d2"
COLLECTION_NAME_DIM2_EUCLIDEAN = "lc_test_d2_eucl"

MATCH_EPSILON = 0.0001

# Ad-hoc embedding classes:

class AstraDBCredentials(TypedDict):
    token: str
    api_endpoint: str
    namespace: Optional[str]


class SomeEmbeddings(Embeddings):
    """
    Turn a sentence into an embedding vector in some way.
    Not important how. All that counts is that it is deterministic.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        # Code points of the first `dimension` characters, padded with a 1
        # and zeros, then L2-normalized to a unit vector.
        unnormed0 = [ord(c) for c in text[: self.dimension]]
        unnormed = (unnormed0 + [1] + [0] * (self.dimension - 1 - len(unnormed0)))[
            : self.dimension
        ]
        norm = sum(x * x for x in unnormed) ** 0.5
        normed = [x / norm for x in unnormed]
        return normed

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)
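
# Worked example (illustrative note, not part of the original file):
# SomeEmbeddings(dimension=2).embed_query("ab") takes the code points
# [ord("a"), ord("b")] == [97, 98], whose norm is sqrt(97**2 + 98**2)
# ~= 137.89, and returns roughly [0.7035, 0.7107] -- always the same
# vector for the same input, which is all the tests rely on.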


class ParserEmbeddings(Embeddings):
    """
    Parse input texts: if they are json for a List[float], fine.
    Otherwise, return all zeros and call it a day.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        try:
            vals = json.loads(text)
            assert len(vals) == self.dimension
            return vals
        except Exception:
            print(f'[ParserEmbeddings] Returning a moot vector for "{text}"')
            return [0.0] * self.dimension

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)
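
# Worked example (illustrative note, not part of the original file):
# ParserEmbeddings(dimension=2).embed_query("[1.0, 0.0]") parses the JSON
# and returns [1.0, 0.0] unchanged, so tests can plant exact vectors;
# a non-JSON input such as "hello" falls through to the zero vector
# [0.0, 0.0] (after printing a warning).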


def _has_env_vars() -> bool:
    return all(
        [
            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
            "ASTRA_DB_API_ENDPOINT" in os.environ,
        ]
    )


@pytest.fixture(scope="session")
def astradb_credentials() -> Iterable[AstraDBCredentials]:
    yield {
        "token": os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        "api_endpoint": os.environ["ASTRA_DB_API_ENDPOINT"],
        "namespace": os.environ.get("ASTRA_DB_KEYSPACE"),
    }


@pytest.fixture(scope="function")
def store_someemb(
    astradb_credentials: AstraDBCredentials,
) -> Iterable[AstraDBVectorStore]:
    emb = SomeEmbeddings(dimension=2)
    v_store = AstraDBVectorStore(
        embedding=emb,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    v_store.clear()

    yield v_store

    if not SKIP_COLLECTION_DELETE:
        v_store.delete_collection()
    else:
        v_store.clear()


@pytest.fixture(scope="function")
def store_parseremb(
    astradb_credentials: AstraDBCredentials,
) -> Iterable[AstraDBVectorStore]:
    emb = ParserEmbeddings(dimension=2)
    v_store = AstraDBVectorStore(
        embedding=emb,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    v_store.clear()

    yield v_store

    if not SKIP_COLLECTION_DELETE:
        v_store.delete_collection()
    else:
        v_store.clear()


@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBVectorStore:
    def test_astradb_vectorstore_create_delete(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Create and delete."""
        from astrapy.db import AstraDB as LibAstraDB

        emb = SomeEmbeddings(dimension=2)
        # creation by passing the connection secrets
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        v_store.add_texts(["Sample 1"])
        if not SKIP_COLLECTION_DELETE:
            v_store.delete_collection()
        else:
            v_store.clear()

        # Creation by passing a ready-made astrapy client:
        astra_db_client = LibAstraDB(
            **astradb_credentials,
        )
        v_store_2 = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            astra_db_client=astra_db_client,
        )
        v_store_2.add_texts(["Sample 2"])
        if not SKIP_COLLECTION_DELETE:
            v_store_2.delete_collection()
        else:
            v_store_2.clear()

    async def test_astradb_vectorstore_create_delete_async(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Create and delete."""
        emb = SomeEmbeddings(dimension=2)
        # creation by passing the connection secrets
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        await v_store.adelete_collection()
        # Creation by passing a ready-made astrapy client:
        from astrapy.db import AsyncAstraDB

        astra_db_client = AsyncAstraDB(
            **astradb_credentials,
        )
        v_store_2 = AstraDBVectorStore(
            embedding=emb,
            collection_name="lc_test_2_async",
            async_astra_db_client=astra_db_client,
        )
        if not SKIP_COLLECTION_DELETE:
            await v_store_2.adelete_collection()
        else:
            await v_store_2.aclear()

    @pytest.mark.skipif(
        SKIP_COLLECTION_DELETE,
        reason="Collection-deletion tests are suppressed",
    )
    def test_astradb_vectorstore_pre_delete_collection(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Use of the pre_delete_collection flag."""
        emb = SomeEmbeddings(dimension=2)
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        v_store.clear()
        try:
            v_store.add_texts(
                texts=["aa"],
                metadatas=[
                    {"k": "a", "ord": 0},
                ],
                ids=["a"],
            )
            res1 = v_store.similarity_search("aa", k=5)
            assert len(res1) == 1
            v_store = AstraDBVectorStore(
                embedding=emb,
                pre_delete_collection=True,
                collection_name=COLLECTION_NAME_DIM2,
                **astradb_credentials,
            )
            res1 = v_store.similarity_search("aa", k=5)
            assert len(res1) == 0
        finally:
            v_store.delete_collection()

    @pytest.mark.skipif(
        SKIP_COLLECTION_DELETE,
        reason="Collection-deletion tests are suppressed",
    )
    async def test_astradb_vectorstore_pre_delete_collection_async(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Use of the pre_delete_collection flag."""
        emb = SomeEmbeddings(dimension=2)
        # creation by passing the connection secrets
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        try:
            await v_store.aadd_texts(
                texts=["aa"],
                metadatas=[
                    {"k": "a", "ord": 0},
                ],
                ids=["a"],
            )
            res1 = await v_store.asimilarity_search("aa", k=5)
            assert len(res1) == 1
            v_store = AstraDBVectorStore(
                embedding=emb,
                pre_delete_collection=True,
                collection_name=COLLECTION_NAME_DIM2,
                **astradb_credentials,
            )
            res1 = await v_store.asimilarity_search("aa", k=5)
            assert len(res1) == 0
        finally:
            await v_store.adelete_collection()

    def test_astradb_vectorstore_from_x(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """from_texts and from_documents methods."""
        emb = SomeEmbeddings(dimension=2)
        # prepare empty collection
        AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        ).clear()
        # from_texts
        v_store = AstraDBVectorStore.from_texts(
            texts=["Hi", "Ho"],
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        try:
            assert v_store.similarity_search("Ho", k=1)[0].page_content == "Ho"
        finally:
            if not SKIP_COLLECTION_DELETE:
                v_store.delete_collection()
            else:
                v_store.clear()

        # from_documents
        v_store_2 = AstraDBVectorStore.from_documents(
            [
                Document(page_content="Hee"),
                Document(page_content="Hoi"),
            ],
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        try:
            assert v_store_2.similarity_search("Hoi", k=1)[0].page_content == "Hoi"
        finally:
            if not SKIP_COLLECTION_DELETE:
                v_store_2.delete_collection()
            else:
                v_store_2.clear()

    async def test_astradb_vectorstore_from_x_async(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """from_texts and from_documents methods."""
        emb = SomeEmbeddings(dimension=2)
        # prepare empty collection
        await AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        ).aclear()
        # from_texts
        v_store = await AstraDBVectorStore.afrom_texts(
            texts=["Hi", "Ho"],
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        try:
            assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho"
        finally:
            if not SKIP_COLLECTION_DELETE:
                await v_store.adelete_collection()
            else:
                await v_store.aclear()

        # from_documents
        v_store_2 = await AstraDBVectorStore.afrom_documents(
            [
                Document(page_content="Hee"),
                Document(page_content="Hoi"),
            ],
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        try:
            assert (await v_store_2.asimilarity_search("Hoi", k=1))[
                0
            ].page_content == "Hoi"
        finally:
            if not SKIP_COLLECTION_DELETE:
                await v_store_2.adelete_collection()
            else:
                await v_store_2.aclear()

    def test_astradb_vectorstore_crud(self, store_someemb: AstraDBVectorStore) -> None:
        """Basic add/delete/update behaviour."""
        res0 = store_someemb.similarity_search("Abc", k=2)
        assert res0 == []
        # write and check again
        store_someemb.add_texts(
            texts=["aa", "bb", "cc"],
            metadatas=[
                {"k": "a", "ord": 0},
                {"k": "b", "ord": 1},
                {"k": "c", "ord": 2},
            ],
            ids=["a", "b", "c"],
        )
        res1 = store_someemb.similarity_search("Abc", k=5)
        assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
        # partial overwrite and count total entries
        store_someemb.add_texts(
            texts=["cc", "dd"],
            metadatas=[
                {"k": "c_new", "ord": 102},
                {"k": "d_new", "ord": 103},
            ],
            ids=["c", "d"],
        )
        res2 = store_someemb.similarity_search("Abc", k=10)
        assert len(res2) == 4
        # pick one that was just updated and check its metadata
        res3 = store_someemb.similarity_search_with_score_id(
            query="cc", k=1, filter={"k": "c_new"}
        )
        print(str(res3))
        doc3, score3, id3 = res3[0]
        assert doc3.page_content == "cc"
        assert doc3.metadata == {"k": "c_new", "ord": 102}
        assert score3 > 0.999  # leaving some leeway for approximations...
        assert id3 == "c"
        # delete and count again
        del1_res = store_someemb.delete(["b"])
        assert del1_res is True
        del2_res = store_someemb.delete(["a", "c", "Z!"])
        assert del2_res is True  # a non-existing ID was supplied
        assert len(store_someemb.similarity_search("xy", k=10)) == 1
        # clear store
        store_someemb.clear()
        assert store_someemb.similarity_search("Abc", k=2) == []
        # add_documents with "ids" arg passthrough
        store_someemb.add_documents(
            [
                Document(page_content="vv", metadata={"k": "v", "ord": 204}),
                Document(page_content="ww", metadata={"k": "w", "ord": 205}),
            ],
            ids=["v", "w"],
        )
        assert len(store_someemb.similarity_search("xy", k=10)) == 2
        res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"})
        assert res4[0].metadata["ord"] == 205
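
    # Note (illustrative, not part of the original file): in the CRUD test
    # above, re-adding with ids=["c", "d"] overwrites document "c" in place
    # and inserts "d", so the store holds ids {a, b, c, d} -- four documents,
    # with "c" now carrying the updated metadata {"k": "c_new", "ord": 102},
    # which is exactly what the len(res2) == 4 and res3 assertions verify.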

    async def test_astradb_vectorstore_crud_async(
        self, store_someemb: AstraDBVectorStore
    ) -> None:
        """Basic add/delete/update behaviour."""
        res0 = await store_someemb.asimilarity_search("Abc", k=2)
        assert res0 == []
        # write and check again
        await store_someemb.aadd_texts(
            texts=["aa", "bb", "cc"],
            metadatas=[
                {"k": "a", "ord": 0},
                {"k": "b", "ord": 1},
                {"k": "c", "ord": 2},
            ],
            ids=["a", "b", "c"],
        )
        res1 = await store_someemb.asimilarity_search("Abc", k=5)
        assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
        # partial overwrite and count total entries
        await store_someemb.aadd_texts(
            texts=["cc", "dd"],
            metadatas=[
                {"k": "c_new", "ord": 102},
                {"k": "d_new", "ord": 103},
            ],
            ids=["c", "d"],
        )
        res2 = await store_someemb.asimilarity_search("Abc", k=10)
        assert len(res2) == 4
        # pick one that was just updated and check its metadata
        res3 = await store_someemb.asimilarity_search_with_score_id(
            query="cc", k=1, filter={"k": "c_new"}
        )
        print(str(res3))
        doc3, score3, id3 = res3[0]
        assert doc3.page_content == "cc"
        assert doc3.metadata == {"k": "c_new", "ord": 102}
        assert score3 > 0.999  # leaving some leeway for approximations...
        assert id3 == "c"
        # delete and count again
        del1_res = await store_someemb.adelete(["b"])
        assert del1_res is True
        del2_res = await store_someemb.adelete(["a", "c", "Z!"])
        assert del2_res is False  # a non-existing ID was supplied
        assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1
        # clear store
        await store_someemb.aclear()
        assert await store_someemb.asimilarity_search("Abc", k=2) == []
        # add_documents with "ids" arg passthrough
        await store_someemb.aadd_documents(
            [
                Document(page_content="vv", metadata={"k": "v", "ord": 204}),
                Document(page_content="ww", metadata={"k": "w", "ord": 205}),
            ],
            ids=["v", "w"],
        )
        assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2
        res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"})
        assert res4[0].metadata["ord"] == 205

    def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDBVectorStore) -> None:
        """
        MMR testing. We work on the unit circle with angle multiples
        of 2*pi/20 and prepare a store with known vectors for a controlled
        MMR outcome.
        """

        def _v_from_i(i: int, N: int) -> str:
            angle = 2 * math.pi * i / N
            vector = [math.cos(angle), math.sin(angle)]
            return json.dumps(vector)

        i_vals = [0, 4, 5, 13]
        N_val = 20
        store_parseremb.add_texts(
            [_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
        )
        res1 = store_parseremb.max_marginal_relevance_search(
            _v_from_i(3, N_val),
            k=2,
            fetch_k=3,
        )
        res_i_vals = {doc.metadata["i"] for doc in res1}
        assert res_i_vals == {0, 4}
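
    # Worked example (illustrative, not part of the original file): the query
    # sits at angle 2*pi*3/20. By cosine similarity the fetch_k=3 closest
    # stored vectors are i=4, i=5 and i=0 (i=13 points exactly opposite).
    # MMR first picks the closest, i=4, then prefers the more diverse i=0
    # over i=5 (which is nearly redundant with i=4), hence the {0, 4} result
    # asserted above.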

    async def test_astradb_vectorstore_mmr_async(
        self, store_parseremb: AstraDBVectorStore
    ) -> None:
        """
        MMR testing. We work on the unit circle with angle multiples
        of 2*pi/20 and prepare a store with known vectors for a controlled
        MMR outcome.
        """

        def _v_from_i(i: int, N: int) -> str:
            angle = 2 * math.pi * i / N
            vector = [math.cos(angle), math.sin(angle)]
            return json.dumps(vector)

        i_vals = [0, 4, 5, 13]
        N_val = 20
        await store_parseremb.aadd_texts(
            [_v_from_i(i, N_val) for i in i_vals],
            metadatas=[{"i": i} for i in i_vals],
        )
        res1 = await store_parseremb.amax_marginal_relevance_search(
            _v_from_i(3, N_val),
            k=2,
            fetch_k=3,
        )
        res_i_vals = {doc.metadata["i"] for doc in res1}
        assert res_i_vals == {0, 4}

    def test_astradb_vectorstore_metadata(
        self, store_someemb: AstraDBVectorStore
    ) -> None:
        """Metadata filtering."""
        store_someemb.add_documents(
            [
                Document(
                    page_content="q",
                    metadata={"ord": ord("q"), "group": "consonant"},
                ),
                Document(
                    page_content="w",
                    metadata={"ord": ord("w"), "group": "consonant"},
                ),
                Document(
                    page_content="r",
                    metadata={"ord": ord("r"), "group": "consonant"},
                ),
                Document(
                    page_content="e",
                    metadata={"ord": ord("e"), "group": "vowel"},
                ),
                Document(
                    page_content="i",
                    metadata={"ord": ord("i"), "group": "vowel"},
                ),
                Document(
                    page_content="o",
                    metadata={"ord": ord("o"), "group": "vowel"},
                ),
            ]
        )
        # no filters
        res0 = store_someemb.similarity_search("x", k=10)
        assert {doc.page_content for doc in res0} == set("qwreio")
        # single filter
        res1 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"group": "vowel"},
        )
        assert {doc.page_content for doc in res1} == set("eio")
        # multiple filters
        res2 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"group": "consonant", "ord": ord("q")},
        )
        assert {doc.page_content for doc in res2} == set("q")
        # excessive filters
        res3 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"group": "consonant", "ord": ord("q"), "case": "upper"},
        )
        assert res3 == []
        # filter with logical operator
        res4 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]},
        )
        assert {doc.page_content for doc in res4} == {"q", "r"}

    def test_astradb_vectorstore_similarity_scale(
        self, store_parseremb: AstraDBVectorStore
    ) -> None:
        """Scale of the similarity scores."""
        store_parseremb.add_texts(
            texts=[
                json.dumps([1, 1]),
                json.dumps([-1, -1]),
            ],
            ids=["near", "far"],
        )
        res1 = store_parseremb.similarity_search_with_score(
            json.dumps([0.5, 0.5]),
            k=2,
        )
        scores = [sco for _, sco in res1]
        sco_near, sco_far = scores
        assert abs(1 - sco_near) < MATCH_EPSILON and abs(sco_far) < MATCH_EPSILON
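
    # Worked example (illustrative, not part of the original file): with the
    # scores rescaled to [0, 1] (presumably (1 + cosine) / 2), the query
    # [0.5, 0.5] is perfectly aligned with [1, 1] (cosine 1 -> score 1) and
    # exactly opposite to [-1, -1] (cosine -1 -> score 0), which is what the
    # two MATCH_EPSILON assertions above check.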

    async def test_astradb_vectorstore_similarity_scale_async(
        self, store_parseremb: AstraDBVectorStore
    ) -> None:
        """Scale of the similarity scores."""
        await store_parseremb.aadd_texts(
            texts=[
                json.dumps([1, 1]),
                json.dumps([-1, -1]),
            ],
            ids=["near", "far"],
        )
        res1 = await store_parseremb.asimilarity_search_with_score(
            json.dumps([0.5, 0.5]),
            k=2,
        )
        scores = [sco for _, sco in res1]
        sco_near, sco_far = scores
        assert abs(1 - sco_near) < MATCH_EPSILON and abs(sco_far) < MATCH_EPSILON

    def test_astradb_vectorstore_massive_delete(
        self, store_someemb: AstraDBVectorStore
    ) -> None:
        """Larger-scale bulk deletes."""
        M = 50
        texts = [str(i + 1 / 7.0) for i in range(2 * M)]
        ids0 = ["doc_%i" % i for i in range(M)]
        ids1 = ["doc_%i" % (i + M) for i in range(M)]
        ids = ids0 + ids1
        store_someemb.add_texts(texts=texts, ids=ids)
        # deleting a bunch of these
        del_res0 = store_someemb.delete(ids0)
        assert del_res0 is True
        # deleting the rest plus a fake one
        del_res1 = store_someemb.delete(ids1 + ["ghost!"])
        assert del_res1 is True  # ensure no error
        # nothing left
        assert store_someemb.similarity_search("x", k=2 * M) == []

    @pytest.mark.skipif(
        SKIP_COLLECTION_DELETE,
        reason="Collection-deletion tests are suppressed",
    )
    def test_astradb_vectorstore_delete_collection(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Behaviour of 'delete_collection'."""
        collection_name = COLLECTION_NAME_DIM2
        emb = SomeEmbeddings(dimension=2)
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name=collection_name,
            **astradb_credentials,
        )
        v_store.add_texts(["huh"])
        assert len(v_store.similarity_search("hah", k=10)) == 1
        # another instance pointing to the same collection on DB
        v_store_kenny = AstraDBVectorStore(
            embedding=emb,
            collection_name=collection_name,
            **astradb_credentials,
        )
        v_store_kenny.delete_collection()
        # dropped on DB, but 'v_store' should have no clue:
        with pytest.raises(ValueError):
            _ = v_store.similarity_search("hah", k=10)

    def test_astradb_vectorstore_custom_params(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Custom batch size and concurrency params."""
        emb = SomeEmbeddings(dimension=2)
        # prepare empty collection
        AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        ).clear()
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
            batch_size=17,
            bulk_insert_batch_concurrency=13,
            bulk_insert_overwrite_concurrency=7,
            bulk_delete_concurrency=19,
        )
        try:
            # add_texts, with per-call params overriding the constructor's
            N = 50
            texts = [str(i + 1 / 7.0) for i in range(N)]
            ids = ["doc_%i" % i for i in range(N)]
            v_store.add_texts(texts=texts, ids=ids)
            v_store.add_texts(
                texts=texts,
                ids=ids,
                batch_size=19,
                batch_concurrency=7,
                overwrite_concurrency=13,
            )
            #
            _ = v_store.delete(ids[: N // 2])
            _ = v_store.delete(ids[N // 2 :], concurrency=23)
            #
        finally:
            if not SKIP_COLLECTION_DELETE:
                v_store.delete_collection()
            else:
                v_store.clear()

    async def test_astradb_vectorstore_custom_params_async(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """Custom batch size and concurrency params."""
        emb = SomeEmbeddings(dimension=2)
        v_store = AstraDBVectorStore(
            embedding=emb,
            collection_name="lc_test_c_async",
            batch_size=17,
            bulk_insert_batch_concurrency=13,
            bulk_insert_overwrite_concurrency=7,
            bulk_delete_concurrency=19,
            **astradb_credentials,
        )
        try:
            # add_texts, with per-call params overriding the constructor's
            N = 50
            texts = [str(i + 1 / 7.0) for i in range(N)]
            ids = ["doc_%i" % i for i in range(N)]
            await v_store.aadd_texts(texts=texts, ids=ids)
            await v_store.aadd_texts(
                texts=texts,
                ids=ids,
                batch_size=19,
                batch_concurrency=7,
                overwrite_concurrency=13,
            )
            #
            await v_store.adelete(ids[: N // 2])
            await v_store.adelete(ids[N // 2 :], concurrency=23)
            #
        finally:
            if not SKIP_COLLECTION_DELETE:
                await v_store.adelete_collection()
            else:
                await v_store.aclear()

    def test_astradb_vectorstore_metrics(
        self, astradb_credentials: AstraDBCredentials
    ) -> None:
        """
        Different choices of similarity metric.
        Both stores (with "cosine" and "euclidean" metrics) contain these two:
            - a vector slightly rotated w.r.t. the query vector
            - a vector which is a long multiple of the query vector
        so which one is "the closest one" depends on the metric.
        """
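        # Worked numbers (illustrative, not part of the original file): with
        # q = [sqrt(2)/2, sqrt(2)/2], the "scaled" vector 10*q has cosine
        # similarity exactly 1 with q (so cosine ranks it first), while the
        # "rotated" unit vector [0.7, ~0.714] has cosine ~0.99997 with q but
        # sits at a tiny Euclidean distance from it, versus distance 9 to
        # 10*q (so euclidean ranks "rotated" first).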
        emb = ParserEmbeddings(dimension=2)
        isq2 = 0.5**0.5
        isa = 0.7
        isb = (1.0 - isa * isa) ** 0.5
        texts = [
            json.dumps([isa, isb]),
            json.dumps([10 * isq2, 10 * isq2]),
        ]
        ids = [
            "rotated",
            "scaled",
        ]
        query_text = json.dumps([isq2, isq2])

        # prepare empty collections
        AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        ).clear()
        AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
            metric="euclidean",
            **astradb_credentials,
        ).clear()

        # creation, population, query - cosine
        vstore_cos = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2,
            metric="cosine",
            **astradb_credentials,
        )
        try:
            vstore_cos.add_texts(
                texts=texts,
                ids=ids,
            )
            _, _, id_from_cos = vstore_cos.similarity_search_with_score_id(
                query_text,
                k=1,
            )[0]
            assert id_from_cos == "scaled"
        finally:
            if not SKIP_COLLECTION_DELETE:
                vstore_cos.delete_collection()
            else:
                vstore_cos.clear()
        # creation, population, query - euclidean
        vstore_euc = AstraDBVectorStore(
            embedding=emb,
            collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
            metric="euclidean",
            **astradb_credentials,
        )
        try:
            vstore_euc.add_texts(
                texts=texts,
                ids=ids,
            )
            _, _, id_from_euc = vstore_euc.similarity_search_with_score_id(
                query_text,
                k=1,
            )[0]
            assert id_from_euc == "rotated"
        finally:
            if not SKIP_COLLECTION_DELETE:
                vstore_euc.delete_collection()
            else:
                vstore_euc.clear()
@ -1,12 +0,0 @@
from langchain_astradb import __all__

EXPECTED_ALL = [
    "AstraDBByteStore",
    "AstraDBStore",
    "AstraDBChatMessageHistory",
    "AstraDBVectorStore",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@ -1,45 +0,0 @@
from typing import List
from unittest.mock import Mock

from langchain_core.embeddings import Embeddings

from langchain_astradb.vectorstores import AstraDBVectorStore


class SomeEmbeddings(Embeddings):
    """
    Turn a sentence into an embedding vector in some way.
    Not important how. All that counts is that it is deterministic.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        unnormed0 = [ord(c) for c in text[: self.dimension]]
        unnormed = (unnormed0 + [1] + [0] * (self.dimension - 1 - len(unnormed0)))[
            : self.dimension
        ]
        norm = sum(x * x for x in unnormed) ** 0.5
        normed = [x / norm for x in unnormed]
        return normed

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)


def test_initialization() -> None:
    """Test integration vectorstore initialization."""
    mock_astra_db = Mock()
    embedding = SomeEmbeddings(dimension=2)
    AstraDBVectorStore(
        embedding=embedding,
        collection_name="mock_coll_name",
        astra_db_client=mock_astra_db,
    )