diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index b9fa4ea27..32f669a31 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -238,6 +238,18 @@ class Config(metaclass=Singleton):
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
+ ## OceanBase Configuration
+ self.OB_HOST = os.getenv("OB_HOST", "127.0.0.1")
+ self.OB_PORT = int(os.getenv("OB_PORT", "2881"))
+ self.OB_USER = os.getenv("OB_USER", "root")
+ self.OB_PASSWORD = os.getenv("OB_PASSWORD", "")
+ self.OB_DATABASE = os.getenv("OB_DATABASE", "test")
+ self.OB_SQL_DBG_LOG_PATH = os.getenv("OB_SQL_DBG_LOG_PATH", "")
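+        # Note: these two flags are simple on/off switches: any non-empty
+        # value enables them; leave them unset (or empty) to disable.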
+ self.OB_ENABLE_NORMALIZE_VECTOR = bool(
+ os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")
+ )
+ self.OB_ENABLE_INDEX = bool(os.getenv("OB_ENABLE_INDEX", ""))
+
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
diff --git a/dbgpt/rag/text_splitter/tests/test_splitters.py b/dbgpt/rag/text_splitter/tests/test_splitters.py
index fd32956ce..6b31f397f 100644
--- a/dbgpt/rag/text_splitter/tests/test_splitters.py
+++ b/dbgpt/rag/text_splitter/tests/test_splitters.py
@@ -25,11 +25,11 @@ def test_md_header_text_splitter() -> None:
output = markdown_splitter.split_text(markdown_document)
expected_output = [
Chunk(
- content="{'Header 1': 'dbgpt', 'Header 2': 'description'}, my name is dbgpt",
+ content='"dbgpt-description": my name is dbgpt',
metadata={"Header 1": "dbgpt", "Header 2": "description"},
),
Chunk(
- content="{'Header 1': 'dbgpt', 'Header 2': 'content'}, my name is aries",
+ content='"dbgpt-content": my name is aries',
metadata={"Header 1": "dbgpt", "Header 2": "content"},
),
]
diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py
index 32ae4a45d..a4b44ca8e 100644
--- a/dbgpt/rag/text_splitter/text_splitter.py
+++ b/dbgpt/rag/text_splitter/text_splitter.py
@@ -515,7 +515,8 @@ class MarkdownHeaderTextSplitter(TextSplitter):
aggregated_chunks[-1]["content"] += " \n" + line["content"]
else:
# Otherwise, append the current line to the aggregated list
- line["content"] = f"{line['metadata']}, " + line["content"]
+            subtitles = "-".join(line["metadata"].values())
+ line["content"] = f'"{subtitles}": ' + line["content"]
aggregated_chunks.append(line)
return [
@@ -557,16 +558,28 @@ class MarkdownHeaderTextSplitter(TextSplitter):
# header_stack: List[Dict[str, Union[int, str]]] = []
header_stack: List[HeaderType] = []
initial_metadata: Dict[str, str] = {}
+        # Track whether the current line is inside a fenced code block.
+ in_code_block = False
for line in lines:
stripped_line = line.strip()
+            # An opening code fence starts with "```" followed by an info
+            # string; a bare "```" is treated as the closing fence below.
+ with_code_frame = stripped_line.startswith("```") and (
+ stripped_line != "```"
+ )
+ if (not in_code_block) and with_code_frame:
+ in_code_block = True
# Check each line against each of the header types (e.g., #, ##)
for sep, name in self.headers_to_split_on:
# Check if line starts with a header that we intend to split on
- if stripped_line.startswith(sep) and (
- # Header with no text OR header is followed by space
- # Both are valid conditions that sep is being used a header
- len(stripped_line) == len(sep)
- or stripped_line[len(sep)] == " "
+ if (
+ (not in_code_block)
+ and stripped_line.startswith(sep)
+ and (
+ # Header with no text OR header is followed by space
+                    # Both are valid conditions indicating sep is being used as a header
+ len(stripped_line) == len(sep)
+ or stripped_line[len(sep)] == " "
+ )
):
# Ensure we are tracking the header as metadata
if name is not None:
@@ -620,6 +633,10 @@ class MarkdownHeaderTextSplitter(TextSplitter):
)
current_content.clear()
+ # Code block ends
+ if in_code_block and stripped_line == "```":
+ in_code_block = False
+
current_metadata = initial_metadata.copy()
if current_content:
lines_with_metadata.append(
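
Behavior sketch for the splitter changes above: header markers inside fenced code blocks are no longer treated as section headers, and each aggregated chunk's content is prefixed with its joined header values instead of the raw metadata dict. A minimal example (the constructor is assumed to accept `headers_to_split_on`, mirroring the existing tests):

```python
from dbgpt.rag.text_splitter.text_splitter import MarkdownHeaderTextSplitter

markdown_document = (
    "# dbgpt\n"
    "## usage\n"
    "```bash\n"
    "# not a markdown header, just a shell comment\n"
    "echo hello\n"
    "```\n"
    "run the command above\n"
)
splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
)
for chunk in splitter.split_text(markdown_document):
    # Roughly: content like '"dbgpt-usage": ...' with the code block kept intact
    print(chunk.content, chunk.metadata)
```
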
diff --git a/dbgpt/storage/vector_store/__init__.py b/dbgpt/storage/vector_store/__init__.py
index 924680b39..9199bd08d 100644
--- a/dbgpt/storage/vector_store/__init__.py
+++ b/dbgpt/storage/vector_store/__init__.py
@@ -26,6 +26,12 @@ def _import_weaviate() -> Any:
return WeaviateStore
+def _import_oceanbase() -> Any:
+ from dbgpt.storage.vector_store.oceanbase_store import OceanBaseStore
+
+ return OceanBaseStore
+
+
def __getattr__(name: str) -> Any:
if name == "Chroma":
return _import_chroma()
@@ -35,8 +41,10 @@ def __getattr__(name: str) -> Any:
return _import_weaviate()
elif name == "PGVector":
return _import_pgvector()
+ elif name == "OceanBase":
+ return _import_oceanbase()
else:
raise AttributeError(f"Could not find: {name}")
-__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
+__all__ = ["Chroma", "Milvus", "Weaviate", "OceanBase", "PGVector"]
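
The lazy `__getattr__` hook above means the OceanBase store module is only imported on first access; a small sketch:

```python
from dbgpt.storage.vector_store import OceanBase  # resolved via __getattr__

# `OceanBase` resolves to the OceanBaseStore class defined in oceanbase_store.py
assert OceanBase.__name__ == "OceanBaseStore"
```
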
diff --git a/dbgpt/storage/vector_store/oceanbase_store.py b/dbgpt/storage/vector_store/oceanbase_store.py
new file mode 100644
index 000000000..fa58d7fdf
--- /dev/null
+++ b/dbgpt/storage/vector_store/oceanbase_store.py
@@ -0,0 +1,798 @@
+"""OceanBase vector store."""
+import json
+import logging
+import os
+import threading
+import uuid
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from pydantic import Field
+from sqlalchemy import Column, Table, create_engine, insert, text
+from sqlalchemy.dialects.mysql import JSON, LONGTEXT, VARCHAR
+from sqlalchemy.types import String, UserDefinedType
+
+from dbgpt.core import Chunk, Document, Embeddings
+from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
+from dbgpt.storage.vector_store.base import (
+ _COMMON_PARAMETERS,
+ VectorStoreBase,
+ VectorStoreConfig,
+)
+from dbgpt.storage.vector_store.filters import MetadataFilters
+from dbgpt.util.i18n_utils import _
+
+try:
+ from sqlalchemy.orm import declarative_base
+except ImportError:
+ from sqlalchemy.ext.declarative import declarative_base
+
+
+logger = logging.getLogger(__name__)
+sql_logger = None
+sql_dbg_log_path = os.getenv("OB_SQL_DBG_LOG_PATH", "")
+if sql_dbg_log_path != "":
+ sql_logger = logging.getLogger("ob_sql_dbg")
+ sql_logger.setLevel(logging.DEBUG)
+ file_handler = logging.FileHandler(sql_dbg_log_path)
+ file_handler.setLevel(logging.DEBUG)
+ formatter = logging.Formatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ file_handler.setFormatter(formatter)
+ sql_logger.addHandler(file_handler)
+
+_OCEANBASE_DEFAULT_EMBEDDING_DIM = 1536
+_OCEANBASE_DEFAULT_COLLECTION_NAME = "langchain_document"
+_OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD = 10000
+_OCEANBASE_DEFAULT_RWLOCK_MAX_READER = 64
+
+Base = declarative_base()
+
+
+def ob_vector_from_db(value):
+    """Parse a vector string returned by OceanBase into a list of floats."""
+ return [float(v) for v in value[1:-1].split(",")]
+
+
+def ob_vector_to_db(value, dim=None):
+    """Serialize a vector into the OceanBase vector literal format."""
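+    # e.g. [1, 2] -> "[1.0,2.0]" (ob_vector_from_db parses it back)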
+ if value is None:
+ return value
+
+ return "[" + ",".join([str(float(v)) for v in value]) + "]"
+
+
+class Vector(UserDefinedType):
+ """OceanBase Vector Column Type."""
+
+ cache_ok = True
+ _string = String()
+
+ def __init__(self, dim):
+        """Create a Vector column with dimension `dim`."""
+ super(UserDefinedType, self).__init__()
+ self.dim = dim
+
+ def get_col_spec(self, **kw):
+ """Get vector column definition in string format."""
+ return "VECTOR(%d)" % self.dim
+
+ def bind_processor(self, dialect):
+        """Get a bind processor that converts an array to an OceanBase vector."""
+
+ def process(value):
+ return ob_vector_to_db(value, self.dim)
+
+ return process
+
+ def literal_processor(self, dialect):
+        """Get a literal processor that converts an array to a vector string."""
+ string_literal_processor = self._string._cached_literal_processor(dialect)
+
+ def process(value):
+ return string_literal_processor(ob_vector_to_db(value, self.dim))
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+        """Get a result processor that converts an OceanBase vector to an array."""
+
+ def process(value):
+ return ob_vector_from_db(value)
+
+ return process
+
+
+class OceanBaseCollectionStat:
+    """A tracker that maintains the cached status of a table in OceanBase."""
+
+ def __init__(self):
+ """Create OceanBaseCollectionStat instance."""
+ self._lock = threading.Lock()
+ self.maybe_collection_not_exist = True
+ self.maybe_collection_index_not_exist = True
+
+ def collection_exists(self):
+        """Mark the table as existing."""
+ with self._lock:
+ self.maybe_collection_not_exist = False
+
+ def collection_index_exists(self):
+        """Mark the table's index as existing."""
+ with self._lock:
+ self.maybe_collection_index_not_exist = False
+
+ def collection_not_exists(self):
+        """Mark the table as dropped."""
+ with self._lock:
+ self.maybe_collection_not_exist = True
+
+ def collection_index_not_exists(self):
+        """Mark the table's index as dropped."""
+ with self._lock:
+ self.maybe_collection_index_not_exist = True
+
+ def get_maybe_collection_not_exist(self):
+        """Get whether the table may not exist."""
+ with self._lock:
+ return self.maybe_collection_not_exist
+
+ def get_maybe_collection_index_not_exist(self):
+        """Get whether the table's index may not exist."""
+ with self._lock:
+ return self.maybe_collection_index_not_exist
+
+
+class OceanBaseGlobalRWLock:
+    """A global rwlock guarding offline DDL that creates the vector index table."""
+
+ def __init__(self, max_readers) -> None:
+ """Create a rwlock."""
+ self.max_readers_ = max_readers
+ self.writer_entered_ = False
+ self.reader_cnt_ = 0
+ self.mutex_ = threading.Lock()
+ self.writer_cv_ = threading.Condition(self.mutex_)
+ self.reader_cv_ = threading.Condition(self.mutex_)
+
+ def rlock(self):
+ """Lock for reading."""
+ self.mutex_.acquire()
+ while self.writer_entered_ or self.max_readers_ == self.reader_cnt_:
+ self.reader_cv_.wait()
+ self.reader_cnt_ += 1
+ self.mutex_.release()
+
+ def runlock(self):
+ """Unlock reading lock."""
+ self.mutex_.acquire()
+ self.reader_cnt_ -= 1
+ if self.writer_entered_:
+ if self.reader_cnt_ == 0:
+ self.writer_cv_.notify(1)
+ else:
+ if self.max_readers_ - 1 == self.reader_cnt_:
+ self.reader_cv_.notify(1)
+ self.mutex_.release()
+
+ def wlock(self):
+ """Lock for writing."""
+ self.mutex_.acquire()
+ while self.writer_entered_:
+ self.reader_cv_.wait()
+ self.writer_entered_ = True
+ while 0 < self.reader_cnt_:
+ self.writer_cv_.wait()
+ self.mutex_.release()
+
+ def wunlock(self):
+ """Unlock writing lock."""
+ self.mutex_.acquire()
+ self.writer_entered_ = False
+        self.reader_cv_.notify_all()
+ self.mutex_.release()
+
+ class OBRLock:
+ """Reading Lock Wrapper for `with` clause."""
+
+ def __init__(self, rwlock) -> None:
+ """Create reading lock wrapper instance."""
+ self.rwlock_ = rwlock
+
+ def __enter__(self):
+ """Lock."""
+ self.rwlock_.rlock()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ """Unlock when exiting."""
+ self.rwlock_.runlock()
+
+ class OBWLock:
+ """Writing Lock Wrapper for `with` clause."""
+
+ def __init__(self, rwlock) -> None:
+ """Create writing lock wrapper instance."""
+ self.rwlock_ = rwlock
+
+ def __enter__(self):
+ """Lock."""
+ self.rwlock_.wlock()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ """Unlock when exiting."""
+ self.rwlock_.wunlock()
+
+ def reader_lock(self):
+ """Get reading lock wrapper."""
+ return self.OBRLock(self)
+
+ def writer_lock(self):
+ """Get writing lock wrapper."""
+ return self.OBWLock(self)
+
+
+ob_grwlock = OceanBaseGlobalRWLock(_OCEANBASE_DEFAULT_RWLOCK_MAX_READER)
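+# Readers (inserts, queries, deletes) take `ob_grwlock.reader_lock()`, while
+# offline DDL such as ivfflat index creation takes `ob_grwlock.writer_lock()`.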
+
+
+class OceanBase:
+ """OceanBase Vector Store implementation."""
+
+ def __init__(
+ self,
+ connection_string: str,
+ embedding_function: Embeddings,
+ embedding_dimension: int = _OCEANBASE_DEFAULT_EMBEDDING_DIM,
+ collection_name: str = _OCEANBASE_DEFAULT_COLLECTION_NAME,
+ pre_delete_collection: bool = False,
+ logger: Optional[logging.Logger] = None,
+ engine_args: Optional[dict] = None,
+ delay_table_creation: bool = True,
+ enable_index: bool = False,
+ th_create_ivfflat_index: int = _OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD,
+ sql_logger: Optional[logging.Logger] = None,
+ collection_stat: Optional[OceanBaseCollectionStat] = None,
+ enable_normalize_vector: bool = False,
+ ) -> None:
+ """Create OceanBase Vector Store instance."""
+ self.connection_string = connection_string
+ self.embedding_function = embedding_function
+ self.embedding_dimension = embedding_dimension
+ self.collection_name = collection_name
+ self.pre_delete_collection = pre_delete_collection
+ self.logger = logger or logging.getLogger(__name__)
+ self.delay_table_creation = delay_table_creation
+ self.th_create_ivfflat_index = th_create_ivfflat_index
+ self.enable_index = enable_index
+ self.sql_logger = sql_logger
+ self.collection_stat = collection_stat
+ self.enable_normalize_vector = enable_normalize_vector
+ self.__post_init__(engine_args)
+
+ def __post_init__(
+ self,
+ engine_args: Optional[dict] = None,
+ ) -> None:
+ """Create connection & vector table."""
+ _engine_args = engine_args or {}
+ if "pool_recycle" not in _engine_args:
+ _engine_args["pool_recycle"] = 3600
+ self.engine = create_engine(self.connection_string, **_engine_args)
+ self.create_collection()
+
+ @property
+ def embeddings(self) -> Embeddings:
+ """Get embedding function."""
+ return self.embedding_function
+
+ def create_collection(self) -> None:
+ """Create vector table."""
+ if self.pre_delete_collection:
+ self.delete_collection()
+ if not self.delay_table_creation and (
+ self.collection_stat is None
+ or self.collection_stat.get_maybe_collection_not_exist()
+ ):
+ self.create_table_if_not_exists()
+ if self.collection_stat is not None:
+ self.collection_stat.collection_exists()
+
+ def delete_collection(self) -> None:
+ """Drop vector table."""
+ drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name}")
+ if self.sql_logger is not None:
+ self.sql_logger.debug(f"Trying to delete collection: {drop_statement}")
+ with self.engine.connect() as conn, conn.begin():
+ conn.execute(drop_statement)
+ if self.collection_stat is not None:
+ self.collection_stat.collection_not_exists()
+ self.collection_stat.collection_index_not_exists()
+
+ def create_table_if_not_exists(self) -> None:
+ """Create vector table with SQL."""
+ create_table_query = f"""
+ CREATE TABLE IF NOT EXISTS `{self.collection_name}` (
+ id VARCHAR(40) NOT NULL,
+ embedding VECTOR({self.embedding_dimension}),
+ document LONGTEXT,
+ metadata JSON,
+ PRIMARY KEY (id)
+ )
+ """
+ if self.sql_logger is not None:
+ self.sql_logger.debug(f"Trying to create table: {create_table_query}")
+ with self.engine.connect() as conn, conn.begin():
+ # Create the table
+ conn.execute(text(create_table_query))
+
+ def create_collection_ivfflat_index_if_not_exists(self) -> None:
+ """Create ivfflat index table with SQL."""
+ create_index_query = f"""
+ CREATE INDEX IF NOT EXISTS `embedding_idx` on `{self.collection_name}` (
+ embedding l2
+ ) using ivfflat with (lists=20)
+ """
+ with ob_grwlock.writer_lock(), self.engine.connect() as conn, conn.begin():
+ # Create Ivfflat Index
+ if self.sql_logger is not None:
+ self.sql_logger.debug(
+ f"Trying to create ivfflat index: {create_index_query}"
+ )
+ conn.execute(text(create_index_query))
+
+ def add_texts(
+ self,
+ texts: Iterable[str],
+ metadatas: Optional[List[dict]] = None,
+ ids: Optional[List[str]] = None,
+ batch_size: int = 500,
+ **kwargs: Any,
+ ) -> List[str]:
+        """Insert texts into vector table."""
+        texts = list(texts)
+        if ids is None:
+            ids = [str(uuid.uuid1()) for _ in texts]
+
+ embeddings = self.embedding_function.embed_documents(list(texts))
+
+ if len(embeddings) == 0:
+ return ids
+
+ if not metadatas:
+ metadatas = [{} for _ in texts]
+
+ if self.delay_table_creation and (
+ self.collection_stat is None
+ or self.collection_stat.get_maybe_collection_not_exist()
+ ):
+ self.embedding_dimension = len(embeddings[0])
+ self.create_table_if_not_exists()
+ self.delay_table_creation = False
+ if self.collection_stat is not None:
+ self.collection_stat.collection_exists()
+
+ chunks_table = Table(
+ self.collection_name,
+ Base.metadata,
+ Column("id", VARCHAR(40), primary_key=True),
+ Column("embedding", Vector(self.embedding_dimension)),
+ Column("document", LONGTEXT, nullable=True),
+ Column("metadata", JSON, nullable=True), # filter
+ keep_existing=True,
+ )
+
+ row_count_query = f"""
+ SELECT COUNT(*) as cnt FROM `{self.collection_name}`
+ """
+ chunks_table_data = []
+ with self.engine.connect() as conn, conn.begin():
+ for document, metadata, chunk_id, embedding in zip(
+ texts, metadatas, ids, embeddings
+ ):
+ chunks_table_data.append(
+ {
+ "id": chunk_id,
+ "embedding": embedding
+ if not self.enable_normalize_vector
+ else self._normalization_vectors(embedding),
+ "document": document,
+ "metadata": metadata,
+ }
+ )
+
+ # Execute the batch insert when the batch size is reached
+ if len(chunks_table_data) == batch_size:
+ with ob_grwlock.reader_lock():
+ if self.sql_logger is not None:
+ insert_sql_for_log = str(
+ insert(chunks_table).values(chunks_table_data)
+ )
+ self.sql_logger.debug(
+ f"""Trying to insert vectors:
+ {insert_sql_for_log}"""
+ )
+ conn.execute(insert(chunks_table).values(chunks_table_data))
+ # Clear the chunks_table_data list for the next batch
+ chunks_table_data.clear()
+
+ # Insert any remaining records that didn't make up a full batch
+ if chunks_table_data:
+ with ob_grwlock.reader_lock():
+ if self.sql_logger is not None:
+ insert_sql_for_log = str(
+ insert(chunks_table).values(chunks_table_data)
+ )
+ self.sql_logger.debug(
+ f"""Trying to insert vectors:
+ {insert_sql_for_log}"""
+ )
+ conn.execute(insert(chunks_table).values(chunks_table_data))
+
+        if self.enable_index and (
+            self.collection_stat is None
+            or self.collection_stat.get_maybe_collection_index_not_exist()
+        ):
+            row_cnt = 0
+            with ob_grwlock.reader_lock(), self.engine.connect() as conn:
+                row_cnt_res = conn.execute(text(row_count_query))
+                for row in row_cnt_res:
+                    row_cnt = row.cnt
+            # Create the ivfflat index outside the reader lock: the DDL below
+            # acquires the writer lock itself.
+            if row_cnt > self.th_create_ivfflat_index:
+                self.create_collection_ivfflat_index_if_not_exists()
+                if self.collection_stat is not None:
+                    self.collection_stat.collection_index_exists()
+
+ return ids
+
+ def similarity_search(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[MetadataFilters] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+        """Perform a similarity search for a query string."""
+ embedding = self.embedding_function.embed_query(query)
+        docs = self.similarity_search_by_vector(
+            embedding=embedding, k=k, filter=filter
+        )
+ return docs
+
+ def similarity_search_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[MetadataFilters] = None,
+ **kwargs: Any,
+ ) -> List[Document]:
+        """Perform a similarity search for a query vector."""
+ docs_and_scores = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter
+ )
+ return [doc for doc, _ in docs_and_scores]
+
+ def similarity_search_with_score_by_vector(
+ self,
+ embedding: List[float],
+ k: int = 4,
+ filter: Optional[MetadataFilters] = None,
+ score_threshold: Optional[float] = None,
+ ) -> List[Tuple[Document, float]]:
+        """Perform a similarity search for a query vector, returning scores."""
+ try:
+ from sqlalchemy.engine import Row
+ except ImportError:
+ raise ImportError(
+ "Could not import Row from sqlalchemy.engine. "
+                "Please run `pip install sqlalchemy>=1.4`."
+ )
+
+        # Metadata filters are not supported by OceanBase currently.
+
+        # Normalize the embedding vector.
+ if self.enable_normalize_vector:
+ embedding = self._normalization_vectors(embedding)
+
+ embedding_str = ob_vector_to_db(embedding, self.embedding_dimension)
+ vector_distance_op = "<@>" if self.enable_normalize_vector else "<->"
+ sql_query = f"""
+ SELECT document, metadata, embedding {vector_distance_op} '{embedding_str}'
+ as distance
+ FROM {self.collection_name}
+ ORDER BY embedding {vector_distance_op} '{embedding_str}'
+ LIMIT :k
+ """
+ sql_query_str_for_log = f"""
+ SELECT document, metadata, embedding {vector_distance_op} '?' as distance
+ FROM {self.collection_name}
+ ORDER BY embedding {vector_distance_op} '?'
+ LIMIT {k}
+ """
+
+ params = {"k": k}
+ try:
+ with ob_grwlock.reader_lock(), self.engine.connect() as conn:
+ if self.sql_logger is not None:
+ self.sql_logger.debug(
+ f"Trying to do similarity search: {sql_query_str_for_log}"
+ )
+ results: Sequence[Row] = conn.execute(
+ text(sql_query), params
+ ).fetchall()
+
+ if (score_threshold is not None) and self.enable_normalize_vector:
+ documents_with_scores = [
+ (
+ Document(
+ content=result.document,
+ metadata=json.loads(result.metadata),
+ ),
+ result.distance,
+ )
+ for result in results
+ if result.distance >= score_threshold
+ ]
+ else:
+ documents_with_scores = [
+ (
+ Document(
+ content=result.document,
+ metadata=json.loads(result.metadata),
+ ),
+ result.distance,
+ )
+ for result in results
+ ]
+ return documents_with_scores
+ except Exception as e:
+            self.logger.error(
+                "similarity_search_with_score_by_vector failed: %s", str(e)
+            )
+ return []
+
+ def similarity_search_with_score(
+ self,
+ query: str,
+ k: int = 4,
+ filter: Optional[MetadataFilters] = None,
+ score_threshold: Optional[float] = None,
+ ) -> List[Tuple[Document, float]]:
+        """Perform a similarity search for a query string, returning scores."""
+ embedding = self.embedding_function.embed_query(query)
+ docs = self.similarity_search_with_score_by_vector(
+ embedding=embedding, k=k, filter=filter, score_threshold=score_threshold
+ )
+ return docs
+
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+ """Delete vectors from vector table."""
+ if ids is None:
+ raise ValueError("No ids provided to delete.")
+
+ # Define the table schema
+ chunks_table = Table(
+ self.collection_name,
+ Base.metadata,
+ Column("id", VARCHAR(40), primary_key=True),
+ Column("embedding", Vector(self.embedding_dimension)),
+ Column("document", LONGTEXT, nullable=True),
+ Column("metadata", JSON, nullable=True), # filter
+ keep_existing=True,
+ )
+
+ try:
+ with self.engine.connect() as conn, conn.begin():
+ delete_condition = chunks_table.c.id.in_(ids)
+ delete_stmt = chunks_table.delete().where(delete_condition)
+ with ob_grwlock.reader_lock():
+ if self.sql_logger is not None:
+ self.sql_logger.debug(
+ f"Trying to delete vectors: {str(delete_stmt)}"
+ )
+ conn.execute(delete_stmt)
+ return True
+ except Exception as e:
+            self.logger.error("Delete operation failed: %s", str(e))
+ return False
+
+ def _normalization_vectors(self, vector):
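+        """L2-normalize a vector, e.g. [3.0, 4.0] -> [0.6, 0.8]."""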
+ import numpy as np
+
+ norm = np.linalg.norm(vector)
+ return (vector / norm).tolist()
+
+ @classmethod
+ def connection_string_from_db_params(
+ cls,
+ host: str,
+ port: int,
+ database: str,
+ user: str,
+ password: str,
+ ) -> str:
+ """Get connection string."""
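+        # e.g. ("127.0.0.1", 2881, "test", "root@test", "") ->
+        # "mysql+pymysql://root@test:@127.0.0.1:2881/test"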
+ return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}"
+
+
+ob_collection_stats_lock = threading.Lock()
+ob_collection_stats: Dict[str, OceanBaseCollectionStat] = {}
+
+
+@register_resource(
+ _("OceanBase Vector Store"),
+ "oceanbase_vector_store",
+ category=ResourceCategory.VECTOR_STORE,
+ parameters=[
+ *_COMMON_PARAMETERS,
+ Parameter.build_from(
+ _("OceanBase Host"),
+ "ob_host",
+ str,
+ description=_("oceanbase host"),
+ optional=True,
+ default=None,
+ ),
+ Parameter.build_from(
+ _("OceanBase Port"),
+ "ob_port",
+ int,
+ description=_("oceanbase port"),
+ optional=True,
+ default=None,
+ ),
+ Parameter.build_from(
+ _("OceanBase User"),
+ "ob_user",
+ str,
+ description=_("user to login"),
+ optional=True,
+ default=None,
+ ),
+ Parameter.build_from(
+ _("OceanBase Password"),
+ "ob_password",
+ str,
+ description=_("password to login"),
+ optional=True,
+ default=None,
+ ),
+ Parameter.build_from(
+ _("OceanBase Database"),
+ "ob_database",
+ str,
+ description=_("database for vector tables"),
+ optional=True,
+ default=None,
+ ),
+ ],
+ description="OceanBase vector store.",
+)
+class OceanBaseConfig(VectorStoreConfig):
+ """OceanBase vector store config."""
+
+ class Config:
+ """Config for BaseModel."""
+
+ arbitrary_types_allowed = True
+
+ """OceanBase config"""
+ ob_host: str = Field(
+ default="127.0.0.1",
+ description="oceanbase host",
+ )
+ ob_port: int = Field(
+ default=2881,
+ description="oceanbase port",
+ )
+ ob_user: str = Field(
+ default="root@test",
+ description="user to login",
+ )
+ ob_password: str = Field(
+ default="",
+ description="password to login",
+ )
+ ob_database: str = Field(
+ default="test",
+ description="database for vector tables",
+ )
+
+
+class OceanBaseStore(VectorStoreBase):
+ """OceanBase vector store."""
+
+ def __init__(self, vector_store_config: OceanBaseConfig) -> None:
+        """Create an OceanBaseStore instance."""
+ if vector_store_config.embedding_fn is None:
+ raise ValueError("embedding_fn is required for OceanBaseStore")
+
+ self.embeddings = vector_store_config.embedding_fn
+ self.collection_name = vector_store_config.name
+ vector_store_config = vector_store_config.dict()
+ self.OB_HOST = str(
+ vector_store_config.get("ob_host") or os.getenv("OB_HOST", "127.0.0.1")
+ )
+ self.OB_PORT = int(
+ vector_store_config.get("ob_port") or int(os.getenv("OB_PORT", "2881"))
+ )
+ self.OB_USER = str(
+ vector_store_config.get("ob_user") or os.getenv("OB_USER", "root@test")
+ )
+ self.OB_PASSWORD = str(
+ vector_store_config.get("ob_password") or os.getenv("OB_PASSWORD", "")
+ )
+ self.OB_DATABASE = str(
+ vector_store_config.get("ob_database") or os.getenv("OB_DATABASE", "test")
+ )
+ self.OB_ENABLE_NORMALIZE_VECTOR = bool(
+ os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")
+ )
+ self.connection_string = OceanBase.connection_string_from_db_params(
+ self.OB_HOST, self.OB_PORT, self.OB_DATABASE, self.OB_USER, self.OB_PASSWORD
+ )
+ self.logger = logger
+ with ob_collection_stats_lock:
+ if ob_collection_stats.get(self.collection_name) is None:
+ ob_collection_stats[self.collection_name] = OceanBaseCollectionStat()
+ self.collection_stat = ob_collection_stats[self.collection_name]
+
+ self.vector_store_client = OceanBase(
+ connection_string=self.connection_string,
+ embedding_function=self.embeddings,
+ collection_name=self.collection_name,
+ logger=self.logger,
+ sql_logger=sql_logger,
+ enable_index=bool(os.getenv("OB_ENABLE_INDEX", "")),
+ collection_stat=self.collection_stat,
+ enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
+ )
+
+ def similar_search(
+ self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
+ ) -> List[Chunk]:
+ """Perform a search on a query string and return results."""
+ self.logger.info("OceanBase: similar_search..")
+ documents = self.vector_store_client.similarity_search(
+ text, topk, filter=filters
+ )
+ return [Chunk(content=doc.content, metadata=doc.metadata) for doc in documents]
+
+ def similar_search_with_scores(
+ self,
+ text,
+ topk,
+ score_threshold: float,
+ filters: Optional[MetadataFilters] = None,
+ ) -> List[Chunk]:
+ """Perform a search on a query string and return results with score."""
+ self.logger.info("OceanBase: similar_search_with_scores..")
+ docs_and_scores = self.vector_store_client.similarity_search_with_score(
+ text, topk, filter=filters
+ )
+ return [
+ Chunk(content=doc.content, metadata=doc.metadata, score=score)
+ for doc, score in docs_and_scores
+ ]
+
+ def vector_name_exists(self):
+ """Whether vector name exists."""
+ self.logger.info("OceanBase: vector_name_exists..")
+ try:
+ self.vector_store_client.create_collection()
+ return True
+ except Exception as e:
+            logger.error("vector_name_exists error: %s", str(e))
+ return False
+
+ def load_document(self, chunks: List[Chunk]) -> List[str]:
+ """Load document in vector database."""
+ self.logger.info("OceanBase: load_document..")
+ # lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
+ texts = [chunk.content for chunk in chunks]
+ metadatas = [chunk.metadata for chunk in chunks]
+ ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
+ return ids
+
+ def delete_vector_name(self, vector_name):
+ """Delete vector name."""
+ self.logger.info("OceanBase: delete_vector_name..")
+ return self.vector_store_client.delete_collection()
+
+ def delete_by_ids(self, ids):
+ """Delete vector by ids."""
+ self.logger.info("OceanBase: delete_by_ids..")
+ ids = ids.split(",")
+ if len(ids) > 0:
+ self.vector_store_client.delete(ids)
diff --git a/docker/allinone/Dockerfile b/docker/allinone/Dockerfile
index 48c8362cf..2422ed253 100644
--- a/docker/allinone/Dockerfile
+++ b/docker/allinone/Dockerfile
@@ -2,6 +2,7 @@ ARG BASE_IMAGE="eosphorosai/dbgpt:latest"
FROM ${BASE_IMAGE}
+RUN pip install dashscope
RUN apt-get update && apt-get install -y wget gnupg lsb-release net-tools
RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29
@@ -9,7 +10,7 @@ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29
RUN wget https://dev.mysql.com/get/mysql-apt-config_0.8.17-1_all.deb
RUN dpkg -i mysql-apt-config_0.8.17-1_all.deb
-RUN apt-get update && apt-get install -y mysql-server && apt-get clean
+RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys B7B3B788A8D3785C && apt-get update && apt-get install -y mysql-server && apt-get clean
# Remote access
RUN sed -i 's/bind-address\s*=\s*127.0.0.1/bind-address = 0.0.0.0/g' /etc/mysql/mysql.conf.d/mysqld.cnf \
diff --git a/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml b/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml
new file mode 100644
index 000000000..64e82078a
--- /dev/null
+++ b/docker/compose_examples/dbgpt-oceanbase-docker-compose.yml
@@ -0,0 +1,42 @@
+version: '3.8'
+
+services:
+ oceanbase:
+ image: oceanbase/oceanbase-ce:vector
+ ports:
+ - 2881:2881
+ networks:
+ - dbgptnet
+ dbgpt:
+ image: eosphorosai/dbgpt-allinone
+ environment:
+ - DBGPT_WEBSERVER_PORT=12345
+ - VECTOR_STORE_TYPE=OceanBase
+      - OB_HOST=oceanbase
+ - OB_PORT=2881
+ - OB_USER=root@test
+ - OB_DATABASE=test
+ - OB_SQL_DBG_LOG_PATH=/sql_log/sql.log
+ - LOCAL_DB_TYPE=sqlite
+ - LLM_MODEL=tongyi_proxyllm
+ - PROXYLLM_BACKEND=qwen-plus
+ - EMBEDDING_MODEL=text2vec
+ # your api key
+ # - TONGYI_PROXY_API_KEY={your-api-key}
+ - LANGUAGE=zh
+ # - OB_ENABLE_NORMALIZE_VECTOR=True
+ # - OB_ENABLE_INDEX=True
+ ports:
+ - 3306:3306
+ - 12345:12345
+ volumes:
+ # - {your-ob-sql-dbg-log-dir}:/sql_log
+ # - {your-model-dir}:/app/models
+ networks:
+ - dbgptnet
+
+networks:
+ dbgptnet:
+ driver: bridge
+ name: dbgptnet
\ No newline at end of file
diff --git a/docs/docs/application/advanced_tutorial/rag.md b/docs/docs/application/advanced_tutorial/rag.md
index 1cfeab5c2..acb015ae6 100644
--- a/docs/docs/application/advanced_tutorial/rag.md
+++ b/docs/docs/application/advanced_tutorial/rag.md
@@ -69,6 +69,7 @@ import TabItem from '@theme/TabItem';
{label: 'Chroma', value: 'Chroma'},
{label: 'Milvus', value: 'Milvus'},
{label: 'Weaviate', value: 'Weaviate'},
+ {label: 'OceanBase', value: 'OceanBase'},
]}>
@@ -106,6 +107,25 @@ set ``VECTOR_STORE_TYPE`` in ``.env`` file
VECTOR_STORE_TYPE=Weaviate
#WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network
```
+
+
+
+set ``VECTOR_STORE_TYPE`` in ``.env`` file
+
+```shell
+VECTOR_STORE_TYPE=OceanBase
+OB_HOST=127.0.0.1
+OB_PORT=2881
+OB_USER=root@test
+OB_DATABASE=test
+## Optional
+# OB_PASSWORD=
+## Optional: SQL statements executed by OceanBase are recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
+# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
+## Optional: If {OB_ENABLE_NORMALIZE_VECTOR} is set, vectors stored in OceanBase are normalized.
+# OB_ENABLE_NORMALIZE_VECTOR=True
+## Optional: If {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
+# OB_ENABLE_INDEX=True
+```
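+
+To construct the store programmatically instead, a minimal sketch (assuming you already have an embedding function object `embedding_fn`) looks like:
+
+```python
+from dbgpt.storage.vector_store.oceanbase_store import OceanBaseConfig, OceanBaseStore
+
+config = OceanBaseConfig(name="my_collection", embedding_fn=embedding_fn)
+store = OceanBaseStore(config)
+chunk_ids = store.load_document(chunks)  # chunks: List[Chunk]
+```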
diff --git a/docs/docs/faq/kbqa.md b/docs/docs/faq/kbqa.md
index d35715b73..ea3e2ac74 100644
--- a/docs/docs/faq/kbqa.md
+++ b/docs/docs/faq/kbqa.md
@@ -18,8 +18,29 @@ git lfs clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
Update .env file and set VECTOR_STORE_TYPE.
-DB-GPT currently support Chroma(Default), Milvus(>2.1), Weaviate vector database.
-If you want to change vector db, Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma (now only support Chroma and Milvus(>2.1), if you set Milvus, please set MILVUS_URL and MILVUS_PORT)
+DB-GPT currently supports the Chroma (default), Milvus (>2.1), Weaviate, and OceanBase vector databases.
+If you want to change the vector DB, update your .env and set the vector store type, e.g. VECTOR_STORE_TYPE=Chroma (if you set Milvus, please also set MILVUS_URL and MILVUS_PORT).
+
+If you want to use OceanBase, please first start a docker container via the following command:
+```shell
+docker run -p 2881:2881 --name obvec -d oceanbase/oceanbase-ce:vector
+```
+Then set the following variables in the .env file:
+```shell
+VECTOR_STORE_TYPE=OceanBase
+OB_HOST=127.0.0.1
+OB_PORT=2881
+OB_USER=root@test
+OB_DATABASE=test
+## Optional
+# OB_PASSWORD=
+## Optional: SQL statements executed by OceanBase are recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
+# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
+## Optional: If {OB_ENABLE_NORMALIZE_VECTOR} is set, vectors stored in OceanBase are normalized.
+# OB_ENABLE_NORMALIZE_VECTOR=True
+## Optional: If {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
+# OB_ENABLE_INDEX=True
+```
If you want to support more vector db, you can integrate yourself.[how to integrate](https://db-gpt.readthedocs.io/en/latest/modules/vector.html)
```commandline
#*******************************************************************#
diff --git a/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po b/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po
index 267db25fc..75c2ed1c3 100644
--- a/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po
+++ b/i18n/locales/zh_CN/LC_MESSAGES/dbgpt_storage.po
@@ -147,6 +147,50 @@ msgid ""
"connection string."
msgstr "向量存储的连接字符串,如果未设置,将使用默认的连接字符串。"
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:542
+msgid "OceanBase Vector Store"
+msgstr "OceanBase 向量存储"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:548
+msgid "OceanBase Host"
+msgstr "OceanBase 主机地址"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:552
+msgid "oceanbase host"
+msgstr "oceanbase 主机地址"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:558
+msgid "OceanBase Port"
+msgstr "OceanBase 端口"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:562
+msgid "oceanbase port"
+msgstr "oceanbase 端口"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:568
+msgid "OceanBase User"
+msgstr "OceanBase 用户名"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:572
+msgid "user to login"
+msgstr "登录用户名"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:578
+msgid "OceanBase Password"
+msgstr "OceanBase 密码"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:582
+msgid "password to login"
+msgstr "登录密码"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:588
+msgid "OceanBase Database"
+msgstr "OceanBase 数据库名"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:592
+msgid "database for vector tables"
+msgstr "存放向量表的数据库名"
+
#: ../dbgpt/storage/vector_store/base.py:19
msgid "Collection Name"
msgstr "集合名称"