community[minor]: Improve CassandraVectorStore from_texts (#20284)

This commit is contained in:
Christophe Bornet 2024-04-17 23:12:28 +02:00 committed by GitHub
parent 463160c3f6
commit 75733c5cc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -28,6 +28,8 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance
CVST = TypeVar("CVST", bound="Cassandra") CVST = TypeVar("CVST", bound="Cassandra")
_NOT_SET = object()
class Cassandra(VectorStore): class Cassandra(VectorStore):
"""Wrapper around Apache Cassandra(R) for vector-store workloads. """Wrapper around Apache Cassandra(R) for vector-store workloads.
@ -48,6 +50,13 @@ class Cassandra(VectorStore):
keyspace = 'my_keyspace' # the keyspace should exist already keyspace = 'my_keyspace' # the keyspace should exist already
table_name = 'my_vector_store' table_name = 'my_vector_store'
vectorstore = Cassandra(embeddings, session, keyspace, table_name) vectorstore = Cassandra(embeddings, session, keyspace, table_name)
Args:
embedding: Embedding function to use.
session: Cassandra driver session.
keyspace: Cassandra key space.
table_name: Cassandra table.
ttl_seconds: Optional time-to-live for the added texts.
""" """
_embedding_dimension: Union[int, None] _embedding_dimension: Union[int, None]
@ -124,7 +133,7 @@ class Cassandra(VectorStore):
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
"""Empty the collection.""" """Empty the table."""
self.table.clear() self.table.clear()
def delete_by_document_id(self, document_id: str) -> None: def delete_by_document_id(self, document_id: str) -> None:
@ -161,12 +170,11 @@ class Cassandra(VectorStore):
"""Run more texts through the embeddings and add to the vectorstore. """Run more texts through the embeddings and add to the vectorstore.
Args: Args:
texts (Iterable[str]): Texts to add to the vectorstore. texts: Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas. metadatas: Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs. ids: Optional list of IDs.
batch_size (int): Number of concurrent requests to send to the server. batch_size: Number of concurrent requests to send to the server.
ttl_seconds (Optional[int], optional): Optional time-to-live ttl_seconds: Optional time-to-live for the added texts.
for the added texts.
Returns: Returns:
List[str]: List of IDs of the added texts. List[str]: List of IDs of the added texts.
@ -337,8 +345,8 @@ class Cassandra(VectorStore):
k: Number of Documents to return. k: Number of Documents to return.
fetch_k: Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding to maximum
to maximum diversity and 1 to minimum diversity. diversity and 1 to minimum diversity.
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
@ -389,9 +397,9 @@ class Cassandra(VectorStore):
k: Number of Documents to return. k: Number of Documents to return.
fetch_k: Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding to maximum
to maximum diversity and 1 to minimum diversity. diversity and 1 to minimum diversity.
Optional. Defaults to 0.5.
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
@ -410,53 +418,91 @@ class Cassandra(VectorStore):
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
*,
session: Session = _NOT_SET,
keyspace: str = "",
table_name: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 16, batch_size: int = 16,
ttl_seconds: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> CVST: ) -> CVST:
"""Create a Cassandra vectorstore from raw texts. """Create a Cassandra vectorstore from raw texts.
No support for specifying text IDs Args:
texts: Texts to add to the vectorstore.
embedding: Embedding function to use.
metadatas: Optional list of metadatas associated with the texts.
session: Cassandra driver session (required).
keyspace: Cassandra key space (required).
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
batch_size: Number of concurrent requests to send to the server.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added texts.
Returns: Returns:
a Cassandra vectorstore. a Cassandra vectorstore.
""" """
session: Session = kwargs["session"] if session is _NOT_SET:
keyspace: str = kwargs["keyspace"] raise ValueError("session parameter is required")
table_name: str = kwargs["table_name"] if not keyspace:
cassandraStore = cls( raise ValueError("keyspace parameter is required")
if not table_name:
raise ValueError("table_name parameter is required")
store = cls(
embedding=embedding, embedding=embedding,
session=session, session=session,
keyspace=keyspace, keyspace=keyspace,
table_name=table_name, table_name=table_name,
ttl_seconds=ttl_seconds,
) )
cassandraStore.add_texts(texts=texts, metadatas=metadatas) store.add_texts(
return cassandraStore texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
)
return store
@classmethod @classmethod
def from_documents( def from_documents(
cls: Type[CVST], cls: Type[CVST],
documents: List[Document], documents: List[Document],
embedding: Embeddings, embedding: Embeddings,
*,
session: Session = _NOT_SET,
keyspace: str = "",
table_name: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 16, batch_size: int = 16,
ttl_seconds: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> CVST: ) -> CVST:
"""Create a Cassandra vectorstore from a document list. """Create a Cassandra vectorstore from a document list.
No support for specifying text IDs Args:
documents: Documents to add to the vectorstore.
embedding: Embedding function to use.
session: Cassandra driver session (required).
keyspace: Cassandra key space (required).
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
batch_size: Number of concurrent requests to send to the server.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added documents.
Returns: Returns:
a Cassandra vectorstore. a Cassandra vectorstore.
""" """
texts = [doc.page_content for doc in documents] texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents] metadatas = [doc.metadata for doc in documents]
session: Session = kwargs["session"]
keyspace: str = kwargs["keyspace"]
table_name: str = kwargs["table_name"]
return cls.from_texts( return cls.from_texts(
texts=texts, texts=texts,
metadatas=metadatas,
embedding=embedding, embedding=embedding,
metadatas=metadatas,
session=session, session=session,
keyspace=keyspace, keyspace=keyspace,
table_name=table_name, table_name=table_name,
ids=ids,
batch_size=batch_size,
ttl_seconds=ttl_seconds,
**kwargs,
) )