community[minor]: Add Initial Support for TiDB Vector Store (#15796)

This pull request introduces initial support for the TiDB vector store.
This first version is deliberately basic: it lays the foundation for the
vector store integration and covers the essential features, with further
enhancements planned for future updates.

Upcoming Enhancements:
* Vector index creation, to improve the efficiency and performance of
similarity search.
* Max marginal relevance search.
* Customized table structures, giving users the flexibility to define
more tailored and efficient data store layouts.

A simple use-case example:

```python
from typing import List, Tuple
from langchain.docstore.document import Document
from langchain_community.vectorstores import TiDBVectorStore
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
# Replace with your own TiDB connection string, e.g.
# "mysql+pymysql://root@127.0.0.1:4000/test"
tidb_connection_url = "mysql+pymysql://root@127.0.0.1:4000/test"

db = TiDBVectorStore.from_texts(
    embedding=embeddings,
    texts=[
        "Andrew likes eating oranges",
        "Alexandra is from England",
        "Ketanji Brown Jackson is a judge",
    ],
    table_name="tidb_vector_langchain",
    connection_string=tidb_connection_url,
    distance_strategy="cosine",
)

query = "Can you tell me about Alexandra?"
docs_with_score: List[Tuple[Document, float]] = db.similarity_search_with_score(query)
for doc, score in docs_with_score:
    print("-" * 80)
    print("Score: ", score)
    print(doc.page_content)
    print("-" * 80)
```
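Once the table exists, it can be reopened later without re-ingesting data, and searches accept metadata filters. A minimal sketch, reusing `embeddings` and `tidb_connection_url` from above; the `category` metadata key and its values are hypothetical:

```python
from langchain_community.vectorstores import TiDBVectorStore

# Reconnect to the table created above; raises ValueError if it does not exist.
db = TiDBVectorStore.from_existing_vector_table(
    embedding=embeddings,
    connection_string=tidb_connection_url,
    table_name="tidb_vector_langchain",
)

# Metadata filters support operators such as $in, $nin, $gt, $gte, $lt,
# $lte, $eq, $ne, $and, and $or (exercised in the integration tests below).
docs = db.similarity_search(
    "Who is from England?",
    k=2,
    filter={"category": {"$in": ["people", "places"]}},  # hypothetical metadata key
)
for doc in docs:
    print(doc.page_content)
```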
Authored by Ian on 2024-03-08 09:18:20 +08:00, committed by GitHub
commit 390ef6abe3 (parent 3b1eb1f828)
10 changed files with 1425 additions and 5 deletions

View File

@@ -436,6 +436,12 @@ def _import_tencentvectordb() -> Any:
return TencentVectorDB
def _import_tidb_vectorstore() -> Any:
from langchain_community.vectorstores.tidb_vector import TiDBVectorStore
return TiDBVectorStore
def _import_tiledb() -> Any:
from langchain_community.vectorstores.tiledb import TileDB
@@ -651,6 +657,8 @@ def __getattr__(name: str) -> Any:
return _import_tair()
elif name == "TencentVectorDB":
return _import_tencentvectordb()
elif name == "TiDBVectorStore":
return _import_tidb_vectorstore()
elif name == "TileDB":
return _import_tiledb()
elif name == "Tigris":
@@ -746,6 +754,7 @@ __all__ = [
"SupabaseVectorStore",
"SurrealDBStore",
"Tair",
"TiDBVectorStore",
"TileDB",
"Tigris",
"TimescaleVector",

View File

@@ -0,0 +1,362 @@
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
DEFAULT_DISTANCE_STRATEGY = "cosine" # or "l2", "inner_product"
DEFAULT_TiDB_VECTOR_TABLE_NAME = "langchain_vector"
class TiDBVectorStore(VectorStore):
def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
table_name: str = DEFAULT_TiDB_VECTOR_TABLE_NAME,
distance_strategy: str = DEFAULT_DISTANCE_STRATEGY,
*,
engine_args: Optional[Dict[str, Any]] = None,
drop_existing_table: bool = False,
**kwargs: Any,
) -> None:
"""
Initialize a TiDB Vector Store in LangChain with a flexible
and standardized table structure for storing vector data.
The schema stays fixed regardless of the configured table name.
The vector table schema includes:
- 'id': a UUID for each entry.
- 'embedding': stores vector data in a VectorType column.
- 'document': a Text column for the original data or additional information.
- 'meta': a JSON column for flexible metadata storage.
- 'create_time' and 'update_time': timestamp columns for tracking data changes.
This table structure caters to general use cases and
complex scenarios where the table serves as a semantic layer for advanced
data integration and analysis, leveraging SQL for join queries.
Args:
connection_string (str): The connection string for the TiDB database,
format: "mysql+pymysql://root@34.212.137.91:4000/test".
embedding_function: The embedding function used to generate embeddings.
table_name (str, optional): The name of the table that will be used to
store vector data. If you do not provide a table name,
a default table named `langchain_vector` will be created automatically.
distance_strategy: The strategy used for similarity search,
defaults to "cosine", valid values: "l2", "cosine", "inner_product".
engine_args (Optional[Dict]): Additional arguments for the database engine,
defaults to None.
drop_existing_table: Drop the existing TiDB table before initializing,
defaults to False.
**kwargs (Any): Additional keyword arguments.
Examples:
.. code-block:: python
from langchain_community.vectorstores import TiDBVectorStore
from langchain_openai import OpenAIEmbeddings
embeddingFunc = OpenAIEmbeddings()
CONNECTION_STRING = "mysql+pymysql://root@34.212.137.91:4000/test"
vs = TiDBVectorStore.from_texts(
embedding=embeddingFunc,
texts = [..., ...],
connection_string=CONNECTION_STRING,
distance_strategy="l2",
table_name="tidb_vector_langchain",
)
query = "What did the president say about Ketanji Brown Jackson"
docs = vs.similarity_search_with_score(query)
"""
super().__init__(**kwargs)
self._connection_string = connection_string
self._embedding_function = embedding_function
self._distance_strategy = distance_strategy
self._vector_dimension = self._get_dimension()
try:
from tidb_vector.integrations import TiDBVectorClient
except ImportError:
raise ImportError(
"Could not import tidbvec python package. "
"Please install it with `pip install tidb-vector`."
)
self._tidb = TiDBVectorClient(
connection_string=connection_string,
table_name=table_name,
distance_strategy=distance_strategy,
vector_dimension=self._vector_dimension,
engine_args=engine_args,
drop_existing_table=drop_existing_table,
**kwargs,
)
@property
def embeddings(self) -> Embeddings:
"""Return the function used to generate embeddings."""
return self._embedding_function
@property
def tidb_vector_client(self) -> Any:
"""Return the TiDB Vector Client."""
return self._tidb
@property
def distance_strategy(self) -> Any:
"""
Returns the current distance strategy.
"""
return self._distance_strategy
def _get_dimension(self) -> int:
"""
Get the dimension of the vector using embedding functions.
"""
return len(self._embedding_function.embed_query("test embedding length"))
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "TiDBVectorStore":
"""
Create a VectorStore from a list of texts.
Args:
texts (List[str]): The list of texts to be added to the TiDB Vector.
embedding (Embeddings): The function to use for generating embeddings.
metadatas: The list of metadata dictionaries corresponding to each text,
defaults to None.
**kwargs (Any): Additional keyword arguments.
connection_string (str): The connection string for the TiDB database,
format: "mysql+pymysql://root@34.212.137.91:4000/test".
table_name (str, optional): The name of table used to store vector data,
defaults to "langchain_vector".
distance_strategy: The distance strategy used for similarity search,
defaults to "cosine", allowed: "l2", "cosine", "inner_product".
ids (Optional[List[str]]): The list of IDs corresponding to each text,
defaults to None.
engine_args: Additional arguments for the underlying database engine,
defaults to None.
drop_existing_table: Drop the existing TiDB table before initializing,
defaults to False.
Returns:
VectorStore: The created TiDB Vector Store.
"""
# Extract arguments from kwargs with default values
connection_string = kwargs.pop("connection_string", None)
if connection_string is None:
raise ValueError("please provide your tidb connection_url")
table_name = kwargs.pop("table_name", "langchain_vector")
distance_strategy = kwargs.pop("distance_strategy", "cosine")
ids = kwargs.pop("ids", None)
engine_args = kwargs.pop("engine_args", None)
drop_existing_table = kwargs.pop("drop_existing_table", False)
embeddings = embedding.embed_documents(list(texts))
vs = cls(
connection_string=connection_string,
table_name=table_name,
embedding_function=embedding,
distance_strategy=distance_strategy,
engine_args=engine_args,
drop_existing_table=drop_existing_table,
**kwargs,
)
vs._tidb.insert(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
return vs
@classmethod
def from_existing_vector_table(
cls,
embedding: Embeddings,
connection_string: str,
table_name: str,
distance_strategy: str = DEFAULT_DISTANCE_STRATEGY,
*,
engine_args: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorStore:
"""
Create a VectorStore instance from an existing TiDB Vector Store in TiDB.
Args:
embedding (Embeddings): The function to use for generating embeddings.
connection_string (str): The connection string for the TiDB database,
format: "mysql+pymysql://root@34.212.137.91:4000/test".
table_name (str, optional): The name of table used to store vector data,
defaults to "langchain_vector".
distance_strategy: The distance strategy used for similarity search,
defaults to "cosine", allowed: "l2", "cosine", 'inner_product'.
engine_args: Additional arguments for the underlying database engine,
defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
VectorStore: The VectorStore instance.
Raises:
NoSuchTableError: If the specified table does not exist in the TiDB.
"""
try:
from tidb_vector.integrations import check_table_existence
except ImportError:
raise ImportError(
"Could not import tidbvec python package. "
"Please install it with `pip install tidb-vector`."
)
if check_table_existence(connection_string, table_name):
return cls(
connection_string=connection_string,
table_name=table_name,
embedding_function=embedding,
distance_strategy=distance_strategy,
engine_args=engine_args,
**kwargs,
)
else:
raise ValueError(f"Table {table_name} does not exist in the TiDB database.")
def drop_vectorstore(self) -> None:
"""
Drop the Vector Store from the TiDB database.
"""
self._tidb.drop_table()
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""
Add texts to TiDB Vector Store.
Args:
texts (Iterable[str]): The texts to be added.
metadatas (Optional[List[dict]]): The metadata associated with each text,
Defaults to None.
ids (Optional[List[str]]): The IDs to be assigned to each text,
Defaults to None, will be generated if not provided.
Returns:
List[str]: The IDs assigned to the added texts.
"""
embeddings = self._embedding_function.embed_documents(list(texts))
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
if not metadatas:
metadatas = [{} for _ in texts]
return self._tidb.insert(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""
Delete vector data from the TiDB Vector Store.
Args:
ids (Optional[List[str]]): A list of vector IDs to delete.
**kwargs: Additional keyword arguments.
"""
self._tidb.delete(ids=ids, **kwargs)
def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""
Perform a similarity search using the given query.
Args:
query (str): The query string.
k (int, optional): The number of results to retrieve. Defaults to 4.
filter (dict, optional): A filter to apply to the search results.
Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
List[Document]: A list of Document objects representing the search results.
"""
result = self.similarity_search_with_score(query, k, filter, **kwargs)
return [doc for doc, _ in result]
def similarity_search_with_score(
self,
query: str,
k: int = 5,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""
Perform a similarity search with score based on the given query.
Args:
query (str): The query string.
k (int, optional): The number of results to return. Defaults to 5.
filter (dict, optional): A filter to apply to the search results.
Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
A list of tuples containing relevant documents and their similarity scores.
"""
query_vector = self._embedding_function.embed_query(query)
relevant_docs = self._tidb.query(
query_vector=query_vector, k=k, filter=filter, **kwargs
)
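# Note: the score returned with each document is the raw distance
# (lower means more similar); an exact match yields 0.0, as the
# integration tests below expect.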
return [
(
Document(
page_content=doc.document,
metadata=doc.metadata,
),
doc.distance,
)
for doc in relevant_docs
]
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
Select the relevance score function based on the distance strategy.
"""
if self._distance_strategy == "cosine":
return self._cosine_relevance_score_fn
elif self._distance_strategy == "l2":
return self._euclidean_relevance_score_fn
else:
raise ValueError(
"No supported normalization function"
f" for distance_strategy of {self._distance_strategy}."
"Consider providing relevance_score_fn to PGVector constructor."
)
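For context on `_select_relevance_score_fn`: the two helpers it returns are inherited from the `VectorStore` base class in langchain-core. A rough sketch of what they compute, paraphrased from langchain-core (not part of this diff):

```python
import math

def cosine_relevance_score(distance: float) -> float:
    # Cosine distance 0.0 (identical direction) maps to relevance 1.0.
    return 1.0 - distance

def euclidean_relevance_score(distance: float) -> float:
    # Euclidean distance between unit-norm vectors lies in [0, sqrt(2)],
    # so it is rescaled so that zero distance maps to relevance 1.0.
    return 1.0 - distance / math.sqrt(2)
```

This also explains the negative l2 scores asserted in `test_relevance_score` below: the fake test embeddings are not unit-normalized, so their euclidean distances can exceed sqrt(2).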

View File

@@ -6527,7 +6527,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -7808,6 +7807,21 @@ files = [
{file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"},
]
[[package]]
name = "tidb-vector"
version = "0.0.4"
description = ""
optional = true
python-versions = ">=3.8.1,<4.0"
files = [
{file = "tidb_vector-0.0.4-py3-none-any.whl", hash = "sha256:8e10d3f06da3beb5d676b3a6d817df1defb5d35a91945778a072c2452e777a3a"},
{file = "tidb_vector-0.0.4.tar.gz", hash = "sha256:b2dcd3c437e6e073724f7e0093bb4e48484d41d8f7c8087329335dd3e44403ef"},
]
[package.dependencies]
numpy = ">=1,<2"
SQLAlchemy = ">=1.4,<3"
[[package]]
name = "tiktoken"
version = "0.5.2"
@@ -9176,9 +9190,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras]
cli = ["typer"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "b744cd68e1c4be95f1461ddfd9c06526bbdef88595e652a3d8019e27a8225c1c"
content-hash = "c8a7a435aabbfafc6c4210c2ceca2030b4a3724bd1d5804526498d2b52aa9db1"

View File

@@ -95,6 +95,7 @@ hdbcli = {version = "^2.19.21", optional = true}
oci = {version = "^2.119.1", optional = true}
rdflib = {version = "7.0.0", optional = true}
nvidia-riva-client = {version = "^2.14.0", optional = true}
tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
[tool.poetry.group.test]
optional = true
@@ -263,6 +264,7 @@ extended_testing = [
"hdbcli",
"oci",
"rdflib",
"tidb-vector",
"cloudpickle",
]

View File

@@ -0,0 +1,349 @@
"""Test TiDB Vector functionality."""
import os
from typing import List
from langchain_core.documents import Document
from langchain_community.vectorstores import TiDBVectorStore
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
TiDB_CONNECT_URL = os.environ.get(
"TEST_TiDB_CONNECTION_URL", "mysql+pymysql://root@127.0.0.1:4000/test"
)
ADA_TOKEN_COUNT = 1536
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings based on ASCII values of text characters."""
return [self._text_to_embedding(text) for text in texts]
def embed_query(self, text: str) -> List[float]:
"""Return simple embeddings based on ASCII values of text characters."""
return self._text_to_embedding(text)
def _text_to_embedding(self, text: str) -> List[float]:
"""Convert text to a unique embedding using ASCII values."""
ascii_values = [float(ord(char)) for char in text]
# Pad or trim the list to make it of length ADA_TOKEN_COUNT
ascii_values = ascii_values[:ADA_TOKEN_COUNT] + [0.0] * (
ADA_TOKEN_COUNT - len(ascii_values)
)
return ascii_values
def test_search() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
ids = ["1", "2", "3"]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
ids=ids,
drop_existing_table=True,
distance_strategy="cosine",
)
with docsearch.tidb_vector_client._make_session() as session:
records = list(session.query(docsearch.tidb_vector_client._table_model).all())
assert len([record.id for record in records]) == 3 # type: ignore
session.close()
output = docsearch.similarity_search("foo", k=1)
docsearch.drop_vectorstore()
assert output == [Document(page_content="foo", metadata={"page": "0"})]
def test_search_with_filter() -> None:
"""Test end to end construction and search."""
# no metadata
texts = ["foo", "bar", "baz"]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
drop_existing_table=True,
)
output = docsearch.similarity_search("foo", k=1)
output_filtered = docsearch.similarity_search(
"foo", k=1, filter={"filter_condition": "N/A"}
)
assert output == [Document(page_content="foo")]
assert output_filtered == []
# having metadata
metadatas = [{"page": i + 1, "page_str": str(i + 1)} for i in range(len(texts))]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
drop_existing_table=True,
)
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == [
Document(page_content="foo", metadata={"page": 1, "page_str": "1"})
]
# test mismatched value
output = docsearch.similarity_search("foo", k=1, filter={"page": "1"})
assert output == []
# test non-existing key
output = docsearch.similarity_search("foo", k=1, filter={"filter_condition": "N/A"})
assert output == []
# test IN, NIN expression
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$in": [1, 2]}})
assert output == [
Document(page_content="foo", metadata={"page": 1, "page_str": "1"})
]
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$nin": [1, 2]}})
assert output == [
Document(page_content="baz", metadata={"page": 3, "page_str": "3"})
]
output = docsearch.similarity_search(
"foo", k=1, filter={"page": {"$in": ["1", "2"]}}
)
assert output == []
output = docsearch.similarity_search(
"foo", k=1, filter={"page_str": {"$in": ["1", "2"]}}
)
assert output == [
Document(page_content="foo", metadata={"page": 1, "page_str": "1"})
]
# test GT, GTE, LT, LTE expression
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$gt": 1}})
assert output == [
Document(page_content="bar", metadata={"page": 2, "page_str": "2"})
]
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$gte": 1}})
assert output == [
Document(page_content="foo", metadata={"page": 1, "page_str": "1"})
]
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$lt": 3}})
assert output == [
Document(page_content="foo", metadata={"page": 1, "page_str": "1"})
]
output = docsearch.similarity_search("baz", k=1, filter={"page": {"$lte": 3}})
assert output == [
Document(page_content="baz", metadata={"page": 3, "page_str": "3"})
]
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$gt": 3}})
assert output == []
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$lt": 1}})
assert output == []
# test eq, neq expression
output = docsearch.similarity_search("foo", k=1, filter={"page": {"$eq": 3}})
assert output == [
Document(page_content="baz", metadata={"page": 3, "page_str": "3"})
]
output = docsearch.similarity_search("bar", k=1, filter={"page": {"$ne": 2}})
assert output == [
Document(page_content="baz", metadata={"page": 3, "page_str": "3"})
]
# test AND, OR expression
output = docsearch.similarity_search(
"bar", k=1, filter={"$and": [{"page": 1}, {"page_str": "1"}]}
)
assert output == [
Document(page_content="foo", metadata={"page": 1, "page_str": "1"})
]
output = docsearch.similarity_search(
"bar", k=1, filter={"$or": [{"page": 1}, {"page_str": "2"}]}
)
assert output == [
Document(page_content="bar", metadata={"page": 2, "page_str": "2"}),
]
output = docsearch.similarity_search(
"foo",
k=1,
filter={
"$or": [{"page": 1}, {"page": 2}],
"$and": [{"page": 2}],
},
)
assert output == [
Document(page_content="bar", metadata={"page": 2, "page_str": "2"})
]
output = docsearch.similarity_search(
"foo", k=1, filter={"$and": [{"$or": [{"page": 1}, {"page": 2}]}, {"page": 3}]}
)
assert output == []
docsearch.drop_vectorstore()
def test_search_with_score() -> None:
"""Test end to end construction, search"""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
drop_existing_table=True,
distance_strategy="cosine",
)
output = docsearch.similarity_search_with_score("foo", k=1)
docsearch.drop_vectorstore()
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
def test_load_from_existing_vectorstore() -> None:
"""Test loading existing TiDB Vector Store."""
# create tidb vector store and add documents
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
drop_existing_table=True,
distance_strategy="cosine",
)
# load from existing tidb vector store
docsearch_copy = TiDBVectorStore.from_existing_vector_table(
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
)
output = docsearch_copy.similarity_search_with_score("foo", k=1)
docsearch.drop_vectorstore()
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
# load from non-existing tidb vector store
try:
_ = TiDBVectorStore.from_existing_vector_table(
table_name="test_vectorstore_non_existing",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
)
assert False, "non-existing tidb vector store testing raised an error"
except ValueError:
pass
def test_delete_doc() -> None:
"""Test delete document from TiDB Vector."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
ids = ["1", "2", "3"]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
ids=ids,
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
drop_existing_table=True,
)
output = docsearch.similarity_search_with_score("foo", k=1)
docsearch.delete(["1", "2"])
output_after_deleted = docsearch.similarity_search_with_score("foo", k=1)
docsearch.drop_vectorstore()
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0)]
assert output_after_deleted == [
(Document(page_content="baz", metadata={"page": "2"}), 0.004691842206844599)
]
def test_relevance_score() -> None:
"""Test to make sure the relevance score is scaled to 0-1."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch_cosine = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
distance_strategy="cosine",
drop_existing_table=True,
)
output_cosine = docsearch_cosine.similarity_search_with_relevance_scores(
"foo", k=3
)
assert output_cosine == [
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="bar", metadata={"page": "1"}), 0.9977280385800326),
(Document(page_content="baz", metadata={"page": "2"}), 0.9953081577931554),
]
docsearch_l2 = TiDBVectorStore.from_existing_vector_table(
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
distance_strategy="l2",
)
output_l2 = docsearch_l2.similarity_search_with_relevance_scores("foo", k=3)
assert output_l2 == [
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="bar", metadata={"page": "1"}), -9.51189802081432),
(Document(page_content="baz", metadata={"page": "2"}), -11.90348790056394),
]
try:
_ = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=TiDB_CONNECT_URL,
metadatas=metadatas,
distance_strategy="inner",
drop_existing_table=True,
)
assert False, "inner product should raise error"
except ValueError:
pass
docsearch_l2.drop_vectorstore()
def test_retriever_search_threshold() -> None:
"""Test using retriever for searching with threshold."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = TiDBVectorStore.from_texts(
texts=texts,
table_name="test_tidb_vectorstore_langchain",
embedding=FakeEmbeddingsWithAdaDimension(),
metadatas=metadatas,
connection_string=TiDB_CONNECT_URL,
drop_existing_table=True,
)
retriever = docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": 3, "score_threshold": 0.997},
)
output = retriever.get_relevant_documents("foo")
assert output == [
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
]
docsearch.drop_vectorstore()

View File

@@ -55,6 +55,7 @@ def test_compatible_vectorstore_documentation() -> None:
"Chroma",
"DashVector",
"DatabricksVectorSearch",
"TiDBVectorStore",
"DeepLake",
"Dingo",
"DocumentDBVectorSearch",

View File

@@ -64,6 +64,7 @@ _EXPECTED = [
"SupabaseVectorStore",
"SurrealDBStore",
"Tair",
"TiDBVectorStore",
"TileDB",
"Tigris",
"TimescaleVector",