mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
Oraclevs integration (#29723)
Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" community: langchain_community/vectorstore/oraclevs.py - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Refactored code to allow a connection or a connection pool. - **Issue:** Normally an idel connection is terminated by the server side listener at timeout. A user thus has to re-instantiate the vector store. The timeout in case of connection is not configurable. The solution is to use a connection pool where a user can specify a user defined timeout and the connections are managed by the pool. - **Dependencies:** None - **Twitter handle:** - [ ] **Add tests and docs**: This is not a new integration. A user can pass either a connection or a connection pool. The determination of what is passed is made at run time. Everything should work as before. - [ ] **Lint and test**: Already done. Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
42ebf6ae0c
commit
c7d74eb7a3
@ -22,6 +22,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from numpy.typing import NDArray
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from oracledb import Connection
|
||||
|
||||
@ -47,6 +49,31 @@ logging.basicConfig(
|
||||
T = TypeVar("T", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _get_connection(client: Any) -> Connection | None:
|
||||
# Dynamically import oracledb and the required classes
|
||||
try:
|
||||
import oracledb
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import oracledb, please install with `pip install -U oracledb`."
|
||||
) from e
|
||||
|
||||
# check if ConnectionPool exists
|
||||
connection_pool_class = getattr(oracledb, "ConnectionPool", None)
|
||||
|
||||
if isinstance(client, oracledb.Connection):
|
||||
return client
|
||||
elif connection_pool_class and isinstance(client, connection_pool_class):
|
||||
return client.acquire()
|
||||
else:
|
||||
valid_types = "oracledb.Connection"
|
||||
if connection_pool_class:
|
||||
valid_types += " or oracledb.ConnectionPool"
|
||||
raise TypeError(
|
||||
f"Expected client of type {valid_types}, got {type(client).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _handle_exceptions(func: T) -> T:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
@ -70,7 +97,7 @@ def _handle_exceptions(func: T) -> T:
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
def _table_exists(client: Connection, table_name: str) -> bool:
|
||||
def _table_exists(connection: Connection, table_name: str) -> bool:
|
||||
try:
|
||||
import oracledb
|
||||
except ImportError as e:
|
||||
@ -79,7 +106,7 @@ def _table_exists(client: Connection, table_name: str) -> bool:
|
||||
) from e
|
||||
|
||||
try:
|
||||
with client.cursor() as cursor:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
|
||||
return True
|
||||
except oracledb.DatabaseError as ex:
|
||||
@ -106,7 +133,7 @@ def _compare_version(version: str, target_version: str) -> bool:
|
||||
|
||||
|
||||
@_handle_exceptions
|
||||
def _index_exists(client: Connection, index_name: str) -> bool:
|
||||
def _index_exists(connection: Connection, index_name: str) -> bool:
|
||||
# Check if the index exists
|
||||
query = """
|
||||
SELECT index_name
|
||||
@ -114,7 +141,7 @@ def _index_exists(client: Connection, index_name: str) -> bool:
|
||||
WHERE upper(index_name) = upper(:idx_name)
|
||||
"""
|
||||
|
||||
with client.cursor() as cursor:
|
||||
with connection.cursor() as cursor:
|
||||
# Execute the query
|
||||
cursor.execute(query, idx_name=index_name.upper())
|
||||
result = cursor.fetchone()
|
||||
@ -146,16 +173,16 @@ def _get_index_name(base_name: str) -> str:
|
||||
|
||||
|
||||
@_handle_exceptions
|
||||
def _create_table(client: Connection, table_name: str, embedding_dim: int) -> None:
|
||||
def _create_table(connection: Connection, table_name: str, embedding_dim: int) -> None:
|
||||
cols_dict = {
|
||||
"id": "RAW(16) DEFAULT SYS_GUID() PRIMARY KEY",
|
||||
"text": "CLOB",
|
||||
"metadata": "CLOB",
|
||||
"metadata": "JSON",
|
||||
"embedding": f"vector({embedding_dim}, FLOAT32)",
|
||||
}
|
||||
|
||||
if not _table_exists(client, table_name):
|
||||
with client.cursor() as cursor:
|
||||
if not _table_exists(connection, table_name):
|
||||
with connection.cursor() as cursor:
|
||||
ddl_body = ", ".join(
|
||||
f"{col_name} {col_type}" for col_name, col_type in cols_dict.items()
|
||||
)
|
||||
@ -168,43 +195,45 @@ def _create_table(client: Connection, table_name: str, embedding_dim: int) -> No
|
||||
|
||||
@_handle_exceptions
|
||||
def create_index(
|
||||
client: Connection,
|
||||
client: Any,
|
||||
vector_store: OracleVS,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Create an index on the vector store.
|
||||
|
||||
Args:
|
||||
client: The OracleDB connection object.
|
||||
vector_store: The vector store object.
|
||||
params: Optional parameters for the index creation.
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid parameter is provided.
|
||||
"""
|
||||
connection = _get_connection(client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
if params:
|
||||
if params["idx_type"] == "HNSW":
|
||||
_create_hnsw_index(
|
||||
client, vector_store.table_name, vector_store.distance_strategy, params
|
||||
connection,
|
||||
vector_store.table_name,
|
||||
vector_store.distance_strategy,
|
||||
params,
|
||||
)
|
||||
elif params["idx_type"] == "IVF":
|
||||
_create_ivf_index(
|
||||
client, vector_store.table_name, vector_store.distance_strategy, params
|
||||
connection,
|
||||
vector_store.table_name,
|
||||
vector_store.distance_strategy,
|
||||
params,
|
||||
)
|
||||
else:
|
||||
_create_hnsw_index(
|
||||
client, vector_store.table_name, vector_store.distance_strategy, params
|
||||
connection,
|
||||
vector_store.table_name,
|
||||
vector_store.distance_strategy,
|
||||
params,
|
||||
)
|
||||
else:
|
||||
_create_hnsw_index(
|
||||
client, vector_store.table_name, vector_store.distance_strategy, params
|
||||
connection, vector_store.table_name, vector_store.distance_strategy, params
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@_handle_exceptions
|
||||
def _create_hnsw_index(
|
||||
client: Connection,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
distance_strategy: DistanceStrategy,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
@ -278,8 +307,8 @@ def _create_hnsw_index(
|
||||
ddl = ddl_assembly.format(**config)
|
||||
|
||||
# Check if the index exists
|
||||
if not _index_exists(client, config["idx_name"]):
|
||||
with client.cursor() as cursor:
|
||||
if not _index_exists(connection, config["idx_name"]):
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(ddl)
|
||||
logger.info("Index created successfully...")
|
||||
else:
|
||||
@ -288,7 +317,7 @@ def _create_hnsw_index(
|
||||
|
||||
@_handle_exceptions
|
||||
def _create_ivf_index(
|
||||
client: Connection,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
distance_strategy: DistanceStrategy,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
@ -350,8 +379,8 @@ def _create_ivf_index(
|
||||
ddl = ddl_assembly.format(**config)
|
||||
|
||||
# Check if the index exists
|
||||
if not _index_exists(client, config["idx_name"]):
|
||||
with client.cursor() as cursor:
|
||||
if not _index_exists(connection, config["idx_name"]):
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(ddl)
|
||||
logger.info("Index created successfully...")
|
||||
else:
|
||||
@ -359,7 +388,7 @@ def _create_ivf_index(
|
||||
|
||||
|
||||
@_handle_exceptions
|
||||
def drop_table_purge(client: Connection, table_name: str) -> None:
|
||||
def drop_table_purge(client: Any, table_name: str) -> None:
|
||||
"""Drop a table and purge it from the database.
|
||||
|
||||
Args:
|
||||
@ -369,9 +398,11 @@ def drop_table_purge(client: Connection, table_name: str) -> None:
|
||||
Raises:
|
||||
RuntimeError: If an error occurs while dropping the table.
|
||||
"""
|
||||
if _table_exists(client, table_name):
|
||||
cursor = client.cursor()
|
||||
with cursor:
|
||||
connection = _get_connection(client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
if _table_exists(connection, table_name):
|
||||
with connection.cursor() as cursor:
|
||||
ddl = f"DROP TABLE {table_name} PURGE"
|
||||
cursor.execute(ddl)
|
||||
logger.info("Table dropped successfully...")
|
||||
@ -381,7 +412,7 @@ def drop_table_purge(client: Connection, table_name: str) -> None:
|
||||
|
||||
|
||||
@_handle_exceptions
|
||||
def drop_index_if_exists(client: Connection, index_name: str) -> None:
|
||||
def drop_index_if_exists(client: Any, index_name: str) -> None:
|
||||
"""Drop an index if it exists.
|
||||
|
||||
Args:
|
||||
@ -391,9 +422,12 @@ def drop_index_if_exists(client: Connection, index_name: str) -> None:
|
||||
Raises:
|
||||
RuntimeError: If an error occurs while dropping the index.
|
||||
"""
|
||||
if _index_exists(client, index_name):
|
||||
connection = _get_connection(client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
if _index_exists(connection, index_name):
|
||||
drop_query = f"DROP INDEX {index_name}"
|
||||
with client.cursor() as cursor:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(drop_query)
|
||||
logger.info(f"Index {index_name} has been dropped.")
|
||||
else:
|
||||
@ -426,7 +460,7 @@ class OracleVS(VectorStore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Connection,
|
||||
client: Any,
|
||||
embedding_function: Union[
|
||||
Callable[[str], List[float]],
|
||||
Embeddings,
|
||||
@ -445,8 +479,11 @@ class OracleVS(VectorStore):
|
||||
) from e
|
||||
|
||||
self.insert_mode = "array"
|
||||
connection = _get_connection(client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
|
||||
if client.thin is True:
|
||||
if hasattr(connection, "thin") and connection.thin:
|
||||
if oracledb.__version__ == "2.1.0":
|
||||
raise Exception(
|
||||
"Oracle DB python thin client driver version 2.1.0 not supported"
|
||||
@ -494,8 +531,7 @@ class OracleVS(VectorStore):
|
||||
self.table_name = table_name
|
||||
self.distance_strategy = distance_strategy
|
||||
self.params = params
|
||||
|
||||
_create_table(client, table_name, embedding_dim)
|
||||
_create_table(connection, table_name, embedding_dim)
|
||||
except oracledb.DatabaseError as db_err:
|
||||
logger.exception(f"Database error occurred while create table: {db_err}")
|
||||
raise RuntimeError(
|
||||
@ -613,13 +649,16 @@ class OracleVS(VectorStore):
|
||||
)
|
||||
]
|
||||
|
||||
with self.client.cursor() as cursor:
|
||||
connection = _get_connection(self.client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.executemany(
|
||||
f"INSERT INTO {self.table_name} (id, embedding, metadata, "
|
||||
f"text) VALUES (:1, :2, :3, :4)",
|
||||
docs,
|
||||
)
|
||||
self.client.commit()
|
||||
connection.commit()
|
||||
return processed_ids
|
||||
|
||||
def similarity_search(
|
||||
@ -630,6 +669,7 @@ class OracleVS(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query."""
|
||||
embedding: List[float] = []
|
||||
if isinstance(self.embedding_function, Embeddings):
|
||||
embedding = self.embedding_function.embed_query(query)
|
||||
documents = self.similarity_search_by_vector(
|
||||
@ -657,6 +697,7 @@ class OracleVS(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query."""
|
||||
embedding: List[float] = []
|
||||
if isinstance(self.embedding_function, Embeddings):
|
||||
embedding = self.embedding_function.embed_query(query)
|
||||
docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
|
||||
@ -716,16 +757,18 @@ class OracleVS(VectorStore):
|
||||
ORDER BY distance
|
||||
FETCH APPROX FIRST {k} ROWS ONLY
|
||||
"""
|
||||
|
||||
# Execute the query
|
||||
with self.client.cursor() as cursor:
|
||||
connection = _get_connection(self.client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, embedding=embedding_arr)
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Filter results if filter is provided
|
||||
for result in results:
|
||||
metadata = json.loads(
|
||||
self._get_clob_value(result[2]) if result[2] is not None else "{}"
|
||||
)
|
||||
metadata = dict(result[2]) if isinstance(result[2], dict) else {}
|
||||
|
||||
# Apply filtering based on the 'filter' dictionary
|
||||
if filter:
|
||||
@ -761,7 +804,7 @@ class OracleVS(VectorStore):
|
||||
k: int,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float, np.ndarray]]:
|
||||
) -> List[Tuple[Document, float, NDArray[np.float32]]]:
|
||||
embedding_arr: Any
|
||||
if self.insert_mode == "clob":
|
||||
embedding_arr = json.dumps(embedding)
|
||||
@ -784,14 +827,16 @@ class OracleVS(VectorStore):
|
||||
"""
|
||||
|
||||
# Execute the query
|
||||
with self.client.cursor() as cursor:
|
||||
connection = _get_connection(self.client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query, embedding=embedding_arr)
|
||||
results = cursor.fetchall()
|
||||
|
||||
for result in results:
|
||||
page_content_str = self._get_clob_value(result[1])
|
||||
metadata_str = self._get_clob_value(result[2])
|
||||
metadata = json.loads(metadata_str)
|
||||
metadata = result[2] if isinstance(result[2], dict) else {}
|
||||
|
||||
# Apply filter if provided and matches; otherwise, add all
|
||||
# documents
|
||||
@ -984,9 +1029,12 @@ class OracleVS(VectorStore):
|
||||
f"id{i}": hashed_id for i, hashed_id in enumerate(hashed_ids, start=1)
|
||||
}
|
||||
|
||||
with self.client.cursor() as cursor:
|
||||
connection = _get_connection(self.client)
|
||||
if connection is None:
|
||||
raise ValueError("Failed to acquire a connection.")
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(ddl, bind_vars)
|
||||
self.client.commit()
|
||||
connection.commit()
|
||||
|
||||
@classmethod
|
||||
@_handle_exceptions
|
||||
@ -997,10 +1045,10 @@ class OracleVS(VectorStore):
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> OracleVS:
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
client = kwargs.get("client")
|
||||
client: Any = kwargs.get("client", None)
|
||||
if client is None:
|
||||
raise ValueError("client parameter is required...")
|
||||
|
||||
params = kwargs.get("params", {})
|
||||
|
||||
table_name = str(kwargs.get("table_name", "langchain"))
|
||||
|
Loading…
Reference in New Issue
Block a user