diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index b9fa4ea27..32f669a31 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -238,6 +238,18 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) + ## OceanBase Configuration + self.OB_HOST = os.getenv("OB_HOST", "127.0.0.1") + self.OB_PORT = int(os.getenv("OB_PORT", "2881")) + self.OB_USER = os.getenv("OB_USER", "root") + self.OB_PASSWORD = os.getenv("OB_PASSWORD", "") + self.OB_DATABASE = os.getenv("OB_DATABASE", "test") + self.OB_SQL_DBG_LOG_PATH = os.getenv("OB_SQL_DBG_LOG_PATH", "") + self.OB_ENABLE_NORMALIZE_VECTOR = bool( + os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "") + ) + self.OB_ENABLE_INDEX = bool(os.getenv("OB_ENABLE_INDEX", "")) + # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true" diff --git a/dbgpt/rag/text_splitter/tests/test_splitters.py b/dbgpt/rag/text_splitter/tests/test_splitters.py index fd32956ce..6b31f397f 100644 --- a/dbgpt/rag/text_splitter/tests/test_splitters.py +++ b/dbgpt/rag/text_splitter/tests/test_splitters.py @@ -25,11 +25,11 @@ def test_md_header_text_splitter() -> None: output = markdown_splitter.split_text(markdown_document) expected_output = [ Chunk( - content="{'Header 1': 'dbgpt', 'Header 2': 'description'}, my name is dbgpt", + content='"dbgpt-description": my name is dbgpt', metadata={"Header 1": "dbgpt", "Header 2": "description"}, ), Chunk( - content="{'Header 1': 'dbgpt', 'Header 2': 'content'}, my name is aries", + content='"dbgpt-content": my name is aries', metadata={"Header 1": "dbgpt", "Header 2": "content"}, ), ] diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py index 32ae4a45d..a4b44ca8e 100644 --- a/dbgpt/rag/text_splitter/text_splitter.py +++ b/dbgpt/rag/text_splitter/text_splitter.py @@ -515,7 +515,8 @@ class MarkdownHeaderTextSplitter(TextSplitter): aggregated_chunks[-1]["content"] += " \n" + line["content"] else: # Otherwise, append the current line to the aggregated list - line["content"] = f"{line['metadata']}, " + line["content"] + subtitles = "-".join((list(line["metadata"].values()))) + line["content"] = f'"{subtitles}": ' + line["content"] aggregated_chunks.append(line) return [ @@ -557,16 +558,28 @@ class MarkdownHeaderTextSplitter(TextSplitter): # header_stack: List[Dict[str, Union[int, str]]] = [] header_stack: List[HeaderType] = [] initial_metadata: Dict[str, str] = {} + # Determine whether a line is within a markdown code block. 
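+ # Headers inside fenced code blocks should not be treated as split points.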
+ in_code_block = False for line in lines: stripped_line = line.strip() + # A code frame starts with "```" + with_code_frame = stripped_line.startswith("```") and ( + stripped_line != "```" + ) + if (not in_code_block) and with_code_frame: + in_code_block = True # Check each line against each of the header types (e.g., #, ##) for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on - if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) - or stripped_line[len(sep)] == " " + if ( + (not in_code_block) + and stripped_line.startswith(sep) + and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ) ): # Ensure we are tracking the header as metadata if name is not None: @@ -620,6 +633,10 @@ class MarkdownHeaderTextSplitter(TextSplitter): ) current_content.clear() + # Code block ends + if in_code_block and stripped_line == "```": + in_code_block = False + current_metadata = initial_metadata.copy() if current_content: lines_with_metadata.append( diff --git a/dbgpt/storage/vector_store/__init__.py b/dbgpt/storage/vector_store/__init__.py index 924680b39..9199bd08d 100644 --- a/dbgpt/storage/vector_store/__init__.py +++ b/dbgpt/storage/vector_store/__init__.py @@ -26,6 +26,12 @@ def _import_weaviate() -> Any: return WeaviateStore +def _import_oceanbase() -> Any: + from dbgpt.storage.vector_store.oceanbase_store import OceanBaseStore + + return OceanBaseStore + + def __getattr__(name: str) -> Any: if name == "Chroma": return _import_chroma() @@ -35,8 +41,10 @@ def __getattr__(name: str) -> Any: return _import_weaviate() elif name == "PGVector": return _import_pgvector() + elif name == "OceanBase": + return _import_oceanbase() else: raise AttributeError(f"Could not find: {name}") -__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"] +__all__ = ["Chroma", "Milvus", "Weaviate", "OceanBase", "PGVector"] diff --git a/dbgpt/storage/vector_store/oceanbase_store.py b/dbgpt/storage/vector_store/oceanbase_store.py new file mode 100644 index 000000000..fa58d7fdf --- /dev/null +++ b/dbgpt/storage/vector_store/oceanbase_store.py @@ -0,0 +1,798 @@ +"""OceanBase vector store.""" +import json +import logging +import os +import threading +import uuid +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from pydantic import Field +from sqlalchemy import Column, Table, create_engine, insert, text +from sqlalchemy.dialects.mysql import JSON, LONGTEXT, VARCHAR +from sqlalchemy.types import String, UserDefinedType + +from dbgpt.core import Chunk, Document, Embeddings +from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource +from dbgpt.storage.vector_store.base import ( + _COMMON_PARAMETERS, + VectorStoreBase, + VectorStoreConfig, +) +from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util.i18n_utils import _ + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + + +logger = logging.getLogger(__name__) +sql_logger = None +sql_dbg_log_path = os.getenv("OB_SQL_DBG_LOG_PATH", "") +if sql_dbg_log_path != "": + sql_logger = logging.getLogger("ob_sql_dbg") + sql_logger.setLevel(logging.DEBUG) + file_handler = logging.FileHandler(sql_dbg_log_path) + 
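+ # Write the executed SQL statements at DEBUG level to the configured log file.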
file_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + sql_logger.addHandler(file_handler) + +_OCEANBASE_DEFAULT_EMBEDDING_DIM = 1536 +_OCEANBASE_DEFAULT_COLLECTION_NAME = "langchain_document" +_OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD = 10000 +_OCEANBASE_DEFAULT_RWLOCK_MAX_READER = 64 + +Base = declarative_base() + + +def ob_vector_from_db(value): + """Parse vector from oceanbase.""" + return [float(v) for v in value[1:-1].split(",")] + + +def ob_vector_to_db(value, dim=None): + """Parse vector to oceanbase vector constant type.""" + if value is None: + return value + + return "[" + ",".join([str(float(v)) for v in value]) + "]" + + +class Vector(UserDefinedType): + """OceanBase Vector Column Type.""" + + cache_ok = True + _string = String() + + def __init__(self, dim): + """Create a Vector column with dimemsion `dim`.""" + super(UserDefinedType, self).__init__() + self.dim = dim + + def get_col_spec(self, **kw): + """Get vector column definition in string format.""" + return "VECTOR(%d)" % self.dim + + def bind_processor(self, dialect): + """Get a processor to parse an array to oceanbase vector.""" + + def process(value): + return ob_vector_to_db(value, self.dim) + + return process + + def literal_processor(self, dialect): + """Get a string processor to parse an array to OceanBase Vector.""" + string_literal_processor = self._string._cached_literal_processor(dialect) + + def process(value): + return string_literal_processor(ob_vector_to_db(value, self.dim)) + + return process + + def result_processor(self, dialect, coltype): + """Get a processor to parse OceanBase Vector to array.""" + + def process(value): + return ob_vector_from_db(value) + + return process + + +class OceanBaseCollectionStat: + """A tracer that maintains a table status in OceanBase.""" + + def __init__(self): + """Create OceanBaseCollectionStat instance.""" + self._lock = threading.Lock() + self.maybe_collection_not_exist = True + self.maybe_collection_index_not_exist = True + + def collection_exists(self): + """Set a table is existing.""" + with self._lock: + self.maybe_collection_not_exist = False + + def collection_index_exists(self): + """Set the index of a table is existing.""" + with self._lock: + self.maybe_collection_index_not_exist = False + + def collection_not_exists(self): + """Set a table is dropped.""" + with self._lock: + self.maybe_collection_not_exist = True + + def collection_index_not_exists(self): + """Set the index of a table is dropped.""" + with self._lock: + self.maybe_collection_index_not_exist = True + + def get_maybe_collection_not_exist(self): + """Get the existing status of a table.""" + with self._lock: + return self.maybe_collection_not_exist + + def get_maybe_collection_index_not_exist(self): + """Get the existing stats of the index of a table.""" + with self._lock: + return self.maybe_collection_index_not_exist + + +class OceanBaseGlobalRWLock: + """A global rwlock for OceanBase to do creating vector index table offline ddl.""" + + def __init__(self, max_readers) -> None: + """Create a rwlock.""" + self.max_readers_ = max_readers + self.writer_entered_ = False + self.reader_cnt_ = 0 + self.mutex_ = threading.Lock() + self.writer_cv_ = threading.Condition(self.mutex_) + self.reader_cv_ = threading.Condition(self.mutex_) + + def rlock(self): + """Lock for reading.""" + self.mutex_.acquire() + while self.writer_entered_ or self.max_readers_ == self.reader_cnt_: + 
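+ # Wait while a writer holds the lock or the reader limit has been reached.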
self.reader_cv_.wait() + self.reader_cnt_ += 1 + self.mutex_.release() + + def runlock(self): + """Unlock reading lock.""" + self.mutex_.acquire() + self.reader_cnt_ -= 1 + if self.writer_entered_: + if self.reader_cnt_ == 0: + self.writer_cv_.notify(1) + else: + if self.max_readers_ - 1 == self.reader_cnt_: + self.reader_cv_.notify(1) + self.mutex_.release() + + def wlock(self): + """Lock for writing.""" + self.mutex_.acquire() + while self.writer_entered_: + self.reader_cv_.wait() + self.writer_entered_ = True + while 0 < self.reader_cnt_: + self.writer_cv_.wait() + self.mutex_.release() + + def wunlock(self): + """Unlock writing lock.""" + self.mutex_.acquire() + self.writer_entered_ = False + self.reader_cv_.notifyAll() + self.mutex_.release() + + class OBRLock: + """Reading Lock Wrapper for `with` clause.""" + + def __init__(self, rwlock) -> None: + """Create reading lock wrapper instance.""" + self.rwlock_ = rwlock + + def __enter__(self): + """Lock.""" + self.rwlock_.rlock() + + def __exit__(self, exc_type, exc_value, traceback): + """Unlock when exiting.""" + self.rwlock_.runlock() + + class OBWLock: + """Writing Lock Wrapper for `with` clause.""" + + def __init__(self, rwlock) -> None: + """Create writing lock wrapper instance.""" + self.rwlock_ = rwlock + + def __enter__(self): + """Lock.""" + self.rwlock_.wlock() + + def __exit__(self, exc_type, exc_value, traceback): + """Unlock when exiting.""" + self.rwlock_.wunlock() + + def reader_lock(self): + """Get reading lock wrapper.""" + return self.OBRLock(self) + + def writer_lock(self): + """Get writing lock wrapper.""" + return self.OBWLock(self) + + +ob_grwlock = OceanBaseGlobalRWLock(_OCEANBASE_DEFAULT_RWLOCK_MAX_READER) + + +class OceanBase: + """OceanBase Vector Store implementation.""" + + def __init__( + self, + connection_string: str, + embedding_function: Embeddings, + embedding_dimension: int = _OCEANBASE_DEFAULT_EMBEDDING_DIM, + collection_name: str = _OCEANBASE_DEFAULT_COLLECTION_NAME, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + engine_args: Optional[dict] = None, + delay_table_creation: bool = True, + enable_index: bool = False, + th_create_ivfflat_index: int = _OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD, + sql_logger: Optional[logging.Logger] = None, + collection_stat: Optional[OceanBaseCollectionStat] = None, + enable_normalize_vector: bool = False, + ) -> None: + """Create OceanBase Vector Store instance.""" + self.connection_string = connection_string + self.embedding_function = embedding_function + self.embedding_dimension = embedding_dimension + self.collection_name = collection_name + self.pre_delete_collection = pre_delete_collection + self.logger = logger or logging.getLogger(__name__) + self.delay_table_creation = delay_table_creation + self.th_create_ivfflat_index = th_create_ivfflat_index + self.enable_index = enable_index + self.sql_logger = sql_logger + self.collection_stat = collection_stat + self.enable_normalize_vector = enable_normalize_vector + self.__post_init__(engine_args) + + def __post_init__( + self, + engine_args: Optional[dict] = None, + ) -> None: + """Create connection & vector table.""" + _engine_args = engine_args or {} + if "pool_recycle" not in _engine_args: + _engine_args["pool_recycle"] = 3600 + self.engine = create_engine(self.connection_string, **_engine_args) + self.create_collection() + + @property + def embeddings(self) -> Embeddings: + """Get embedding function.""" + return self.embedding_function + + def create_collection(self) -> None: + 
"""Create vector table.""" + if self.pre_delete_collection: + self.delete_collection() + if not self.delay_table_creation and ( + self.collection_stat is None + or self.collection_stat.get_maybe_collection_not_exist() + ): + self.create_table_if_not_exists() + if self.collection_stat is not None: + self.collection_stat.collection_exists() + + def delete_collection(self) -> None: + """Drop vector table.""" + drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name}") + if self.sql_logger is not None: + self.sql_logger.debug(f"Trying to delete collection: {drop_statement}") + with self.engine.connect() as conn, conn.begin(): + conn.execute(drop_statement) + if self.collection_stat is not None: + self.collection_stat.collection_not_exists() + self.collection_stat.collection_index_not_exists() + + def create_table_if_not_exists(self) -> None: + """Create vector table with SQL.""" + create_table_query = f""" + CREATE TABLE IF NOT EXISTS `{self.collection_name}` ( + id VARCHAR(40) NOT NULL, + embedding VECTOR({self.embedding_dimension}), + document LONGTEXT, + metadata JSON, + PRIMARY KEY (id) + ) + """ + if self.sql_logger is not None: + self.sql_logger.debug(f"Trying to create table: {create_table_query}") + with self.engine.connect() as conn, conn.begin(): + # Create the table + conn.execute(text(create_table_query)) + + def create_collection_ivfflat_index_if_not_exists(self) -> None: + """Create ivfflat index table with SQL.""" + create_index_query = f""" + CREATE INDEX IF NOT EXISTS `embedding_idx` on `{self.collection_name}` ( + embedding l2 + ) using ivfflat with (lists=20) + """ + with ob_grwlock.writer_lock(), self.engine.connect() as conn, conn.begin(): + # Create Ivfflat Index + if self.sql_logger is not None: + self.sql_logger.debug( + f"Trying to create ivfflat index: {create_index_query}" + ) + conn.execute(text(create_index_query)) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + batch_size: int = 500, + **kwargs: Any, + ) -> List[str]: + """Insert texts into vector table.""" + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + embeddings = self.embedding_function.embed_documents(list(texts)) + + if len(embeddings) == 0: + return ids + + if not metadatas: + metadatas = [{} for _ in texts] + + if self.delay_table_creation and ( + self.collection_stat is None + or self.collection_stat.get_maybe_collection_not_exist() + ): + self.embedding_dimension = len(embeddings[0]) + self.create_table_if_not_exists() + self.delay_table_creation = False + if self.collection_stat is not None: + self.collection_stat.collection_exists() + + chunks_table = Table( + self.collection_name, + Base.metadata, + Column("id", VARCHAR(40), primary_key=True), + Column("embedding", Vector(self.embedding_dimension)), + Column("document", LONGTEXT, nullable=True), + Column("metadata", JSON, nullable=True), # filter + keep_existing=True, + ) + + row_count_query = f""" + SELECT COUNT(*) as cnt FROM `{self.collection_name}` + """ + chunks_table_data = [] + with self.engine.connect() as conn, conn.begin(): + for document, metadata, chunk_id, embedding in zip( + texts, metadatas, ids, embeddings + ): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding + if not self.enable_normalize_vector + else self._normalization_vectors(embedding), + "document": document, + "metadata": metadata, + } + ) + + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == batch_size: + 
with ob_grwlock.reader_lock(): + if self.sql_logger is not None: + insert_sql_for_log = str( + insert(chunks_table).values(chunks_table_data) + ) + self.sql_logger.debug( + f"""Trying to insert vectors: + {insert_sql_for_log}""" + ) + conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + with ob_grwlock.reader_lock(): + if self.sql_logger is not None: + insert_sql_for_log = str( + insert(chunks_table).values(chunks_table_data) + ) + self.sql_logger.debug( + f"""Trying to insert vectors: + {insert_sql_for_log}""" + ) + conn.execute(insert(chunks_table).values(chunks_table_data)) + + if self.enable_index and ( + self.collection_stat is None + or self.collection_stat.get_maybe_collection_index_not_exist() + ): + with ob_grwlock.reader_lock(): + row_cnt_res = conn.execute(text(row_count_query)) + for row in row_cnt_res: + if row.cnt > self.th_create_ivfflat_index: + self.create_collection_ivfflat_index_if_not_exists() + if self.collection_stat is not None: + self.collection_stat.collection_index_exists() + + return ids + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[MetadataFilters] = None, + **kwargs: Any, + ) -> List[Document]: + """Do similarity search via query in String.""" + embedding = self.embedding_function.embed_query(query) + docs = self.similarity_search_by_vector(embedding=embedding, k=k, filter=filter) + return docs + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[MetadataFilters] = None, + **kwargs: Any, + ) -> List[Document]: + """Do similarity search via query vector.""" + docs_and_scores = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return [doc for doc, _ in docs_and_scores] + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[MetadataFilters] = None, + score_threshold: Optional[float] = None, + ) -> List[Tuple[Document, float]]: + """Do similarity search via query vector with score.""" + try: + from sqlalchemy.engine import Row + except ImportError: + raise ImportError( + "Could not import Row from sqlalchemy.engine. " + "Please 'pip install sqlalchemy>=1.4'." + ) + + # filter is not support in OceanBase currently. + + # normailze embedding vector + if self.enable_normalize_vector: + embedding = self._normalization_vectors(embedding) + + embedding_str = ob_vector_to_db(embedding, self.embedding_dimension) + vector_distance_op = "<@>" if self.enable_normalize_vector else "<->" + sql_query = f""" + SELECT document, metadata, embedding {vector_distance_op} '{embedding_str}' + as distance + FROM {self.collection_name} + ORDER BY embedding {vector_distance_op} '{embedding_str}' + LIMIT :k + """ + sql_query_str_for_log = f""" + SELECT document, metadata, embedding {vector_distance_op} '?' as distance + FROM {self.collection_name} + ORDER BY embedding {vector_distance_op} '?' 
+ LIMIT {k} + """ + + params = {"k": k} + try: + with ob_grwlock.reader_lock(), self.engine.connect() as conn: + if self.sql_logger is not None: + self.sql_logger.debug( + f"Trying to do similarity search: {sql_query_str_for_log}" + ) + results: Sequence[Row] = conn.execute( + text(sql_query), params + ).fetchall() + + if (score_threshold is not None) and self.enable_normalize_vector: + documents_with_scores = [ + ( + Document( + content=result.document, + metadata=json.loads(result.metadata), + ), + result.distance, + ) + for result in results + if result.distance >= score_threshold + ] + else: + documents_with_scores = [ + ( + Document( + content=result.document, + metadata=json.loads(result.metadata), + ), + result.distance, + ) + for result in results + ] + return documents_with_scores + except Exception as e: + self.logger.error("similarity_search_with_score_by_vector failed:", str(e)) + return [] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[MetadataFilters] = None, + score_threshold: Optional[float] = None, + ) -> List[Tuple[Document, float]]: + """Do similarity search via query String with score.""" + embedding = self.embedding_function.embed_query(query) + docs = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter, score_threshold=score_threshold + ) + return docs + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """Delete vectors from vector table.""" + if ids is None: + raise ValueError("No ids provided to delete.") + + # Define the table schema + chunks_table = Table( + self.collection_name, + Base.metadata, + Column("id", VARCHAR(40), primary_key=True), + Column("embedding", Vector(self.embedding_dimension)), + Column("document", LONGTEXT, nullable=True), + Column("metadata", JSON, nullable=True), # filter + keep_existing=True, + ) + + try: + with self.engine.connect() as conn, conn.begin(): + delete_condition = chunks_table.c.id.in_(ids) + delete_stmt = chunks_table.delete().where(delete_condition) + with ob_grwlock.reader_lock(): + if self.sql_logger is not None: + self.sql_logger.debug( + f"Trying to delete vectors: {str(delete_stmt)}" + ) + conn.execute(delete_stmt) + return True + except Exception as e: + self.logger.error("Delete operation failed:", str(e)) + return False + + def _normalization_vectors(self, vector): + import numpy as np + + norm = np.linalg.norm(vector) + return (vector / norm).tolist() + + @classmethod + def connection_string_from_db_params( + cls, + host: str, + port: int, + database: str, + user: str, + password: str, + ) -> str: + """Get connection string.""" + return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}" + + +ob_collection_stats_lock = threading.Lock() +ob_collection_stats: Dict[str, OceanBaseCollectionStat] = {} + + +@register_resource( + _("OceanBase Vector Store"), + "oceanbase_vector_store", + category=ResourceCategory.VECTOR_STORE, + parameters=[ + *_COMMON_PARAMETERS, + Parameter.build_from( + _("OceanBase Host"), + "ob_host", + str, + description=_("oceanbase host"), + optional=True, + default=None, + ), + Parameter.build_from( + _("OceanBase Port"), + "ob_port", + int, + description=_("oceanbase port"), + optional=True, + default=None, + ), + Parameter.build_from( + _("OceanBase User"), + "ob_user", + str, + description=_("user to login"), + optional=True, + default=None, + ), + Parameter.build_from( + _("OceanBase Password"), + "ob_password", + str, + description=_("password to login"), + 
optional=True, + default=None, + ), + Parameter.build_from( + _("OceanBase Database"), + "ob_database", + str, + description=_("database for vector tables"), + optional=True, + default=None, + ), + ], + description="OceanBase vector store.", +) +class OceanBaseConfig(VectorStoreConfig): + """OceanBase vector store config.""" + + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + + """OceanBase config""" + ob_host: str = Field( + default="127.0.0.1", + description="oceanbase host", + ) + ob_port: int = Field( + default=2881, + description="oceanbase port", + ) + ob_user: str = Field( + default="root@test", + description="user to login", + ) + ob_password: str = Field( + default="", + description="password to login", + ) + ob_database: str = Field( + default="test", + description="database for vector tables", + ) + + +class OceanBaseStore(VectorStoreBase): + """OceanBase vector store.""" + + def __init__(self, vector_store_config: OceanBaseConfig) -> None: + """Create a OceanBaseStore instance.""" + if vector_store_config.embedding_fn is None: + raise ValueError("embedding_fn is required for OceanBaseStore") + + self.embeddings = vector_store_config.embedding_fn + self.collection_name = vector_store_config.name + vector_store_config = vector_store_config.dict() + self.OB_HOST = str( + vector_store_config.get("ob_host") or os.getenv("OB_HOST", "127.0.0.1") + ) + self.OB_PORT = int( + vector_store_config.get("ob_port") or int(os.getenv("OB_PORT", "2881")) + ) + self.OB_USER = str( + vector_store_config.get("ob_user") or os.getenv("OB_USER", "root@test") + ) + self.OB_PASSWORD = str( + vector_store_config.get("ob_password") or os.getenv("OB_PASSWORD", "") + ) + self.OB_DATABASE = str( + vector_store_config.get("ob_database") or os.getenv("OB_DATABASE", "test") + ) + self.OB_ENABLE_NORMALIZE_VECTOR = bool( + os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "") + ) + self.connection_string = OceanBase.connection_string_from_db_params( + self.OB_HOST, self.OB_PORT, self.OB_DATABASE, self.OB_USER, self.OB_PASSWORD + ) + self.logger = logger + with ob_collection_stats_lock: + if ob_collection_stats.get(self.collection_name) is None: + ob_collection_stats[self.collection_name] = OceanBaseCollectionStat() + self.collection_stat = ob_collection_stats[self.collection_name] + + self.vector_store_client = OceanBase( + connection_string=self.connection_string, + embedding_function=self.embeddings, + collection_name=self.collection_name, + logger=self.logger, + sql_logger=sql_logger, + enable_index=bool(os.getenv("OB_ENABLE_INDEX", "")), + collection_stat=self.collection_stat, + enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR, + ) + + def similar_search( + self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any + ) -> List[Chunk]: + """Perform a search on a query string and return results.""" + self.logger.info("OceanBase: similar_search..") + documents = self.vector_store_client.similarity_search( + text, topk, filter=filters + ) + return [Chunk(content=doc.content, metadata=doc.metadata) for doc in documents] + + def similar_search_with_scores( + self, + text, + topk, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + ) -> List[Chunk]: + """Perform a search on a query string and return results with score.""" + self.logger.info("OceanBase: similar_search_with_scores..") + docs_and_scores = self.vector_store_client.similarity_search_with_score( + text, topk, filter=filters + ) + return [ + Chunk(content=doc.content, metadata=doc.metadata, 
score=score) + for doc, score in docs_and_scores + ] + + def vector_name_exists(self): + """Whether vector name exists.""" + self.logger.info("OceanBase: vector_name_exists..") + try: + self.vector_store_client.create_collection() + return True + except Exception as e: + logger.error("vector_name_exists error", e.message) + return False + + def load_document(self, chunks: List[Chunk]) -> List[str]: + """Load document in vector database.""" + self.logger.info("OceanBase: load_document..") + # lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks] + texts = [chunk.content for chunk in chunks] + metadatas = [chunk.metadata for chunk in chunks] + ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas) + return ids + + def delete_vector_name(self, vector_name): + """Delete vector name.""" + self.logger.info("OceanBase: delete_vector_name..") + return self.vector_store_client.delete_collection() + + def delete_by_ids(self, ids): + """Delete vector by ids.""" + self.logger.info("OceanBase: delete_by_ids..") + ids = ids.split(",") + if len(ids) > 0: + self.vector_store_client.delete(ids) diff --git a/docker/allinone/Dockerfile b/docker/allinone/Dockerfile index 48c8362cf..2422ed253 100644 --- a/docker/allinone/Dockerfile +++ b/docker/allinone/Dockerfile @@ -2,6 +2,7 @@ ARG BASE_IMAGE="eosphorosai/dbgpt:latest" FROM ${BASE_IMAGE} +RUN pip install dashscope RUN apt-get update && apt-get install -y wget gnupg lsb-release net-tools RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29 @@ -9,7 +10,7 @@ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29 RUN wget https://dev.mysql.com/get/mysql-apt-config_0.8.17-1_all.deb RUN dpkg -i mysql-apt-config_0.8.17-1_all.deb -RUN apt-get update && apt-get install -y mysql-server && apt-get clean +RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys B7B3B788A8D3785C && apt-get update && apt-get install -y mysql-server && apt-get clean # Remote access RUN sed -i 's/bind-address\s*=\s*127.0.0.1/bind-address = 0.0.0.0/g' /etc/mysql/mysql.conf.d/mysqld.cnf \ diff --git a/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml b/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml new file mode 100644 index 000000000..64e82078a --- /dev/null +++ b/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml @@ -0,0 +1,42 @@ +version: '3.8' + +services: + oceanbase: + image: oceanbase/oceanbase-ce:vector + ports: + - 2881:2881 + networks: + - dbgptnet + dbgpt: + image: eosphorosai/dbgpt-allinone + environment: + - DBGPT_WEBSERVER_PORT=12345 + - VECTOR_STORE_TYPE=OceanBase + - OB_HOST=oceanbase + - OB_HOST=127.0.0.1 + - OB_PORT=2881 + - OB_USER=root@test + - OB_DATABASE=test + - OB_SQL_DBG_LOG_PATH=/sql_log/sql.log + - LOCAL_DB_TYPE=sqlite + - LLM_MODEL=tongyi_proxyllm + - PROXYLLM_BACKEND=qwen-plus + - EMBEDDING_MODEL=text2vec + # your api key + # - TONGYI_PROXY_API_KEY={your-api-key} + - LANGUAGE=zh + # - OB_ENABLE_NORMALIZE_VECTOR=True + # - OB_ENABLE_INDEX=True + ports: + - 3306:3306 + - 12345:12345 + volumes: + # - {your-ob-sql-dbg-log-dir}:/sql_log + # - {your-model-dir}:/app/models + networks: + - dbgptnet + +networks: + dbgptnet: + driver: bridge + name: dbgptnet \ No newline at end of file diff --git a/docs/docs/application/advanced_tutorial/rag.md b/docs/docs/application/advanced_tutorial/rag.md index 1cfeab5c2..acb015ae6 100644 --- a/docs/docs/application/advanced_tutorial/rag.md +++ b/docs/docs/application/advanced_tutorial/rag.md @@ -69,6 +69,7 @@ import TabItem 
from '@theme/TabItem'; {label: 'Chroma', value: 'Chroma'}, {label: 'Milvus', value: 'Milvus'}, {label: 'Weaviate', value: 'Weaviate'}, + {label: 'OceanBase', value: 'OceanBase'}, ]}> @@ -106,6 +107,25 @@ set ``VECTOR_STORE_TYPE`` in ``.env`` file VECTOR_STORE_TYPE=Weaviate #WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network ``` + + + +set ``VECTOR_STORE_TYPE`` in ``.env`` file + +```shell +OB_HOST=127.0.0.1 +OB_PORT=2881 +OB_USER=root@test +OB_DATABASE=test +## Optional +# OB_PASSWORD= +## Optional: SQL statements executed by OceanBase is recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}. +# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log +## Optional: If {OB_ENABLE_NORMALIZE_VECTOR} is set, the vector stored in OceanBase is normalized. +# OB_ENABLE_NORMALIZE_VECTOR=True +## Optional: If {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table. +# OB_ENABLE_INDEX=True +``` diff --git a/docs/docs/faq/kbqa.md b/docs/docs/faq/kbqa.md index d35715b73..ea3e2ac74 100644 --- a/docs/docs/faq/kbqa.md +++ b/docs/docs/faq/kbqa.md @@ -18,8 +18,29 @@ git lfs clone https://huggingface.co/GanymedeNil/text2vec-large-chinese Update .env file and set VECTOR_STORE_TYPE. -DB-GPT currently support Chroma(Default), Milvus(>2.1), Weaviate vector database. -If you want to change vector db, Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma (now only support Chroma and Milvus(>2.1), if you set Milvus, please set MILVUS_URL and MILVUS_PORT) +DB-GPT currently support Chroma(Default), Milvus(>2.1), Weaviate, OceanBase vector database. +If you want to change vector db, Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma (now only support Chroma and Milvus(>2.1), if you set Milvus, please set MILVUS_URL and MILVUS_PORT). + +If you want to use OceanBase, please first start a docker container via the following command: +```shell +docker run -p 2881:2881 --name obvec -d oceanbase/oceanbase-ce:vector +``` +Then set the following variables in the .env file: +```shell +VECTOR_STORE_TYPE=OceanBase +OB_HOST=127.0.0.1 +OB_PORT=2881 +OB_USER=root@test +OB_DATABASE=test +## Optional +# OB_PASSWORD= +## Optional: SQL statements executed by OceanBase is recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}. +# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log +## Optional: If {OB_ENABLE_NORMALIZE_VECTOR} is set, the vector stored in OceanBase is normalized. +# OB_ENABLE_NORMALIZE_VECTOR=True +## Optional: If {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table. +# OB_ENABLE_INDEX=True +``` If you want to support more vector db, you can integrate yourself.[how to integrate](https://db-gpt.readthedocs.io/en/latest/modules/vector.html) ```commandline #*******************************************************************# diff --git a/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po b/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po index 267db25fc..75c2ed1c3 100644 --- a/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po +++ b/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po @@ -147,6 +147,50 @@ msgid "" "connection string." 
msgstr "向量存储的连接字符串,如果未设置,将使用默认的连接字符串。" +#: ../dbgpt/storage/vector_store/oceanbase_store.py:542 +msgid "OceanBase Vector Store" +msgstr "OceanBase 向量存储" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:548 +msgid "OceanBase Host" +msgstr "OceanBase 主机地址" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:552 +msgid "oceanbase host" +msgstr "oceanbase 主机地址" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:558 +msgid "OceanBase Port" +msgstr "OceanBase 端口" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:562 +msgid "oceanbase port" +msgstr "oceanbase 端口" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:568 +msgid "OceanBase User" +msgstr "OceanBase 用户名" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:572 +msgid "user to login" +msgstr "登录用户名" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:578 +msgid "OceanBase Password" +msgstr "OceanBase 密码" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:582 +msgid "password to login" +msgstr "登录密码" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:588 +msgid "OceanBase Database" +msgstr "OceanBase 数据库名" + +#: ../dbgpt/storage/vector_store/oceanbase_store.py:592 +msgid "database for vector tables" +msgstr "存放向量表的数据库名" + #: ../dbgpt/storage/vector_store/base.py:19 msgid "Collection Name" msgstr "集合名称"