From 75733c5cc1b4d4d382affc42e2c4f3680f2d1e21 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 17 Apr 2024 23:12:28 +0200 Subject: [PATCH] community[minor]: Improve CassandraVectorStore from_texts (#20284) --- .../vectorstores/cassandra.py | 94 ++++++++++++++----- 1 file changed, 70 insertions(+), 24 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py index 041f6995200..799ea79b487 100644 --- a/libs/community/langchain_community/vectorstores/cassandra.py +++ b/libs/community/langchain_community/vectorstores/cassandra.py @@ -28,6 +28,8 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance CVST = TypeVar("CVST", bound="Cassandra") +_NOT_SET = object() + class Cassandra(VectorStore): """Wrapper around Apache Cassandra(R) for vector-store workloads. @@ -48,6 +50,13 @@ class Cassandra(VectorStore): keyspace = 'my_keyspace' # the keyspace should exist already table_name = 'my_vector_store' vectorstore = Cassandra(embeddings, session, keyspace, table_name) + + Args: + embedding: Embedding function to use. + session: Cassandra driver session. + keyspace: Cassandra key space. + table_name: Cassandra table. + ttl_seconds: Optional time-to-live for the added texts. """ _embedding_dimension: Union[int, None] @@ -124,7 +133,7 @@ class Cassandra(VectorStore): self.clear() def clear(self) -> None: - """Empty the collection.""" + """Empty the table.""" self.table.clear() def delete_by_document_id(self, document_id: str) -> None: @@ -161,12 +170,11 @@ class Cassandra(VectorStore): """Run more texts through the embeddings and add to the vectorstore. Args: - texts (Iterable[str]): Texts to add to the vectorstore. - metadatas (Optional[List[dict]], optional): Optional list of metadatas. - ids (Optional[List[str]], optional): Optional list of IDs. - batch_size (int): Number of concurrent requests to send to the server. - ttl_seconds (Optional[int], optional): Optional time-to-live - for the added texts. + texts: Texts to add to the vectorstore. + metadatas: Optional list of metadatas. + ids: Optional list of IDs. + batch_size: Number of concurrent requests to send to the server. + ttl_seconds: Optional time-to-live for the added texts. Returns: List[str]: List of IDs of the added texts. @@ -337,8 +345,8 @@ class Cassandra(VectorStore): k: Number of Documents to return. fetch_k: Number of Documents to fetch to pass to MMR algorithm. lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Returns: List of Documents selected by maximal marginal relevance. """ @@ -389,9 +397,9 @@ class Cassandra(VectorStore): k: Number of Documents to return. fetch_k: Number of Documents to fetch to pass to MMR algorithm. lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Optional. + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. + Defaults to 0.5. Returns: List of Documents selected by maximal marginal relevance. """ @@ -410,53 +418,91 @@ class Cassandra(VectorStore): texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + *, + session: Session = _NOT_SET, + keyspace: str = "", + table_name: str = "", + ids: Optional[List[str]] = None, batch_size: int = 16, + ttl_seconds: Optional[int] = None, **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from raw texts. - No support for specifying text IDs + Args: + texts: Texts to add to the vectorstore. + embedding: Embedding function to use. + metadatas: Optional list of metadatas associated with the texts. + session: Cassandra driver session (required). + keyspace: Cassandra key space (required). + table_name: Cassandra table (required). + ids: Optional list of IDs associated with the texts. + batch_size: Number of concurrent requests to send to the server. + Defaults to 16. + ttl_seconds: Optional time-to-live for the added texts. Returns: a Cassandra vectorstore. """ - session: Session = kwargs["session"] - keyspace: str = kwargs["keyspace"] - table_name: str = kwargs["table_name"] - cassandraStore = cls( + if session is _NOT_SET: + raise ValueError("session parameter is required") + if not keyspace: + raise ValueError("keyspace parameter is required") + if not table_name: + raise ValueError("table_name parameter is required") + store = cls( embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, + ttl_seconds=ttl_seconds, ) - cassandraStore.add_texts(texts=texts, metadatas=metadatas) - return cassandraStore + store.add_texts( + texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size + ) + return store @classmethod def from_documents( cls: Type[CVST], documents: List[Document], embedding: Embeddings, + *, + session: Session = _NOT_SET, + keyspace: str = "", + table_name: str = "", + ids: Optional[List[str]] = None, batch_size: int = 16, + ttl_seconds: Optional[int] = None, **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from a document list. - No support for specifying text IDs + Args: + documents: Documents to add to the vectorstore. + embedding: Embedding function to use. + session: Cassandra driver session (required). + keyspace: Cassandra key space (required). + table_name: Cassandra table (required). + ids: Optional list of IDs associated with the documents. + batch_size: Number of concurrent requests to send to the server. + Defaults to 16. + ttl_seconds: Optional time-to-live for the added documents. Returns: a Cassandra vectorstore. """ texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] - session: Session = kwargs["session"] - keyspace: str = kwargs["keyspace"] - table_name: str = kwargs["table_name"] return cls.from_texts( texts=texts, - metadatas=metadatas, embedding=embedding, + metadatas=metadatas, session=session, keyspace=keyspace, table_name=table_name, + ids=ids, + batch_size=batch_size, + ttl_seconds=ttl_seconds, + **kwargs, )