feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Aries-ckt
2024-06-13 13:49:17 +08:00
committed by GitHub
parent 162e2c9b1c
commit 58d08780d6
86 changed files with 948 additions and 440 deletions

View File

@@ -1,13 +1,12 @@
"""DBSchemaAssembler."""
from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..index.base import IndexStoreBase
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever
@@ -36,36 +35,22 @@ class DBSchemaAssembler(BaseAssembler):
def __init__(
self,
connector: BaseConnector,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
connector: (BaseConnector) BaseConnector connection.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
index_store: (IndexStoreBase) IndexStoreBase to use.
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
knowledge = DatasourceKnowledge(connector)
self._connector = connector
self._vector_store_connector = vector_store_connector
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
if (
embeddings
and self._vector_store_connector.vector_store_config.embedding_fn is None
):
self._vector_store_connector.vector_store_config.embedding_fn = embeddings
self._index_store = index_store
super().__init__(
knowledge=knowledge,
@@ -77,29 +62,23 @@ class DBSchemaAssembler(BaseAssembler):
def load_from_connection(
cls,
connector: BaseConnector,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connector: (BaseConnector) BaseConnector connection.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
index_store: (IndexStoreBase) IndexStoreBase to use.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
vector_store_connector=vector_store_connector,
embedding_model=embedding_model,
index_store=index_store,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
)
def get_chunks(self) -> List[Chunk]:
@@ -112,7 +91,7 @@ class DBSchemaAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
return self._vector_store_connector.load_document(self._chunks)
return self._index_store.load_document(self._chunks)
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
@@ -131,5 +110,5 @@ class DBSchemaAssembler(BaseAssembler):
top_k=top_k,
connector=self._connector,
is_embeddings=True,
vector_store_connector=self._vector_store_connector,
index_store=self._index_store,
)
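
Editor's note: to illustrate the reworked constructor above, here is a minimal caller-side sketch of `DBSchemaAssembler` with the new `index_store` argument. The `ChromaStore`/`ChromaVectorConfig` wiring mirrors the docstring example updated later in this commit; the temporary SQLite connector, model path, and persist path are illustrative assumptions, not part of this diff.

    from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector  # assumption: helper connector
    from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
    from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
    from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

    # A throwaway SQLite database so the schema assembler has something to read.
    connector = SQLiteTempConnector.create_temporary_db()
    connector.create_temp_tables(
        {"user": {"columns": {"id": "INTEGER PRIMARY KEY", "name": "TEXT"}, "data": []}}
    )

    # The embedding function now lives on the store config; the assembler no longer
    # accepts embedding_model/embeddings directly.
    config = ChromaVectorConfig(
        persist_path="/tmp/dbschema_rag_demo",  # assumption: any writable path
        name="dbschema_rag_demo",
        embedding_fn=DefaultEmbeddingFactory(
            default_model_name="/path/to/text2vec-large-chinese"  # assumption: local model
        ).create(),
    )
    index_store = ChromaStore(config)

    assembler = DBSchemaAssembler.load_from_connection(
        connector=connector,
        index_store=index_store,  # replaces the old vector_store_connector argument
    )
    assembler.persist()  # now delegates to index_store.load_document(...)
    retriever = assembler.as_retriever(top_k=3)
    print(retriever.retrieve("show columns from table user"))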

View File

@@ -3,13 +3,13 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ...util.executor_utils import blocking_func_to_async
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..index.base import IndexStoreBase
from ..knowledge.base import Knowledge
from ..retriever import BaseRetriever, RetrieverStrategy
from ..retriever.embedding import EmbeddingRetriever
@@ -32,37 +32,26 @@ class EmbeddingAssembler(BaseAssembler):
def __init__(
self,
knowledge: Knowledge,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
knowledge: (Knowledge) Knowledge datasource.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
index_store: (IndexStoreBase) IndexStoreBase to use.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
keyword_store: (Optional[IndexStoreBase]) IndexStoreBase to use.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
if knowledge is None:
raise ValueError("knowledge datasource must be provided.")
self._vector_store_connector = vector_store_connector
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
if (
embeddings
and self._vector_store_connector.vector_store_config.embedding_fn is None
):
self._vector_store_connector.vector_store_config.embedding_fn = embeddings
self._index_store = index_store
self._retrieve_strategy = retrieve_strategy
super().__init__(
knowledge=knowledge,
@@ -74,52 +63,53 @@ class EmbeddingAssembler(BaseAssembler):
def load_from_knowledge(
cls,
knowledge: Knowledge,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
) -> "EmbeddingAssembler":
"""Load document embedding into vector store from path.
Args:
knowledge: (Knowledge) Knowledge datasource.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
index_store: (IndexStoreBase) IndexStoreBase to use.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
retrieve_strategy: (Optional[RetrieverStrategy]) Retriever strategy.
Returns:
EmbeddingAssembler
"""
return cls(
knowledge=knowledge,
vector_store_connector=vector_store_connector,
index_store=index_store,
chunk_parameters=chunk_parameters,
embedding_model=embedding_model,
embeddings=embeddings,
retrieve_strategy=retrieve_strategy,
)
@classmethod
async def aload_from_knowledge(
cls,
knowledge: Knowledge,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
executor: Optional[ThreadPoolExecutor] = None,
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
) -> "EmbeddingAssembler":
"""Load document embedding into vector store from path.
Args:
knowledge: (Knowledge) Knowledge datasource.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
index_store: (IndexStoreBase) Index store to use.
executor: (Optional[ThreadPoolExecutor]) ThreadPoolExecutor to use.
retrieve_strategy: (Optional[RetrieverStrategy]) Retriever strategy.
Returns:
EmbeddingAssembler
@@ -129,19 +119,18 @@ class EmbeddingAssembler(BaseAssembler):
executor,
cls,
knowledge,
vector_store_connector,
index_store,
chunk_parameters,
embedding_model,
embeddings,
retrieve_strategy,
)
def persist(self) -> List[str]:
"""Persist chunks into vector store.
"""Persist chunks into store.
Returns:
List[str]: List of chunk ids.
"""
return self._vector_store_connector.load_document(self._chunks)
return self._index_store.load_document(self._chunks)
async def apersist(self) -> List[str]:
"""Persist chunks into store.
@@ -149,13 +138,14 @@ class EmbeddingAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
return await self._vector_store_connector.aload_document(self._chunks)
# persist chunks into vector store
return await self._index_store.aload_document(self._chunks)
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
return []
def as_retriever(self, top_k: int = 4, **kwargs) -> EmbeddingRetriever:
def as_retriever(self, top_k: int = 4, **kwargs) -> BaseRetriever:
"""Create a retriever.
Args:
@@ -165,5 +155,7 @@ class EmbeddingAssembler(BaseAssembler):
EmbeddingRetriever
"""
return EmbeddingRetriever(
top_k=top_k, vector_store_connector=self._vector_store_connector
top_k=top_k,
index_store=self._index_store,
retrieve_strategy=self._retrieve_strategy,
)
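
Editor's note: a rough sketch of the corresponding caller-side change for `EmbeddingAssembler`, including the new `retrieve_strategy` argument. The store construction follows the same pattern as the sketch above; file path and model path are illustrative assumptions.

    from dbgpt.rag.assembler.embedding import EmbeddingAssembler
    from dbgpt.rag.chunk_manager import ChunkParameters
    from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
    from dbgpt.rag.knowledge import KnowledgeFactory
    from dbgpt.rag.retriever import RetrieverStrategy
    from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

    # Any IndexStoreBase works here; a ChromaStore keeps the embedding behaviour,
    # while a full-text store plus RetrieverStrategy.KEYWORD enables the new keyword path.
    index_store = ChromaStore(
        ChromaVectorConfig(
            persist_path="/tmp/embedding_rag_demo",  # assumption: any writable path
            name="embedding_rag_demo",
            embedding_fn=DefaultEmbeddingFactory(
                default_model_name="/path/to/text2vec-large-chinese"  # assumption
            ).create(),
        )
    )

    knowledge = KnowledgeFactory.from_file_path("./docs/overview.md")  # assumption: sample doc
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        index_store=index_store,                        # replaces vector_store_connector
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
        retrieve_strategy=RetrieverStrategy.EMBEDDING,  # or RetrieverStrategy.KEYWORD
    )
    assembler.persist()                                  # -> index_store.load_document(chunks)
    retriever = assembler.as_retriever(top_k=4)          # now typed as BaseRetriever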

View File

@@ -8,7 +8,7 @@ from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaStore
@pytest.fixture
@@ -48,7 +48,7 @@ def mock_embedding_factory():
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=VectorStoreConnector)
return MagicMock(spec=ChromaStore)
@pytest.fixture
@@ -70,7 +70,7 @@ def test_load_knowledge(
knowledge=mock_knowledge,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
vector_store_connector=mock_vector_store_connector,
index_store=mock_vector_store_connector,
)
assembler.load_knowledge(knowledge=mock_knowledge)
assert len(assembler._chunks) == 0

View File

@@ -7,7 +7,7 @@ from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaStore
@pytest.fixture
@@ -47,7 +47,7 @@ def mock_embedding_factory():
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=VectorStoreConnector)
return MagicMock(spec=ChromaStore)
def test_load_knowledge(
@@ -63,6 +63,6 @@ def test_load_knowledge(
connector=mock_db_connection,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
vector_store_connector=mock_vector_store_connector,
index_store=mock_vector_store_connector,
)
assert len(assembler._chunks) == 1

View File

@@ -2,7 +2,7 @@
import logging
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
@@ -47,7 +47,7 @@ class IndexStoreConfig(BaseModel):
class IndexStoreBase(ABC):
"""Index store base class."""
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
def __init__(self, executor: Optional[Executor] = None):
"""Init index store."""
self._executor = executor or ThreadPoolExecutor()
@@ -63,7 +63,7 @@ class IndexStoreBase(ABC):
"""
@abstractmethod
def aload_document(self, chunks: List[Chunk]) -> List[str]:
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database.
Args:
@@ -94,7 +94,7 @@ class IndexStoreBase(ABC):
"""
@abstractmethod
def delete_by_ids(self, ids: str):
def delete_by_ids(self, ids: str) -> List[str]:
"""Delete docs.
Args:

View File

@@ -5,10 +5,10 @@ from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.db_schema import DBSchemaAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator
@@ -19,13 +19,13 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
Args:
connector (BaseConnector): The connection.
top_k (int, optional): The top k. Defaults to 4.
vector_store_connector (VectorStoreConnector, optional): The vector store
index_store (IndexStoreBase, optional): The vector store
connector. Defaults to None.
"""
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
top_k: int = 4,
connector: Optional[BaseConnector] = None,
**kwargs
@@ -35,7 +35,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
self._retriever = DBSchemaRetriever(
top_k=top_k,
connector=connector,
vector_store_connector=vector_store_connector,
index_store=index_store,
)
def retrieve(self, query: str) -> List[Chunk]:
@@ -53,7 +53,7 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
def __init__(
self,
connector: BaseConnector,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
**kwargs
):
@@ -61,14 +61,14 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
Args:
connector (BaseConnector): The connection.
vector_store_connector (VectorStoreConnector): The vector store connector.
index_store (IndexStoreBase): The Storage IndexStoreBase.
chunk_parameters (Optional[ChunkParameters], optional): The chunk
parameters.
"""
if not chunk_parameters:
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
self._chunk_parameters = chunk_parameters
self._vector_store_connector = vector_store_connector
self._index_store = index_store
self._connector = connector
super().__init__(**kwargs)
@@ -84,7 +84,7 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
assembler = DBSchemaAssembler.load_from_connection(
connector=self._connector,
chunk_parameters=self._chunk_parameters,
vector_store_connector=self._vector_store_connector,
index_store=self._index_store,
)
assembler.persist()
return assembler.get_chunks()
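
Editor's note: a minimal AWEL wiring sketch with the renamed keyword arguments of the two operators above. The operator import path is assumed, and `connector`/`index_store` are assumed to be built as in the assembler sketch earlier in this page.

    from dbgpt.core.awel import DAG
    from dbgpt.rag.operators.db_schema import (  # assumption: operator module path
        DBSchemaAssemblerOperator,
        DBSchemaRetrieverOperator,
    )

    # `connector` is any BaseConnector and `index_store` any IndexStoreBase (see earlier sketch).
    with DAG("db_schema_rag_sketch") as dag:
        assembler_task = DBSchemaAssemblerOperator(
            connector=connector,
            index_store=index_store,   # was vector_store_connector
        )
        retriever_task = DBSchemaRetrieverOperator(
            index_store=index_store,   # was vector_store_connector
            top_k=4,
            connector=connector,
        )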

View File

@@ -6,11 +6,11 @@ from typing import List, Optional, Union
from dbgpt.core import Chunk
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.i18n_utils import _
from ..assembler.embedding import EmbeddingAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..knowledge import Knowledge
from ..retriever.embedding import EmbeddingRetriever
from ..retriever.rerank import Ranker
@@ -28,9 +28,9 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
category=OperatorCategory.RAG,
parameters=[
Parameter.build_from(
_("Vector Store Connector"),
_("Storage Index Store"),
"vector_store_connector",
VectorStoreConnector,
IndexStoreBase,
description=_("The vector store connector."),
),
Parameter.build_from(
@@ -88,7 +88,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
top_k: int,
score_threshold: float = 0.3,
query_rewrite: Optional[QueryRewrite] = None,
@@ -99,7 +99,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
super().__init__(**kwargs)
self._score_threshold = score_threshold
self._retriever = EmbeddingRetriever(
vector_store_connector=vector_store_connector,
index_store=index_store,
top_k=top_k,
query_rewrite=query_rewrite,
rerank=rerank,
@@ -129,7 +129,7 @@ class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
Parameter.build_from(
_("Vector Store Connector"),
"vector_store_connector",
VectorStoreConnector,
IndexStoreBase,
description=_("The vector store connector."),
),
Parameter.build_from(
@@ -164,21 +164,21 @@ class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
**kwargs
):
"""Create a new EmbeddingAssemblerOperator.
Args:
vector_store_connector (VectorStoreConnector): The vector store connector.
index_store (IndexStoreBase): The index storage.
chunk_parameters (Optional[ChunkParameters], optional): The chunk
parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
"""
if not chunk_parameters:
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
self._chunk_parameters = chunk_parameters
self._vector_store_connector = vector_store_connector
self._index_store = index_store
super().__init__(**kwargs)
def assemble(self, knowledge: Knowledge) -> List[Chunk]:
@@ -186,7 +186,7 @@ class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=self._chunk_parameters,
vector_store_connector=self._vector_store_connector,
index_store=self._index_store,
)
assembler.persist()
return assembler.get_chunks()
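
Editor's note: the embedding operators change the same way; a short construction sketch under the same assumptions (operator module path assumed, `index_store` built as in the earlier sketches).

    from dbgpt.core.awel import DAG
    from dbgpt.rag.operators.embedding import (  # assumption: operator module path
        EmbeddingAssemblerOperator,
        EmbeddingRetrieverOperator,
    )

    # `index_store` is any IndexStoreBase, e.g. the ChromaStore built above.
    with DAG("embedding_rag_sketch") as dag:
        assembler_task = EmbeddingAssemblerOperator(
            index_store=index_store,        # was vector_store_connector
        )
        retriever_task = EmbeddingRetrieverOperator(
            index_store=index_store,        # was vector_store_connector
            top_k=4,
            score_threshold=0.3,
        )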

View File

@@ -8,8 +8,8 @@ from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class SchemaLinkingOperator(MapOperator[Any, Any]):
@@ -21,7 +21,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
model_name: str,
llm: LLMClient,
top_k: int = 5,
vector_store_connector: Optional[VectorStoreConnector] = None,
index_store: Optional[IndexStoreBase] = None,
**kwargs
):
"""Create the schema linking operator.
@@ -37,7 +37,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
connector=connector,
llm=llm,
model_name=model_name,
vector_store_connector=vector_store_connector,
index_store=index_store,
)
async def map(self, query: str) -> str:

View File

@@ -17,6 +17,7 @@ class RetrieverStrategy(str, Enum):
"""
EMBEDDING = "embedding"
GRAPH = "graph"
KEYWORD = "keyword"
HYBRID = "hybrid"
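
Editor's note: since `RetrieverStrategy` subclasses `str`, the new `KEYWORD` member interoperates with call sites that pass plain strings; a quick check:

    from dbgpt.rag.retriever.base import RetrieverStrategy

    strategy = RetrieverStrategy.KEYWORD
    assert strategy == "keyword"          # str-valued enum compares equal to the raw string
    assert RetrieverStrategy("keyword") is RetrieverStrategy.KEYWORD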

View File

@@ -4,10 +4,10 @@ from typing import List, Optional, cast
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
@@ -17,7 +17,7 @@ class DBSchemaRetriever(BaseRetriever):
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
@@ -27,7 +27,7 @@ class DBSchemaRetriever(BaseRetriever):
"""Create DBSchemaRetriever.
Args:
vector_store_connector (VectorStoreConnector): vector store connector
index_store(IndexStore): index connector
top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
@@ -67,18 +67,22 @@ class DBSchemaRetriever(BaseRetriever):
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="dbschema_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(
MODEL_PATH, "text2vec-large-chinese"
),
).create(),
)
vector_store = ChromaStore(config)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3,
vector_store_connector=vector_connector,
index_store=vector_store,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
@@ -88,9 +92,9 @@ class DBSchemaRetriever(BaseRetriever):
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._index_store = index_store
self._need_embeddings = False
if self._vector_store_connector:
if self._index_store:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
@@ -109,7 +113,7 @@ class DBSchemaRetriever(BaseRetriever):
if self._need_embeddings:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k, filters)
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
@@ -185,7 +189,7 @@ class DBSchemaRetriever(BaseRetriever):
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(query, self._top_k, filters)
return self._index_store.similar_search(query, self._top_k, filters)
async def _aparse_db_summary(self) -> List[str]:
"""Similar search."""

View File

@@ -4,10 +4,10 @@ from functools import reduce
from typing import Any, Dict, List, Optional, cast
from dbgpt.core import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.tracer import root_tracer
@@ -18,18 +18,19 @@ class EmbeddingRetriever(BaseRetriever):
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
top_k: int = 4,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
):
"""Create EmbeddingRetriever.
Args:
index_store(IndexStore): vector store connector
top_k (int): top k
query_rewrite (Optional[QueryRewrite]): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
Examples:
.. code-block:: python
@@ -64,8 +65,9 @@ class EmbeddingRetriever(BaseRetriever):
"""
self._top_k = top_k
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._index_store = index_store
self._rerank = rerank or DefaultRanker(self._top_k)
self._retrieve_strategy = retrieve_strategy
def load_document(self, chunks: List[Chunk], **kwargs: Dict[str, Any]) -> List[str]:
"""Load document in vector database.
@@ -75,7 +77,7 @@ class EmbeddingRetriever(BaseRetriever):
Return:
List[str]: chunk ids.
"""
return self._vector_store_connector.load_document(chunks)
return self._index_store.load_document(chunks)
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
@@ -90,7 +92,7 @@ class EmbeddingRetriever(BaseRetriever):
"""
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k, filters)
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
@@ -113,7 +115,7 @@ class EmbeddingRetriever(BaseRetriever):
"""
queries = [query]
candidates_with_score = [
self._vector_store_connector.similar_search_with_scores(
self._index_store.similar_search_with_scores(
query, self._top_k, score_threshold, filters
)
for query in queries
@@ -217,7 +219,7 @@ class EmbeddingRetriever(BaseRetriever):
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(query, self._top_k, filters)
return self._index_store.similar_search(query, self._top_k, filters)
async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
@@ -229,6 +231,6 @@ class EmbeddingRetriever(BaseRetriever):
self, query, score_threshold, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search with score."""
return await self._vector_store_connector.asimilar_search_with_scores(
return await self._index_store.asimilar_search_with_scores(
query, self._top_k, score_threshold, filters
)
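
Editor's note: for completeness, a short sketch of calling `EmbeddingRetriever` directly with the new argument, assuming `index_store` is the ChromaStore built in the earlier sketches (with an `embedding_fn` configured).

    import asyncio

    from dbgpt.rag.retriever.embedding import EmbeddingRetriever

    retriever = EmbeddingRetriever(
        index_store=index_store,     # was vector_store_connector
        top_k=3,
    )
    chunks = retriever.retrieve("what does this document cover")
    scored = asyncio.run(
        retriever.aretrieve_with_scores("what does this document cover", 0.3)
    )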

View File

@@ -25,7 +25,7 @@ def mock_vector_store_connector():
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
return DBSchemaRetriever(
connector=mock_db_connection,
vector_store_connector=mock_vector_store_connector,
index_store=mock_vector_store_connector,
)

View File

@@ -25,8 +25,8 @@ def mock_vector_store_connector():
def embedding_retriever(top_k, mock_vector_store_connector):
return EmbeddingRetriever(
top_k=top_k,
query_rewrite=False,
vector_store_connector=mock_vector_store_connector,
query_rewrite=None,
index_store=mock_vector_store_connector,
)

View File

@@ -7,9 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple
from dbgpt.core import Chunk
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from ..index.base import IndexStoreBase
from .embedding import EmbeddingRetriever
@@ -23,7 +23,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
top_k: int = 100,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
@@ -32,13 +32,13 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
"""Initialize TimeWeightedEmbeddingRetriever.
Args:
vector_store_connector (VectorStoreConnector): vector store connector
index_store (IndexStoreBase): vector store connector
top_k (int): top k
query_rewrite (Optional[QueryRewrite]): query rewrite
rerank (Ranker): rerank
"""
super().__init__(
vector_store_connector=vector_store_connector,
index_store=index_store,
top_k=top_k,
query_rewrite=query_rewrite,
rerank=rerank,
@@ -69,7 +69,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
doc.metadata["created_at"] = current_time
doc.metadata["buffer_idx"] = len(self.memory_stream) + i
self.memory_stream.extend(dup_docs)
return self._vector_store_connector.load_document(dup_docs)
return self._index_store.load_document(dup_docs)
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
@@ -125,7 +125,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
def get_salient_docs(self, query: str) -> Dict[int, Tuple[Chunk, float]]:
"""Return documents that are salient to the query."""
docs_and_scores: List[Chunk]
docs_and_scores = self._vector_store_connector.similar_search_with_scores(
docs_and_scores = self._index_store.similar_search_with_scores(
query, topk=self._top_k, score_threshold=0
)
results = {}

View File

@@ -13,7 +13,7 @@ from dbgpt.core import (
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
INSTRUCTION = """

View File

@@ -48,8 +48,8 @@ class DBSummaryClient:
def get_db_summary(self, dbname, query, topk):
"""Get user query related tables info."""
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
vector_connector = VectorStoreConnector.from_default(
@@ -60,7 +60,7 @@ class DBSummaryClient:
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
retriever = DBSchemaRetriever(
top_k=topk, vector_store_connector=vector_connector
top_k=topk, index_store=vector_connector.index_client
)
table_docs = retriever.retrieve(query)
ans = [d.content for d in table_docs]
@@ -88,8 +88,8 @@ class DBSummaryClient:
dbname(str): dbname
"""
vector_store_name = dbname + "_profile"
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default(
@@ -102,7 +102,7 @@ class DBSummaryClient:
db_assembler = DBSchemaAssembler.load_from_connection(
connector=db_summary_client.db,
vector_store_connector=vector_connector,
index_store=vector_connector.index_client,
)
if len(db_assembler.get_chunks()) > 0:
@@ -114,8 +114,8 @@ class DBSummaryClient:
def delete_db_profile(self, dbname):
"""Delete db profile."""
vector_store_name = dbname + "_profile"
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default(