Support OceanBase as an optional vector database (#1435)
Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
Parent 91c1371234 → commit 6520367623
@@ -238,6 +238,18 @@ class Config(metaclass=Singleton):
         self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
         self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
 
+        ## OceanBase Configuration
+        self.OB_HOST = os.getenv("OB_HOST", "127.0.0.1")
+        self.OB_PORT = int(os.getenv("OB_PORT", "2881"))
+        self.OB_USER = os.getenv("OB_USER", "root")
+        self.OB_PASSWORD = os.getenv("OB_PASSWORD", "")
+        self.OB_DATABASE = os.getenv("OB_DATABASE", "test")
+        self.OB_SQL_DBG_LOG_PATH = os.getenv("OB_SQL_DBG_LOG_PATH", "")
+        self.OB_ENABLE_NORMALIZE_VECTOR = bool(
+            os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")
+        )
+        self.OB_ENABLE_INDEX = bool(os.getenv("OB_ENABLE_INDEX", ""))
+
         # QLoRA
         self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
         self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
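One subtlety in the hunk above: `bool()` over a raw environment string is truthy for any non-empty value, so `OB_ENABLE_NORMALIZE_VECTOR` and `OB_ENABLE_INDEX` behave as set/unset switches rather than parsed booleans. A minimal sketch of that behavior:

```python
import os

# Any non-empty string is truthy, so even "False" enables the flag here.
os.environ["OB_ENABLE_NORMALIZE_VECTOR"] = "False"
assert bool(os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")) is True

# Leaving the variable unset (or empty) is the only way to disable it.
del os.environ["OB_ENABLE_NORMALIZE_VECTOR"]
assert bool(os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")) is False
```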
@@ -25,11 +25,11 @@ def test_md_header_text_splitter() -> None:
     output = markdown_splitter.split_text(markdown_document)
     expected_output = [
         Chunk(
-            content="{'Header 1': 'dbgpt', 'Header 2': 'description'}, my name is dbgpt",
+            content='"dbgpt-description": my name is dbgpt',
             metadata={"Header 1": "dbgpt", "Header 2": "description"},
         ),
         Chunk(
-            content="{'Header 1': 'dbgpt', 'Header 2': 'content'}, my name is aries",
+            content='"dbgpt-content": my name is aries',
             metadata={"Header 1": "dbgpt", "Header 2": "content"},
         ),
     ]
@@ -515,7 +515,8 @@ class MarkdownHeaderTextSplitter(TextSplitter):
                 aggregated_chunks[-1]["content"] += " \n" + line["content"]
             else:
                 # Otherwise, append the current line to the aggregated list
-                line["content"] = f"{line['metadata']}, " + line["content"]
+                subtitles = "-".join((list(line["metadata"].values())))
+                line["content"] = f'"{subtitles}": ' + line["content"]
                 aggregated_chunks.append(line)
 
         return [
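The effect of the replaced line: instead of prefixing each chunk with the raw metadata dict repr, the chunk is prefixed with a compact subtitle joined from the header values. For example:

```python
# Chunk prefix before vs. after this change, for the same metadata.
metadata = {"Header 1": "dbgpt", "Header 2": "description"}

old_prefix = f"{metadata}, "  # "{'Header 1': 'dbgpt', ...}, "
subtitles = "-".join(list(metadata.values()))
new_prefix = f'"{subtitles}": '
assert new_prefix == '"dbgpt-description": '
```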
@@ -557,16 +558,28 @@ class MarkdownHeaderTextSplitter(TextSplitter):
         # header_stack: List[Dict[str, Union[int, str]]] = []
         header_stack: List[HeaderType] = []
         initial_metadata: Dict[str, str] = {}
+        # Determine whether a line is within a markdown code block.
+        in_code_block = False
         for line in lines:
             stripped_line = line.strip()
+            # A code frame starts with "```"
+            with_code_frame = stripped_line.startswith("```") and (
+                stripped_line != "```"
+            )
+            if (not in_code_block) and with_code_frame:
+                in_code_block = True
             # Check each line against each of the header types (e.g., #, ##)
             for sep, name in self.headers_to_split_on:
                 # Check if line starts with a header that we intend to split on
-                if stripped_line.startswith(sep) and (
-                    # Header with no text OR header is followed by space
-                    # Both are valid conditions that sep is being used a header
-                    len(stripped_line) == len(sep)
-                    or stripped_line[len(sep)] == " "
-                ):
+                if (
+                    (not in_code_block)
+                    and stripped_line.startswith(sep)
+                    and (
+                        # Header with no text OR header is followed by space
+                        # Both are valid conditions that sep is being used as a header
+                        len(stripped_line) == len(sep)
+                        or stripped_line[len(sep)] == " "
+                    )
+                ):
                     # Ensure we are tracking the header as metadata
                     if name is not None:
@@ -620,6 +633,10 @@ class MarkdownHeaderTextSplitter(TextSplitter):
                     )
                     current_content.clear()
 
+            # Code block ends
+            if in_code_block and stripped_line == "```":
+                in_code_block = False
+
             current_metadata = initial_metadata.copy()
             if current_content:
                 lines_with_metadata.append(
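The net effect of these two hunks: a `#`-prefixed line inside a fenced code block no longer opens a new header section. A small sketch (the import path is an assumption; see the test module earlier in this diff for the real one, and the fence markers are built programmatically so the example can be shown here):

```python
from dbgpt.rag.text_splitter.text_splitter import MarkdownHeaderTextSplitter

fence = "`" * 3  # literal "```"
markdown_document = "\n".join(
    [
        "# dbgpt",
        "## description",
        "my name is dbgpt",
        fence + "bash",
        "# inside a code block: no longer treated as a header",
        fence,
    ]
)

splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
)
# With the in_code_block guard, the commented bash line stays in the
# "dbgpt-description" chunk instead of starting a new section.
chunks = splitter.split_text(markdown_document)
```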
@@ -26,6 +26,12 @@ def _import_weaviate() -> Any:
     return WeaviateStore
 
 
+def _import_oceanbase() -> Any:
+    from dbgpt.storage.vector_store.oceanbase_store import OceanBaseStore
+
+    return OceanBaseStore
+
+
 def __getattr__(name: str) -> Any:
     if name == "Chroma":
         return _import_chroma()
@@ -35,8 +41,10 @@ def __getattr__(name: str) -> Any:
         return _import_weaviate()
     elif name == "PGVector":
         return _import_pgvector()
+    elif name == "OceanBase":
+        return _import_oceanbase()
     else:
         raise AttributeError(f"Could not find: {name}")
 
 
-__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
+__all__ = ["Chroma", "Milvus", "Weaviate", "OceanBase", "PGVector"]
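The registry relies on module-level `__getattr__` (PEP 562), so the OceanBase store and its SQLAlchemy dependency chain are imported only when the symbol is first touched:

```python
# Lazy import via PEP 562: nothing OceanBase-related is loaded at import time.
from dbgpt.storage import vector_store

store_cls = vector_store.OceanBase  # first access triggers _import_oceanbase()
print(store_cls.__name__)  # OceanBaseStore
```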
dbgpt/storage/vector_store/oceanbase_store.py (new file, 798 lines)
@@ -0,0 +1,798 @@
"""OceanBase vector store."""
import json
import logging
import os
import threading
import uuid
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from pydantic import Field
from sqlalchemy import Column, Table, create_engine, insert, text
from sqlalchemy.dialects.mysql import JSON, LONGTEXT, VARCHAR
from sqlalchemy.types import String, UserDefinedType

from dbgpt.core import Chunk, Document, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
    _COMMON_PARAMETERS,
    VectorStoreBase,
    VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base


logger = logging.getLogger(__name__)
sql_logger = None
sql_dbg_log_path = os.getenv("OB_SQL_DBG_LOG_PATH", "")
if sql_dbg_log_path != "":
    sql_logger = logging.getLogger("ob_sql_dbg")
    sql_logger.setLevel(logging.DEBUG)
    file_handler = logging.FileHandler(sql_dbg_log_path)
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    sql_logger.addHandler(file_handler)
_OCEANBASE_DEFAULT_EMBEDDING_DIM = 1536
_OCEANBASE_DEFAULT_COLLECTION_NAME = "langchain_document"
_OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD = 10000
_OCEANBASE_DEFAULT_RWLOCK_MAX_READER = 64

Base = declarative_base()


def ob_vector_from_db(value):
    """Parse a vector string from OceanBase into a list of floats."""
    return [float(v) for v in value[1:-1].split(",")]


def ob_vector_to_db(value, dim=None):
    """Serialize a vector into an OceanBase vector constant."""
    if value is None:
        return value

    return "[" + ",".join([str(float(v)) for v in value]) + "]"
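A quick round trip of the two helpers above; OceanBase represents a vector as a bracketed, comma-separated string:

```python
assert ob_vector_to_db([1, 2, 3]) == "[1.0,2.0,3.0]"
assert ob_vector_from_db("[1.0,2.0,3.0]") == [1.0, 2.0, 3.0]
```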
class Vector(UserDefinedType):
    """OceanBase Vector Column Type."""

    cache_ok = True
    _string = String()

    def __init__(self, dim):
        """Create a Vector column with dimension `dim`."""
        super(UserDefinedType, self).__init__()
        self.dim = dim

    def get_col_spec(self, **kw):
        """Get the vector column definition in string format."""
        return "VECTOR(%d)" % self.dim

    def bind_processor(self, dialect):
        """Get a processor that serializes an array to an OceanBase vector."""

        def process(value):
            return ob_vector_to_db(value, self.dim)

        return process

    def literal_processor(self, dialect):
        """Get a string processor that serializes an array to an OceanBase vector."""
        string_literal_processor = self._string._cached_literal_processor(dialect)

        def process(value):
            return string_literal_processor(ob_vector_to_db(value, self.dim))

        return process

    def result_processor(self, dialect, coltype):
        """Get a processor that parses an OceanBase vector into an array."""

        def process(value):
            return ob_vector_from_db(value)

        return process
class OceanBaseCollectionStat:
    """A tracker that maintains the status of a table in OceanBase."""

    def __init__(self):
        """Create an OceanBaseCollectionStat instance."""
        self._lock = threading.Lock()
        self.maybe_collection_not_exist = True
        self.maybe_collection_index_not_exist = True

    def collection_exists(self):
        """Mark the table as existing."""
        with self._lock:
            self.maybe_collection_not_exist = False

    def collection_index_exists(self):
        """Mark the index of the table as existing."""
        with self._lock:
            self.maybe_collection_index_not_exist = False

    def collection_not_exists(self):
        """Mark the table as dropped."""
        with self._lock:
            self.maybe_collection_not_exist = True

    def collection_index_not_exists(self):
        """Mark the index of the table as dropped."""
        with self._lock:
            self.maybe_collection_index_not_exist = True

    def get_maybe_collection_not_exist(self):
        """Get whether the table may not exist."""
        with self._lock:
            return self.maybe_collection_not_exist

    def get_maybe_collection_index_not_exist(self):
        """Get whether the index of the table may not exist."""
        with self._lock:
            return self.maybe_collection_index_not_exist
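These flags are shared memo state: once a process has confirmed (or created) the table and its index, later loads skip the redundant DDL round trips. A sketch of the intended call pattern:

```python
stat = OceanBaseCollectionStat()
if stat.get_maybe_collection_not_exist():
    # ... issue CREATE TABLE IF NOT EXISTS here ...
    stat.collection_exists()
# Subsequent calls now short-circuit the existence check.
assert not stat.get_maybe_collection_not_exist()
```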
class OceanBaseGlobalRWLock:
    """A global rwlock that serializes OceanBase's offline vector-index DDL."""

    def __init__(self, max_readers) -> None:
        """Create the rwlock."""
        self.max_readers_ = max_readers
        self.writer_entered_ = False
        self.reader_cnt_ = 0
        self.mutex_ = threading.Lock()
        self.writer_cv_ = threading.Condition(self.mutex_)
        self.reader_cv_ = threading.Condition(self.mutex_)

    def rlock(self):
        """Acquire the lock for reading."""
        self.mutex_.acquire()
        while self.writer_entered_ or self.max_readers_ == self.reader_cnt_:
            self.reader_cv_.wait()
        self.reader_cnt_ += 1
        self.mutex_.release()

    def runlock(self):
        """Release the reading lock."""
        self.mutex_.acquire()
        self.reader_cnt_ -= 1
        if self.writer_entered_:
            if self.reader_cnt_ == 0:
                self.writer_cv_.notify(1)
        else:
            if self.max_readers_ - 1 == self.reader_cnt_:
                self.reader_cv_.notify(1)
        self.mutex_.release()

    def wlock(self):
        """Acquire the lock for writing."""
        self.mutex_.acquire()
        while self.writer_entered_:
            self.reader_cv_.wait()
        self.writer_entered_ = True
        while 0 < self.reader_cnt_:
            self.writer_cv_.wait()
        self.mutex_.release()

    def wunlock(self):
        """Release the writing lock."""
        self.mutex_.acquire()
        self.writer_entered_ = False
        self.reader_cv_.notify_all()
        self.mutex_.release()

    class OBRLock:
        """Reading lock wrapper for use in a `with` clause."""

        def __init__(self, rwlock) -> None:
            """Create a reading lock wrapper instance."""
            self.rwlock_ = rwlock

        def __enter__(self):
            """Lock on entry."""
            self.rwlock_.rlock()

        def __exit__(self, exc_type, exc_value, traceback):
            """Unlock on exit."""
            self.rwlock_.runlock()

    class OBWLock:
        """Writing lock wrapper for use in a `with` clause."""

        def __init__(self, rwlock) -> None:
            """Create a writing lock wrapper instance."""
            self.rwlock_ = rwlock

        def __enter__(self):
            """Lock on entry."""
            self.rwlock_.wlock()

        def __exit__(self, exc_type, exc_value, traceback):
            """Unlock on exit."""
            self.rwlock_.wunlock()

    def reader_lock(self):
        """Get a reading lock wrapper."""
        return self.OBRLock(self)

    def writer_lock(self):
        """Get a writing lock wrapper."""
        return self.OBWLock(self)


ob_grwlock = OceanBaseGlobalRWLock(_OCEANBASE_DEFAULT_RWLOCK_MAX_READER)
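Usage pattern in the rest of this file: every DML path (insert, search, delete) takes the reader side, while the offline ivfflat index DDL takes the writer side, so index creation waits for in-flight statements to drain:

```python
with ob_grwlock.reader_lock():
    pass  # e.g. INSERT / SELECT / DELETE against the vector table

with ob_grwlock.writer_lock():
    pass  # e.g. CREATE INDEX ... USING ivfflat (offline DDL)
```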
class OceanBase:
    """OceanBase Vector Store implementation."""

    def __init__(
        self,
        connection_string: str,
        embedding_function: Embeddings,
        embedding_dimension: int = _OCEANBASE_DEFAULT_EMBEDDING_DIM,
        collection_name: str = _OCEANBASE_DEFAULT_COLLECTION_NAME,
        pre_delete_collection: bool = False,
        logger: Optional[logging.Logger] = None,
        engine_args: Optional[dict] = None,
        delay_table_creation: bool = True,
        enable_index: bool = False,
        th_create_ivfflat_index: int = _OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD,
        sql_logger: Optional[logging.Logger] = None,
        collection_stat: Optional[OceanBaseCollectionStat] = None,
        enable_normalize_vector: bool = False,
    ) -> None:
        """Create an OceanBase vector store instance."""
        self.connection_string = connection_string
        self.embedding_function = embedding_function
        self.embedding_dimension = embedding_dimension
        self.collection_name = collection_name
        self.pre_delete_collection = pre_delete_collection
        self.logger = logger or logging.getLogger(__name__)
        self.delay_table_creation = delay_table_creation
        self.th_create_ivfflat_index = th_create_ivfflat_index
        self.enable_index = enable_index
        self.sql_logger = sql_logger
        self.collection_stat = collection_stat
        self.enable_normalize_vector = enable_normalize_vector
        self.__post_init__(engine_args)

    def __post_init__(
        self,
        engine_args: Optional[dict] = None,
    ) -> None:
        """Create the connection and the vector table."""
        _engine_args = engine_args or {}
        if "pool_recycle" not in _engine_args:
            _engine_args["pool_recycle"] = 3600
        self.engine = create_engine(self.connection_string, **_engine_args)
        self.create_collection()

    @property
    def embeddings(self) -> Embeddings:
        """Get the embedding function."""
        return self.embedding_function

    def create_collection(self) -> None:
        """Create the vector table."""
        if self.pre_delete_collection:
            self.delete_collection()
        if not self.delay_table_creation and (
            self.collection_stat is None
            or self.collection_stat.get_maybe_collection_not_exist()
        ):
            self.create_table_if_not_exists()
            if self.collection_stat is not None:
                self.collection_stat.collection_exists()

    def delete_collection(self) -> None:
        """Drop the vector table."""
        drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name}")
        if self.sql_logger is not None:
            self.sql_logger.debug(f"Trying to delete collection: {drop_statement}")
        with self.engine.connect() as conn, conn.begin():
            conn.execute(drop_statement)
            if self.collection_stat is not None:
                self.collection_stat.collection_not_exists()
                self.collection_stat.collection_index_not_exists()

    def create_table_if_not_exists(self) -> None:
        """Create the vector table with SQL."""
        create_table_query = f"""
            CREATE TABLE IF NOT EXISTS `{self.collection_name}` (
                id VARCHAR(40) NOT NULL,
                embedding VECTOR({self.embedding_dimension}),
                document LONGTEXT,
                metadata JSON,
                PRIMARY KEY (id)
            )
        """
        if self.sql_logger is not None:
            self.sql_logger.debug(f"Trying to create table: {create_table_query}")
        with self.engine.connect() as conn, conn.begin():
            # Create the table
            conn.execute(text(create_table_query))

    def create_collection_ivfflat_index_if_not_exists(self) -> None:
        """Create the ivfflat index with SQL."""
        create_index_query = f"""
            CREATE INDEX IF NOT EXISTS `embedding_idx` on `{self.collection_name}` (
                embedding l2
            ) using ivfflat with (lists=20)
        """
        with ob_grwlock.writer_lock(), self.engine.connect() as conn, conn.begin():
            # Create the ivfflat index
            if self.sql_logger is not None:
                self.sql_logger.debug(
                    f"Trying to create ivfflat index: {create_index_query}"
                )
            conn.execute(text(create_index_query))
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 500,
        **kwargs: Any,
    ) -> List[str]:
        """Insert texts into the vector table."""
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]

        embeddings = self.embedding_function.embed_documents(list(texts))

        if len(embeddings) == 0:
            return ids

        if not metadatas:
            metadatas = [{} for _ in texts]

        if self.delay_table_creation and (
            self.collection_stat is None
            or self.collection_stat.get_maybe_collection_not_exist()
        ):
            self.embedding_dimension = len(embeddings[0])
            self.create_table_if_not_exists()
            self.delay_table_creation = False
            if self.collection_stat is not None:
                self.collection_stat.collection_exists()

        chunks_table = Table(
            self.collection_name,
            Base.metadata,
            Column("id", VARCHAR(40), primary_key=True),
            Column("embedding", Vector(self.embedding_dimension)),
            Column("document", LONGTEXT, nullable=True),
            Column("metadata", JSON, nullable=True),  # filter
            keep_existing=True,
        )

        row_count_query = f"""
            SELECT COUNT(*) as cnt FROM `{self.collection_name}`
        """
        chunks_table_data = []
        with self.engine.connect() as conn, conn.begin():
            for document, metadata, chunk_id, embedding in zip(
                texts, metadatas, ids, embeddings
            ):
                chunks_table_data.append(
                    {
                        "id": chunk_id,
                        "embedding": embedding
                        if not self.enable_normalize_vector
                        else self._normalization_vectors(embedding),
                        "document": document,
                        "metadata": metadata,
                    }
                )

                # Execute the batch insert when the batch size is reached
                if len(chunks_table_data) == batch_size:
                    with ob_grwlock.reader_lock():
                        if self.sql_logger is not None:
                            insert_sql_for_log = str(
                                insert(chunks_table).values(chunks_table_data)
                            )
                            self.sql_logger.debug(
                                f"Trying to insert vectors: {insert_sql_for_log}"
                            )
                        conn.execute(insert(chunks_table).values(chunks_table_data))
                    # Clear the chunks_table_data list for the next batch
                    chunks_table_data.clear()

            # Insert any remaining records that didn't make up a full batch
            if chunks_table_data:
                with ob_grwlock.reader_lock():
                    if self.sql_logger is not None:
                        insert_sql_for_log = str(
                            insert(chunks_table).values(chunks_table_data)
                        )
                        self.sql_logger.debug(
                            f"Trying to insert vectors: {insert_sql_for_log}"
                        )
                    conn.execute(insert(chunks_table).values(chunks_table_data))

            if self.enable_index and (
                self.collection_stat is None
                or self.collection_stat.get_maybe_collection_index_not_exist()
            ):
                with ob_grwlock.reader_lock():
                    row_cnt_res = conn.execute(text(row_count_query))
                for row in row_cnt_res:
                    if row.cnt > self.th_create_ivfflat_index:
                        self.create_collection_ivfflat_index_if_not_exists()
                        if self.collection_stat is not None:
                            self.collection_stat.collection_index_exists()

        return ids
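A minimal end-to-end sketch of the low-level client above; `my_embeddings` stands in for any `dbgpt.core.Embeddings` implementation and is not part of this diff:

```python
ob = OceanBase(
    connection_string=OceanBase.connection_string_from_db_params(
        "127.0.0.1", 2881, "test", "root@test", ""
    ),
    embedding_function=my_embeddings,  # hypothetical Embeddings instance
)
ids = ob.add_texts(["my name is dbgpt"], metadatas=[{"source": "demo"}])
docs = ob.similarity_search("who are you", k=4)
```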
    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Do similarity search via a query string."""
        embedding = self.embedding_function.embed_query(query)
        docs = self.similarity_search_by_vector(embedding=embedding, k=k, filter=filter)
        return docs

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Do similarity search via a query vector."""
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, filter=filter
        )
        return [doc for doc, _ in docs_and_scores]
    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[MetadataFilters] = None,
        score_threshold: Optional[float] = None,
    ) -> List[Tuple[Document, float]]:
        """Do similarity search via a query vector, returning scores."""
        try:
            from sqlalchemy.engine import Row
        except ImportError:
            raise ImportError(
                "Could not import Row from sqlalchemy.engine. "
                "Please 'pip install sqlalchemy>=1.4'."
            )

        # `filter` is not supported in OceanBase currently.

        # Normalize the embedding vector
        if self.enable_normalize_vector:
            embedding = self._normalization_vectors(embedding)

        embedding_str = ob_vector_to_db(embedding, self.embedding_dimension)
        vector_distance_op = "<@>" if self.enable_normalize_vector else "<->"
        sql_query = f"""
            SELECT document, metadata, embedding {vector_distance_op} '{embedding_str}'
            as distance
            FROM {self.collection_name}
            ORDER BY embedding {vector_distance_op} '{embedding_str}'
            LIMIT :k
        """
        sql_query_str_for_log = f"""
            SELECT document, metadata, embedding {vector_distance_op} '?' as distance
            FROM {self.collection_name}
            ORDER BY embedding {vector_distance_op} '?'
            LIMIT {k}
        """

        params = {"k": k}
        try:
            with ob_grwlock.reader_lock(), self.engine.connect() as conn:
                if self.sql_logger is not None:
                    self.sql_logger.debug(
                        f"Trying to do similarity search: {sql_query_str_for_log}"
                    )
                results: Sequence[Row] = conn.execute(
                    text(sql_query), params
                ).fetchall()

            if (score_threshold is not None) and self.enable_normalize_vector:
                documents_with_scores = [
                    (
                        Document(
                            content=result.document,
                            metadata=json.loads(result.metadata),
                        ),
                        result.distance,
                    )
                    for result in results
                    if result.distance >= score_threshold
                ]
            else:
                documents_with_scores = [
                    (
                        Document(
                            content=result.document,
                            metadata=json.loads(result.metadata),
                        ),
                        result.distance,
                    )
                    for result in results
                ]
            return documents_with_scores
        except Exception as e:
            self.logger.error(
                "similarity_search_with_score_by_vector failed: %s", str(e)
            )
            return []
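On the operator switch above: assuming `<->` computes L2 distance and `<@>` an inner-product score (which is why it is only used when vectors are normalized), the two rankings coincide for unit vectors, since ||a - b||^2 = 2 - 2(a . b). A quick numerical check:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=8), rng.normal(size=8)
a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)

# For unit vectors, squared L2 distance and inner product are affinely
# related, so ordering by one determines ordering by the other.
assert np.isclose(np.sum((a - b) ** 2), 2 - 2 * np.dot(a, b))
```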
    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[MetadataFilters] = None,
        score_threshold: Optional[float] = None,
    ) -> List[Tuple[Document, float]]:
        """Do similarity search via a query string, returning scores."""
        embedding = self.embedding_function.embed_query(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, filter=filter, score_threshold=score_threshold
        )
        return docs
    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete vectors from the vector table."""
        if ids is None:
            raise ValueError("No ids provided to delete.")

        # Define the table schema
        chunks_table = Table(
            self.collection_name,
            Base.metadata,
            Column("id", VARCHAR(40), primary_key=True),
            Column("embedding", Vector(self.embedding_dimension)),
            Column("document", LONGTEXT, nullable=True),
            Column("metadata", JSON, nullable=True),  # filter
            keep_existing=True,
        )

        try:
            with self.engine.connect() as conn, conn.begin():
                delete_condition = chunks_table.c.id.in_(ids)
                delete_stmt = chunks_table.delete().where(delete_condition)
                with ob_grwlock.reader_lock():
                    if self.sql_logger is not None:
                        self.sql_logger.debug(
                            f"Trying to delete vectors: {str(delete_stmt)}"
                        )
                    conn.execute(delete_stmt)
                return True
        except Exception as e:
            self.logger.error("Delete operation failed: %s", str(e))
            return False
    def _normalization_vectors(self, vector):
        """Normalize a vector to unit length."""
        import numpy as np

        norm = np.linalg.norm(vector)
        # Convert to ndarray first: a plain Python list does not support "/".
        return (np.asarray(vector) / norm).tolist()

    @classmethod
    def connection_string_from_db_params(
        cls,
        host: str,
        port: int,
        database: str,
        user: str,
        password: str,
    ) -> str:
        """Get the connection string."""
        return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}"
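For the defaults used elsewhere in this commit, the builder above yields a standard PyMySQL URL:

```python
url = OceanBase.connection_string_from_db_params(
    "127.0.0.1", 2881, "test", "root@test", ""
)
assert url == "mysql+pymysql://root@test:@127.0.0.1:2881/test"
```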
ob_collection_stats_lock = threading.Lock()
ob_collection_stats: Dict[str, OceanBaseCollectionStat] = {}
@register_resource(
    _("OceanBase Vector Store"),
    "oceanbase_vector_store",
    category=ResourceCategory.VECTOR_STORE,
    parameters=[
        *_COMMON_PARAMETERS,
        Parameter.build_from(
            _("OceanBase Host"),
            "ob_host",
            str,
            description=_("oceanbase host"),
            optional=True,
            default=None,
        ),
        Parameter.build_from(
            _("OceanBase Port"),
            "ob_port",
            int,
            description=_("oceanbase port"),
            optional=True,
            default=None,
        ),
        Parameter.build_from(
            _("OceanBase User"),
            "ob_user",
            str,
            description=_("user to login"),
            optional=True,
            default=None,
        ),
        Parameter.build_from(
            _("OceanBase Password"),
            "ob_password",
            str,
            description=_("password to login"),
            optional=True,
            default=None,
        ),
        Parameter.build_from(
            _("OceanBase Database"),
            "ob_database",
            str,
            description=_("database for vector tables"),
            optional=True,
            default=None,
        ),
    ],
    description="OceanBase vector store.",
)
class OceanBaseConfig(VectorStoreConfig):
    """OceanBase vector store config."""

    class Config:
        """Config for BaseModel."""

        arbitrary_types_allowed = True

    """OceanBase config"""
    ob_host: str = Field(
        default="127.0.0.1",
        description="oceanbase host",
    )
    ob_port: int = Field(
        default=2881,
        description="oceanbase port",
    )
    ob_user: str = Field(
        default="root@test",
        description="user to login",
    )
    ob_password: str = Field(
        default="",
        description="password to login",
    )
    ob_database: str = Field(
        default="test",
        description="database for vector tables",
    )
class OceanBaseStore(VectorStoreBase):
    """OceanBase vector store."""

    def __init__(self, vector_store_config: OceanBaseConfig) -> None:
        """Create an OceanBaseStore instance."""
        if vector_store_config.embedding_fn is None:
            raise ValueError("embedding_fn is required for OceanBaseStore")

        self.embeddings = vector_store_config.embedding_fn
        self.collection_name = vector_store_config.name
        vector_store_config = vector_store_config.dict()
        self.OB_HOST = str(
            vector_store_config.get("ob_host") or os.getenv("OB_HOST", "127.0.0.1")
        )
        self.OB_PORT = int(
            vector_store_config.get("ob_port") or int(os.getenv("OB_PORT", "2881"))
        )
        self.OB_USER = str(
            vector_store_config.get("ob_user") or os.getenv("OB_USER", "root@test")
        )
        self.OB_PASSWORD = str(
            vector_store_config.get("ob_password") or os.getenv("OB_PASSWORD", "")
        )
        self.OB_DATABASE = str(
            vector_store_config.get("ob_database") or os.getenv("OB_DATABASE", "test")
        )
        self.OB_ENABLE_NORMALIZE_VECTOR = bool(
            os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")
        )
        self.connection_string = OceanBase.connection_string_from_db_params(
            self.OB_HOST, self.OB_PORT, self.OB_DATABASE, self.OB_USER, self.OB_PASSWORD
        )
        self.logger = logger
        with ob_collection_stats_lock:
            if ob_collection_stats.get(self.collection_name) is None:
                ob_collection_stats[self.collection_name] = OceanBaseCollectionStat()
            self.collection_stat = ob_collection_stats[self.collection_name]

        self.vector_store_client = OceanBase(
            connection_string=self.connection_string,
            embedding_function=self.embeddings,
            collection_name=self.collection_name,
            logger=self.logger,
            sql_logger=sql_logger,
            enable_index=bool(os.getenv("OB_ENABLE_INDEX", "")),
            collection_stat=self.collection_stat,
            enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
        )
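A minimal construction sketch for the high-level store; the embedding factory import is an assumption for illustration and is not part of this diff:

```python
from dbgpt.rag.embedding import DefaultEmbeddingFactory  # hypothetical helper
from dbgpt.storage.vector_store.oceanbase_store import (
    OceanBaseConfig,
    OceanBaseStore,
)

config = OceanBaseConfig(
    name="my_docs",
    embedding_fn=DefaultEmbeddingFactory.default("models/text2vec-large-chinese"),
)
store = OceanBaseStore(config)  # connects using OB_* env vars as fallbacks
```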
    def similar_search(
        self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
    ) -> List[Chunk]:
        """Perform a search on a query string and return results."""
        self.logger.info("OceanBase: similar_search..")
        documents = self.vector_store_client.similarity_search(
            text, topk, filter=filters
        )
        return [Chunk(content=doc.content, metadata=doc.metadata) for doc in documents]

    def similar_search_with_scores(
        self,
        text,
        topk,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Perform a search on a query string and return results with scores."""
        self.logger.info("OceanBase: similar_search_with_scores..")
        docs_and_scores = self.vector_store_client.similarity_search_with_score(
            text, topk, filter=filters, score_threshold=score_threshold
        )
        return [
            Chunk(content=doc.content, metadata=doc.metadata, score=score)
            for doc, score in docs_and_scores
        ]
    def vector_name_exists(self):
        """Whether the vector name exists."""
        self.logger.info("OceanBase: vector_name_exists..")
        try:
            self.vector_store_client.create_collection()
            return True
        except Exception as e:
            logger.error("vector_name_exists error: %s", str(e))
            return False
    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load documents into the vector database."""
        self.logger.info("OceanBase: load_document..")
        # lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
        texts = [chunk.content for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
        ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
        return ids

    def delete_vector_name(self, vector_name):
        """Delete the vector table by name."""
        self.logger.info("OceanBase: delete_vector_name..")
        return self.vector_store_client.delete_collection()

    def delete_by_ids(self, ids):
        """Delete vectors by ids."""
        self.logger.info("OceanBase: delete_by_ids..")
        ids = ids.split(",")
        if len(ids) > 0:
            self.vector_store_client.delete(ids)
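Note that `delete_by_ids` takes a single comma-separated string rather than a list, so the ids returned by `load_document` need to be joined first:

```python
ids = store.load_document(chunks)   # `store` and `chunks` as set up above
store.delete_by_ids(",".join(ids))  # delete everything that was just loaded
```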
@@ -2,6 +2,7 @@ ARG BASE_IMAGE="eosphorosai/dbgpt:latest"
 
 FROM ${BASE_IMAGE}
 
+RUN pip install dashscope
 RUN apt-get update && apt-get install -y wget gnupg lsb-release net-tools
 
 RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29
@@ -9,7 +10,7 @@ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29
 RUN wget https://dev.mysql.com/get/mysql-apt-config_0.8.17-1_all.deb
 RUN dpkg -i mysql-apt-config_0.8.17-1_all.deb
 
-RUN apt-get update && apt-get install -y mysql-server && apt-get clean
+RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys B7B3B788A8D3785C && apt-get update && apt-get install -y mysql-server && apt-get clean
 
 # Remote access
 RUN sed -i 's/bind-address\s*=\s*127.0.0.1/bind-address = 0.0.0.0/g' /etc/mysql/mysql.conf.d/mysqld.cnf \
docker/compose_examples/dbgpt-oceanbase-docker-compose.yml (new file, 42 lines)
@@ -0,0 +1,42 @@
version: '3.8'

services:
  oceanbase:
    image: oceanbase/oceanbase-ce:vector
    ports:
      - 2881:2881
    networks:
      - dbgptnet
  dbgpt:
    image: eosphorosai/dbgpt-allinone
    environment:
      - DBGPT_WEBSERVER_PORT=12345
      - VECTOR_STORE_TYPE=OceanBase
      # Within the compose network, the service name is the reachable host;
      # use 127.0.0.1 only when OceanBase runs outside this compose file.
      - OB_HOST=oceanbase
      # - OB_HOST=127.0.0.1
      - OB_PORT=2881
      - OB_USER=root@test
      - OB_DATABASE=test
      - OB_SQL_DBG_LOG_PATH=/sql_log/sql.log
      - LOCAL_DB_TYPE=sqlite
      - LLM_MODEL=tongyi_proxyllm
      - PROXYLLM_BACKEND=qwen-plus
      - EMBEDDING_MODEL=text2vec
      # your api key
      # - TONGYI_PROXY_API_KEY={your-api-key}
      - LANGUAGE=zh
      # - OB_ENABLE_NORMALIZE_VECTOR=True
      # - OB_ENABLE_INDEX=True
    ports:
      - 3306:3306
      - 12345:12345
    volumes:
      # - {your-ob-sql-dbg-log-dir}:/sql_log
      # - {your-model-dir}:/app/models
    networks:
      - dbgptnet

networks:
  dbgptnet:
    driver: bridge
    name: dbgptnet
@@ -69,6 +69,7 @@ import TabItem from '@theme/TabItem';
     {label: 'Chroma', value: 'Chroma'},
     {label: 'Milvus', value: 'Milvus'},
     {label: 'Weaviate', value: 'Weaviate'},
+    {label: 'OceanBase', value: 'OceanBase'},
   ]}>
 <TabItem value="Chroma" label="Chroma">
 
@@ -106,6 +107,25 @@ set ``VECTOR_STORE_TYPE`` in ``.env`` file
 VECTOR_STORE_TYPE=Weaviate
 #WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network
 ```
 </TabItem>
+
+<TabItem value="OceanBase" label="OceanBase">
+
+set ``VECTOR_STORE_TYPE`` in ``.env`` file
+
+```shell
+OB_HOST=127.0.0.1
+OB_PORT=2881
+OB_USER=root@test
+OB_DATABASE=test
+## Optional
+# OB_PASSWORD=
+## Optional: SQL statements executed by OceanBase are recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
+# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
+## Optional: if {OB_ENABLE_NORMALIZE_VECTOR} is set, vectors stored in OceanBase are normalized.
+# OB_ENABLE_NORMALIZE_VECTOR=True
+## Optional: if {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
+# OB_ENABLE_INDEX=True
+```
+</TabItem>
 </Tabs>
@@ -18,8 +18,29 @@ git lfs clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
 
 Update .env file and set VECTOR_STORE_TYPE.
 
-DB-GPT currently support Chroma(Default), Milvus(>2.1), Weaviate vector database.
-If you want to change vector db, Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma (now only support Chroma and Milvus(>2.1), if you set Milvus, please set MILVUS_URL and MILVUS_PORT)
+DB-GPT currently supports the Chroma (default), Milvus (>2.1), Weaviate, and OceanBase vector databases.
+If you want to change the vector db, update your .env and set your vector store type, e.g. VECTOR_STORE_TYPE=Chroma (if you set Milvus, please also set MILVUS_URL and MILVUS_PORT).
+
+If you want to use OceanBase, first start a docker container via the following command:
+```shell
+docker run -p 2881:2881 --name obvec -d oceanbase/oceanbase-ce:vector
+```
+Then set the following variables in the .env file:
+```shell
+VECTOR_STORE_TYPE=OceanBase
+OB_HOST=127.0.0.1
+OB_PORT=2881
+OB_USER=root@test
+OB_DATABASE=test
+## Optional
+# OB_PASSWORD=
+## Optional: SQL statements executed by OceanBase are recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
+# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
+## Optional: if {OB_ENABLE_NORMALIZE_VECTOR} is set, vectors stored in OceanBase are normalized.
+# OB_ENABLE_NORMALIZE_VECTOR=True
+## Optional: if {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
+# OB_ENABLE_INDEX=True
+```
 
 If you want to support more vector db, you can integrate yourself.[how to integrate](https://db-gpt.readthedocs.io/en/latest/modules/vector.html)
 ```commandline
 #*******************************************************************#
@@ -147,6 +147,50 @@ msgid ""
 "connection string."
 msgstr "向量存储的连接字符串,如果未设置,将使用默认的连接字符串。"
 
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:542
+msgid "OceanBase Vector Store"
+msgstr "OceanBase 向量存储"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:548
+msgid "OceanBase Host"
+msgstr "OceanBase 主机地址"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:552
+msgid "oceanbase host"
+msgstr "oceanbase 主机地址"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:558
+msgid "OceanBase Port"
+msgstr "OceanBase 端口"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:562
+msgid "oceanbase port"
+msgstr "oceanbase 端口"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:568
+msgid "OceanBase User"
+msgstr "OceanBase 用户名"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:572
+msgid "user to login"
+msgstr "登录用户名"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:578
+msgid "OceanBase Password"
+msgstr "OceanBase 密码"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:582
+msgid "password to login"
+msgstr "登录密码"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:588
+msgid "OceanBase Database"
+msgstr "OceanBase 数据库名"
+
+#: ../dbgpt/storage/vector_store/oceanbase_store.py:592
+msgid "database for vector tables"
+msgstr "存放向量表的数据库名"
+
 #: ../dbgpt/storage/vector_store/base.py:19
 msgid "Collection Name"
 msgstr "集合名称"