Oraclevs integration (#29723)

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core, etc. is
being modified. Use "docs: ..." for purely docs changes, "infra: ..."
for CI changes.
  - Example: "community: add foobar LLM"
  community: langchain_community/vectorstore/oraclevs.py


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
- **Description:** Refactored the code so that the vector store accepts
either a single connection or a connection pool.
- **Issue:** Normally an idle connection is terminated by the server-side
listener at timeout, so the user has to re-instantiate the vector
store. The timeout for a single connection is not configurable. The
solution is to use a connection pool, where the user can specify a
user-defined timeout and the connections are managed by the pool.
    - **Dependencies:** None
    - **Twitter handle:** 


- [ ] **Add tests and docs**: This is not a new integration. A user can
pass either a connection or a connection pool. The determination of what
is passed is made at run time. Everything should work as before.

- [ ] **Lint and test**:  Already done.

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Shailendra Mishra 2025-02-11 14:56:55 -08:00 committed by GitHub
parent 42ebf6ae0c
commit c7d74eb7a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,8 @@ from typing import (
cast, cast,
) )
from numpy.typing import NDArray
if TYPE_CHECKING: if TYPE_CHECKING:
from oracledb import Connection from oracledb import Connection
@ -47,6 +49,31 @@ logging.basicConfig(
T = TypeVar("T", bound=Callable[..., Any]) T = TypeVar("T", bound=Callable[..., Any])
def _get_connection(client: Any) -> Connection | None:
# Dynamically import oracledb and the required classes
try:
import oracledb
except ImportError as e:
raise ImportError(
"Unable to import oracledb, please install with `pip install -U oracledb`."
) from e
# check if ConnectionPool exists
connection_pool_class = getattr(oracledb, "ConnectionPool", None)
if isinstance(client, oracledb.Connection):
return client
elif connection_pool_class and isinstance(client, connection_pool_class):
return client.acquire()
else:
valid_types = "oracledb.Connection"
if connection_pool_class:
valid_types += " or oracledb.ConnectionPool"
raise TypeError(
f"Expected client of type {valid_types}, got {type(client).__name__}"
)
def _handle_exceptions(func: T) -> T: def _handle_exceptions(func: T) -> T:
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: Any, **kwargs: Any) -> Any:
@ -70,7 +97,7 @@ def _handle_exceptions(func: T) -> T:
return cast(T, wrapper) return cast(T, wrapper)
def _table_exists(client: Connection, table_name: str) -> bool: def _table_exists(connection: Connection, table_name: str) -> bool:
try: try:
import oracledb import oracledb
except ImportError as e: except ImportError as e:
@ -79,7 +106,7 @@ def _table_exists(client: Connection, table_name: str) -> bool:
) from e ) from e
try: try:
with client.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute(f"SELECT COUNT(*) FROM {table_name}") cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
return True return True
except oracledb.DatabaseError as ex: except oracledb.DatabaseError as ex:
@ -106,7 +133,7 @@ def _compare_version(version: str, target_version: str) -> bool:
@_handle_exceptions @_handle_exceptions
def _index_exists(client: Connection, index_name: str) -> bool: def _index_exists(connection: Connection, index_name: str) -> bool:
# Check if the index exists # Check if the index exists
query = """ query = """
SELECT index_name SELECT index_name
@ -114,7 +141,7 @@ def _index_exists(client: Connection, index_name: str) -> bool:
WHERE upper(index_name) = upper(:idx_name) WHERE upper(index_name) = upper(:idx_name)
""" """
with client.cursor() as cursor: with connection.cursor() as cursor:
# Execute the query # Execute the query
cursor.execute(query, idx_name=index_name.upper()) cursor.execute(query, idx_name=index_name.upper())
result = cursor.fetchone() result = cursor.fetchone()
@ -146,16 +173,16 @@ def _get_index_name(base_name: str) -> str:
@_handle_exceptions @_handle_exceptions
def _create_table(client: Connection, table_name: str, embedding_dim: int) -> None: def _create_table(connection: Connection, table_name: str, embedding_dim: int) -> None:
cols_dict = { cols_dict = {
"id": "RAW(16) DEFAULT SYS_GUID() PRIMARY KEY", "id": "RAW(16) DEFAULT SYS_GUID() PRIMARY KEY",
"text": "CLOB", "text": "CLOB",
"metadata": "CLOB", "metadata": "JSON",
"embedding": f"vector({embedding_dim}, FLOAT32)", "embedding": f"vector({embedding_dim}, FLOAT32)",
} }
if not _table_exists(client, table_name): if not _table_exists(connection, table_name):
with client.cursor() as cursor: with connection.cursor() as cursor:
ddl_body = ", ".join( ddl_body = ", ".join(
f"{col_name} {col_type}" for col_name, col_type in cols_dict.items() f"{col_name} {col_type}" for col_name, col_type in cols_dict.items()
) )
@ -168,43 +195,45 @@ def _create_table(client: Connection, table_name: str, embedding_dim: int) -> No
@_handle_exceptions @_handle_exceptions
def create_index( def create_index(
client: Connection, client: Any,
vector_store: OracleVS, vector_store: OracleVS,
params: Optional[dict[str, Any]] = None, params: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
"""Create an index on the vector store. connection = _get_connection(client)
if connection is None:
Args: raise ValueError("Failed to acquire a connection.")
client: The OracleDB connection object.
vector_store: The vector store object.
params: Optional parameters for the index creation.
Raises:
ValueError: If an invalid parameter is provided.
"""
if params: if params:
if params["idx_type"] == "HNSW": if params["idx_type"] == "HNSW":
_create_hnsw_index( _create_hnsw_index(
client, vector_store.table_name, vector_store.distance_strategy, params connection,
vector_store.table_name,
vector_store.distance_strategy,
params,
) )
elif params["idx_type"] == "IVF": elif params["idx_type"] == "IVF":
_create_ivf_index( _create_ivf_index(
client, vector_store.table_name, vector_store.distance_strategy, params connection,
vector_store.table_name,
vector_store.distance_strategy,
params,
) )
else: else:
_create_hnsw_index( _create_hnsw_index(
client, vector_store.table_name, vector_store.distance_strategy, params connection,
vector_store.table_name,
vector_store.distance_strategy,
params,
) )
else: else:
_create_hnsw_index( _create_hnsw_index(
client, vector_store.table_name, vector_store.distance_strategy, params connection, vector_store.table_name, vector_store.distance_strategy, params
) )
return return
@_handle_exceptions @_handle_exceptions
def _create_hnsw_index( def _create_hnsw_index(
client: Connection, connection: Connection,
table_name: str, table_name: str,
distance_strategy: DistanceStrategy, distance_strategy: DistanceStrategy,
params: Optional[dict[str, Any]] = None, params: Optional[dict[str, Any]] = None,
@ -278,8 +307,8 @@ def _create_hnsw_index(
ddl = ddl_assembly.format(**config) ddl = ddl_assembly.format(**config)
# Check if the index exists # Check if the index exists
if not _index_exists(client, config["idx_name"]): if not _index_exists(connection, config["idx_name"]):
with client.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute(ddl) cursor.execute(ddl)
logger.info("Index created successfully...") logger.info("Index created successfully...")
else: else:
@ -288,7 +317,7 @@ def _create_hnsw_index(
@_handle_exceptions @_handle_exceptions
def _create_ivf_index( def _create_ivf_index(
client: Connection, connection: Connection,
table_name: str, table_name: str,
distance_strategy: DistanceStrategy, distance_strategy: DistanceStrategy,
params: Optional[dict[str, Any]] = None, params: Optional[dict[str, Any]] = None,
@ -350,8 +379,8 @@ def _create_ivf_index(
ddl = ddl_assembly.format(**config) ddl = ddl_assembly.format(**config)
# Check if the index exists # Check if the index exists
if not _index_exists(client, config["idx_name"]): if not _index_exists(connection, config["idx_name"]):
with client.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute(ddl) cursor.execute(ddl)
logger.info("Index created successfully...") logger.info("Index created successfully...")
else: else:
@ -359,7 +388,7 @@ def _create_ivf_index(
@_handle_exceptions @_handle_exceptions
def drop_table_purge(client: Connection, table_name: str) -> None: def drop_table_purge(client: Any, table_name: str) -> None:
"""Drop a table and purge it from the database. """Drop a table and purge it from the database.
Args: Args:
@ -369,9 +398,11 @@ def drop_table_purge(client: Connection, table_name: str) -> None:
Raises: Raises:
RuntimeError: If an error occurs while dropping the table. RuntimeError: If an error occurs while dropping the table.
""" """
if _table_exists(client, table_name): connection = _get_connection(client)
cursor = client.cursor() if connection is None:
with cursor: raise ValueError("Failed to acquire a connection.")
if _table_exists(connection, table_name):
with connection.cursor() as cursor:
ddl = f"DROP TABLE {table_name} PURGE" ddl = f"DROP TABLE {table_name} PURGE"
cursor.execute(ddl) cursor.execute(ddl)
logger.info("Table dropped successfully...") logger.info("Table dropped successfully...")
@ -381,7 +412,7 @@ def drop_table_purge(client: Connection, table_name: str) -> None:
@_handle_exceptions @_handle_exceptions
def drop_index_if_exists(client: Connection, index_name: str) -> None: def drop_index_if_exists(client: Any, index_name: str) -> None:
"""Drop an index if it exists. """Drop an index if it exists.
Args: Args:
@ -391,9 +422,12 @@ def drop_index_if_exists(client: Connection, index_name: str) -> None:
Raises: Raises:
RuntimeError: If an error occurs while dropping the index. RuntimeError: If an error occurs while dropping the index.
""" """
if _index_exists(client, index_name): connection = _get_connection(client)
if connection is None:
raise ValueError("Failed to acquire a connection.")
if _index_exists(connection, index_name):
drop_query = f"DROP INDEX {index_name}" drop_query = f"DROP INDEX {index_name}"
with client.cursor() as cursor: with connection.cursor() as cursor:
cursor.execute(drop_query) cursor.execute(drop_query)
logger.info(f"Index {index_name} has been dropped.") logger.info(f"Index {index_name} has been dropped.")
else: else:
@ -426,7 +460,7 @@ class OracleVS(VectorStore):
def __init__( def __init__(
self, self,
client: Connection, client: Any,
embedding_function: Union[ embedding_function: Union[
Callable[[str], List[float]], Callable[[str], List[float]],
Embeddings, Embeddings,
@ -445,8 +479,11 @@ class OracleVS(VectorStore):
) from e ) from e
self.insert_mode = "array" self.insert_mode = "array"
connection = _get_connection(client)
if connection is None:
raise ValueError("Failed to acquire a connection.")
if client.thin is True: if hasattr(connection, "thin") and connection.thin:
if oracledb.__version__ == "2.1.0": if oracledb.__version__ == "2.1.0":
raise Exception( raise Exception(
"Oracle DB python thin client driver version 2.1.0 not supported" "Oracle DB python thin client driver version 2.1.0 not supported"
@ -494,8 +531,7 @@ class OracleVS(VectorStore):
self.table_name = table_name self.table_name = table_name
self.distance_strategy = distance_strategy self.distance_strategy = distance_strategy
self.params = params self.params = params
_create_table(connection, table_name, embedding_dim)
_create_table(client, table_name, embedding_dim)
except oracledb.DatabaseError as db_err: except oracledb.DatabaseError as db_err:
logger.exception(f"Database error occurred while create table: {db_err}") logger.exception(f"Database error occurred while create table: {db_err}")
raise RuntimeError( raise RuntimeError(
@ -613,13 +649,16 @@ class OracleVS(VectorStore):
) )
] ]
with self.client.cursor() as cursor: connection = _get_connection(self.client)
if connection is None:
raise ValueError("Failed to acquire a connection.")
with connection.cursor() as cursor:
cursor.executemany( cursor.executemany(
f"INSERT INTO {self.table_name} (id, embedding, metadata, " f"INSERT INTO {self.table_name} (id, embedding, metadata, "
f"text) VALUES (:1, :2, :3, :4)", f"text) VALUES (:1, :2, :3, :4)",
docs, docs,
) )
self.client.commit() connection.commit()
return processed_ids return processed_ids
def similarity_search( def similarity_search(
@ -630,6 +669,7 @@ class OracleVS(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to query.""" """Return docs most similar to query."""
embedding: List[float] = []
if isinstance(self.embedding_function, Embeddings): if isinstance(self.embedding_function, Embeddings):
embedding = self.embedding_function.embed_query(query) embedding = self.embedding_function.embed_query(query)
documents = self.similarity_search_by_vector( documents = self.similarity_search_by_vector(
@ -657,6 +697,7 @@ class OracleVS(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.""" """Return docs most similar to query."""
embedding: List[float] = []
if isinstance(self.embedding_function, Embeddings): if isinstance(self.embedding_function, Embeddings):
embedding = self.embedding_function.embed_query(query) embedding = self.embedding_function.embed_query(query)
docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
@ -707,25 +748,27 @@ class OracleVS(VectorStore):
embedding_arr = array.array("f", embedding) embedding_arr = array.array("f", embedding)
query = f""" query = f"""
SELECT id, SELECT id,
text, text,
metadata, metadata,
vector_distance(embedding, :embedding, vector_distance(embedding, :embedding,
{_get_distance_function(self.distance_strategy)}) as distance {_get_distance_function(self.distance_strategy)}) as distance
FROM {self.table_name} FROM {self.table_name}
ORDER BY distance ORDER BY distance
FETCH APPROX FIRST {k} ROWS ONLY FETCH APPROX FIRST {k} ROWS ONLY
""" """
# Execute the query # Execute the query
with self.client.cursor() as cursor: connection = _get_connection(self.client)
if connection is None:
raise ValueError("Failed to acquire a connection.")
with connection.cursor() as cursor:
cursor.execute(query, embedding=embedding_arr) cursor.execute(query, embedding=embedding_arr)
results = cursor.fetchall() results = cursor.fetchall()
# Filter results if filter is provided # Filter results if filter is provided
for result in results: for result in results:
metadata = json.loads( metadata = dict(result[2]) if isinstance(result[2], dict) else {}
self._get_clob_value(result[2]) if result[2] is not None else "{}"
)
# Apply filtering based on the 'filter' dictionary # Apply filtering based on the 'filter' dictionary
if filter: if filter:
@ -761,7 +804,7 @@ class OracleVS(VectorStore):
k: int, k: int,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[Document, float, np.ndarray]]: ) -> List[Tuple[Document, float, NDArray[np.float32]]]:
embedding_arr: Any embedding_arr: Any
if self.insert_mode == "clob": if self.insert_mode == "clob":
embedding_arr = json.dumps(embedding) embedding_arr = json.dumps(embedding)
@ -771,27 +814,29 @@ class OracleVS(VectorStore):
documents = [] documents = []
query = f""" query = f"""
SELECT id, SELECT id,
text, text,
metadata, metadata,
vector_distance(embedding, :embedding, { vector_distance(embedding, :embedding, {
_get_distance_function(self.distance_strategy) _get_distance_function(self.distance_strategy)
}) as distance, }) as distance,
embedding embedding
FROM {self.table_name} FROM {self.table_name}
ORDER BY distance ORDER BY distance
FETCH APPROX FIRST {k} ROWS ONLY FETCH APPROX FIRST {k} ROWS ONLY
""" """
# Execute the query # Execute the query
with self.client.cursor() as cursor: connection = _get_connection(self.client)
if connection is None:
raise ValueError("Failed to acquire a connection.")
with connection.cursor() as cursor:
cursor.execute(query, embedding=embedding_arr) cursor.execute(query, embedding=embedding_arr)
results = cursor.fetchall() results = cursor.fetchall()
for result in results: for result in results:
page_content_str = self._get_clob_value(result[1]) page_content_str = self._get_clob_value(result[1])
metadata_str = self._get_clob_value(result[2]) metadata = result[2] if isinstance(result[2], dict) else {}
metadata = json.loads(metadata_str)
# Apply filter if provided and matches; otherwise, add all # Apply filter if provided and matches; otherwise, add all
# documents # documents
@ -984,9 +1029,12 @@ class OracleVS(VectorStore):
f"id{i}": hashed_id for i, hashed_id in enumerate(hashed_ids, start=1) f"id{i}": hashed_id for i, hashed_id in enumerate(hashed_ids, start=1)
} }
with self.client.cursor() as cursor: connection = _get_connection(self.client)
if connection is None:
raise ValueError("Failed to acquire a connection.")
with connection.cursor() as cursor:
cursor.execute(ddl, bind_vars) cursor.execute(ddl, bind_vars)
self.client.commit() connection.commit()
@classmethod @classmethod
@_handle_exceptions @_handle_exceptions
@ -997,10 +1045,10 @@ class OracleVS(VectorStore):
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> OracleVS: ) -> OracleVS:
"""Return VectorStore initialized from texts and embeddings.""" client: Any = kwargs.get("client", None)
client = kwargs.get("client")
if client is None: if client is None:
raise ValueError("client parameter is required...") raise ValueError("client parameter is required...")
params = kwargs.get("params", {}) params = kwargs.get("params", {})
table_name = str(kwargs.get("table_name", "langchain")) table_name = str(kwargs.get("table_name", "langchain"))