mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
"""DBSchemaAssembler."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..knowledge.datasource import DatasourceKnowledge
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
|
||||
@@ -36,36 +35,22 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with Embedding Assembler arguments.
|
||||
|
||||
Args:
|
||||
connector: (BaseConnector) BaseConnector connection.
|
||||
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
"""
|
||||
knowledge = DatasourceKnowledge(connector)
|
||||
self._connector = connector
|
||||
self._vector_store_connector = vector_store_connector
|
||||
|
||||
self._embedding_model = embedding_model
|
||||
if self._embedding_model and not embeddings:
|
||||
embeddings = DefaultEmbeddingFactory(
|
||||
default_model_name=self._embedding_model
|
||||
).create(self._embedding_model)
|
||||
|
||||
if (
|
||||
embeddings
|
||||
and self._vector_store_connector.vector_store_config.embedding_fn is None
|
||||
):
|
||||
self._vector_store_connector.vector_store_config.embedding_fn = embeddings
|
||||
self._index_store = index_store
|
||||
|
||||
super().__init__(
|
||||
knowledge=knowledge,
|
||||
@@ -77,29 +62,23 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
def load_from_connection(
|
||||
cls,
|
||||
connector: BaseConnector,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> "DBSchemaAssembler":
|
||||
"""Load document embedding into vector store from path.
|
||||
|
||||
Args:
|
||||
connector: (BaseConnector) BaseConnector connection.
|
||||
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
Returns:
|
||||
DBSchemaAssembler
|
||||
"""
|
||||
return cls(
|
||||
connector=connector,
|
||||
vector_store_connector=vector_store_connector,
|
||||
embedding_model=embedding_model,
|
||||
index_store=index_store,
|
||||
chunk_parameters=chunk_parameters,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
def get_chunks(self) -> List[Chunk]:
|
||||
@@ -112,7 +91,7 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
Returns:
|
||||
List[str]: List of chunk ids.
|
||||
"""
|
||||
return self._vector_store_connector.load_document(self._chunks)
|
||||
return self._index_store.load_document(self._chunks)
|
||||
|
||||
def _extract_info(self, chunks) -> List[Chunk]:
|
||||
"""Extract info from chunks."""
|
||||
@@ -131,5 +110,5 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
top_k=top_k,
|
||||
connector=self._connector,
|
||||
is_embeddings=True,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
index_store=self._index_store,
|
||||
)
|
||||
|
@@ -3,13 +3,13 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ...util.executor_utils import blocking_func_to_async
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..knowledge.base import Knowledge
|
||||
from ..retriever import BaseRetriever, RetrieverStrategy
|
||||
from ..retriever.embedding import EmbeddingRetriever
|
||||
|
||||
|
||||
@@ -32,37 +32,26 @@ class EmbeddingAssembler(BaseAssembler):
|
||||
def __init__(
|
||||
self,
|
||||
knowledge: Knowledge,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with Embedding Assembler arguments.
|
||||
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
chunking.
|
||||
keyword_store: (Optional[IndexStoreBase]) IndexStoreBase to use.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
"""
|
||||
if knowledge is None:
|
||||
raise ValueError("knowledge datasource must be provided.")
|
||||
self._vector_store_connector = vector_store_connector
|
||||
|
||||
self._embedding_model = embedding_model
|
||||
if self._embedding_model and not embeddings:
|
||||
embeddings = DefaultEmbeddingFactory(
|
||||
default_model_name=self._embedding_model
|
||||
).create(self._embedding_model)
|
||||
|
||||
if (
|
||||
embeddings
|
||||
and self._vector_store_connector.vector_store_config.embedding_fn is None
|
||||
):
|
||||
self._vector_store_connector.vector_store_config.embedding_fn = embeddings
|
||||
self._index_store = index_store
|
||||
self._retrieve_strategy = retrieve_strategy
|
||||
|
||||
super().__init__(
|
||||
knowledge=knowledge,
|
||||
@@ -74,52 +63,53 @@ class EmbeddingAssembler(BaseAssembler):
|
||||
def load_from_knowledge(
|
||||
cls,
|
||||
knowledge: Knowledge,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
|
||||
) -> "EmbeddingAssembler":
|
||||
"""Load document embedding into vector store from path.
|
||||
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
retrieve_strategy: (Optional[RetrieverStrategy]) Retriever strategy.
|
||||
|
||||
Returns:
|
||||
EmbeddingAssembler
|
||||
"""
|
||||
return cls(
|
||||
knowledge=knowledge,
|
||||
vector_store_connector=vector_store_connector,
|
||||
index_store=index_store,
|
||||
chunk_parameters=chunk_parameters,
|
||||
embedding_model=embedding_model,
|
||||
embeddings=embeddings,
|
||||
retrieve_strategy=retrieve_strategy,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def aload_from_knowledge(
|
||||
cls,
|
||||
knowledge: Knowledge,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
executor: Optional[ThreadPoolExecutor] = None,
|
||||
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
|
||||
) -> "EmbeddingAssembler":
|
||||
"""Load document embedding into vector store from path.
|
||||
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
index_store: (IndexStoreBase) Index store to use.
|
||||
executor: (Optional[ThreadPoolExecutor) ThreadPoolExecutor to use.
|
||||
retrieve_strategy: (Optional[RetrieverStrategy]) Retriever strategy.
|
||||
|
||||
Returns:
|
||||
EmbeddingAssembler
|
||||
@@ -129,19 +119,18 @@ class EmbeddingAssembler(BaseAssembler):
|
||||
executor,
|
||||
cls,
|
||||
knowledge,
|
||||
vector_store_connector,
|
||||
index_store,
|
||||
chunk_parameters,
|
||||
embedding_model,
|
||||
embeddings,
|
||||
retrieve_strategy,
|
||||
)
|
||||
|
||||
def persist(self) -> List[str]:
|
||||
"""Persist chunks into vector store.
|
||||
"""Persist chunks into store.
|
||||
|
||||
Returns:
|
||||
List[str]: List of chunk ids.
|
||||
"""
|
||||
return self._vector_store_connector.load_document(self._chunks)
|
||||
return self._index_store.load_document(self._chunks)
|
||||
|
||||
async def apersist(self) -> List[str]:
|
||||
"""Persist chunks into store.
|
||||
@@ -149,13 +138,14 @@ class EmbeddingAssembler(BaseAssembler):
|
||||
Returns:
|
||||
List[str]: List of chunk ids.
|
||||
"""
|
||||
return await self._vector_store_connector.aload_document(self._chunks)
|
||||
# persist chunks into vector store
|
||||
return await self._index_store.aload_document(self._chunks)
|
||||
|
||||
def _extract_info(self, chunks) -> List[Chunk]:
|
||||
"""Extract info from chunks."""
|
||||
return []
|
||||
|
||||
def as_retriever(self, top_k: int = 4, **kwargs) -> EmbeddingRetriever:
|
||||
def as_retriever(self, top_k: int = 4, **kwargs) -> BaseRetriever:
|
||||
"""Create a retriever.
|
||||
|
||||
Args:
|
||||
@@ -165,5 +155,7 @@ class EmbeddingAssembler(BaseAssembler):
|
||||
EmbeddingRetriever
|
||||
"""
|
||||
return EmbeddingRetriever(
|
||||
top_k=top_k, vector_store_connector=self._vector_store_connector
|
||||
top_k=top_k,
|
||||
index_store=self._index_store,
|
||||
retrieve_strategy=self._retrieve_strategy,
|
||||
)
|
||||
|
@@ -8,7 +8,7 @@ from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import Knowledge
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -48,7 +48,7 @@ def mock_embedding_factory():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=VectorStoreConnector)
|
||||
return MagicMock(spec=ChromaStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -70,7 +70,7 @@ def test_load_knowledge(
|
||||
knowledge=mock_knowledge,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
index_store=mock_vector_store_connector,
|
||||
)
|
||||
assembler.load_knowledge(knowledge=mock_knowledge)
|
||||
assert len(assembler._chunks) == 0
|
||||
|
@@ -7,7 +7,7 @@ from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -47,7 +47,7 @@ def mock_embedding_factory():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=VectorStoreConnector)
|
||||
return MagicMock(spec=ChromaStore)
|
||||
|
||||
|
||||
def test_load_knowledge(
|
||||
@@ -63,6 +63,6 @@ def test_load_knowledge(
|
||||
connector=mock_db_connection,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
index_store=mock_vector_store_connector,
|
||||
)
|
||||
assert len(assembler._chunks) == 1
|
||||
|
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
|
||||
@@ -47,7 +47,7 @@ class IndexStoreConfig(BaseModel):
|
||||
class IndexStoreBase(ABC):
|
||||
"""Index store base class."""
|
||||
|
||||
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
|
||||
def __init__(self, executor: Optional[Executor] = None):
|
||||
"""Init index store."""
|
||||
self._executor = executor or ThreadPoolExecutor()
|
||||
|
||||
@@ -63,7 +63,7 @@ class IndexStoreBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def aload_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load document in index database.
|
||||
|
||||
Args:
|
||||
@@ -94,7 +94,7 @@ class IndexStoreBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: str):
|
||||
def delete_by_ids(self, ids: str) -> List[str]:
|
||||
"""Delete docs.
|
||||
|
||||
Args:
|
||||
|
@@ -5,10 +5,10 @@ from typing import List, Optional
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.db_schema import DBSchemaAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
from .assembler import AssemblerOperator
|
||||
|
||||
@@ -19,13 +19,13 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
||||
Args:
|
||||
connector (BaseConnector): The connection.
|
||||
top_k (int, optional): The top k. Defaults to 4.
|
||||
vector_store_connector (VectorStoreConnector, optional): The vector store
|
||||
index_store (IndexStoreBase, optional): The vector store
|
||||
connector. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
top_k: int = 4,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
**kwargs
|
||||
@@ -35,7 +35,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
||||
self._retriever = DBSchemaRetriever(
|
||||
top_k=top_k,
|
||||
connector=connector,
|
||||
vector_store_connector=vector_store_connector,
|
||||
index_store=index_store,
|
||||
)
|
||||
|
||||
def retrieve(self, query: str) -> List[Chunk]:
|
||||
@@ -53,7 +53,7 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -61,14 +61,14 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
|
||||
Args:
|
||||
connector (BaseConnector): The connection.
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
index_store (IndexStoreBase): The Storage IndexStoreBase.
|
||||
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
||||
parameters.
|
||||
"""
|
||||
if not chunk_parameters:
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
self._chunk_parameters = chunk_parameters
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._index_store = index_store
|
||||
self._connector = connector
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -84,7 +84,7 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=self._connector,
|
||||
chunk_parameters=self._chunk_parameters,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
index_store=self._index_store,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@@ -6,11 +6,11 @@ from typing import List, Optional, Union
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from ..assembler.embedding import EmbeddingAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..knowledge import Knowledge
|
||||
from ..retriever.embedding import EmbeddingRetriever
|
||||
from ..retriever.rerank import Ranker
|
||||
@@ -28,9 +28,9 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
|
||||
category=OperatorCategory.RAG,
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
_("Vector Store Connector"),
|
||||
_("Storage Index Store"),
|
||||
"vector_store_connector",
|
||||
VectorStoreConnector,
|
||||
IndexStoreBase,
|
||||
description=_("The vector store connector."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
@@ -88,7 +88,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
top_k: int,
|
||||
score_threshold: float = 0.3,
|
||||
query_rewrite: Optional[QueryRewrite] = None,
|
||||
@@ -99,7 +99,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
|
||||
super().__init__(**kwargs)
|
||||
self._score_threshold = score_threshold
|
||||
self._retriever = EmbeddingRetriever(
|
||||
vector_store_connector=vector_store_connector,
|
||||
index_store=index_store,
|
||||
top_k=top_k,
|
||||
query_rewrite=query_rewrite,
|
||||
rerank=rerank,
|
||||
@@ -129,7 +129,7 @@ class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
|
||||
Parameter.build_from(
|
||||
_("Vector Store Connector"),
|
||||
"vector_store_connector",
|
||||
VectorStoreConnector,
|
||||
IndexStoreBase,
|
||||
description=_("The vector store connector."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
@@ -164,21 +164,21 @@ class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new EmbeddingAssemblerOperator.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
index_store (IndexStoreBase): The index storage.
|
||||
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
||||
parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
|
||||
"""
|
||||
if not chunk_parameters:
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
self._chunk_parameters = chunk_parameters
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._index_store = index_store
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def assemble(self, knowledge: Knowledge) -> List[Chunk]:
|
||||
@@ -186,7 +186,7 @@ class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=self._chunk_parameters,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
index_store=self._index_store,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@@ -8,8 +8,8 @@ from typing import Any, Optional
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
@@ -21,7 +21,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
model_name: str,
|
||||
llm: LLMClient,
|
||||
top_k: int = 5,
|
||||
vector_store_connector: Optional[VectorStoreConnector] = None,
|
||||
index_store: Optional[IndexStoreBase] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create the schema linking operator.
|
||||
@@ -37,7 +37,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
connector=connector,
|
||||
llm=llm,
|
||||
model_name=model_name,
|
||||
vector_store_connector=vector_store_connector,
|
||||
index_store=index_store,
|
||||
)
|
||||
|
||||
async def map(self, query: str) -> str:
|
||||
|
@@ -17,6 +17,7 @@ class RetrieverStrategy(str, Enum):
|
||||
"""
|
||||
|
||||
EMBEDDING = "embedding"
|
||||
GRAPH = "graph"
|
||||
KEYWORD = "keyword"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
@@ -4,10 +4,10 @@ from typing import List, Optional, cast
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
|
||||
@@ -17,7 +17,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
top_k: int = 4,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
@@ -27,7 +27,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Create DBSchemaRetriever.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
index_store(IndexStore): index connector
|
||||
top_k (int): top k
|
||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||
query_rewrite (bool): query rewrite
|
||||
@@ -67,18 +67,22 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
|
||||
connector = _create_temporary_connection()
|
||||
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||
embedding_model_path = "{your_embedding_model_path}"
|
||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_fn=embedding_fn,
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
name="dbschema_rag_test",
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(
|
||||
MODEL_PATH, "text2vec-large-chinese"
|
||||
),
|
||||
).create(),
|
||||
)
|
||||
|
||||
vector_store = ChromaStore(config)
|
||||
# get db struct retriever
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=3,
|
||||
vector_store_connector=vector_connector,
|
||||
index_store=vector_store,
|
||||
connector=connector,
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
@@ -88,9 +92,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self._top_k = top_k
|
||||
self._connector = connector
|
||||
self._query_rewrite = query_rewrite
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._index_store = index_store
|
||||
self._need_embeddings = False
|
||||
if self._vector_store_connector:
|
||||
if self._index_store:
|
||||
self._need_embeddings = True
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
@@ -109,7 +113,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
self._index_store.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
@@ -185,7 +189,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
return self._index_store.similar_search(query, self._top_k, filters)
|
||||
|
||||
async def _aparse_db_summary(self) -> List[str]:
|
||||
"""Similar search."""
|
||||
|
@@ -4,10 +4,10 @@ from functools import reduce
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
@@ -18,18 +18,19 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
top_k: int = 4,
|
||||
query_rewrite: Optional[QueryRewrite] = None,
|
||||
rerank: Optional[Ranker] = None,
|
||||
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
|
||||
):
|
||||
"""Create EmbeddingRetriever.
|
||||
|
||||
Args:
|
||||
index_store(IndexStore): vector store connector
|
||||
top_k (int): top k
|
||||
query_rewrite (Optional[QueryRewrite]): query rewrite
|
||||
rerank (Ranker): rerank
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -64,8 +65,9 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
"""
|
||||
self._top_k = top_k
|
||||
self._query_rewrite = query_rewrite
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._index_store = index_store
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
self._retrieve_strategy = retrieve_strategy
|
||||
|
||||
def load_document(self, chunks: List[Chunk], **kwargs: Dict[str, Any]) -> List[str]:
|
||||
"""Load document in vector database.
|
||||
@@ -75,7 +77,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
Return:
|
||||
List[str]: chunk ids.
|
||||
"""
|
||||
return self._vector_store_connector.load_document(chunks)
|
||||
return self._index_store.load_document(chunks)
|
||||
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
@@ -90,7 +92,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
"""
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
self._index_store.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
@@ -113,7 +115,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
"""
|
||||
queries = [query]
|
||||
candidates_with_score = [
|
||||
self._vector_store_connector.similar_search_with_scores(
|
||||
self._index_store.similar_search_with_scores(
|
||||
query, self._top_k, score_threshold, filters
|
||||
)
|
||||
for query in queries
|
||||
@@ -217,7 +219,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
return self._index_store.similar_search(query, self._top_k, filters)
|
||||
|
||||
async def _run_async_tasks(self, tasks) -> List[Chunk]:
|
||||
"""Run async tasks."""
|
||||
@@ -229,6 +231,6 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
self, query, score_threshold, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search with score."""
|
||||
return await self._vector_store_connector.asimilar_search_with_scores(
|
||||
return await self._index_store.asimilar_search_with_scores(
|
||||
query, self._top_k, score_threshold, filters
|
||||
)
|
||||
|
@@ -25,7 +25,7 @@ def mock_vector_store_connector():
|
||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
return DBSchemaRetriever(
|
||||
connector=mock_db_connection,
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
index_store=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -25,8 +25,8 @@ def mock_vector_store_connector():
|
||||
def embedding_retriever(top_k, mock_vector_store_connector):
|
||||
return EmbeddingRetriever(
|
||||
top_k=top_k,
|
||||
query_rewrite=False,
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
query_rewrite=None,
|
||||
index_store=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -7,9 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
from ..index.base import IndexStoreBase
|
||||
from .embedding import EmbeddingRetriever
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
top_k: int = 100,
|
||||
query_rewrite: Optional[QueryRewrite] = None,
|
||||
rerank: Optional[Ranker] = None,
|
||||
@@ -32,13 +32,13 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
|
||||
"""Initialize TimeWeightedEmbeddingRetriever.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
index_store (IndexStoreBase): vector store connector
|
||||
top_k (int): top k
|
||||
query_rewrite (Optional[QueryRewrite]): query rewrite
|
||||
rerank (Ranker): rerank
|
||||
"""
|
||||
super().__init__(
|
||||
vector_store_connector=vector_store_connector,
|
||||
index_store=index_store,
|
||||
top_k=top_k,
|
||||
query_rewrite=query_rewrite,
|
||||
rerank=rerank,
|
||||
@@ -69,7 +69,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
|
||||
doc.metadata["created_at"] = current_time
|
||||
doc.metadata["buffer_idx"] = len(self.memory_stream) + i
|
||||
self.memory_stream.extend(dup_docs)
|
||||
return self._vector_store_connector.load_document(dup_docs)
|
||||
return self._index_store.load_document(dup_docs)
|
||||
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
@@ -125,7 +125,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
|
||||
def get_salient_docs(self, query: str) -> Dict[int, Tuple[Chunk, float]]:
|
||||
"""Return documents that are salient to the query."""
|
||||
docs_and_scores: List[Chunk]
|
||||
docs_and_scores = self._vector_store_connector.similar_search_with_scores(
|
||||
docs_and_scores = self._index_store.similar_search_with_scores(
|
||||
query, topk=self._top_k, score_threshold=0
|
||||
)
|
||||
results = {}
|
||||
|
@@ -13,7 +13,7 @@ from dbgpt.core import (
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
|
||||
INSTRUCTION = """
|
||||
|
@@ -48,8 +48,8 @@ class DBSummaryClient:
|
||||
|
||||
def get_db_summary(self, dbname, query, topk):
|
||||
"""Get user query related tables info."""
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
@@ -60,7 +60,7 @@ class DBSummaryClient:
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=topk, vector_store_connector=vector_connector
|
||||
top_k=topk, index_store=vector_connector.index_client
|
||||
)
|
||||
table_docs = retriever.retrieve(query)
|
||||
ans = [d.content for d in table_docs]
|
||||
@@ -88,8 +88,8 @@ class DBSummaryClient:
|
||||
dbname(str): dbname
|
||||
"""
|
||||
vector_store_name = dbname + "_profile"
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
@@ -102,7 +102,7 @@ class DBSummaryClient:
|
||||
|
||||
db_assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=db_summary_client.db,
|
||||
vector_store_connector=vector_connector,
|
||||
index_store=vector_connector.index_client,
|
||||
)
|
||||
|
||||
if len(db_assembler.get_chunks()) > 0:
|
||||
@@ -114,8 +114,8 @@ class DBSummaryClient:
|
||||
def delete_db_profile(self, dbname):
|
||||
"""Delete db profile."""
|
||||
vector_store_name = dbname + "_profile"
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
|
Reference in New Issue
Block a user