From 484a0090124fe1ec7d32b3c38c7041077a37e598 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 6 May 2024 20:32:32 +0200 Subject: [PATCH] community[minor]: Relax constraints on Cassandra VectorStore constructors (#21209) If Session and/or keyspace are not provided, they are resolved from cassio's context. So they are not required. This change is fully backward compatible. --- .../vectorstores/cassandra.py | 69 +++++++++---------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py index 9656f56339e..6f64000da4e 100644 --- a/libs/community/langchain_community/vectorstores/cassandra.py +++ b/libs/community/langchain_community/vectorstores/cassandra.py @@ -31,8 +31,6 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance CVST = TypeVar("CVST", bound="Cassandra") -_NOT_SET = object() - class Cassandra(VectorStore): """Apache Cassandra(R) for vector-store workloads. @@ -56,9 +54,9 @@ class Cassandra(VectorStore): Args: embedding: Embedding function to use. - session: Cassandra driver session. - keyspace: Cassandra key space. - table_name: Cassandra table. + session: Cassandra driver session. If not provided, it is resolved from cassio. + keyspace: Cassandra key space. If not provided, it is resolved from cassio. + table_name: Cassandra table (required). ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] @@ -83,9 +81,9 @@ class Cassandra(VectorStore): def __init__( self, embedding: Embeddings, - session: Session, - keyspace: str, - table_name: str, + session: Optional[Session] = None, + keyspace: Optional[str] = None, + table_name: str = "", ttl_seconds: Optional[int] = None, *, body_index_options: Optional[List[Tuple[str, Any]]] = None, @@ -98,7 +96,8 @@ class Cassandra(VectorStore): "Could not import cassio python package. " "Please install it with `pip install cassio`." ) - """Create a vector table.""" + if not table_name: + raise ValueError("Missing required parameter 'table_name'.") self.embedding = embedding self.session = session self.keyspace = keyspace @@ -779,8 +778,8 @@ class Cassandra(VectorStore): embedding: Embeddings, metadatas: Optional[List[dict]] = None, *, - session: Session = _NOT_SET, - keyspace: str = "", + session: Optional[Session] = None, + keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, batch_size: int = 16, @@ -794,8 +793,10 @@ class Cassandra(VectorStore): texts: Texts to add to the vectorstore. embedding: Embedding function to use. metadatas: Optional list of metadatas associated with the texts. - session: Cassandra driver session (required). - keyspace: Cassandra key space (required). + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the texts. batch_size: Number of concurrent requests to send to the server. @@ -807,12 +808,6 @@ class Cassandra(VectorStore): Returns: a Cassandra vectorstore. """ - if session is _NOT_SET: - raise ValueError("session parameter is required") - if not keyspace: - raise ValueError("keyspace parameter is required") - if not table_name: - raise ValueError("table_name parameter is required") store = cls( embedding=embedding, session=session, @@ -833,8 +828,8 @@ class Cassandra(VectorStore): embedding: Embeddings, metadatas: Optional[List[dict]] = None, *, - session: Session = _NOT_SET, - keyspace: str = "", + session: Optional[Session] = None, + keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, concurrency: int = 16, @@ -848,8 +843,10 @@ class Cassandra(VectorStore): texts: Texts to add to the vectorstore. embedding: Embedding function to use. metadatas: Optional list of metadatas associated with the texts. - session: Cassandra driver session (required). - keyspace: Cassandra key space (required). + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the texts. concurrency: Number of concurrent queries to send to the database. @@ -861,12 +858,6 @@ class Cassandra(VectorStore): Returns: a Cassandra vectorstore. """ - if session is _NOT_SET: - raise ValueError("session parameter is required") - if not keyspace: - raise ValueError("keyspace parameter is required") - if not table_name: - raise ValueError("table_name parameter is required") store = cls( embedding=embedding, session=session, @@ -887,8 +878,8 @@ class Cassandra(VectorStore): documents: List[Document], embedding: Embeddings, *, - session: Session = _NOT_SET, - keyspace: str = "", + session: Optional[Session] = None, + keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, batch_size: int = 16, @@ -901,8 +892,10 @@ class Cassandra(VectorStore): Args: documents: Documents to add to the vectorstore. embedding: Embedding function to use. - session: Cassandra driver session (required). - keyspace: Cassandra key space (required). + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the documents. batch_size: Number of concurrent requests to send to the server. @@ -936,8 +929,8 @@ class Cassandra(VectorStore): documents: List[Document], embedding: Embeddings, *, - session: Session = _NOT_SET, - keyspace: str = "", + session: Optional[Session] = None, + keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, concurrency: int = 16, @@ -950,8 +943,10 @@ class Cassandra(VectorStore): Args: documents: Documents to add to the vectorstore. embedding: Embedding function to use. - session: Cassandra driver session (required). - keyspace: Cassandra key space (required). + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the documents. concurrency: Number of concurrent queries to send to the database.