support oceanbase as an optional vector database (#1435)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
GITHUBear authored on 2024-04-24 16:08:30 +08:00, committed by GitHub
parent 91c1371234
commit 6520367623
10 changed files with 975 additions and 12 deletions


@@ -238,6 +238,18 @@ class Config(metaclass=Singleton):
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
## OceanBase Configuration
self.OB_HOST = os.getenv("OB_HOST", "127.0.0.1")
self.OB_PORT = int(os.getenv("OB_PORT", "2881"))
self.OB_USER = os.getenv("OB_USER", "root")
self.OB_PASSWORD = os.getenv("OB_PASSWORD", "")
self.OB_DATABASE = os.getenv("OB_DATABASE", "test")
self.OB_SQL_DBG_LOG_PATH = os.getenv("OB_SQL_DBG_LOG_PATH", "")
# Note: bool() on the raw env string means any non-empty value enables the flag.
self.OB_ENABLE_NORMALIZE_VECTOR = bool(
os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")
)
self.OB_ENABLE_INDEX = bool(os.getenv("OB_ENABLE_INDEX", ""))
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
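The two OB_ENABLE_* flags above are parsed with `bool()` on the raw environment string, so any non-empty value enables them. A minimal standalone sketch of that behavior (not DB-GPT code):

```python
import os

# bool() on a raw env string: any non-empty value, even "False", is truthy.
os.environ["OB_ENABLE_INDEX"] = "False"
assert bool(os.getenv("OB_ENABLE_INDEX", "")) is True   # non-empty string

del os.environ["OB_ENABLE_INDEX"]
assert bool(os.getenv("OB_ENABLE_INDEX", "")) is False  # default "" is falsy
```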


@@ -25,11 +25,11 @@ def test_md_header_text_splitter() -> None:
output = markdown_splitter.split_text(markdown_document)
expected_output = [
Chunk(
- content="{'Header 1': 'dbgpt', 'Header 2': 'description'}, my name is dbgpt",
+ content='"dbgpt-description": my name is dbgpt',
metadata={"Header 1": "dbgpt", "Header 2": "description"},
),
Chunk(
- content="{'Header 1': 'dbgpt', 'Header 2': 'content'}, my name is aries",
+ content='"dbgpt-content": my name is aries',
metadata={"Header 1": "dbgpt", "Header 2": "content"},
),
]


@@ -515,7 +515,8 @@ class MarkdownHeaderTextSplitter(TextSplitter):
aggregated_chunks[-1]["content"] += " \n" + line["content"]
else:
# Otherwise, append the current line to the aggregated list
- line["content"] = f"{line['metadata']}, " + line["content"]
+ subtitles = "-".join(list(line["metadata"].values()))
+ line["content"] = f'"{subtitles}": ' + line["content"]
aggregated_chunks.append(line)
return [
@@ -557,16 +558,28 @@ class MarkdownHeaderTextSplitter(TextSplitter):
# header_stack: List[Dict[str, Union[int, str]]] = []
header_stack: List[HeaderType] = []
initial_metadata: Dict[str, str] = {}
# Determine whether a line is within a markdown code block.
in_code_block = False
for line in lines:
stripped_line = line.strip()
# An opening code fence is "```" plus a language tag;
# a bare "```" is treated as a closing fence below.
with_code_frame = stripped_line.startswith("```") and (
stripped_line != "```"
)
if (not in_code_block) and with_code_frame:
in_code_block = True
# Check each line against each of the header types (e.g., #, ##)
for sep, name in self.headers_to_split_on:
# Check if line starts with a header that we intend to split on
- if stripped_line.startswith(sep) and (
- # Header with no text OR header is followed by space
- # Both are valid conditions that sep is being used a header
- len(stripped_line) == len(sep)
- or stripped_line[len(sep)] == " "
+ if (
+ (not in_code_block)
+ and stripped_line.startswith(sep)
+ and (
+ # Header with no text OR header is followed by a space;
+ # both are valid signs that sep is being used as a header
+ len(stripped_line) == len(sep)
+ or stripped_line[len(sep)] == " "
+ )
):
# Ensure we are tracking the header as metadata
if name is not None:
@@ -620,6 +633,10 @@ class MarkdownHeaderTextSplitter(TextSplitter):
)
current_content.clear()
# Code block ends
if in_code_block and stripped_line == "```":
in_code_block = False
current_metadata = initial_metadata.copy()
if current_content:
lines_with_metadata.append(
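The net effect of the new `in_code_block` tracking is that "#" lines inside fenced code blocks are no longer mistaken for markdown headers. A standalone sketch of the same scan (not the DB-GPT class itself), mirroring the logic above, including the detail that an opening fence carries a language tag while a bare "```" closes the block:

```python
def split_headers(markdown: str):
    """Collect markdown headers, skipping '#' lines inside fenced code blocks."""
    headers, in_code_block = [], False
    for line in markdown.splitlines():
        stripped = line.strip()
        if not in_code_block and stripped.startswith("```") and stripped != "```":
            in_code_block = True   # an opening fence such as ```python
        elif in_code_block and stripped == "```":
            in_code_block = False  # a bare ``` closes the block
        elif not in_code_block and stripped.startswith("#"):
            headers.append(stripped)
    return headers

doc = "# real header\n```python\n# just a comment\n```\n## another header\n"
assert split_headers(doc) == ["# real header", "## another header"]
```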


@@ -26,6 +26,12 @@ def _import_weaviate() -> Any:
return WeaviateStore
def _import_oceanbase() -> Any:
from dbgpt.storage.vector_store.oceanbase_store import OceanBaseStore
return OceanBaseStore
def __getattr__(name: str) -> Any:
if name == "Chroma":
return _import_chroma()
@@ -35,8 +41,10 @@ def __getattr__(name: str) -> Any:
return _import_weaviate()
elif name == "PGVector":
return _import_pgvector()
elif name == "OceanBase":
return _import_oceanbase()
else:
raise AttributeError(f"Could not find: {name}")
__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
__all__ = ["Chroma", "Milvus", "Weaviate", "OceanBase", "PGVector"]


@@ -0,0 +1,798 @@
"""OceanBase vector store."""
import json
import logging
import os
import threading
import uuid
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from pydantic import Field
from sqlalchemy import Column, Table, create_engine, insert, text
from sqlalchemy.dialects.mysql import JSON, LONGTEXT, VARCHAR
from sqlalchemy.types import String, UserDefinedType
from dbgpt.core import Chunk, Document, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
_COMMON_PARAMETERS,
VectorStoreBase,
VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
logger = logging.getLogger(__name__)
sql_logger = None
sql_dbg_log_path = os.getenv("OB_SQL_DBG_LOG_PATH", "")
if sql_dbg_log_path != "":
sql_logger = logging.getLogger("ob_sql_dbg")
sql_logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(sql_dbg_log_path)
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
sql_logger.addHandler(file_handler)
_OCEANBASE_DEFAULT_EMBEDDING_DIM = 1536
_OCEANBASE_DEFAULT_COLLECTION_NAME = "langchain_document"
_OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD = 10000
_OCEANBASE_DEFAULT_RWLOCK_MAX_READER = 64
Base = declarative_base()
def ob_vector_from_db(value):
"""Parse vector from oceanbase."""
return [float(v) for v in value[1:-1].split(",")]
def ob_vector_to_db(value, dim=None):
"""Parse vector to oceanbase vector constant type."""
if value is None:
return value
return "[" + ",".join([str(float(v)) for v in value]) + "]"
class Vector(UserDefinedType):
"""OceanBase Vector Column Type."""
cache_ok = True
_string = String()
def __init__(self, dim):
"""Create a Vector column with dimemsion `dim`."""
super(UserDefinedType, self).__init__()
self.dim = dim
def get_col_spec(self, **kw):
"""Get vector column definition in string format."""
return "VECTOR(%d)" % self.dim
def bind_processor(self, dialect):
"""Get a processor to parse an array to oceanbase vector."""
def process(value):
return ob_vector_to_db(value, self.dim)
return process
def literal_processor(self, dialect):
"""Get a string processor to parse an array to OceanBase Vector."""
string_literal_processor = self._string._cached_literal_processor(dialect)
def process(value):
return string_literal_processor(ob_vector_to_db(value, self.dim))
return process
def result_processor(self, dialect, coltype):
"""Get a processor to parse OceanBase Vector to array."""
def process(value):
return ob_vector_from_db(value)
return process
class OceanBaseCollectionStat:
"""A tracer that maintains a table status in OceanBase."""
def __init__(self):
"""Create OceanBaseCollectionStat instance."""
self._lock = threading.Lock()
self.maybe_collection_not_exist = True
self.maybe_collection_index_not_exist = True
def collection_exists(self):
"""Mark the table as existing."""
with self._lock:
self.maybe_collection_not_exist = False
def collection_index_exists(self):
"""Mark the table's index as existing."""
with self._lock:
self.maybe_collection_index_not_exist = False
def collection_not_exists(self):
"""Mark the table as dropped."""
with self._lock:
self.maybe_collection_not_exist = True
def collection_index_not_exists(self):
"""Mark the table's index as dropped."""
with self._lock:
self.maybe_collection_index_not_exist = True
def get_maybe_collection_not_exist(self):
"""Get whether the table may not exist."""
with self._lock:
return self.maybe_collection_not_exist
def get_maybe_collection_index_not_exist(self):
"""Get whether the table's index may not exist."""
with self._lock:
return self.maybe_collection_index_not_exist
class OceanBaseGlobalRWLock:
"""A global rwlock for OceanBase to do creating vector index table offline ddl."""
def __init__(self, max_readers) -> None:
"""Create a rwlock."""
self.max_readers_ = max_readers
self.writer_entered_ = False
self.reader_cnt_ = 0
self.mutex_ = threading.Lock()
self.writer_cv_ = threading.Condition(self.mutex_)
self.reader_cv_ = threading.Condition(self.mutex_)
def rlock(self):
"""Lock for reading."""
self.mutex_.acquire()
while self.writer_entered_ or self.max_readers_ == self.reader_cnt_:
self.reader_cv_.wait()
self.reader_cnt_ += 1
self.mutex_.release()
def runlock(self):
"""Unlock reading lock."""
self.mutex_.acquire()
self.reader_cnt_ -= 1
if self.writer_entered_:
if self.reader_cnt_ == 0:
self.writer_cv_.notify(1)
else:
if self.max_readers_ - 1 == self.reader_cnt_:
self.reader_cv_.notify(1)
self.mutex_.release()
def wlock(self):
"""Lock for writing."""
self.mutex_.acquire()
while self.writer_entered_:
self.reader_cv_.wait()
self.writer_entered_ = True
while 0 < self.reader_cnt_:
self.writer_cv_.wait()
self.mutex_.release()
def wunlock(self):
"""Unlock writing lock."""
self.mutex_.acquire()
self.writer_entered_ = False
self.reader_cv_.notify_all()
self.mutex_.release()
class OBRLock:
"""Reading Lock Wrapper for `with` clause."""
def __init__(self, rwlock) -> None:
"""Create reading lock wrapper instance."""
self.rwlock_ = rwlock
def __enter__(self):
"""Lock."""
self.rwlock_.rlock()
def __exit__(self, exc_type, exc_value, traceback):
"""Unlock when exiting."""
self.rwlock_.runlock()
class OBWLock:
"""Writing Lock Wrapper for `with` clause."""
def __init__(self, rwlock) -> None:
"""Create writing lock wrapper instance."""
self.rwlock_ = rwlock
def __enter__(self):
"""Lock."""
self.rwlock_.wlock()
def __exit__(self, exc_type, exc_value, traceback):
"""Unlock when exiting."""
self.rwlock_.wunlock()
def reader_lock(self):
"""Get reading lock wrapper."""
return self.OBRLock(self)
def writer_lock(self):
"""Get writing lock wrapper."""
return self.OBWLock(self)
ob_grwlock = OceanBaseGlobalRWLock(_OCEANBASE_DEFAULT_RWLOCK_MAX_READER)
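Usage sketch for the lock above: normal DML (inserts, searches, deletes) takes the shared reader side, while the offline `CREATE INDEX` DDL takes the exclusive writer side, so index creation never overlaps in-flight writes:

```python
from dbgpt.storage.vector_store.oceanbase_store import ob_grwlock

with ob_grwlock.reader_lock():  # shared: many DML operations in parallel
    pass  # e.g. conn.execute(insert(chunks_table).values(rows))

with ob_grwlock.writer_lock():  # exclusive: offline vector-index DDL
    pass  # e.g. conn.execute(text(create_index_query))
```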
class OceanBase:
"""OceanBase Vector Store implementation."""
def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
embedding_dimension: int = _OCEANBASE_DEFAULT_EMBEDDING_DIM,
collection_name: str = _OCEANBASE_DEFAULT_COLLECTION_NAME,
pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None,
engine_args: Optional[dict] = None,
delay_table_creation: bool = True,
enable_index: bool = False,
th_create_ivfflat_index: int = _OCEANBASE_DEFAULT_IVFFLAT_ROW_THRESHOLD,
sql_logger: Optional[logging.Logger] = None,
collection_stat: Optional[OceanBaseCollectionStat] = None,
enable_normalize_vector: bool = False,
) -> None:
"""Create OceanBase Vector Store instance."""
self.connection_string = connection_string
self.embedding_function = embedding_function
self.embedding_dimension = embedding_dimension
self.collection_name = collection_name
self.pre_delete_collection = pre_delete_collection
self.logger = logger or logging.getLogger(__name__)
self.delay_table_creation = delay_table_creation
self.th_create_ivfflat_index = th_create_ivfflat_index
self.enable_index = enable_index
self.sql_logger = sql_logger
self.collection_stat = collection_stat
self.enable_normalize_vector = enable_normalize_vector
self.__post_init__(engine_args)
def __post_init__(
self,
engine_args: Optional[dict] = None,
) -> None:
"""Create connection & vector table."""
_engine_args = engine_args or {}
if "pool_recycle" not in _engine_args:
_engine_args["pool_recycle"] = 3600
self.engine = create_engine(self.connection_string, **_engine_args)
self.create_collection()
@property
def embeddings(self) -> Embeddings:
"""Get embedding function."""
return self.embedding_function
def create_collection(self) -> None:
"""Create vector table."""
if self.pre_delete_collection:
self.delete_collection()
if not self.delay_table_creation and (
self.collection_stat is None
or self.collection_stat.get_maybe_collection_not_exist()
):
self.create_table_if_not_exists()
if self.collection_stat is not None:
self.collection_stat.collection_exists()
def delete_collection(self) -> None:
"""Drop vector table."""
drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name}")
if self.sql_logger is not None:
self.sql_logger.debug(f"Trying to delete collection: {drop_statement}")
with self.engine.connect() as conn, conn.begin():
conn.execute(drop_statement)
if self.collection_stat is not None:
self.collection_stat.collection_not_exists()
self.collection_stat.collection_index_not_exists()
def create_table_if_not_exists(self) -> None:
"""Create vector table with SQL."""
create_table_query = f"""
CREATE TABLE IF NOT EXISTS `{self.collection_name}` (
id VARCHAR(40) NOT NULL,
embedding VECTOR({self.embedding_dimension}),
document LONGTEXT,
metadata JSON,
PRIMARY KEY (id)
)
"""
if self.sql_logger is not None:
self.sql_logger.debug(f"Trying to create table: {create_table_query}")
with self.engine.connect() as conn, conn.begin():
# Create the table
conn.execute(text(create_table_query))
def create_collection_ivfflat_index_if_not_exists(self) -> None:
"""Create ivfflat index table with SQL."""
create_index_query = f"""
CREATE INDEX IF NOT EXISTS `embedding_idx` on `{self.collection_name}` (
embedding l2
) using ivfflat with (lists=20)
"""
with ob_grwlock.writer_lock(), self.engine.connect() as conn, conn.begin():
# Create Ivfflat Index
if self.sql_logger is not None:
self.sql_logger.debug(
f"Trying to create ivfflat index: {create_index_query}"
)
conn.execute(text(create_index_query))
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
batch_size: int = 500,
**kwargs: Any,
) -> List[str]:
"""Insert texts into vector table."""
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
embeddings = self.embedding_function.embed_documents(list(texts))
if len(embeddings) == 0:
return ids
if not metadatas:
metadatas = [{} for _ in texts]
if self.delay_table_creation and (
self.collection_stat is None
or self.collection_stat.get_maybe_collection_not_exist()
):
self.embedding_dimension = len(embeddings[0])
self.create_table_if_not_exists()
self.delay_table_creation = False
if self.collection_stat is not None:
self.collection_stat.collection_exists()
chunks_table = Table(
self.collection_name,
Base.metadata,
Column("id", VARCHAR(40), primary_key=True),
Column("embedding", Vector(self.embedding_dimension)),
Column("document", LONGTEXT, nullable=True),
Column("metadata", JSON, nullable=True), # filter
keep_existing=True,
)
row_count_query = f"""
SELECT COUNT(*) as cnt FROM `{self.collection_name}`
"""
chunks_table_data = []
with self.engine.connect() as conn, conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding
if not self.enable_normalize_vector
else self._normalization_vectors(embedding),
"document": document,
"metadata": metadata,
}
)
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == batch_size:
with ob_grwlock.reader_lock():
if self.sql_logger is not None:
insert_sql_for_log = str(
insert(chunks_table).values(chunks_table_data)
)
self.sql_logger.debug(
f"""Trying to insert vectors:
{insert_sql_for_log}"""
)
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
with ob_grwlock.reader_lock():
if self.sql_logger is not None:
insert_sql_for_log = str(
insert(chunks_table).values(chunks_table_data)
)
self.sql_logger.debug(
f"""Trying to insert vectors:
{insert_sql_for_log}"""
)
conn.execute(insert(chunks_table).values(chunks_table_data))
if self.enable_index and (
self.collection_stat is None
or self.collection_stat.get_maybe_collection_index_not_exist()
):
with ob_grwlock.reader_lock():
row_cnt_res = conn.execute(text(row_count_query))
for row in row_cnt_res:
if row.cnt > self.th_create_ivfflat_index:
self.create_collection_ivfflat_index_if_not_exists()
if self.collection_stat is not None:
self.collection_stat.collection_index_exists()
return ids
def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilters] = None,
**kwargs: Any,
) -> List[Document]:
"""Do similarity search via query in String."""
embedding = self.embedding_function.embed_query(query)
docs = self.similarity_search_by_vector(embedding=embedding, k=k, filter=filter)
return docs
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[MetadataFilters] = None,
**kwargs: Any,
) -> List[Document]:
"""Do similarity search via query vector."""
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter
)
return [doc for doc, _ in docs_and_scores]
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[MetadataFilters] = None,
score_threshold: Optional[float] = None,
) -> List[Tuple[Document, float]]:
"""Do similarity search via query vector with score."""
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
# Metadata filtering is currently not supported by OceanBase vector search.
# Normalize the embedding vector if configured.
if self.enable_normalize_vector:
embedding = self._normalization_vectors(embedding)
embedding_str = ob_vector_to_db(embedding, self.embedding_dimension)
vector_distance_op = "<@>" if self.enable_normalize_vector else "<->"
sql_query = f"""
SELECT document, metadata, embedding {vector_distance_op} '{embedding_str}'
as distance
FROM {self.collection_name}
ORDER BY embedding {vector_distance_op} '{embedding_str}'
LIMIT :k
"""
sql_query_str_for_log = f"""
SELECT document, metadata, embedding {vector_distance_op} '?' as distance
FROM {self.collection_name}
ORDER BY embedding {vector_distance_op} '?'
LIMIT {k}
"""
params = {"k": k}
try:
with ob_grwlock.reader_lock(), self.engine.connect() as conn:
if self.sql_logger is not None:
self.sql_logger.debug(
f"Trying to do similarity search: {sql_query_str_for_log}"
)
results: Sequence[Row] = conn.execute(
text(sql_query), params
).fetchall()
if (score_threshold is not None) and self.enable_normalize_vector:
documents_with_scores = [
(
Document(
content=result.document,
metadata=json.loads(result.metadata),
),
result.distance,
)
for result in results
if result.distance >= score_threshold
]
else:
documents_with_scores = [
(
Document(
content=result.document,
metadata=json.loads(result.metadata),
),
result.distance,
)
for result in results
]
return documents_with_scores
except Exception as e:
self.logger.error("similarity_search_with_score_by_vector failed: %s", str(e))
return []
def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilters] = None,
score_threshold: Optional[float] = None,
) -> List[Tuple[Document, float]]:
"""Do similarity search via query String with score."""
embedding = self.embedding_function.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter, score_threshold=score_threshold
)
return docs
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete vectors from vector table."""
if ids is None:
raise ValueError("No ids provided to delete.")
# Define the table schema
chunks_table = Table(
self.collection_name,
Base.metadata,
Column("id", VARCHAR(40), primary_key=True),
Column("embedding", Vector(self.embedding_dimension)),
Column("document", LONGTEXT, nullable=True),
Column("metadata", JSON, nullable=True), # filter
keep_existing=True,
)
try:
with self.engine.connect() as conn, conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
delete_stmt = chunks_table.delete().where(delete_condition)
with ob_grwlock.reader_lock():
if self.sql_logger is not None:
self.sql_logger.debug(
f"Trying to delete vectors: {str(delete_stmt)}"
)
conn.execute(delete_stmt)
return True
except Exception as e:
self.logger.error("Delete operation failed: %s", str(e))
return False
def _normalization_vectors(self, vector):
"""Normalize a vector to unit L2 norm."""
import numpy as np
norm = np.linalg.norm(vector)
return (vector / norm).tolist()
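Normalizing to unit length is what makes the `<@>` inner-product operator usable as a similarity score: for unit vectors the dot product equals cosine similarity. A standalone check:

```python
import numpy as np

def unit(v):
    """Scale v to unit L2 norm."""
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v)

a, b = unit([1.0, 2.0, 2.0]), unit([2.0, 0.0, 1.0])
cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
assert abs(np.dot(a, b) - cos) < 1e-12  # dot == cosine once both are unit length
```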
@classmethod
def connection_string_from_db_params(
cls,
host: str,
port: int,
database: str,
user: str,
password: str,
) -> str:
"""Get connection string."""
return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}"
ob_collection_stats_lock = threading.Lock()
ob_collection_stats: Dict[str, OceanBaseCollectionStat] = {}
@register_resource(
_("OceanBase Vector Store"),
"oceanbase_vector_store",
category=ResourceCategory.VECTOR_STORE,
parameters=[
*_COMMON_PARAMETERS,
Parameter.build_from(
_("OceanBase Host"),
"ob_host",
str,
description=_("oceanbase host"),
optional=True,
default=None,
),
Parameter.build_from(
_("OceanBase Port"),
"ob_port",
int,
description=_("oceanbase port"),
optional=True,
default=None,
),
Parameter.build_from(
_("OceanBase User"),
"ob_user",
str,
description=_("user to login"),
optional=True,
default=None,
),
Parameter.build_from(
_("OceanBase Password"),
"ob_password",
str,
description=_("password to login"),
optional=True,
default=None,
),
Parameter.build_from(
_("OceanBase Database"),
"ob_database",
str,
description=_("database for vector tables"),
optional=True,
default=None,
),
],
description="OceanBase vector store.",
)
class OceanBaseConfig(VectorStoreConfig):
"""OceanBase vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
"""OceanBase config"""
ob_host: str = Field(
default="127.0.0.1",
description="oceanbase host",
)
ob_port: int = Field(
default=2881,
description="oceanbase port",
)
ob_user: str = Field(
default="root@test",
description="user to login",
)
ob_password: str = Field(
default="",
description="password to login",
)
ob_database: str = Field(
default="test",
description="database for vector tables",
)
class OceanBaseStore(VectorStoreBase):
"""OceanBase vector store."""
def __init__(self, vector_store_config: OceanBaseConfig) -> None:
"""Create a OceanBaseStore instance."""
if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore")
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
vector_store_config = vector_store_config.dict()
self.OB_HOST = str(
vector_store_config.get("ob_host") or os.getenv("OB_HOST", "127.0.0.1")
)
self.OB_PORT = int(
vector_store_config.get("ob_port") or int(os.getenv("OB_PORT", "2881"))
)
self.OB_USER = str(
vector_store_config.get("ob_user") or os.getenv("OB_USER", "root@test")
)
self.OB_PASSWORD = str(
vector_store_config.get("ob_password") or os.getenv("OB_PASSWORD", "")
)
self.OB_DATABASE = str(
vector_store_config.get("ob_database") or os.getenv("OB_DATABASE", "test")
)
self.OB_ENABLE_NORMALIZE_VECTOR = bool(
os.getenv("OB_ENABLE_NORMALIZE_VECTOR", "")
)
self.connection_string = OceanBase.connection_string_from_db_params(
self.OB_HOST, self.OB_PORT, self.OB_DATABASE, self.OB_USER, self.OB_PASSWORD
)
self.logger = logger
with ob_collection_stats_lock:
if ob_collection_stats.get(self.collection_name) is None:
ob_collection_stats[self.collection_name] = OceanBaseCollectionStat()
self.collection_stat = ob_collection_stats[self.collection_name]
self.vector_store_client = OceanBase(
connection_string=self.connection_string,
embedding_function=self.embeddings,
collection_name=self.collection_name,
logger=self.logger,
sql_logger=sql_logger,
enable_index=bool(os.getenv("OB_ENABLE_INDEX", "")),
collection_stat=self.collection_stat,
enable_normalize_vector=self.OB_ENABLE_NORMALIZE_VECTOR,
)
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None, **kwargs: Any
) -> List[Chunk]:
"""Perform a search on a query string and return results."""
self.logger.info("OceanBase: similar_search..")
documents = self.vector_store_client.similarity_search(
text, topk, filter=filters
)
return [Chunk(content=doc.content, metadata=doc.metadata) for doc in documents]
def similar_search_with_scores(
self,
text,
topk,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Perform a search on a query string and return results with score."""
self.logger.info("OceanBase: similar_search_with_scores..")
docs_and_scores = self.vector_store_client.similarity_search_with_score(
text, topk, filter=filters, score_threshold=score_threshold
)
return [
Chunk(content=doc.content, metadata=doc.metadata, score=score)
for doc, score in docs_and_scores
]
def vector_name_exists(self):
"""Whether vector name exists."""
self.logger.info("OceanBase: vector_name_exists..")
try:
self.vector_store_client.create_collection()
return True
except Exception as e:
logger.error("vector_name_exists error: %s", str(e))
return False
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in vector database."""
self.logger.info("OceanBase: load_document..")
# lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
texts = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks]
ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
return ids
def delete_vector_name(self, vector_name):
"""Delete vector name."""
self.logger.info("OceanBase: delete_vector_name..")
return self.vector_store_client.delete_collection()
def delete_by_ids(self, ids):
"""Delete vector by ids."""
self.logger.info("OceanBase: delete_by_ids..")
ids = ids.split(",")
if len(ids) > 0:
self.vector_store_client.delete(ids)
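A minimal end-to-end sketch of the store as wired above, assuming a reachable OceanBase instance; `MyEmbeddings` is a hypothetical toy stand-in for a real `Embeddings` implementation:

```python
from dbgpt.core import Chunk
from dbgpt.storage.vector_store.oceanbase_store import (
    OceanBaseConfig,
    OceanBaseStore,
)

class MyEmbeddings:  # hypothetical: duck-typed embed_documents / embed_query
    def embed_documents(self, texts):
        return [[float(len(t)), 0.0] for t in texts]  # toy 2-dim vectors
    def embed_query(self, text):
        return [float(len(text)), 0.0]

config = OceanBaseConfig(name="doc_chunks", embedding_fn=MyEmbeddings())
store = OceanBaseStore(config)
ids = store.load_document([Chunk(content="hello oceanbase", metadata={"k": "v"})])
hits = store.similar_search("hello oceanbase", topk=1)
print(hits[0].content)
```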


@@ -2,6 +2,7 @@ ARG BASE_IMAGE="eosphorosai/dbgpt:latest"
FROM ${BASE_IMAGE}
RUN pip install dashscope
RUN apt-get update && apt-get install -y wget gnupg lsb-release net-tools
RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29
@@ -9,7 +10,7 @@ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 467B942D3A79BD29
RUN wget https://dev.mysql.com/get/mysql-apt-config_0.8.17-1_all.deb
RUN dpkg -i mysql-apt-config_0.8.17-1_all.deb
- RUN apt-get update && apt-get install -y mysql-server && apt-get clean
+ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys B7B3B788A8D3785C && apt-get update && apt-get install -y mysql-server && apt-get clean
# Remote access
RUN sed -i 's/bind-address\s*=\s*127.0.0.1/bind-address = 0.0.0.0/g' /etc/mysql/mysql.conf.d/mysqld.cnf \


@@ -0,0 +1,42 @@
version: '3.8'
services:
oceanbase:
image: oceanbase/oceanbase-ce:vector
ports:
- 2881:2881
networks:
- dbgptnet
dbgpt:
image: eosphorosai/dbgpt-allinone
environment:
- DBGPT_WEBSERVER_PORT=12345
- VECTOR_STORE_TYPE=OceanBase
# within the dbgptnet network, the oceanbase service name resolves
- OB_HOST=oceanbase
- OB_PORT=2881
- OB_USER=root@test
- OB_DATABASE=test
- OB_SQL_DBG_LOG_PATH=/sql_log/sql.log
- LOCAL_DB_TYPE=sqlite
- LLM_MODEL=tongyi_proxyllm
- PROXYLLM_BACKEND=qwen-plus
- EMBEDDING_MODEL=text2vec
# your api key
# - TONGYI_PROXY_API_KEY={your-api-key}
- LANGUAGE=zh
# - OB_ENABLE_NORMALIZE_VECTOR=True
# - OB_ENABLE_INDEX=True
ports:
- 3306:3306
- 12345:12345
volumes:
# - {your-ob-sql-dbg-log-dir}:/sql_log
# - {your-model-dir}:/app/models
networks:
- dbgptnet
networks:
dbgptnet:
driver: bridge
name: dbgptnet


@@ -69,6 +69,7 @@ import TabItem from '@theme/TabItem';
{label: 'Chroma', value: 'Chroma'},
{label: 'Milvus', value: 'Milvus'},
{label: 'Weaviate', value: 'Weaviate'},
{label: 'OceanBase', value: 'OceanBase'},
]}>
<TabItem value="Chroma" label="Chroma">
@@ -106,6 +107,25 @@ set ``VECTOR_STORE_TYPE`` in ``.env`` file
VECTOR_STORE_TYPE=Weaviate
#WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network
```
</TabItem>
<TabItem value="OceanBase" label="OceanBase">
set ``VECTOR_STORE_TYPE`` in ``.env`` file
```shell
VECTOR_STORE_TYPE=OceanBase
OB_HOST=127.0.0.1
OB_PORT=2881
OB_USER=root@test
OB_DATABASE=test
## Optional
# OB_PASSWORD=
## Optional: SQL statements executed by OceanBase are recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
## Optional: if {OB_ENABLE_NORMALIZE_VECTOR} is set, vectors stored in OceanBase are normalized.
# OB_ENABLE_NORMALIZE_VECTOR=True
## Optional: if {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
# OB_ENABLE_INDEX=True
```
</TabItem>
</Tabs>


@@ -18,8 +18,29 @@ git lfs clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
Update .env file and set VECTOR_STORE_TYPE.
- DB-GPT currently support Chroma(Default), Milvus(>2.1), Weaviate vector database.
- If you want to change vector db, Update your .env, set your vector store type, VECTOR_STORE_TYPE=Chroma (now only support Chroma and Milvus(>2.1), if you set Milvus, please set MILVUS_URL and MILVUS_PORT)
+ DB-GPT currently supports the Chroma (default), Milvus (>2.1), Weaviate, and OceanBase vector databases.
+ To change the vector db, update your .env and set your vector store type, e.g. VECTOR_STORE_TYPE=Chroma (if you set Milvus, please also set MILVUS_URL and MILVUS_PORT).
If you want to use OceanBase, please first start a docker container via the following command:
```shell
docker run -p 2881:2881 --name obvec -d oceanbase/oceanbase-ce:vector
```
Then set the following variables in the .env file:
```shell
VECTOR_STORE_TYPE=OceanBase
OB_HOST=127.0.0.1
OB_PORT=2881
OB_USER=root@test
OB_DATABASE=test
## Optional
# OB_PASSWORD=
## Optional: SQL statements executed by OceanBase are recorded in the log file specified by {OB_SQL_DBG_LOG_PATH}.
# OB_SQL_DBG_LOG_PATH={your-sql-dbg-log-dir}/sql.log
## Optional: if {OB_ENABLE_NORMALIZE_VECTOR} is set, vectors stored in OceanBase are normalized.
# OB_ENABLE_NORMALIZE_VECTOR=True
## Optional: if {OB_ENABLE_INDEX} is set, OceanBase will automatically create a vector index table.
# OB_ENABLE_INDEX=True
```
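Before pointing DB-GPT at the container, a quick connectivity check can save a debugging round trip. A sketch, assuming `sqlalchemy` and `pymysql` are installed; note the `@` in the tenant-style user is percent-encoded so the URL parser keeps the right host part:

```python
from sqlalchemy import create_engine, text

# root@test -> root%40test inside the SQLAlchemy URL
engine = create_engine("mysql+pymysql://root%40test:@127.0.0.1:2881/test")
with engine.connect() as conn:
    print(conn.execute(text("SELECT version()")).scalar())
```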
If you want to support another vector db, you can integrate it yourself: [how to integrate](https://db-gpt.readthedocs.io/en/latest/modules/vector.html)
```commandline
#*******************************************************************#


@@ -147,6 +147,50 @@ msgid ""
"connection string."
msgstr "向量存储的连接字符串,如果未设置,将使用默认的连接字符串。"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:542
msgid "OceanBase Vector Store"
msgstr "OceanBase 向量存储"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:548
msgid "OceanBase Host"
msgstr "OceanBase 主机地址"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:552
msgid "oceanbase host"
msgstr "oceanbase 主机地址"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:558
msgid "OceanBase Port"
msgstr "OceanBase 端口"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:562
msgid "oceanbase port"
msgstr "oceanbase 端口"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:568
msgid "OceanBase User"
msgstr "OceanBase 用户名"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:572
msgid "user to login"
msgstr "登录用户名"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:578
msgid "OceanBase Password"
msgstr "OceanBase 密码"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:582
msgid "password to login"
msgstr "登录密码"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:588
msgid "OceanBase Database"
msgstr "OceanBase 数据库名"
#: ../dbgpt/storage/vector_store/oceanbase_store.py:592
msgid "database for vector tables"
msgstr "存放向量表的数据库名"
#: ../dbgpt/storage/vector_store/base.py:19
msgid "Collection Name"
msgstr "集合名称"