diff --git a/libs/community/langchain_community/vectorstores/oraclevs.py b/libs/community/langchain_community/vectorstores/oraclevs.py index 6aa767927f5..0e0181c67f4 100644 --- a/libs/community/langchain_community/vectorstores/oraclevs.py +++ b/libs/community/langchain_community/vectorstores/oraclevs.py @@ -22,6 +22,8 @@ from typing import ( cast, ) +from numpy.typing import NDArray + if TYPE_CHECKING: from oracledb import Connection @@ -47,6 +49,31 @@ logging.basicConfig( T = TypeVar("T", bound=Callable[..., Any]) +def _get_connection(client: Any) -> Connection | None: + # Dynamically import oracledb and the required classes + try: + import oracledb + except ImportError as e: + raise ImportError( + "Unable to import oracledb, please install with `pip install -U oracledb`." + ) from e + + # check if ConnectionPool exists + connection_pool_class = getattr(oracledb, "ConnectionPool", None) + + if isinstance(client, oracledb.Connection): + return client + elif connection_pool_class and isinstance(client, connection_pool_class): + return client.acquire() + else: + valid_types = "oracledb.Connection" + if connection_pool_class: + valid_types += " or oracledb.ConnectionPool" + raise TypeError( + f"Expected client of type {valid_types}, got {type(client).__name__}" + ) + + def _handle_exceptions(func: T) -> T: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -70,7 +97,7 @@ def _handle_exceptions(func: T) -> T: return cast(T, wrapper) -def _table_exists(client: Connection, table_name: str) -> bool: +def _table_exists(connection: Connection, table_name: str) -> bool: try: import oracledb except ImportError as e: @@ -79,7 +106,7 @@ def _table_exists(client: Connection, table_name: str) -> bool: ) from e try: - with client.cursor() as cursor: + with connection.cursor() as cursor: cursor.execute(f"SELECT COUNT(*) FROM {table_name}") return True except oracledb.DatabaseError as ex: @@ -106,7 +133,7 @@ def _compare_version(version: str, target_version: str) -> bool: @_handle_exceptions -def _index_exists(client: Connection, index_name: str) -> bool: +def _index_exists(connection: Connection, index_name: str) -> bool: # Check if the index exists query = """ SELECT index_name @@ -114,7 +141,7 @@ def _index_exists(client: Connection, index_name: str) -> bool: WHERE upper(index_name) = upper(:idx_name) """ - with client.cursor() as cursor: + with connection.cursor() as cursor: # Execute the query cursor.execute(query, idx_name=index_name.upper()) result = cursor.fetchone() @@ -146,16 +173,16 @@ def _get_index_name(base_name: str) -> str: @_handle_exceptions -def _create_table(client: Connection, table_name: str, embedding_dim: int) -> None: +def _create_table(connection: Connection, table_name: str, embedding_dim: int) -> None: cols_dict = { "id": "RAW(16) DEFAULT SYS_GUID() PRIMARY KEY", "text": "CLOB", - "metadata": "CLOB", + "metadata": "JSON", "embedding": f"vector({embedding_dim}, FLOAT32)", } - if not _table_exists(client, table_name): - with client.cursor() as cursor: + if not _table_exists(connection, table_name): + with connection.cursor() as cursor: ddl_body = ", ".join( f"{col_name} {col_type}" for col_name, col_type in cols_dict.items() ) @@ -168,43 +195,45 @@ def _create_table(client: Connection, table_name: str, embedding_dim: int) -> No @_handle_exceptions def create_index( - client: Connection, + client: Any, vector_store: OracleVS, params: Optional[dict[str, Any]] = None, ) -> None: - """Create an index on the vector store. - - Args: - client: The OracleDB connection object. - vector_store: The vector store object. - params: Optional parameters for the index creation. - - Raises: - ValueError: If an invalid parameter is provided. - """ + connection = _get_connection(client) + if connection is None: + raise ValueError("Failed to acquire a connection.") if params: if params["idx_type"] == "HNSW": _create_hnsw_index( - client, vector_store.table_name, vector_store.distance_strategy, params + connection, + vector_store.table_name, + vector_store.distance_strategy, + params, ) elif params["idx_type"] == "IVF": _create_ivf_index( - client, vector_store.table_name, vector_store.distance_strategy, params + connection, + vector_store.table_name, + vector_store.distance_strategy, + params, ) else: _create_hnsw_index( - client, vector_store.table_name, vector_store.distance_strategy, params + connection, + vector_store.table_name, + vector_store.distance_strategy, + params, ) else: _create_hnsw_index( - client, vector_store.table_name, vector_store.distance_strategy, params + connection, vector_store.table_name, vector_store.distance_strategy, params ) return @_handle_exceptions def _create_hnsw_index( - client: Connection, + connection: Connection, table_name: str, distance_strategy: DistanceStrategy, params: Optional[dict[str, Any]] = None, @@ -278,8 +307,8 @@ def _create_hnsw_index( ddl = ddl_assembly.format(**config) # Check if the index exists - if not _index_exists(client, config["idx_name"]): - with client.cursor() as cursor: + if not _index_exists(connection, config["idx_name"]): + with connection.cursor() as cursor: cursor.execute(ddl) logger.info("Index created successfully...") else: @@ -288,7 +317,7 @@ def _create_hnsw_index( @_handle_exceptions def _create_ivf_index( - client: Connection, + connection: Connection, table_name: str, distance_strategy: DistanceStrategy, params: Optional[dict[str, Any]] = None, @@ -350,8 +379,8 @@ def _create_ivf_index( ddl = ddl_assembly.format(**config) # Check if the index exists - if not _index_exists(client, config["idx_name"]): - with client.cursor() as cursor: + if not _index_exists(connection, config["idx_name"]): + with connection.cursor() as cursor: cursor.execute(ddl) logger.info("Index created successfully...") else: @@ -359,7 +388,7 @@ def _create_ivf_index( @_handle_exceptions -def drop_table_purge(client: Connection, table_name: str) -> None: +def drop_table_purge(client: Any, table_name: str) -> None: """Drop a table and purge it from the database. Args: @@ -369,9 +398,11 @@ def drop_table_purge(client: Connection, table_name: str) -> None: Raises: RuntimeError: If an error occurs while dropping the table. """ - if _table_exists(client, table_name): - cursor = client.cursor() - with cursor: + connection = _get_connection(client) + if connection is None: + raise ValueError("Failed to acquire a connection.") + if _table_exists(connection, table_name): + with connection.cursor() as cursor: ddl = f"DROP TABLE {table_name} PURGE" cursor.execute(ddl) logger.info("Table dropped successfully...") @@ -381,7 +412,7 @@ def drop_table_purge(client: Connection, table_name: str) -> None: @_handle_exceptions -def drop_index_if_exists(client: Connection, index_name: str) -> None: +def drop_index_if_exists(client: Any, index_name: str) -> None: """Drop an index if it exists. Args: @@ -391,9 +422,12 @@ def drop_index_if_exists(client: Connection, index_name: str) -> None: Raises: RuntimeError: If an error occurs while dropping the index. """ - if _index_exists(client, index_name): + connection = _get_connection(client) + if connection is None: + raise ValueError("Failed to acquire a connection.") + if _index_exists(connection, index_name): drop_query = f"DROP INDEX {index_name}" - with client.cursor() as cursor: + with connection.cursor() as cursor: cursor.execute(drop_query) logger.info(f"Index {index_name} has been dropped.") else: @@ -426,7 +460,7 @@ class OracleVS(VectorStore): def __init__( self, - client: Connection, + client: Any, embedding_function: Union[ Callable[[str], List[float]], Embeddings, @@ -445,8 +479,11 @@ class OracleVS(VectorStore): ) from e self.insert_mode = "array" + connection = _get_connection(client) + if connection is None: + raise ValueError("Failed to acquire a connection.") - if client.thin is True: + if hasattr(connection, "thin") and connection.thin: if oracledb.__version__ == "2.1.0": raise Exception( "Oracle DB python thin client driver version 2.1.0 not supported" @@ -494,8 +531,7 @@ class OracleVS(VectorStore): self.table_name = table_name self.distance_strategy = distance_strategy self.params = params - - _create_table(client, table_name, embedding_dim) + _create_table(connection, table_name, embedding_dim) except oracledb.DatabaseError as db_err: logger.exception(f"Database error occurred while create table: {db_err}") raise RuntimeError( @@ -613,13 +649,16 @@ class OracleVS(VectorStore): ) ] - with self.client.cursor() as cursor: + connection = _get_connection(self.client) + if connection is None: + raise ValueError("Failed to acquire a connection.") + with connection.cursor() as cursor: cursor.executemany( f"INSERT INTO {self.table_name} (id, embedding, metadata, " f"text) VALUES (:1, :2, :3, :4)", docs, ) - self.client.commit() + connection.commit() return processed_ids def similarity_search( @@ -630,6 +669,7 @@ class OracleVS(VectorStore): **kwargs: Any, ) -> List[Document]: """Return docs most similar to query.""" + embedding: List[float] = [] if isinstance(self.embedding_function, Embeddings): embedding = self.embedding_function.embed_query(query) documents = self.similarity_search_by_vector( @@ -657,6 +697,7 @@ class OracleVS(VectorStore): **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query.""" + embedding: List[float] = [] if isinstance(self.embedding_function, Embeddings): embedding = self.embedding_function.embed_query(query) docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( @@ -707,25 +748,27 @@ class OracleVS(VectorStore): embedding_arr = array.array("f", embedding) query = f""" - SELECT id, - text, - metadata, - vector_distance(embedding, :embedding, - {_get_distance_function(self.distance_strategy)}) as distance - FROM {self.table_name} - ORDER BY distance - FETCH APPROX FIRST {k} ROWS ONLY + SELECT id, + text, + metadata, + vector_distance(embedding, :embedding, + {_get_distance_function(self.distance_strategy)}) as distance + FROM {self.table_name} + ORDER BY distance + FETCH APPROX FIRST {k} ROWS ONLY """ + # Execute the query - with self.client.cursor() as cursor: + connection = _get_connection(self.client) + if connection is None: + raise ValueError("Failed to acquire a connection.") + with connection.cursor() as cursor: cursor.execute(query, embedding=embedding_arr) results = cursor.fetchall() # Filter results if filter is provided for result in results: - metadata = json.loads( - self._get_clob_value(result[2]) if result[2] is not None else "{}" - ) + metadata = dict(result[2]) if isinstance(result[2], dict) else {} # Apply filtering based on the 'filter' dictionary if filter: @@ -761,7 +804,7 @@ class OracleVS(VectorStore): k: int, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> List[Tuple[Document, float, np.ndarray]]: + ) -> List[Tuple[Document, float, NDArray[np.float32]]]: embedding_arr: Any if self.insert_mode == "clob": embedding_arr = json.dumps(embedding) @@ -771,27 +814,29 @@ class OracleVS(VectorStore): documents = [] query = f""" - SELECT id, - text, - metadata, - vector_distance(embedding, :embedding, { + SELECT id, + text, + metadata, + vector_distance(embedding, :embedding, { _get_distance_function(self.distance_strategy) }) as distance, - embedding - FROM {self.table_name} - ORDER BY distance - FETCH APPROX FIRST {k} ROWS ONLY + embedding + FROM {self.table_name} + ORDER BY distance + FETCH APPROX FIRST {k} ROWS ONLY """ # Execute the query - with self.client.cursor() as cursor: + connection = _get_connection(self.client) + if connection is None: + raise ValueError("Failed to acquire a connection.") + with connection.cursor() as cursor: cursor.execute(query, embedding=embedding_arr) results = cursor.fetchall() for result in results: page_content_str = self._get_clob_value(result[1]) - metadata_str = self._get_clob_value(result[2]) - metadata = json.loads(metadata_str) + metadata = result[2] if isinstance(result[2], dict) else {} # Apply filter if provided and matches; otherwise, add all # documents @@ -984,9 +1029,12 @@ class OracleVS(VectorStore): f"id{i}": hashed_id for i, hashed_id in enumerate(hashed_ids, start=1) } - with self.client.cursor() as cursor: + connection = _get_connection(self.client) + if connection is None: + raise ValueError("Failed to acquire a connection.") + with connection.cursor() as cursor: cursor.execute(ddl, bind_vars) - self.client.commit() + connection.commit() @classmethod @_handle_exceptions @@ -997,10 +1045,10 @@ class OracleVS(VectorStore): metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> OracleVS: - """Return VectorStore initialized from texts and embeddings.""" - client = kwargs.get("client") + client: Any = kwargs.get("client", None) if client is None: raise ValueError("client parameter is required...") + params = kwargs.get("params", {}) table_name = str(kwargs.get("table_name", "langchain"))