mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -13,10 +13,10 @@ from dbgpt.datasource.manages.connect_config_db import (
|
||||
ConnectConfigEntity,
|
||||
)
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.schema import DBType
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
|
||||
from ..api.schemas import DatasourceServeRequest, DatasourceServeResponse
|
||||
|
@@ -66,24 +66,24 @@ async def check_api_key(
|
||||
if request.url.path.startswith(f"/api/v1"):
|
||||
return None
|
||||
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
# if service.config.api_keys:
|
||||
# api_keys = _parse_api_keys(service.config.api_keys)
|
||||
# if auth is None or (token := auth.credentials) not in api_keys:
|
||||
# raise HTTPException(
|
||||
# status_code=401,
|
||||
# detail={
|
||||
# "error": {
|
||||
# "message": "",
|
||||
# "type": "invalid_request_error",
|
||||
# "param": None,
|
||||
# "code": "invalid_api_key",
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
# return token
|
||||
# else:
|
||||
# # api_keys not set; allow all
|
||||
# return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
|
@@ -78,18 +78,6 @@ async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
New file: dbgpt/serve/rag/connector.py (300 lines)
@@ -0,0 +1,300 @@
|
||||
"""Connector for vector store."""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Type, cast
|
||||
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import (
|
||||
FunctionDynamicOptions,
|
||||
OptionValue,
|
||||
Parameter,
|
||||
ResourceCategory,
|
||||
register_resource,
|
||||
)
|
||||
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
logger = logging.getLogger(__name__)

# Registry mapping vector-store type name -> (store class, config class).
# Populated lazily by VectorStoreConnector._register().
connector: Dict[str, Tuple[Type, Type]] = {}
# Pool of live store clients keyed by store type, then by config name, so
# connectors built for the same (type, name) pair share one client instance.
pools: DefaultDict[str, Dict] = defaultdict(dict)
|
||||
|
||||
|
||||
def _load_vector_options() -> List[OptionValue]:
    """Build the selectable vector-store options for the AWEL flow UI.

    Scans ``dbgpt.storage.vector_store.__all__`` and keeps every entry whose
    registered store class is a subclass of ``IndexStoreBase``.
    """
    from dbgpt.storage import vector_store

    options: List[OptionValue] = []
    for store_name in vector_store.__all__:
        store_cls = getattr(vector_store, store_name)[0]
        if issubclass(store_cls, IndexStoreBase):
            options.append(
                OptionValue(label=store_name, name=store_name, value=store_name)
            )
    return options
|
||||
|
||||
|
||||
@register_resource(
    _("Vector Store Connector"),
    "vector_store_connector",
    category=ResourceCategory.VECTOR_STORE,
    parameters=[
        Parameter.build_from(
            _("Vector Store Type"),
            "vector_store_type",
            str,
            description=_("The type of vector store."),
            options=FunctionDynamicOptions(func=_load_vector_options),
        ),
        Parameter.build_from(
            _("Vector Store Implementation"),
            "vector_store_config",
            VectorStoreConfig,
            description=_("The vector store implementation."),
            optional=True,
            default=None,
        ),
    ],
)
class VectorStoreConnector:
    """The connector for vector store.

    VectorStoreConnector, can connect different vector db provided load document api_v1
    and similar search api_v1.

    1.load_document:knowledge document source into vector store.(Chroma, Milvus,
    Weaviate).
    2.similar_search: similarity search from vector_store.
    3.similar_search_with_scores: similarity search with similarity score from
    vector_store

    code example:
        >>> from dbgpt.serve.rag.connector import VectorStoreConnector

        >>> vector_store_config = VectorStoreConfig()
        >>> vector_store_connector = VectorStoreConnector(vector_store_type="Chroma")
    """

    def __init__(
        self,
        vector_store_type: str,
        vector_store_config: Optional[IndexStoreConfig] = None,
    ) -> None:
        """Create a VectorStoreConnector instance.

        Args:
            - vector_store_type: vector store type Milvus, Chroma, Weaviate
            - vector_store_config: vector store config params (required despite
              the Optional annotation; None raises immediately).

        Raises:
            Exception: when no config is given or the store type is unknown.
        """
        if vector_store_config is None:
            raise Exception("vector_store_config is required")

        self._index_store_config = vector_store_config
        self._register()

        if self._match(vector_store_type):
            self.connector_class, self.config_class = connector[vector_store_type]
        else:
            raise Exception(f"Vector store {vector_store_type} not supported")

        logger.info(f"VectorStore:{self.connector_class}")

        self._vector_store_type = vector_store_type
        self._embeddings = vector_store_config.embedding_fn

        # Rebuild the config as the store-specific config class, carrying over
        # every non-None declared field plus any extra (undeclared) fields.
        config_dict = {}
        for key in vector_store_config.to_dict().keys():
            value = getattr(vector_store_config, key)
            if value is not None:
                config_dict[key] = value
        # model_extra may be None when the pydantic model forbids extras.
        for key, value in (vector_store_config.model_extra or {}).items():
            if value is not None:
                config_dict[key] = value
        config = self.config_class(**config_dict)
        try:
            # Reuse a pooled client for the same (store type, config name) pair.
            if vector_store_type in pools and config.name in pools[vector_store_type]:
                self.client = pools[vector_store_type][config.name]
            else:
                client = self.connector_class(config)
                pools[vector_store_type][config.name] = self.client = client
        except Exception as e:
            logger.error("connect vector store failed: %s", e)
            # Bare re-raise preserves the original traceback.
            raise

    @classmethod
    def from_default(
        cls,
        vector_store_type: Optional[str] = None,
        embedding_fn: Optional[Any] = None,
        vector_store_config: Optional[VectorStoreConfig] = None,
    ) -> "VectorStoreConnector":
        """Initialize default vector store connector.

        Args:
            vector_store_type: store type; falls back to the VECTOR_STORE_TYPE
                environment variable, then "Chroma".
            embedding_fn: embedding function assigned onto the config.
            vector_store_config: pre-built config; a ChromaVectorConfig is
                created when omitted.
        """
        vector_store_type = vector_store_type or os.getenv(
            "VECTOR_STORE_TYPE", "Chroma"
        )
        from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

        vector_store_config = vector_store_config or ChromaVectorConfig()
        # NOTE(review): this overwrites any embedding_fn already set on a
        # caller-supplied config, even with None — confirm that is intended.
        vector_store_config.embedding_fn = embedding_fn
        real_vector_store_type = cast(str, vector_store_type)
        return cls(real_vector_store_type, vector_store_config)

    @property
    def index_client(self):
        """Return the underlying index store client."""
        return self.client

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load document in vector database.

        Args:
            - chunks: document chunks.
        Return chunk ids.
        """
        max_chunks_once_load = (
            self._index_store_config.max_chunks_once_load
            if self._index_store_config
            else 10
        )
        max_threads = (
            self._index_store_config.max_threads if self._index_store_config else 1
        )
        return self.client.load_document_with_limit(
            chunks,
            max_chunks_once_load,
            max_threads,
        )

    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
        """Async load document in vector database.

        Args:
            - chunks: document chunks.
        Return chunk ids.
        """
        max_chunks_once_load = (
            self._index_store_config.max_chunks_once_load
            if self._index_store_config
            else 10
        )
        max_threads = (
            self._index_store_config.max_threads if self._index_store_config else 1
        )
        return await self.client.aload_document_with_limit(
            chunks, max_chunks_once_load, max_threads
        )

    def similar_search(
        self, doc: str, topk: int, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Similar search in vector database.

        Args:
            - doc: query text
            - topk: topk
            - filters: metadata filters.
        Return:
            - chunks: chunks.
        """
        return self.client.similar_search(doc, topk, filters)

    def similar_search_with_scores(
        self,
        doc: str,
        topk: int,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Similar_search_with_score in vector database.

        Return docs and relevance scores in the range [0, 1].

        Args:
            doc(str): query text
            topk(int): return docs nums. Defaults to 4.
            score_threshold(float): score_threshold: Optional, a floating point value
                between 0 to 1 to filter the resulting set of retrieved docs,0 is
                dissimilar, 1 is most similar.
            filters: metadata filters.
        Return:
            - chunks: Return docs and relevance scores in the range [0, 1].
        """
        return self.client.similar_search_with_scores(
            doc, topk, score_threshold, filters
        )

    async def asimilar_search_with_scores(
        self,
        doc: str,
        topk: int,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Async similar_search_with_score in vector database."""
        return await self.client.asimilar_search_with_scores(
            doc, topk, score_threshold, filters
        )

    @property
    def vector_store_config(self) -> IndexStoreConfig:
        """Return the vector store config.

        Raises:
            ValueError: when no config was set on this connector.
        """
        if not self._index_store_config:
            raise ValueError("vector store config not set.")
        return self._index_store_config

    def vector_name_exists(self):
        """Whether vector name exists."""
        return self.client.vector_name_exists()

    def delete_vector_name(self, vector_name: str):
        """Delete vector name.

        Args:
            - vector_name: vector store name

        Raises:
            Exception: when the underlying store fails to delete the name.
        """
        try:
            if self.vector_name_exists():
                self.client.delete_vector_name(vector_name)
        except Exception as e:
            logger.error(f"delete vector name {vector_name} failed: {e}")
            # Chain the cause so the original failure is not lost.
            raise Exception(f"delete name {vector_name} failed") from e
        return True

    def delete_by_ids(self, ids):
        """Delete vector by ids.

        Args:
            - ids: vector ids
        """
        return self.client.delete_by_ids(ids=ids)

    @property
    def current_embeddings(self) -> Optional[Embeddings]:
        """Return the current embeddings."""
        return self._embeddings

    def new_connector(self, name: str, **kwargs) -> "VectorStoreConnector":
        """Create a new connector.

        New connector based on the current connector: shallow-copies this
        connector's config, applies any non-None keyword overrides, renames
        it, and builds a fresh connector of the same store type.
        """
        config = copy.copy(self.vector_store_config)
        for k, v in kwargs.items():
            if v is not None:
                setattr(config, k, v)
        config.name = name

        return self.__class__(self._vector_store_type, config)

    def _match(self, vector_store_type) -> bool:
        # True when the type has been registered in the module-level registry.
        return bool(connector.get(vector_store_type))

    def _register(self):
        # Populate the module-level registry from dbgpt.storage.vector_store:
        # each __all__ entry resolves to a (store class, config class) pair.
        from dbgpt.storage import vector_store

        for cls in vector_store.__all__:
            store_cls, config_cls = getattr(vector_store, cls)
            if issubclass(store_cls, IndexStoreBase) and issubclass(
                config_cls, IndexStoreConfig
            ):
                connector[cls] = (store_cls, config_cls)
|
@@ -25,8 +25,8 @@ from dbgpt.core.awel.task.base import IN, OUT
|
||||
from dbgpt.core.interface.operators.prompt_operator import BasePromptBuilderOperator
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
|
@@ -6,7 +6,7 @@ from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
|
||||
|
@@ -22,7 +22,7 @@ from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.core import Chunk, LLMClient
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
@@ -31,12 +31,11 @@ from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag.embedding import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeFactory, KnowledgeType
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
|
||||
@@ -481,7 +480,6 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
)
|
||||
)
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
# return chunk_docs
|
||||
|
||||
@trace("async_doc_embedding")
|
||||
async def async_doc_embedding(
|
||||
@@ -495,7 +493,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
- doc: doc
|
||||
"""
|
||||
|
||||
logger.info(f"async doc embedding sync, doc:{doc.doc_name}")
|
||||
logger.info(f"async doc persist sync, doc:{doc.doc_name}")
|
||||
try:
|
||||
with root_tracer.start_span(
|
||||
"app.knowledge.assembler.persist",
|
||||
@@ -503,17 +501,17 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
):
|
||||
assembler = await EmbeddingAssembler.aload_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
index_store=vector_store_connector.index_client,
|
||||
chunk_parameters=chunk_parameters,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
chunk_docs = assembler.get_chunks()
|
||||
doc.chunk_size = len(chunk_docs)
|
||||
vector_ids = await assembler.apersist()
|
||||
doc.status = SyncStatus.FINISHED.name
|
||||
doc.result = "document embedding success"
|
||||
doc.result = "document persist into index store success"
|
||||
if vector_ids is not None:
|
||||
doc.vector_ids = ",".join(vector_ids)
|
||||
logger.info(f"async document embedding, success:{doc.doc_name}")
|
||||
logger.info(f"async document persist index store success:{doc.doc_name}")
|
||||
# save chunk details
|
||||
chunk_entities = [
|
||||
DocumentChunkEntity(
|
||||
|
Reference in New Issue
Block a user