diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 35c424029..99eb0e2eb 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -232,11 +232,9 @@ class Config(metaclass=Singleton): self.OB_USER = os.getenv("OB_USER", "root") self.OB_PASSWORD = os.getenv("OB_PASSWORD", "") self.OB_DATABASE = os.getenv("OB_DATABASE", "test") - self.OB_SQL_DBG_LOG_PATH = os.getenv("OB_SQL_DBG_LOG_PATH", "") self.OB_ENABLE_NORMALIZE_VECTOR = bool( os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "") ) - self.OB_ENABLE_INDEX = bool(os.getenv("OB_ENABLE_INDEX", "")) # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") diff --git a/dbgpt/storage/vector_store/oceanbase_store.py b/dbgpt/storage/vector_store/oceanbase_store.py index 4b30f0ff7..c9a2545af 100644 --- a/dbgpt/storage/vector_store/oceanbase_store.py +++ b/dbgpt/storage/vector_store/oceanbase_store.py @@ -1,633 +1,75 @@ """OceanBase vector store.""" import json import logging +import math import os -import threading import uuid -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Tuple +import numpy as np from pydantic import Field -from sqlalchemy import Column, Table, create_engine, insert, text -from sqlalchemy.dialects.mysql import JSON, LONGTEXT, VARCHAR -from sqlalchemy.types import String, UserDefinedType +from sqlalchemy import JSON, Column, String, Table, func, text +from sqlalchemy.dialects.mysql import LONGTEXT -from dbgpt.core import Chunk, Document, Embeddings +from dbgpt.core import Chunk from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig, ) -from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters from dbgpt.util.i18n_utils import _ -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base - - logger = logging.getLogger(__name__) -sql_logger = None -sql_dbg_log_path = os.getenv("OB_SQL_DBG_LOG_PATH", "") -if sql_dbg_log_path != "": - sql_logger = logging.getLogger("ob_sql_dbg") - sql_logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(sql_dbg_log_path) - file_handler.setLevel(logging.DEBUG) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - file_handler.setFormatter(formatter) - sql_logger.addHandler(file_handler) -_OCEANBASE_DEFAULT_EMBEDDING_DIM = 1536 -_OCEANBASE_DEFAULT_COLLECTION_NAME = "langchain_document" -_OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD = 10000 -_OCEANBASE_DEFAULT_RWLOCK_MAX_READER = 64 +DEFAULT_OCEANBASE_BATCH_SIZE = 100 +DEFAULT_OCEANBASE_VECTOR_TABLE_NAME = "dbgpt_vector" +DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256} +DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64} +OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW" +DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2" -Base = declarative_base() +DEFAULT_OCEANBASE_PFIELD = "id" +DEFAULT_OCEANBASE_DOCID_FIELD = "doc_id" +DEFAULT_OCEANBASE_VEC_FIELD = "embedding" +DEFAULT_OCEANBASE_DOC_FIELD = "document" +DEFAULT_OCEANBASE_METADATA_FIELD = "metadata" +DEFAULT_OCEANBASE_VEC_INDEX_NAME = "vidx" -def ob_vector_from_db(value): - """Parse vector from oceanbase.""" - return [float(v) for v in value[1:-1].split(",")] +def _parse_filter_value(filter_value: Any, is_text_match: bool = False): + if filter_value is None: + return 
filter_value -def ob_vector_to_db(value, dim=None): - """Parse vector to oceanbase vector constant type.""" - if value is None: - return value + if is_text_match: + return f"'{filter_value!s}%'" - return "[" + ",".join([str(float(v)) for v in value]) + "]" + if isinstance(filter_value, str): + return f"'{filter_value!s}'" + if isinstance(filter_value, list): + if all(isinstance(item, str) for item in filter_value): + return "(" + ",".join([f"'{str(v)}'" for v in filter_value]) + ")" + return "(" + ",".join([str(v) for v in filter_value]) + ")" -class Vector(UserDefinedType): - """OceanBase Vector Column Type.""" + return str(filter_value) - cache_ok = True - _string = String() - def __init__(self, dim): - """Create a Vector column with dimemsion `dim`.""" - super(UserDefinedType, self).__init__() - self.dim = dim +def _euclidean_similarity(distance: float) -> float: + return 1.0 - distance / math.sqrt(2) - def get_col_spec(self, **kw): - """Get vector column definition in string format.""" - return "VECTOR(%d)" % self.dim - def bind_processor(self, dialect): - """Get a processor to parse an array to oceanbase vector.""" +def _neg_inner_product_similarity(distance: float) -> float: + return -distance - def process(value): - return ob_vector_to_db(value, self.dim) - return process - - def literal_processor(self, dialect): - """Get a string processor to parse an array to OceanBase Vector.""" - string_literal_processor = self._string._cached_literal_processor(dialect) - - def process(value): - return string_literal_processor(ob_vector_to_db(value, self.dim)) - - return process - - def result_processor(self, dialect, coltype): - """Get a processor to parse OceanBase Vector to array.""" - - def process(value): - return ob_vector_from_db(value) - - return process - - -class OceanBaseCollectionStat: - """A tracer that maintains a table status in OceanBase.""" - - def __init__(self): - """Create OceanBaseCollectionStat instance.""" - self._lock = threading.Lock() - self.maybe_collection_not_exist = True - self.maybe_collection_index_not_exist = True - - def collection_exists(self): - """Set a table is existing.""" - with self._lock: - self.maybe_collection_not_exist = False - - def collection_index_exists(self): - """Set the index of a table is existing.""" - with self._lock: - self.maybe_collection_index_not_exist = False - - def collection_not_exists(self): - """Set a table is dropped.""" - with self._lock: - self.maybe_collection_not_exist = True - - def collection_index_not_exists(self): - """Set the index of a table is dropped.""" - with self._lock: - self.maybe_collection_index_not_exist = True - - def get_maybe_collection_not_exist(self): - """Get the existing status of a table.""" - with self._lock: - return self.maybe_collection_not_exist - - def get_maybe_collection_index_not_exist(self): - """Get the existing stats of the index of a table.""" - with self._lock: - return self.maybe_collection_index_not_exist - - -class OceanBaseGlobalRWLock: - """A global rwlock for OceanBase to do creating vector index table offline ddl.""" - - def __init__(self, max_readers) -> None: - """Create a rwlock.""" - self.max_readers_ = max_readers - self.writer_entered_ = False - self.reader_cnt_ = 0 - self.mutex_ = threading.Lock() - self.writer_cv_ = threading.Condition(self.mutex_) - self.reader_cv_ = threading.Condition(self.mutex_) - - def rlock(self): - """Lock for reading.""" - self.mutex_.acquire() - while self.writer_entered_ or self.max_readers_ == self.reader_cnt_: - self.reader_cv_.wait() - 
self.reader_cnt_ += 1 - self.mutex_.release() - - def runlock(self): - """Unlock reading lock.""" - self.mutex_.acquire() - self.reader_cnt_ -= 1 - if self.writer_entered_: - if self.reader_cnt_ == 0: - self.writer_cv_.notify(1) - else: - if self.max_readers_ - 1 == self.reader_cnt_: - self.reader_cv_.notify(1) - self.mutex_.release() - - def wlock(self): - """Lock for writing.""" - self.mutex_.acquire() - while self.writer_entered_: - self.reader_cv_.wait() - self.writer_entered_ = True - while 0 < self.reader_cnt_: - self.writer_cv_.wait() - self.mutex_.release() - - def wunlock(self): - """Unlock writing lock.""" - self.mutex_.acquire() - self.writer_entered_ = False - self.reader_cv_.notifyAll() - self.mutex_.release() - - class OBRLock: - """Reading Lock Wrapper for `with` clause.""" - - def __init__(self, rwlock) -> None: - """Create reading lock wrapper instance.""" - self.rwlock_ = rwlock - - def __enter__(self): - """Lock.""" - self.rwlock_.rlock() - - def __exit__(self, exc_type, exc_value, traceback): - """Unlock when exiting.""" - self.rwlock_.runlock() - - class OBWLock: - """Writing Lock Wrapper for `with` clause.""" - - def __init__(self, rwlock) -> None: - """Create writing lock wrapper instance.""" - self.rwlock_ = rwlock - - def __enter__(self): - """Lock.""" - self.rwlock_.wlock() - - def __exit__(self, exc_type, exc_value, traceback): - """Unlock when exiting.""" - self.rwlock_.wunlock() - - def reader_lock(self): - """Get reading lock wrapper.""" - return self.OBRLock(self) - - def writer_lock(self): - """Get writing lock wrapper.""" - return self.OBWLock(self) - - -ob_grwlock = OceanBaseGlobalRWLock(_OCEANBASE_DEFAULT_RWLOCK_MAX_READER) - - -class OceanBase: - """OceanBase Vector Store implementation.""" - - def __init__( - self, - database: str, - connection_string: str, - embedding_function: Embeddings, - embedding_dimension: int = _OCEANBASE_DEFAULT_EMBEDDING_DIM, - collection_name: str = _OCEANBASE_DEFAULT_COLLECTION_NAME, - pre_delete_collection: bool = False, - logger: Optional[logging.Logger] = None, - engine_args: Optional[dict] = None, - delay_table_creation: bool = True, - enable_index: bool = False, - th_create_ivfflat_index: int = _OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD, - sql_logger: Optional[logging.Logger] = None, - collection_stat: Optional[OceanBaseCollectionStat] = None, - enable_normalize_vector: bool = False, - ) -> None: - """Create OceanBase Vector Store instance.""" - self.database = database - self.connection_string = connection_string - self.embedding_function = embedding_function - self.embedding_dimension = embedding_dimension - self.collection_name = collection_name - self.pre_delete_collection = pre_delete_collection - self.logger = logger or logging.getLogger(__name__) - self.delay_table_creation = delay_table_creation - self.th_create_ivfflat_index = th_create_ivfflat_index - self.enable_index = enable_index - self.sql_logger = sql_logger - self.collection_stat = collection_stat - self.enable_normalize_vector = enable_normalize_vector - self.__post_init__(engine_args) - - def __post_init__( - self, - engine_args: Optional[dict] = None, - ) -> None: - """Create connection & vector table.""" - _engine_args = engine_args or {} - if "pool_recycle" not in _engine_args: - _engine_args["pool_recycle"] = 3600 - self.engine = create_engine(self.connection_string, **_engine_args) - self.create_collection() - - @property - def embeddings(self) -> Embeddings: - """Get embedding function.""" - return self.embedding_function - - def 
create_collection(self) -> None: - """Create vector table.""" - if self.pre_delete_collection: - self.delete_collection() - if not self.delay_table_creation and ( - self.collection_stat is None - or self.collection_stat.get_maybe_collection_not_exist() - ): - self.create_table_if_not_exists() - if self.collection_stat is not None: - self.collection_stat.collection_exists() - - def delete_collection(self) -> None: - """Drop vector table.""" - drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name}") - if self.sql_logger is not None: - self.sql_logger.debug(f"Trying to delete collection: {drop_statement}") - with self.engine.connect() as conn, conn.begin(): - conn.execute(drop_statement) - if self.collection_stat is not None: - self.collection_stat.collection_not_exists() - self.collection_stat.collection_index_not_exists() - - def create_table_if_not_exists(self) -> None: - """Create vector table with SQL.""" - create_table_query = f""" - CREATE TABLE IF NOT EXISTS `{self.collection_name}` ( - id VARCHAR(40) NOT NULL, - embedding VECTOR({self.embedding_dimension}), - document LONGTEXT, - metadata JSON, - PRIMARY KEY (id) - ) - """ - if self.sql_logger is not None: - self.sql_logger.debug(f"Trying to create table: {create_table_query}") - with self.engine.connect() as conn, conn.begin(): - # Create the table - conn.execute(text(create_table_query)) - - def create_collection_ivfflat_index_if_not_exists(self) -> None: - """Create ivfflat index table with SQL.""" - create_index_query = f""" - CREATE INDEX IF NOT EXISTS `embedding_idx` on `{self.collection_name}` ( - embedding l2 - ) using ivfflat with (lists=20) - """ - with ob_grwlock.writer_lock(), self.engine.connect() as conn, conn.begin(): - # Create Ivfflat Index - if self.sql_logger is not None: - self.sql_logger.debug( - f"Trying to create ivfflat index: {create_index_query}" - ) - conn.execute(text(create_index_query)) - - def check_table_exists(self) -> bool: - """Whether table `collection_name` exists.""" - check_table_query = f""" - SELECT COUNT(*) as cnt - FROM information_schema.tables - WHERE table_schema='{self.database}' AND table_name='{self.collection_name}' - """ - try: - with self.engine.connect() as conn, conn.begin(), ob_grwlock.reader_lock(): - table_exists_res = conn.execute(text(check_table_query)) - for row in table_exists_res: - return row.cnt > 0 - # No `cnt` rows? 
Just return False to pass `make mypy` - return False - except Exception as e: - logger.error(f"check_table_exists error: {e}") - return False - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - batch_size: int = 500, - **kwargs: Any, - ) -> List[str]: - """Insert texts into vector table.""" - if ids is None: - ids = [str(uuid.uuid1()) for _ in texts] - - embeddings = self.embedding_function.embed_documents(list(texts)) - - if len(embeddings) == 0: - return ids - - if not metadatas: - metadatas = [{} for _ in texts] - - if self.delay_table_creation and ( - self.collection_stat is None - or self.collection_stat.get_maybe_collection_not_exist() - ): - self.embedding_dimension = len(embeddings[0]) - self.create_table_if_not_exists() - self.delay_table_creation = False - if self.collection_stat is not None: - self.collection_stat.collection_exists() - - chunks_table = Table( - self.collection_name, - Base.metadata, - Column("id", VARCHAR(40), primary_key=True), - Column("embedding", Vector(self.embedding_dimension)), - Column("document", LONGTEXT, nullable=True), - Column("metadata", JSON, nullable=True), # filter - keep_existing=True, - ) - - row_count_query = f""" - SELECT COUNT(*) as cnt FROM `{self.collection_name}` - """ - chunks_table_data = [] - with self.engine.connect() as conn, conn.begin(): - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): - chunks_table_data.append( - { - "id": chunk_id, - "embedding": embedding - if not self.enable_normalize_vector - else self._normalization_vectors(embedding), - "document": document, - "metadata": metadata, - } - ) - - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == batch_size: - with ob_grwlock.reader_lock(): - if self.sql_logger is not None: - insert_sql_for_log = str( - insert(chunks_table).values(chunks_table_data) - ) - self.sql_logger.debug( - f"""Trying to insert vectors: - {insert_sql_for_log}""" - ) - conn.execute(insert(chunks_table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: - with ob_grwlock.reader_lock(): - if self.sql_logger is not None: - insert_sql_for_log = str( - insert(chunks_table).values(chunks_table_data) - ) - self.sql_logger.debug( - f"""Trying to insert vectors: - {insert_sql_for_log}""" - ) - conn.execute(insert(chunks_table).values(chunks_table_data)) - - if self.enable_index and ( - self.collection_stat is None - or self.collection_stat.get_maybe_collection_index_not_exist() - ): - with ob_grwlock.reader_lock(): - row_cnt_res = conn.execute(text(row_count_query)) - for row in row_cnt_res: - if row.cnt > self.th_create_ivfflat_index: - self.create_collection_ivfflat_index_if_not_exists() - if self.collection_stat is not None: - self.collection_stat.collection_index_exists() - - return ids - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilters] = None, - **kwargs: Any, - ) -> List[Document]: - """Do similarity search via query in String.""" - embedding = self.embedding_function.embed_query(query) - docs = self.similarity_search_by_vector(embedding=embedding, k=k, filter=filter) - return docs - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[MetadataFilters] = None, - **kwargs: Any, - ) -> List[Document]: - 
"""Do similarity search via query vector.""" - docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) - return [doc for doc, _ in docs_and_scores] - - def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[MetadataFilters] = None, - score_threshold: Optional[float] = None, - ) -> List[Tuple[Document, float]]: - """Do similarity search via query vector with score.""" - try: - from sqlalchemy.engine import Row - except ImportError: - raise ImportError( - "Could not import Row from sqlalchemy.engine. " - "Please 'pip install sqlalchemy>=1.4'." - ) - - # filter is not support in OceanBase currently. - - # normailze embedding vector - if self.enable_normalize_vector: - embedding = self._normalization_vectors(embedding) - - embedding_str = ob_vector_to_db(embedding, self.embedding_dimension) - vector_distance_op = "<@>" if self.enable_normalize_vector else "<->" - sql_query = f""" - SELECT document, metadata, embedding {vector_distance_op} '{embedding_str}' - as distance - FROM {self.collection_name} - ORDER BY embedding {vector_distance_op} '{embedding_str}' - LIMIT :k - """ - sql_query_str_for_log = f""" - SELECT document, metadata, embedding {vector_distance_op} '?' as distance - FROM {self.collection_name} - ORDER BY embedding {vector_distance_op} '?' - LIMIT {k} - """ - - params = {"k": k} - try: - with ob_grwlock.reader_lock(), self.engine.connect() as conn: - if self.sql_logger is not None: - self.sql_logger.debug( - f"Trying to do similarity search: {sql_query_str_for_log}" - ) - results: Sequence[Row] = conn.execute( - text(sql_query), params - ).fetchall() - - if (score_threshold is not None) and self.enable_normalize_vector: - documents_with_scores = [ - ( - Document( - content=result.document, - metadata=json.loads(result.metadata), - ), - result.distance, - ) - for result in results - if result.distance >= score_threshold - ] - else: - documents_with_scores = [ - ( - Document( - content=result.document, - metadata=json.loads(result.metadata), - ), - result.distance, - ) - for result in results - ] - return documents_with_scores - except Exception as e: - self.logger.error("similarity_search_with_score_by_vector failed:", str(e)) - return [] - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilters] = None, - score_threshold: Optional[float] = None, - ) -> List[Tuple[Document, float]]: - """Do similarity search via query String with score.""" - embedding = self.embedding_function.embed_query(query) - docs = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, score_threshold=score_threshold - ) - return docs - - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: - """Delete vectors from vector table.""" - if ids is None: - raise ValueError("No ids provided to delete.") - - # Define the table schema - chunks_table = Table( - self.collection_name, - Base.metadata, - Column("id", VARCHAR(40), primary_key=True), - Column("embedding", Vector(self.embedding_dimension)), - Column("document", LONGTEXT, nullable=True), - Column("metadata", JSON, nullable=True), # filter - keep_existing=True, - ) - - try: - with self.engine.connect() as conn, conn.begin(): - delete_condition = chunks_table.c.id.in_(ids) - delete_stmt = chunks_table.delete().where(delete_condition) - with ob_grwlock.reader_lock(): - if self.sql_logger is not None: - self.sql_logger.debug( - f"Trying to 
delete vectors: {str(delete_stmt)}" - ) - conn.execute(delete_stmt) - return True - except Exception as e: - self.logger.error("Delete operation failed:", str(e)) - return False - - def _normalization_vectors(self, vector): - import numpy as np - - norm = np.linalg.norm(vector) - return (vector / norm).tolist() - - @classmethod - def connection_string_from_db_params( - cls, - host: str, - port: int, - database: str, - user: str, - password: str, - ) -> str: - """Get connection string.""" - return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}" - - -ob_collection_stats_lock = threading.Lock() -ob_collection_stats: Dict[str, OceanBaseCollectionStat] = {} +def _normalize(vector: List[float]) -> List[float]: + arr = np.array(vector) + norm = np.linalg.norm(arr) + arr = arr / norm + return arr.tolist() @register_resource( @@ -715,106 +157,333 @@ class OceanBaseStore(VectorStoreBase): def __init__(self, vector_store_config: OceanBaseConfig) -> None: """Create a OceanBaseStore instance.""" + try: + from pyobvector import ObVecClient # type: ignore + except ImportError: + raise ImportError( + "Could not import pyobvector package. " + "Please install it with `pip install pyobvector`." + ) + if vector_store_config.embedding_fn is None: raise ValueError("embedding_fn is required for OceanBaseStore") + super().__init__() + self._vector_store_config = vector_store_config + self.embedding_function = vector_store_config.embedding_fn + self.table_name = vector_store_config.name - self.embeddings = vector_store_config.embedding_fn - self.collection_name = vector_store_config.name - vector_store_config = vector_store_config.dict() - self.OB_HOST = str( - vector_store_config.get("ob_host") or os.getenv("OB_HOST", "127.0.0.1") + vector_store_config_map = vector_store_config.to_dict() + OB_HOST = str( + vector_store_config_map.get("ob_host") or os.getenv("OB_HOST", "127.0.0.1") ) - self.OB_PORT = int( - vector_store_config.get("ob_port") or int(os.getenv("OB_PORT", "2881")) + OB_PORT = int( + vector_store_config_map.get("ob_port") or int(os.getenv("OB_PORT", "2881")) ) - self.OB_USER = str( - vector_store_config.get("ob_user") or os.getenv("OB_USER", "root@test") + OB_USER = str( + vector_store_config_map.get("ob_user") or os.getenv("OB_USER", "root@test") ) - self.OB_PASSWORD = str( - vector_store_config.get("ob_password") or os.getenv("OB_PASSWORD", "") + OB_PASSWORD = str( + vector_store_config_map.get("ob_password") or os.getenv("OB_PASSWORD", "") ) - self.OB_DATABASE = str( - vector_store_config.get("ob_database") or os.getenv("OB_DATABASE", "test") + OB_DATABASE = str( + vector_store_config_map.get("ob_database") + or os.getenv("OB_DATABASE", "test") ) - self.OB_ENABLE_NORMALIZE_VECTOR = bool( - os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "") - ) - self.connection_string = OceanBase.connection_string_from_db_params( - self.OB_HOST, self.OB_PORT, self.OB_DATABASE, self.OB_USER, self.OB_PASSWORD - ) - self.logger = logger - with ob_collection_stats_lock: - if ob_collection_stats.get(self.collection_name) is None: - ob_collection_stats[self.collection_name] = OceanBaseCollectionStat() - self.collection_stat = ob_collection_stats[self.collection_name] - self.vector_store_client = OceanBase( - database=self.OB_DATABASE, - connection_string=self.connection_string, - embedding_function=self.embeddings, - collection_name=self.collection_name, - logger=self.logger, - sql_logger=sql_logger, - enable_index=bool(os.getenv("OB_ENABLE_INDEX", "")), - collection_stat=self.collection_stat, - 
enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR, + self.normalize = bool(os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")) + self.vidx_metric_type = DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE + self.vidx_algo_params = DEFAULT_OCEANBASE_HNSW_BUILD_PARAM + self.primary_field = DEFAULT_OCEANBASE_PFIELD + self.vector_field = DEFAULT_OCEANBASE_VEC_FIELD + self.text_field = DEFAULT_OCEANBASE_DOC_FIELD + self.metadata_field = DEFAULT_OCEANBASE_METADATA_FIELD + self.vidx_name = DEFAULT_OCEANBASE_VEC_INDEX_NAME + self.hnsw_ef_search = -1 + + self.vector_store_client = ObVecClient( + uri=OB_HOST + ":" + str(OB_PORT), + user=OB_USER, + password=OB_PASSWORD, + db_name=OB_DATABASE, + echo=True, ) def get_config(self) -> OceanBaseConfig: """Get the vector store config.""" return self._vector_store_config - def similar_search( - self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any - ) -> List[Chunk]: - """Perform a search on a query string and return results.""" - self.logger.info("OceanBase: similar_search..") - documents = self.vector_store_client.similarity_search( - text, topk, filter=filters - ) - return [Chunk(content=doc.content, metadata=doc.metadata) for doc in documents] + def vector_name_exists(self) -> bool: + """Whether vector name exists.""" + return self.vector_store_client.check_table_exists(table_name=self.table_name) - def similar_search_with_scores( - self, - text, - topk, - score_threshold: float, - filters: Optional[MetadataFilters] = None, - ) -> List[Chunk]: - """Perform a search on a query string and return results with score.""" - self.logger.info("OceanBase: similar_search_with_scores..") - docs_and_scores = self.vector_store_client.similarity_search_with_score( - text, topk, filter=filters + def _load_table(self) -> None: + table = Table( + self.table_name, + self.vector_store_client.metadata_obj, + autoload_with=self.vector_store_client.engine, ) - return [ - Chunk(content=doc.content, metadata=doc.metadata, score=score) - for doc, score in docs_and_scores + column_names = [column.name for column in table.columns] + assert len(column_names) == 4 + + self.primary_field = column_names[0] + self.vector_field = column_names[1] + self.text_field = column_names[2] + self.metadata_field = column_names[3] + + def _create_table_with_index(self, embeddings: list) -> None: + try: + from pyobvector import VECTOR + except ImportError: + raise ImportError( + "Could not import pyobvector package. " + "Please install it with `pip install pyobvector`." 
+ ) + + if self.vector_store_client.check_table_exists(self.table_name): + self._load_table() + return + + dim = len(embeddings[0]) + cols = [ + Column( + self.primary_field, String(4096), primary_key=True, autoincrement=False + ), + Column(self.vector_field, VECTOR(dim)), + Column(self.text_field, LONGTEXT), + Column(self.metadata_field, JSON), ] - def vector_name_exists(self): - """Whether vector name exists.""" - self.logger.info("OceanBase: vector_name_exists..") - return self.vector_store_client.check_table_exists() + vidx_params = self.vector_store_client.prepare_index_params() + vidx_params.add_index( + field_name=self.vector_field, + index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE, + index_name=self.vidx_name, + metric_type=self.vidx_metric_type, + params=self.vidx_algo_params, + ) + + self.vector_store_client.create_table_with_index_params( + table_name=self.table_name, + columns=cols, + indexes=None, + vidxs=vidx_params, + ) def load_document(self, chunks: List[Chunk]) -> List[str]: """Load document in vector database.""" - self.logger.info("OceanBase: load_document..") - # lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks] - texts = [chunk.content for chunk in chunks] - metadatas = [chunk.metadata for chunk in chunks] - ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas) - return ids + batch_size = 100 + texts = [d.content for d in chunks] + metadatas = [d.metadata for d in chunks] + embeddings = self.embedding_function.embed_documents(texts) - def delete_vector_name(self, vector_name): + self._create_table_with_index(embeddings) + + ids = [str(uuid.uuid4()) for _ in texts] + pks: list[str] = [] + for i in range(0, len(embeddings), batch_size): + data = [ + { + self.primary_field: id, + self.vector_field: ( + embedding if not self.normalize else _normalize(embedding) + ), + self.text_field: text, + self.metadata_field: metadata, + } + for id, embedding, text, metadata in zip( + ids[i : i + batch_size], + embeddings[i : i + batch_size], + texts[i : i + batch_size], + metadatas[i : i + batch_size], + ) + ] + self.vector_store_client.insert( + table_name=self.table_name, + data=data, + ) + pks.extend(ids[i : i + batch_size]) + return pks + + def _parse_metric_type_str_to_dist_func(self) -> Any: + if self.vidx_metric_type == "l2": + return func.l2_distance + if self.vidx_metric_type == "cosine": + return func.cosine_distance + if self.vidx_metric_type == "inner_product": + return func.negative_inner_product + raise ValueError(f"Invalid vector index metric type: {self.vidx_metric_type}") + + def similar_search( + self, + text: str, + topk: int, + filters: Optional[MetadataFilters] = None, + param: Optional[dict] = None, + ) -> List[Chunk]: + """Perform a search on a query string and return results.""" + query_vector = self.embedding_function.embed_query(text) + return self._similarity_search_by_vector( + embedding=query_vector, k=topk, param=param, filters=filters + ) + + def similar_search_with_scores( + self, + text: str, + topk: int, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + param: Optional[dict] = None, + ) -> List[Chunk]: + """Perform a search on a query string and return results with score.""" + query_vector = self.embedding_function.embed_query(text) + docs_with_id_and_scores = self._similarity_search_with_score_by_vector( + embedding=query_vector, k=topk, param=param, filters=filters + ) + return [ + Chunk( + metadata=doc.metadata, + content=doc.content, + score=score, + chunk_id=str(id), + ) + for doc, id, 
score in docs_with_id_and_scores + if score >= score_threshold + ] + + def _similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[Chunk]: + if filters is not None: + filter = self._convert_metadata_filters(filters) + else: + filter = None + + search_param = ( + param if param is not None else DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM + ) + ef_search = search_param.get( + "efSearch", DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM["efSearch"] + ) + if ef_search != self.hnsw_ef_search: + self.vector_store_client.set_ob_hnsw_ef_search(ef_search) + self.hnsw_ef_search = ef_search + + res = self.vector_store_client.ann_search( + table_name=self.table_name, + vec_data=(embedding if not self.normalize else _normalize(embedding)), + vec_column_name=self.vector_field, + distance_func=self._parse_metric_type_str_to_dist_func(), + topk=k, + output_column_names=[self.text_field, self.metadata_field], + where_clause=([text(filter)] if filter is not None else None), + ) + return [ + Chunk( + content=r[0], + metadata=json.loads(r[1]), + ) + for r in res.fetchall() + ] + + def _similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 10, + param: Optional[dict] = None, + filters: Optional[MetadataFilters] = None, + **kwargs: Any, + ) -> List[Tuple[Chunk, str, float]]: + if filters is not None: + filter = self._convert_metadata_filters(filters) + else: + filter = None + + search_param = ( + param if param is not None else DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM + ) + ef_search = search_param.get( + "efSearch", DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM["efSearch"] + ) + if ef_search != self.hnsw_ef_search: + self.vector_store_client.set_ob_hnsw_ef_search(ef_search) + self.hnsw_ef_search = ef_search + + res = self.vector_store_client.ann_search( + table_name=self.table_name, + vec_data=(embedding if not self.normalize else _normalize(embedding)), + vec_column_name=self.vector_field, + distance_func=self._parse_metric_type_str_to_dist_func(), + with_dist=True, + topk=k, + output_column_names=[ + self.text_field, + self.metadata_field, + self.primary_field, + ], + where_clause=([text(filter)] if filter is not None else None), + **kwargs, + ) + return [ + ( + Chunk(content=r[0], metadata=json.loads(r[1])), + r[2], + r[3], + ) + for r in res.fetchall() + ] + + def delete_vector_name(self, vector_name: str): """Delete vector name.""" - self.logger.info("OceanBase: delete_vector_name..") - return self.vector_store_client.delete_collection() + self.vector_store_client.drop_table_if_exist(table_name=self.table_name) - def delete_by_ids(self, ids): + def delete_by_ids(self, ids: str): """Delete vector by ids.""" - self.logger.info("OceanBase: delete_by_ids..") - ids = ids.split(",") - if len(ids) > 0: - self.vector_store_client.delete(ids) + split_ids = ids.split(",") + self.vector_store_client.delete(table_name=self.table_name, ids=split_ids) + + def _enhance_filter_key(self, filter_key: str) -> str: + return f"{self.metadata_field}->'$.{filter_key}'" + + def _convert_metadata_filters(self, metafilters: MetadataFilters) -> str: + filters = [] + for filter in metafilters.filters: + filter_value = _parse_filter_value(filter.value) + + if filter.operator == FilterOperator.EQ: + filters.append(f"{self._enhance_filter_key(filter.key)}={filter_value}") + elif filter.operator == FilterOperator.GT: + filters.append(f"{self._enhance_filter_key(filter.key)}>{filter_value}") + elif filter.operator == 
FilterOperator.LT:
+                filters.append(f"{self._enhance_filter_key(filter.key)}<{filter_value}")
+            elif filter.operator == FilterOperator.NE:
+                filters.append(
+                    f"{self._enhance_filter_key(filter.key)}!={filter_value}"
+                )
+            elif filter.operator == FilterOperator.GTE:
+                filters.append(
+                    f"{self._enhance_filter_key(filter.key)}>={filter_value}"
+                )
+            elif filter.operator == FilterOperator.LTE:
+                filters.append(
+                    f"{self._enhance_filter_key(filter.key)}<={filter_value}"
+                )
+            elif filter.operator == FilterOperator.IN:
+                filters.append(
+                    f"{self._enhance_filter_key(filter.key)} in {filter_value}"
+                )
+            elif filter.operator == FilterOperator.NIN:
+                filters.append(
+                    f"{self._enhance_filter_key(filter.key)} not in {filter_value}"
+                )
+            else:
+                raise ValueError(
+                    f"Operator {filter.operator} ('{filter.operator.value}') "
+                    f"is not supported by OceanBase."
+                )
+        return f" {metafilters.condition.value} ".join(filters)
diff --git a/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml b/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml
index 64e82078a..a45f3d953 100644
--- a/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml
+++ b/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml
@@ -17,7 +17,6 @@ services:
       - OB_PORT=2881
       - OB_USER=root@test
       - OB_DATABASE=test
-      - OB_SQL_DBG_LOG_PATH=/sql_log/sql.log
       - LOCAL_DB_TYPE=sqlite
       - LLM_MODEL=tongyi_proxyllm
       - PROXYLLM_BACKEND=qwen-plus
@@ -26,7 +25,6 @@
 #      - TONGYI_PROXY_API_KEY={your-api-key}
       - LANGUAGE=zh
 #      - OB_ENABLE_NORMALIZE_VECTOR=True
-#      - OB_ENABLE_INDEX=True
     ports:
       - 3306:3306
       - 12345:12345
diff --git a/docs/docs/application/advanced_tutorial/rag.md b/docs/docs/application/advanced_tutorial/rag.md
index acb015ae6..b73ba0988 100644
--- a/docs/docs/application/advanced_tutorial/rag.md
+++ b/docs/docs/application/advanced_tutorial/rag.md
@@ -120,12 +120,8 @@
 OB_USER=root@test
 OB_DATABASE=test
 ## Optional
 # OB_PASSWORD=
-## Optional: SQL statements executed by OceanBase is recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
-# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
 ## Optional: If {OB_ENABLE_NORMALIZE_VECTOR} is set, the vector stored in OceanBase is normalized.
 # OB_ENABLE_NORMALIZE_VECTOR=True
-## Optional: If {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
-# OB_ENABLE_INDEX=True
 ```
diff --git a/docs/docs/faq/kbqa.md b/docs/docs/faq/kbqa.md
index ea3e2ac74..7b3a79bf6 100644
--- a/docs/docs/faq/kbqa.md
+++ b/docs/docs/faq/kbqa.md
@@ -23,8 +23,24 @@ If you want to change vector db, Update your .env, set your vector store type, V
 If you want to use OceanBase, please first start a docker container via the following command:
 ```shell
-docker run -p 2881:2881 --name obvec -d oceanbase/oceanbase-ce:vector
+docker run --name=ob433 -e MODE=slim -p 2881:2881 -d quay.io/oceanbase/oceanbase-ce:4.3.3.0-100000142024101215
 ```
+
+Download the partner package:
+```shell
+pip install --upgrade --quiet pyobvector
+```
+
+Check the connection to OceanBase and set the memory usage ratio for vector data:
+```python
+from pyobvector import ObVecClient
+
+tmp_client = ObVecClient()
+tmp_client.perform_raw_text_sql(
+    "ALTER SYSTEM ob_vector_memory_limit_percentage = 30"
+)
+```
+
 Then set the following variables in the .env file:
 ```shell
 VECTOR_STORE_TYPE=OceanBase
@@ -34,12 +50,8 @@
 OB_USER=root@test
 OB_DATABASE=test
 ## Optional
 # OB_PASSWORD=
-## Optional: SQL statements executed by OceanBase is recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
-# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log ## Optional: If {OB_ENABLE_NORMALIZE_VECTOR} is set, the vector stored in OceanBase is normalized. # OB_ENABLE_NORMALIZE_VECTOR=True -## Optional: If {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table. -# OB_ENABLE_INDEX=True ``` If you want to support more vector db, you can integrate yourself.[how to integrate](https://db-gpt.readthedocs.io/en/latest/modules/vector.html) ```commandline
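
For reviewers, a minimal end-to-end sketch of how the refactored `OceanBaseStore` is exercised, including the new metadata-filter push-down. This is illustrative only and not part of the diff: it assumes an OceanBase instance reachable at `127.0.0.1:2881` (e.g. the `ob433` container above), and `DefaultEmbeddingFactory.openai()` is merely a stand-in for whatever `Embeddings` implementation is actually configured.

```python
# Illustrative usage sketch (not part of this diff). Assumes a running
# OceanBase at 127.0.0.1:2881 and a valid embedding backend.
import os

from dbgpt.core import Chunk
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.storage.vector_store.filters import (
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)
from dbgpt.storage.vector_store.oceanbase_store import (
    OceanBaseConfig,
    OceanBaseStore,
)

os.environ.setdefault("OB_HOST", "127.0.0.1")
os.environ.setdefault("OB_PORT", "2881")

store = OceanBaseStore(
    OceanBaseConfig(name="kbqa_demo", embedding_fn=DefaultEmbeddingFactory.openai())
)

# The first insert embeds the chunks, creates the table together with its
# HNSW vector index (the dimension is taken from the first embedding), and
# writes rows in batches of 100.
store.load_document(
    [
        Chunk(
            content="OceanBase 4.3.3 supports HNSW vector indexes",
            metadata={"topic": "db"},
        ),
        Chunk(
            content="DB-GPT is an AI native data app development framework",
            metadata={"topic": "ai"},
        ),
    ]
)

# Metadata filters are compiled into a SQL WHERE clause over the JSON
# metadata column (e.g. metadata->'$.topic'='db') and pushed down into
# pyobvector's ann_search.
hits = store.similar_search(
    "vector index support",
    topk=1,
    filters=MetadataFilters(
        filters=[MetadataFilter(key="topic", value="db", operator=FilterOperator.EQ)]
    ),
)
print([c.content for c in hits])
```

Note the design choice this sketch relies on: the table and its HNSW index are created lazily on the first `load_document` call, because the vector dimension is only known once embeddings have been computed; later loads detect the existing table via `check_table_exists` and reload its schema with `_load_table`.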