diff --git a/docs/docs/integrations/vectorstores/yellowbrick.ipynb b/docs/docs/integrations/vectorstores/yellowbrick.ipynb
index 367fc8ca58f..dc789a86659 100644
--- a/docs/docs/integrations/vectorstores/yellowbrick.ipynb
+++ b/docs/docs/integrations/vectorstores/yellowbrick.ipynb
@@ -98,7 +98,7 @@
     "import psycopg2\n",
     "from IPython.display import Markdown, display\n",
     "from langchain.chains import LLMChain, RetrievalQAWithSourcesChain\n",
-    "from langchain_community.docstore.document import Document\n",
+    "from langchain.schema import Document\n",
     "from langchain_community.vectorstores import Yellowbrick\n",
     "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
     "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
@@ -209,14 +209,12 @@
     "\n",
     "# Define the SQL statement to create a table\n",
     "create_table_query = f\"\"\"\n",
-    "CREATE TABLE if not exists {embedding_table} (\n",
-    "    id uuid,\n",
-    "    embedding_id integer,\n",
-    "    text character varying(60000),\n",
-    "    metadata character varying(1024),\n",
-    "    embedding double precision\n",
+    "CREATE TABLE IF NOT EXISTS {embedding_table} (\n",
+    "    doc_id uuid NOT NULL,\n",
+    "    embedding_id smallint NOT NULL,\n",
+    "    embedding double precision NOT NULL\n",
     ")\n",
-    "DISTRIBUTE ON (id);\n",
+    "DISTRIBUTE ON (doc_id);\n",
     "truncate table {embedding_table};\n",
     "\"\"\"\n",
     "\n",
@@ -257,6 +255,8 @@
     "    f\"postgres://{urlparse.quote(YBUSER)}:{YBPASSWORD}@{YBHOST}:5432/{YB_DOC_DATABASE}\"\n",
     ")\n",
+    "\n",
+    "print(yellowbrick_doc_connection_string)\n",
     "\n",
     "# Establish a connection to the Yellowbrick database\n",
     "conn = psycopg2.connect(yellowbrick_doc_connection_string)\n",
@@ -403,6 +403,88 @@
     "print_result_sources(\"Whats an easy way to add users in bulk to Yellowbrick?\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "1f39fd30",
+   "metadata": {},
+   "source": [
+    "## Part 6: Introducing an Index to Increase Performance\n",
+    "\n",
+    "Yellowbrick also supports indexing using the Locality-Sensitive Hashing (LSH) approach. This is an approximate nearest-neighbor search technique that lets you trade search accuracy for speed. The index introduces two new tunable parameters:\n",
+    "\n",
+    "- The number of hyperplanes, which is provided through `Yellowbrick.IndexParams` when calling `create_index` (for example `{'num_hyperplanes': 8}`). The more documents you have, the more hyperplanes are needed. LSH is a form of dimensionality reduction: the original embeddings are transformed into lower-dimensional vectors where the number of components is the same as the number of hyperplanes.\n",
+    "- The Hamming distance, an integer representing the breadth of the search. Smaller Hamming distances result in faster retrieval but lower accuracy.\n",
+    "\n",
+    "Here's how you can create an index on the embeddings we loaded into Yellowbrick. We'll also re-run the previous chat session, but this time the retrieval will use the index. Note that for such a small number of documents, you won't see the benefit of indexing in terms of performance.\n",
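+    "\n",
+    "To build some intuition for what the index stores: each hyperplane splits the embedding space in two, and a document's hash is the sequence of sides it falls on. Below is a small illustrative sketch of that hashing step in plain Python (inside Yellowbrick, the equivalent computation runs as SQL over the hyperplane tables):\n",
+    "\n",
+    "```python\n",
+    "def lsh_hash(embedding, hyperplanes):\n",
+    "    # One bit per hyperplane: 1 if the embedding lies on its positive side\n",
+    "    return [int(sum(e * h for e, h in zip(embedding, plane)) > 0) for plane in hyperplanes]\n",
+    "\n",
+    "def hamming_distance(hash_a, hash_b):\n",
+    "    # Documents whose hash is within this distance of the query hash become candidates\n",
+    "    return sum(a != b for a, b in zip(hash_a, hash_b))\n",
+    "```"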
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "02ba61c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "system_template = \"\"\"Use the following pieces of context to answer the users question.\n",
+    "Take note of the sources and include them in the answer in the format: \"SOURCES: source1 source2\", use \"SOURCES\" in capital letters regardless of the number of sources.\n",
+    "If you don't know the answer, just say that \"I don't know\", don't try to make up an answer.\n",
+    "----------------\n",
+    "{summaries}\"\"\"\n",
+    "messages = [\n",
+    "    SystemMessagePromptTemplate.from_template(system_template),\n",
+    "    HumanMessagePromptTemplate.from_template(\"{question}\"),\n",
+    "]\n",
+    "prompt = ChatPromptTemplate.from_messages(messages)\n",
+    "\n",
+    "vector_store = Yellowbrick(\n",
+    "    OpenAIEmbeddings(),\n",
+    "    yellowbrick_connection_string,\n",
+    "    embedding_table,  # Change the table name to reflect your embeddings\n",
+    ")\n",
+    "\n",
+    "lsh_params = Yellowbrick.IndexParams(\n",
+    "    Yellowbrick.IndexType.LSH, {\"num_hyperplanes\": 8, \"hamming_distance\": 2}\n",
+    ")\n",
+    "vector_store.create_index(lsh_params)\n",
+    "\n",
+    "chain_type_kwargs = {\"prompt\": prompt}\n",
+    "llm = ChatOpenAI(\n",
+    "    model_name=\"gpt-3.5-turbo\",  # Modify model_name if you have access to GPT-4\n",
+    "    temperature=0,\n",
+    "    max_tokens=256,\n",
+    ")\n",
+    "chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
+    "    llm=llm,\n",
+    "    chain_type=\"stuff\",\n",
+    "    retriever=vector_store.as_retriever(\n",
+    "        search_kwargs={\"k\": 5, \"index_params\": lsh_params}\n",
+    "    ),\n",
+    "    return_source_documents=True,\n",
+    "    chain_type_kwargs=chain_type_kwargs,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def print_result_sources(query):\n",
+    "    result = chain(query)\n",
+    "    output_text = f\"\"\"### Question: \n",
+    "    {query}\n",
+    "    ### Answer: \n",
+    "    {result['answer']}\n",
+    "    ### Sources: \n",
+    "    {result['sources']}\n",
+    "    ### All relevant sources:\n",
+    "    {', '.join(list(set([doc.metadata['source'] for doc in result['source_documents']])))}\n",
+    "    \"\"\"\n",
+    "    display(Markdown(output_text))\n",
+    "\n",
+    "\n",
+    "# Use the chain to query\n",
+    "\n",
+    "print_result_sources(\"How many databases can be in a Yellowbrick Instance?\")\n",
+    "\n",
+    "print_result_sources(\"What's an easy way to add users in bulk to Yellowbrick?\")"
+   ]
+  },
  {
   "cell_type": "markdown",
   "id": "697c8a38",
@@ -418,9 +500,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "langchain_venv",
+   "display_name": "Python 3",
    "language": "python",
-   "name": "langchain_venv"
+   "name": "python3"
   },
  "language_info": {
   "codemirror_mode": {
diff --git a/docs/docs/modules/data_connection/indexing.ipynb b/docs/docs/modules/data_connection/indexing.ipynb
index c8d8b2fd1e2..831718b55f6 100644
--- a/docs/docs/modules/data_connection/indexing.ipynb
+++ b/docs/docs/modules/data_connection/indexing.ipynb
@@ -60,7 +60,7 @@
     "   * document addition by id (`add_documents` method with `ids` argument)\n",
     "   * delete by id (`delete` method with `ids` argument)\n",
     "\n",
-    "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AzureCosmosDBVectorSearch`, `AzureSearch`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `LanceDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `UpstashVectorStore`, `Vald`, `VDMS`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`, `TencentVectorDB`, `OpenSearchVectorSearch`.\n",
+    "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AzureCosmosDBVectorSearch`, `AzureSearch`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `CouchbaseVectorStore`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `HanaDB`, `LanceDB`, `Milvus`, `MyScale`, `OpenSearchVectorSearch`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `Rockset`, `ScaNN`, `SupabaseVectorStore`, `SurrealDBStore`, `TimescaleVector`, `UpstashVectorStore`, `Vald`, `VDMS`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`, `TencentVectorDB`, `Yellowbrick`.\n",
     " \n",
     "## Caution\n",
     "\n",
diff --git a/libs/community/langchain_community/vectorstores/yellowbrick.py b/libs/community/langchain_community/vectorstores/yellowbrick.py
index e3e5504346a..dbf9df917e8 100644
--- a/libs/community/langchain_community/vectorstores/yellowbrick.py
+++ b/libs/community/langchain_community/vectorstores/yellowbrick.py
@@ -1,12 +1,18 @@
 from __future__ import annotations

+import atexit
+import csv
+import enum
 import json
 import logging
 import uuid
-import warnings
-from itertools import repeat
+from contextlib import contextmanager
+from io import StringIO
 from typing import (
+    TYPE_CHECKING,
     Any,
+    Dict,
+    Generator,
     Iterable,
     List,
     Optional,
@@ -19,12 +25,13 @@
 from langchain_core.vectorstores import VectorStore

 from langchain_community.docstore.document import Document

-logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+    from psycopg2.extensions import connection as PgConnection
+    from psycopg2.extensions import cursor as PgCursor


 class Yellowbrick(VectorStore):
     """Yellowbrick as a vector database.
-
     Example:
         .. code-block:: python
             from langchain_community.vectorstores import Yellowbrick
@@ -32,11 +39,37 @@ class Yellowbrick(VectorStore):
             ...
     """

+    class IndexType(str, enum.Enum):
+        """Enumerator for the supported index types within Yellowbrick."""
+
+        NONE = "none"
+        LSH = "lsh"
+
+    class IndexParams:
+        """Parameters for configuring a Yellowbrick index."""
+
+        def __init__(
+            self,
+            index_type: Optional["Yellowbrick.IndexType"] = None,
+            params: Optional[Dict[str, Any]] = None,
+        ):
+            if index_type is None:
+                index_type = Yellowbrick.IndexType.NONE
+            self.index_type = index_type
+            self.params = params or {}
+
+        def get_param(self, key: str, default: Any = None) -> Any:
+            return self.params.get(key, default)
+
     def __init__(
         self,
         embedding: Embeddings,
         connection_string: str,
         table: str,
+        *,
+        schema: Optional[str] = None,
+        logger: Optional[logging.Logger] = None,
+        drop: bool = False,
     ) -> None:
         """Initialize with yellowbrick client.
         Args:
@@ -44,79 +77,232 @@ class Yellowbrick(VectorStore):
             connection_string: Format 'postgres://username:password@host:port/database'
             table: Table used to store / retrieve embeddings from
+            schema: Optional schema in which to create the tables
+            logger: Optional logger to use; defaults to a module-level logger
+            drop: If True, drop and recreate the store's tables before use
         """
+        from psycopg2 import extras

-        import psycopg2
+        extras.register_uuid()
+
+        if logger:
+            self.logger = logger
+        else:
+            self.logger = logging.getLogger(__name__)
+            self.logger.setLevel(logging.ERROR)
+            handler = logging.StreamHandler()
+            handler.setLevel(logging.DEBUG)
+            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+            handler.setFormatter(formatter)
+            self.logger.addHandler(handler)

         if not isinstance(embedding, Embeddings):
-            warnings.warn("embeddings input must be Embeddings object.")
+            self.logger.error("embeddings input must be Embeddings object.")
+            return
+
+        self.LSH_INDEX_TABLE: str = "_lsh_index"
+        self.LSH_HYPERPLANE_TABLE: str = "_lsh_hyperplane"
+        self.CONTENT_TABLE: str = "_content"

         self.connection_string = connection_string
+        self.connection = Yellowbrick.DatabaseConnection(connection_string, self.logger)
+        atexit.register(self.connection.close_connection)
+
+        self._schema = schema
         self._table = table
         self._embedding = embedding
-        self._connection = psycopg2.connect(connection_string)
+        self._max_embedding_len = None

-        self.__post_init__()
+        self._check_database_utf8()

-    def __post_init__(
-        self,
-    ) -> None:
-        """Initialize the store."""
-        self.check_database_utf8()
-        self.create_table_if_not_exists()
+        with self.connection.get_cursor() as cursor:
+            if drop:
+                self.drop(table=self._table, schema=self._schema, cursor=cursor)
+                self.drop(
+                    table=self._table + self.CONTENT_TABLE,
+                    schema=self._schema,
+                    cursor=cursor,
+                )
+                self._drop_lsh_index_tables(cursor)

-    def __del__(self) -> None:
-        if self._connection:
-            self._connection.close()
+            self._create_schema(cursor)
+            self._create_table(cursor)

+    class DatabaseConnection:
+        _instance = None
+        _connection_string: str
+        _connection: Optional["PgConnection"] = None
+        _logger: logging.Logger

-    def create_table_if_not_exists(self) -> None:
+        def __new__(
+            cls, connection_string: str, logger: logging.Logger
+        ) -> "Yellowbrick.DatabaseConnection":
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                cls._instance._connection_string = connection_string
+                cls._instance._logger = logger
+            return cls._instance
+
+        def close_connection(self) -> None:
+            if self._connection and not self._connection.closed:
+                self._connection.close()
+                self._connection = None
+
+        def get_connection(self) -> "PgConnection":
+            import psycopg2
+
+            if not self._connection or self._connection.closed:
+                self._connection = psycopg2.connect(self._connection_string)
+                self._connection.autocommit = False
+
+            return self._connection
+
+        @contextmanager
+        def get_managed_connection(self) -> Generator["PgConnection", None, None]:
+            from psycopg2 import DatabaseError
+
+            conn = self.get_connection()
+            try:
+                yield conn
+            except DatabaseError as e:
+                conn.rollback()
+                self._logger.error(
+                    "Database error occurred, rolling back transaction.", exc_info=True
+                )
+                raise RuntimeError("Database transaction failed.") from e
+            else:
+                conn.commit()
+
+        @contextmanager
+        def get_cursor(self) -> Generator["PgCursor", None, None]:
+            with self.get_managed_connection() as conn:
+                cursor = conn.cursor()
+                try:
+                    yield cursor
+                finally:
+                    cursor.close()
+
+    def _create_schema(self, cursor: "PgCursor") -> None:
+        """
+        Helper function: create schema if not exists
+        """
+        from psycopg2 import sql
+
+        if self._schema:
+            cursor.execute(
+                sql.SQL(
""" + CREATE SCHEMA IF NOT EXISTS {s} + """ + ).format( + s=sql.Identifier(self._schema), + ) + ) + + def _create_table(self, cursor: "PgCursor") -> None: """ Helper function: create table if not exists """ from psycopg2 import sql - cursor = self._connection.cursor() + schema_prefix = (self._schema,) if self._schema else () + t = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + c = sql.Identifier(self._table + self.CONTENT_TABLE + "_pk_doc_id") cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS {} ( \ - id UUID, \ - embedding_id INTEGER, \ - text VARCHAR(60000), \ - metadata VARCHAR(1024), \ - embedding FLOAT)" - ).format(sql.Identifier(self._table)) + """ + CREATE TABLE IF NOT EXISTS {t} ( + doc_id UUID NOT NULL, + text VARCHAR(60000) NOT NULL, + metadata VARCHAR(1024) NOT NULL, + CONSTRAINT {c} PRIMARY KEY (doc_id)) + DISTRIBUTE ON (doc_id) SORT ON (doc_id) + """ + ).format( + t=t, + c=c, + ) ) - self._connection.commit() - cursor.close() - def drop(self, table: str) -> None: + schema_prefix = (self._schema,) if self._schema else () + t1 = sql.Identifier(*schema_prefix, self._table) + t2 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + c1 = sql.Identifier( + self._table + self.CONTENT_TABLE + "_pk_doc_id_embedding_id" + ) + c2 = sql.Identifier(self._table + self.CONTENT_TABLE + "_fk_doc_id") + cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {t1} ( + doc_id UUID NOT NULL, + embedding_id SMALLINT NOT NULL, + embedding FLOAT NOT NULL, + CONSTRAINT {c1} PRIMARY KEY (doc_id, embedding_id), + CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {t2}(doc_id)) + DISTRIBUTE ON (doc_id) SORT ON (doc_id) + """ + ).format( + t1=t1, + t2=t2, + c1=c1, + c2=c2, + ) + ) + + def drop( + self, + table: str, + schema: Optional[str] = None, + cursor: Optional["PgCursor"] = None, + ) -> None: """ - Helper function: Drop data + Helper function: Drop data. If a cursor is provided, use it; + otherwise, obtain a new cursor for the operation. + """ + if cursor is None: + with self.connection.get_cursor() as cursor: + self._drop_table(cursor, table, schema=schema) + else: + self._drop_table(cursor, table, schema=schema) + + def _drop_table( + self, + cursor: "PgCursor", + table: str, + schema: Optional[str] = None, + ) -> None: + """ + Executes the drop table command using the given cursor. 
""" from psycopg2 import sql - cursor = self._connection.cursor() - cursor.execute(sql.SQL("DROP TABLE IF EXISTS {}").format(sql.Identifier(table))) - self._connection.commit() - cursor.close() + if schema: + table_name = sql.Identifier(schema, table) + else: + table_name = sql.Identifier(table) - def check_database_utf8(self) -> bool: + drop_table_query = sql.SQL( + """ + DROP TABLE IF EXISTS {} CASCADE + """ + ).format(table_name) + cursor.execute(drop_table_query) + + def _check_database_utf8(self) -> bool: """ Helper function: Test the database is UTF-8 encoded """ - cursor = self._connection.cursor() - query = "SELECT pg_encoding_to_char(encoding) \ - FROM pg_database \ - WHERE datname = current_database();" - cursor.execute(query) - encoding = cursor.fetchone()[0] - cursor.close() + with self.connection.get_cursor() as cursor: + query = """ + SELECT pg_encoding_to_char(encoding) + FROM pg_database + WHERE datname = current_database(); + """ + cursor.execute(query) + encoding = cursor.fetchone()[0] + if encoding.lower() == "utf8" or encoding.lower() == "utf-8": return True else: - raise Exception( - f"Database \ - '{self.connection_string.split('/')[-1]}' encoding is not UTF-8" - ) + raise Exception("Database encoding is not UTF-8") + + return False def add_texts( self, @@ -124,52 +310,83 @@ class Yellowbrick(VectorStore): metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> List[str]: - """Add more texts to the vectorstore index. - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - kwargs: vectorstore specific parameters - """ - from psycopg2 import sql + batch_size = 10000 texts = list(texts) - cursor = self._connection.cursor() embeddings = self._embedding.embed_documents(list(texts)) results = [] if not metadatas: metadatas = [{} for _ in texts] - for id in range(len(embeddings)): - doc_uuid = uuid.uuid4() - results.append(str(doc_uuid)) - data_input = [ - (str(id), embedding_id, text, json.dumps(metadata), embedding) - for id, embedding_id, text, metadata, embedding in zip( - repeat(doc_uuid), - range(len(embeddings[id])), - repeat(texts[id]), - repeat(metadatas[id]), - embeddings[id], - ) - ] - flattened_input = [val for sublist in data_input for val in sublist] - insert_query = sql.SQL( - "INSERT INTO {t} \ - (id, embedding_id, text, metadata, embedding) VALUES {v}" - ).format( - t=sql.Identifier(self._table), - v=( - sql.SQL(",").join( - [ - sql.SQL("(%s,%s,%s,%s,%s)") - for _ in range(len(embeddings[id])) - ] - ) - ), + + index_params = kwargs.get("index_params") or Yellowbrick.IndexParams() + + with self.connection.get_cursor() as cursor: + content_io = StringIO() + embeddings_io = StringIO() + content_writer = csv.writer( + content_io, delimiter="\t", quotechar='"', quoting=csv.QUOTE_MINIMAL ) - cursor.execute(insert_query, flattened_input) - self._connection.commit() + embeddings_writer = csv.writer( + embeddings_io, delimiter="\t", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + current_batch_size = 0 + + for i, text in enumerate(texts): + doc_uuid = str(uuid.uuid4()) + results.append(doc_uuid) + + content_writer.writerow([doc_uuid, text, json.dumps(metadatas[i])]) + + for embedding_id, embedding in enumerate(embeddings[i]): + embeddings_writer.writerow([doc_uuid, embedding_id, embedding]) + + current_batch_size += 1 + + if current_batch_size >= batch_size: + self._copy_to_db(cursor, content_io, embeddings_io) + + content_io.seek(0) + content_io.truncate(0) + embeddings_io.seek(0) + 
+                    embeddings_io.truncate(0)
+                    current_batch_size = 0
+
+            # Flush any remaining rows smaller than a full batch
+            if current_batch_size > 0:
+                self._copy_to_db(cursor, content_io, embeddings_io)
+
+            # Keep the LSH index in sync for every newly added document
+            if index_params.index_type == Yellowbrick.IndexType.LSH:
+                for doc_uuid in results:
+                    self._update_index(index_params, uuid.UUID(doc_uuid))
+
         return results

+    def _copy_to_db(
+        self, cursor: "PgCursor", content_io: StringIO, embeddings_io: StringIO
+    ) -> None:
+        content_io.seek(0)
+        embeddings_io.seek(0)
+
+        from psycopg2 import sql
+
+        schema_prefix = (self._schema,) if self._schema else ()
+        table = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE)
+        content_copy_query = sql.SQL(
+            """
+            COPY {table} (doc_id, text, metadata) FROM
+            STDIN WITH (FORMAT CSV, DELIMITER E'\\t', QUOTE '\"')
+            """
+        ).format(table=table)
+        cursor.copy_expert(content_copy_query, content_io)
+
+        schema_prefix = (self._schema,) if self._schema else ()
+        table = sql.Identifier(*schema_prefix, self._table)
+        embeddings_copy_query = sql.SQL(
+            """
+            COPY {table} (doc_id, embedding_id, embedding) FROM
+            STDIN WITH (FORMAT CSV, DELIMITER E'\\t', QUOTE '\"')
+            """
+        ).format(table=table)
+        cursor.copy_expert(embeddings_copy_query, embeddings_io)
+
     @classmethod
     def from_texts(
         cls: Type[Yellowbrick],
@@ -178,6 +395,8 @@ class Yellowbrick(VectorStore):
         metadatas: Optional[List[dict]] = None,
         connection_string: str = "",
         table: str = "langchain",
+        schema: str = "public",
+        drop: bool = False,
         **kwargs: Any,
     ) -> Yellowbrick:
         """Add texts to the vectorstore index.
@@ -189,16 +408,110 @@ class Yellowbrick(VectorStore):
             table: table to store embeddings
             kwargs: vectorstore specific parameters
         """
-        if connection_string is None:
-            raise ValueError("connection_string must be provided")
+        if not connection_string:
+            raise ValueError("connection_string must be provided")
         vss = cls(
             embedding=embedding,
             connection_string=connection_string,
             table=table,
+            schema=schema,
+            drop=drop,
         )
-        vss.add_texts(texts=texts, metadatas=metadatas)
+        vss.add_texts(texts=texts, metadatas=metadatas, **kwargs)
         return vss

+    def delete(
+        self,
+        ids: Optional[List[str]] = None,
+        delete_all: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Delete vectors by uuids.
+
+        Args:
+            ids: List of ids to delete, where each id is a uuid string.
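+            delete_all: If True, delete all vectors and any index entries,
+                ignoring ids.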
+ """ + from psycopg2 import sql + + if delete_all: + where_sql = sql.SQL( + """ + WHERE 1=1 + """ + ) + elif ids is not None: + uuids = tuple(sql.Literal(id) for id in ids) + ids_formatted = sql.SQL(", ").join(uuids) + where_sql = sql.SQL( + """ + WHERE doc_id IN ({ids}) + """ + ).format( + ids=ids_formatted, + ) + else: + raise ValueError("Either ids or delete_all must be provided.") + + schema_prefix = (self._schema,) if self._schema else () + with self.connection.get_cursor() as cursor: + table_identifier = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + ) + query = sql.SQL("DELETE FROM {table} {where_sql}").format( + table=table_identifier, where_sql=where_sql + ) + cursor.execute(query) + + table_identifier = sql.Identifier(*schema_prefix, self._table) + query = sql.SQL("DELETE FROM {table} {where_sql}").format( + table=table_identifier, where_sql=where_sql + ) + cursor.execute(query) + + if self._table_exists( + cursor, self._table + self.LSH_INDEX_TABLE, *schema_prefix + ): + table_identifier = sql.Identifier( + *schema_prefix, self._table + self.LSH_INDEX_TABLE + ) + query = sql.SQL("DELETE FROM {table} {where_sql}").format( + table=table_identifier, where_sql=where_sql + ) + cursor.execute(query) + + return None + + def _table_exists( + self, cursor: "PgCursor", table_name: str, schema: str = "public" + ) -> bool: + """ + Checks if a table exists in the given schema + """ + from psycopg2 import sql + + schema = sql.Literal(schema) + table_name = sql.Literal(table_name) + cursor.execute( + sql.SQL( + """ + SELECT COUNT(*) + FROM sys.table t INNER JOIN sys.schema s ON t.schema_id = s.schema_id + WHERE s.name = {schema} AND t.name = {table_name} + """ + ).format( + schema=schema, + table_name=table_name, + ) + ) + return cursor.fetchone()[0] > 0 + + def _generate_vector_uuid(self, vector: List[float]) -> uuid.UUID: + import hashlib + + vector_str = ",".join(map(str, vector)) + hash_object = hashlib.sha1(vector_str.encode()) + hash_digest = hash_object.digest() + vector_uuid = uuid.UUID(bytes=hash_digest[:16]) + return vector_uuid + def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, **kwargs: Any ) -> List[Tuple[Document, float]]: @@ -215,46 +528,133 @@ class Yellowbrick(VectorStore): List[Document, float]: List of Documents and scores """ from psycopg2 import sql + from psycopg2.extras import execute_values - cursor = self._connection.cursor() - tmp_table = "tmp_" + self._table - cursor.execute( - sql.SQL( - "CREATE TEMPORARY TABLE {} ( \ - embedding_id INTEGER, embedding FLOAT)" - ).format(sql.Identifier(tmp_table)) - ) - self._connection.commit() + index_params = kwargs.get("index_params") or Yellowbrick.IndexParams() - data_input = [ - (embedding_id, embedding) - for embedding_id, embedding in zip(range(len(embedding)), embedding) - ] - flattened_input = [val for sublist in data_input for val in sublist] - insert_query = sql.SQL( - "INSERT INTO {t} \ - (embedding_id, embedding) VALUES {v}" - ).format( - t=sql.Identifier(tmp_table), - v=sql.SQL(",").join([sql.SQL("(%s,%s)") for _ in range(len(embedding))]), - ) - cursor.execute(insert_query, flattened_input) - self._connection.commit() - sql_query = sql.SQL( - "SELECT text, \ - metadata, \ - sum(v1.embedding * v2.embedding) / \ - ( sqrt(sum(v1.embedding * v1.embedding)) * \ - sqrt(sum(v2.embedding * v2.embedding))) AS score \ - FROM {v1} v1 INNER JOIN {v2} v2 \ - ON v1.embedding_id = v2.embedding_id \ - GROUP BY v2.id, v2.text, v2.metadata \ - ORDER BY score DESC \ - LIMIT 
%s" - ).format(v1=sql.Identifier(tmp_table), v2=sql.Identifier(self._table)) - cursor.execute(sql_query, (k,)) - results = cursor.fetchall() - self.drop(tmp_table) + with self.connection.get_cursor() as cursor: + tmp_embeddings_table = "tmp_" + self._table + tmp_doc_id = self._generate_vector_uuid(embedding) + create_table_query = sql.SQL( + """ + CREATE TEMPORARY TABLE {} ( + doc_id UUID, + embedding_id SMALLINT, + embedding FLOAT) + ON COMMIT DROP + DISTRIBUTE REPLICATE + """ + ).format(sql.Identifier(tmp_embeddings_table)) + cursor.execute(create_table_query) + data_input = [ + (str(tmp_doc_id), embedding_id, embedding_value) + for embedding_id, embedding_value in enumerate(embedding) + ] + insert_query = sql.SQL( + "INSERT INTO {} (doc_id, embedding_id, embedding) VALUES %s" + ).format(sql.Identifier(tmp_embeddings_table)) + execute_values(cursor, insert_query, data_input) + + v1 = sql.Identifier(tmp_embeddings_table) + schema_prefix = (self._schema,) if self._schema else () + v2 = sql.Identifier(*schema_prefix, self._table) + v3 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + if index_params.index_type == Yellowbrick.IndexType.LSH: + tmp_hash_table = self._table + "_tmp_hash" + self._generate_tmp_lsh_hashes( + cursor, + tmp_embeddings_table, + tmp_hash_table, + ) + + schema_prefix = (self._schema,) if self._schema else () + lsh_index = sql.Identifier( + *schema_prefix, self._table + self.LSH_INDEX_TABLE + ) + input_hash_table = sql.Identifier(tmp_hash_table) + sql_query = sql.SQL( + """ + WITH index_docs AS ( + SELECT + t1.doc_id, + SUM(ABS(t1.hash-t2.hash)) as hamming_distance + FROM + {lsh_index} t1 + INNER JOIN + {input_hash_table} t2 + ON t1.hash_index = t2.hash_index + GROUP BY t1.doc_id + HAVING hamming_distance <= {hamming_distance} + ) + SELECT + text, + metadata, + SUM(v1.embedding * v2.embedding) / + (SQRT(SUM(v1.embedding * v1.embedding)) * + SQRT(SUM(v2.embedding * v2.embedding))) AS score + FROM + {v1} v1 + INNER JOIN + {v2} v2 + ON v1.embedding_id = v2.embedding_id + INNER JOIN + {v3} v3 + ON v2.doc_id = v3.doc_id + INNER JOIN + index_docs v4 + ON v2.doc_id = v4.doc_id + GROUP BY v3.doc_id, v3.text, v3.metadata + ORDER BY score DESC + LIMIT %s + """ + ).format( + v1=v1, + v2=v2, + v3=v3, + lsh_index=lsh_index, + input_hash_table=input_hash_table, + hamming_distance=sql.Literal( + index_params.get_param("hamming_distance", 0) + ), + ) + cursor.execute( + sql_query, + (k,), + ) + results = cursor.fetchall() + else: + sql_query = sql.SQL( + """ + SELECT + text, + metadata, + score + FROM + (SELECT + v2.doc_id doc_id, + SUM(v1.embedding * v2.embedding) / + (SQRT(SUM(v1.embedding * v1.embedding)) * + SQRT(SUM(v2.embedding * v2.embedding))) AS score + FROM + {v1} v1 + INNER JOIN + {v2} v2 + ON v1.embedding_id = v2.embedding_id + GROUP BY v2.doc_id + ORDER BY score DESC LIMIT %s + ) v4 + INNER JOIN + {v3} v3 + ON v4.doc_id = v3.doc_id + ORDER BY score DESC + """ + ).format( + v1=v1, + v2=v2, + v3=v3, + ) + cursor.execute(sql_query, (k,)) + results = cursor.fetchall() documents: List[Tuple[Document, float]] = [] for result in results: @@ -262,7 +662,6 @@ class Yellowbrick(VectorStore): doc = Document(page_content=result[0], metadata=metadata) documents.append((doc, result[2])) - cursor.close() return documents def similarity_search( @@ -282,7 +681,7 @@ class Yellowbrick(VectorStore): """ embedding = self._embedding.embed_query(query) documents = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k + embedding=embedding, k=k, **kwargs ) 
return [doc for doc, _ in documents] @@ -303,7 +702,7 @@ class Yellowbrick(VectorStore): """ embedding = self._embedding.embed_query(query) documents = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k + embedding=embedding, k=k, **kwargs ) return documents @@ -323,6 +722,252 @@ class Yellowbrick(VectorStore): List[Document]: List of documents """ documents = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k + embedding=embedding, k=k, **kwargs ) return [doc for doc, _ in documents] + + def _update_lsh_hashes( + self, + cursor: "PgCursor", + doc_id: Optional[uuid.UUID] = None, + ) -> None: + """Add hashes to LSH index""" + from psycopg2 import sql + + schema_prefix = (self._schema,) if self._schema else () + lsh_hyperplane_table = sql.Identifier( + *schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE + ) + lsh_index_table_id = sql.Identifier( + *schema_prefix, self._table + self.LSH_INDEX_TABLE + ) + embedding_table_id = sql.Identifier(*schema_prefix, self._table) + query_prefix_id = sql.SQL("INSERT INTO {}").format(lsh_index_table_id) + condition = ( + sql.SQL("WHERE e.doc_id = {doc_id}").format(doc_id=sql.Literal(str(doc_id))) + if doc_id + else sql.SQL("") + ) + group_by = sql.SQL("GROUP BY 1, 2") + + input_query = sql.SQL( + """ + {query_prefix} + SELECT + e.doc_id as doc_id, + h.id as hash_index, + CASE WHEN SUM(e.embedding * h.hyperplane) > 0 THEN 1 ELSE 0 END as hash + FROM {embedding_table} e + INNER JOIN {hyperplanes} h ON e.embedding_id = h.hyperplane_id + {condition} + {group_by} + """ + ).format( + query_prefix=query_prefix_id, + embedding_table=embedding_table_id, + hyperplanes=lsh_hyperplane_table, + condition=condition, + group_by=group_by, + ) + cursor.execute(input_query) + + def _generate_tmp_lsh_hashes( + self, cursor: "PgCursor", tmp_embedding_table: str, tmp_hash_table: str + ) -> None: + """Generate temp LSH""" + from psycopg2 import sql + + schema_prefix = (self._schema,) if self._schema else () + lsh_hyperplane_table = sql.Identifier( + *schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE + ) + tmp_embedding_table_id = sql.Identifier(tmp_embedding_table) + tmp_hash_table_id = sql.Identifier(tmp_hash_table) + query_prefix = sql.SQL("CREATE TEMPORARY TABLE {} ON COMMIT DROP AS").format( + tmp_hash_table_id + ) + group_by = sql.SQL("GROUP BY 1") + + input_query = sql.SQL( + """ + {query_prefix} + SELECT + h.id as hash_index, + CASE WHEN SUM(e.embedding * h.hyperplane) > 0 THEN 1 ELSE 0 END as hash + FROM {embedding_table} e + INNER JOIN {hyperplanes} h ON e.embedding_id = h.hyperplane_id + {group_by} + DISTRIBUTE REPLICATE + """ + ).format( + query_prefix=query_prefix, + embedding_table=tmp_embedding_table_id, + hyperplanes=lsh_hyperplane_table, + group_by=group_by, + ) + cursor.execute(input_query) + + def _populate_hyperplanes(self, cursor: "PgCursor", num_hyperplanes: int) -> None: + """Generate random hyperplanes and store in Yellowbrick""" + from psycopg2 import sql + + schema_prefix = (self._schema,) if self._schema else () + hyperplanes_table = sql.Identifier( + *schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE + ) + cursor.execute(sql.SQL("SELECT COUNT(*) FROM {t}").format(t=hyperplanes_table)) + if cursor.fetchone()[0] > 0: + return + + t = sql.Identifier(*schema_prefix, self._table) + cursor.execute(sql.SQL("SELECT MAX(embedding_id) FROM {t}").format(t=t)) + num_dimensions = cursor.fetchone()[0] + num_dimensions += 1 + + insert_query = sql.SQL( + """ + WITH parameters AS ( + SELECT {num_hyperplanes} AS 
num_hyperplanes, + {dims_per_hyperplane} AS dims_per_hyperplane + ) + INSERT INTO {hyperplanes_table} (id, hyperplane_id, hyperplane) + SELECT id, hyperplane_id, (random() * 2 - 1) AS hyperplane + FROM + (SELECT range-1 id FROM sys.rowgenerator + WHERE range BETWEEN 1 AND + (SELECT num_hyperplanes FROM parameters) AND + worker_lid = 0 AND thread_id = 0) a, + (SELECT range-1 hyperplane_id FROM sys.rowgenerator + WHERE range BETWEEN 1 AND + (SELECT dims_per_hyperplane FROM parameters) AND + worker_lid = 0 AND thread_id = 0) b + """ + ).format( + num_hyperplanes=sql.Literal(num_hyperplanes), + dims_per_hyperplane=sql.Literal(num_dimensions), + hyperplanes_table=hyperplanes_table, + ) + cursor.execute(insert_query) + + def _create_lsh_index_tables(self, cursor: "PgCursor") -> None: + """Create LSH index and hyperplane tables""" + from psycopg2 import sql + + schema_prefix = (self._schema,) if self._schema else () + t1 = sql.Identifier(*schema_prefix, self._table + self.LSH_INDEX_TABLE) + t2 = sql.Identifier(*schema_prefix, self._table + self.CONTENT_TABLE) + c1 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_pk_doc_id") + c2 = sql.Identifier(self._table + self.LSH_INDEX_TABLE + "_fk_doc_id") + cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {t1} ( + doc_id UUID NOT NULL, + hash_index SMALLINT NOT NULL, + hash SMALLINT NOT NULL, + CONSTRAINT {c1} PRIMARY KEY (doc_id, hash_index), + CONSTRAINT {c2} FOREIGN KEY (doc_id) REFERENCES {t2}(doc_id)) + DISTRIBUTE ON (doc_id) SORT ON (doc_id) + """ + ).format( + t1=t1, + t2=t2, + c1=c1, + c2=c2, + ) + ) + + schema_prefix = (self._schema,) if self._schema else () + t = sql.Identifier(*schema_prefix, self._table + self.LSH_HYPERPLANE_TABLE) + c = sql.Identifier(self._table + self.LSH_HYPERPLANE_TABLE + "_pk_id_hp_id") + cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {t} ( + id SMALLINT NOT NULL, + hyperplane_id SMALLINT NOT NULL, + hyperplane FLOAT NOT NULL, + CONSTRAINT {c} PRIMARY KEY (id, hyperplane_id)) + DISTRIBUTE REPLICATE SORT ON (id) + """ + ).format( + t=t, + c=c, + ) + ) + + def _drop_lsh_index_tables(self, cursor: "PgCursor") -> None: + """Drop LSH index tables""" + self.drop( + schema=self._schema, table=self._table + self.LSH_INDEX_TABLE, cursor=cursor + ) + self.drop( + schema=self._schema, + table=self._table + self.LSH_HYPERPLANE_TABLE, + cursor=cursor, + ) + + def create_index(self, index_params: Yellowbrick.IndexParams) -> None: + """Create index from existing vectors""" + if index_params.index_type == Yellowbrick.IndexType.LSH: + with self.connection.get_cursor() as cursor: + self._drop_lsh_index_tables(cursor) + self._create_lsh_index_tables(cursor) + self._populate_hyperplanes( + cursor, index_params.get_param("num_hyperplanes", 128) + ) + self._update_lsh_hashes(cursor) + + def drop_index(self, index_params: Yellowbrick.IndexParams) -> None: + """Drop an index""" + if index_params.index_type == Yellowbrick.IndexType.LSH: + with self.connection.get_cursor() as cursor: + self._drop_lsh_index_tables(cursor) + + def _update_index( + self, index_params: Yellowbrick.IndexParams, doc_id: uuid.UUID + ) -> None: + """Update an index with a new or modified embedding in the embeddings table""" + if index_params.index_type == Yellowbrick.IndexType.LSH: + with self.connection.get_cursor() as cursor: + self._update_lsh_hashes(cursor, doc_id) + + def migrate_schema_v1_to_v2(self) -> None: + from psycopg2 import sql + + try: + with self.connection.get_cursor() as cursor: + schema_prefix = (self._schema,) if 
self._schema else () + embeddings = sql.Identifier(*schema_prefix, self._table) + old_embeddings = sql.Identifier(*schema_prefix, self._table + "_v1") + content = sql.Identifier( + *schema_prefix, self._table + self.CONTENT_TABLE + ) + alter_table_query = sql.SQL("ALTER TABLE {t1} RENAME TO {t2}").format( + t1=embeddings, + t2=old_embeddings, + ) + cursor.execute(alter_table_query) + + self._create_table(cursor) + + insert_query = sql.SQL( + """ + INSERT INTO {t1} (doc_id, embedding_id, embedding) + SELECT id, embedding_id, embedding FROM {t2} + """ + ).format( + t1=embeddings, + t2=old_embeddings, + ) + cursor.execute(insert_query) + + insert_content_query = sql.SQL( + """ + INSERT INTO {t1} (doc_id, text, metadata) + SELECT DISTINCT id, text, metadata FROM {t2} + """ + ).format(t1=content, t2=old_embeddings) + cursor.execute(insert_content_query) + except Exception as e: + raise RuntimeError(f"Failed to migrate schema: {e}") from e diff --git a/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py b/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py index b5eba727cb4..bff44688581 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py +++ b/libs/community/tests/integration_tests/vectorstores/test_yellowbrick.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional import pytest @@ -5,60 +6,256 @@ import pytest from langchain_community.docstore.document import Document from langchain_community.vectorstores import Yellowbrick from tests.integration_tests.vectorstores.fake_embeddings import ( - FakeEmbeddings, + ConsistentFakeEmbeddings, fake_texts, ) YELLOWBRICK_URL = "postgres://username:password@host:port/database" YELLOWBRICK_TABLE = "test_table" +YELLOWBRICK_CONTENT = "test_table_content" +YELLOWBRICK_SCHEMA = "test_schema" def _yellowbrick_vector_from_texts( metadatas: Optional[List[dict]] = None, drop: bool = True ) -> Yellowbrick: - return Yellowbrick.from_texts( + db = Yellowbrick.from_texts( fake_texts, - FakeEmbeddings(), + ConsistentFakeEmbeddings(), metadatas, YELLOWBRICK_URL, - YELLOWBRICK_TABLE, + table=YELLOWBRICK_TABLE, + schema=YELLOWBRICK_SCHEMA, + drop=drop, ) + db.logger.setLevel(logging.DEBUG) + return db + + +def _yellowbrick_vector_from_texts_no_schema( + metadatas: Optional[List[dict]] = None, drop: bool = True +) -> Yellowbrick: + db = Yellowbrick.from_texts( + fake_texts, + ConsistentFakeEmbeddings(), + metadatas, + YELLOWBRICK_URL, + table=YELLOWBRICK_TABLE, + drop=drop, + ) + db.logger.setLevel(logging.DEBUG) + return db @pytest.mark.requires("yb-vss") def test_yellowbrick() -> None: """Test end to end construction and search.""" - docsearch = _yellowbrick_vector_from_texts() - output = docsearch.similarity_search("foo", k=1) - docsearch.drop(YELLOWBRICK_TABLE) - assert output == [Document(page_content="foo", metadata={})] + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_add_text() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + output = 
docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + docsearch.add_texts(texts) + output = docsearch.similarity_search("oof", k=1) + assert output == [Document(page_content="oof", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_delete() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + added_docs = docsearch.add_texts(texts) + output = docsearch.similarity_search("oof", k=1) + assert output == [Document(page_content="oof", metadata={})] + docsearch.delete(added_docs) + output = docsearch.similarity_search("oof", k=1) + assert output != [Document(page_content="oof", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_delete_all() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + docsearch.add_texts(texts) + output = docsearch.similarity_search("oof", k=1) + assert output == [Document(page_content="oof", metadata={})] + docsearch.delete(delete_all=True) + output = docsearch.similarity_search("oof", k=1) + assert output != [Document(page_content="oof", metadata={})] + output = docsearch.similarity_search("foo", k=1) + assert output != [Document(page_content="foo", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_lsh_search() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.LSH, {"num_hyperplanes": 10, "hamming_distance": 0} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_lsh_search_update() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.LSH, {"num_hyperplanes": 10, "hamming_distance": 0} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == 
[Document(page_content="foo", metadata={})] + texts = ["oof"] + docsearch.add_texts(texts, index_params=index_params) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output == [Document(page_content="oof", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_lsh_delete() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.LSH, {"num_hyperplanes": 10, "hamming_distance": 0} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + added_docs = docsearch.add_texts(texts, index_params=index_params) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output == [Document(page_content="oof", metadata={})] + docsearch.delete(added_docs) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output != [Document(page_content="oof", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) + + +@pytest.mark.requires("yb-vss") +def test_yellowbrick_lsh_delete_all() -> None: + """Test end to end construction and search.""" + docsearches = [ + _yellowbrick_vector_from_texts(), + _yellowbrick_vector_from_texts_no_schema(), + ] + for docsearch in docsearches: + index_params = Yellowbrick.IndexParams( + Yellowbrick.IndexType.LSH, {"num_hyperplanes": 10, "hamming_distance": 0} + ) + docsearch.drop_index(index_params) + docsearch.create_index(index_params) + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output == [Document(page_content="foo", metadata={})] + texts = ["oof"] + docsearch.add_texts(texts, index_params=index_params) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output == [Document(page_content="oof", metadata={})] + docsearch.delete(delete_all=True) + output = docsearch.similarity_search("oof", k=1, index_params=index_params) + assert output != [Document(page_content="oof", metadata={})] + output = docsearch.similarity_search("foo", k=1, index_params=index_params) + assert output != [Document(page_content="foo", metadata={})] + docsearch.drop(table=YELLOWBRICK_TABLE, schema=docsearch._schema) + docsearch.drop(table=YELLOWBRICK_CONTENT, schema=docsearch._schema) + docsearch.drop_index(index_params=index_params) @pytest.mark.requires("yb-vss") def test_yellowbrick_with_score() -> None: """Test end to end construction and search with scores and IDs.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _yellowbrick_vector_from_texts(metadatas=metadatas) - output = docsearch.similarity_search_with_score("foo", k=3) - docs = [o[0] for o in output] - distances = [o[1] for o in output] - docsearch.drop(YELLOWBRICK_TABLE) - assert docs == [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="bar", metadata={"page": 
1}),
-        Document(page_content="baz", metadata={"page": 2}),
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    docsearches = [
+        _yellowbrick_vector_from_texts(metadatas=metadatas),
+        _yellowbrick_vector_from_texts_no_schema(metadatas=metadatas),
     ]
-    assert distances[0] > distances[1] > distances[2]
+    for docsearch in docsearches:
+        output = docsearch.similarity_search_with_score("foo", k=3)
+        docs = [o[0] for o in output]
+        distances = [o[1] for o in output]
+        assert docs == [
+            Document(page_content="foo", metadata={"page": 0}),
+            Document(page_content="bar", metadata={"page": 1}),
+            Document(page_content="baz", metadata={"page": 2}),
+        ]
+        assert distances[0] > distances[1] > distances[2]


 @pytest.mark.requires("yb-vss")
 def test_yellowbrick_add_extra() -> None:
     """Test end to end construction and MRR search."""
-    texts = ["foo", "bar", "baz"]
-    metadatas = [{"page": i} for i in range(len(texts))]
-    docsearch = _yellowbrick_vector_from_texts(metadatas=metadatas)
-    docsearch.add_texts(texts, metadatas)
-    output = docsearch.similarity_search("foo", k=10)
-    docsearch.drop(YELLOWBRICK_TABLE)
-    assert len(output) == 6
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    docsearches = [
+        _yellowbrick_vector_from_texts(metadatas=metadatas),
+        _yellowbrick_vector_from_texts_no_schema(metadatas=metadatas),
+    ]
+    for docsearch in docsearches:
+        docsearch.add_texts(texts, metadatas)
+        output = docsearch.similarity_search("foo", k=10)
+        assert len(output) == 6
diff --git a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py
index 3df9d17bcc4..3658f5bf540 100644
--- a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py
+++ b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py
@@ -95,6 +95,7 @@ def test_compatible_vectorstore_documentation() -> None:
         "VespaStore",
         "VLite",
         "Weaviate",
+        "Yellowbrick",
         "ZepVectorStore",
         "Zilliz",
         "Lantern",
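A minimal end-to-end sketch of the index API this patch adds, for reviewers who want to try it locally. The connection string, table name, and embedding model below are placeholders, not values from this patch:

    from langchain_community.vectorstores import Yellowbrick
    from langchain_openai import OpenAIEmbeddings

    # Hypothetical connection details; substitute your own instance.
    connection_string = "postgres://user:password@host:5432/db"

    vector_store = Yellowbrick(
        OpenAIEmbeddings(),
        connection_string,
        "my_embeddings",  # tables are created on first use; pass drop=True to rebuild
    )
    vector_store.add_texts(["foo", "bar"])

    # Build the LSH index once over the stored embeddings...
    lsh_params = Yellowbrick.IndexParams(
        Yellowbrick.IndexType.LSH, {"num_hyperplanes": 8, "hamming_distance": 2}
    )
    vector_store.create_index(lsh_params)

    # ...then pass the same params at query time to route the search through it.
    docs = vector_store.similarity_search("foo", k=4, index_params=lsh_params)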