feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Aries-ckt
2024-06-13 13:49:17 +08:00
committed by GitHub
parent 162e2c9b1c
commit 58d08780d6
86 changed files with 948 additions and 440 deletions

View File

@@ -1,13 +1,12 @@
"""DBSchemaAssembler."""
from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..index.base import IndexStoreBase
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever
@@ -36,36 +35,22 @@ class DBSchemaAssembler(BaseAssembler):
def __init__(
self,
connector: BaseConnector,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
connector: (BaseConnector) BaseConnector connection.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
index_store: (IndexStoreBase) IndexStoreBase to use.
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
knowledge = DatasourceKnowledge(connector)
self._connector = connector
self._vector_store_connector = vector_store_connector
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
if (
embeddings
and self._vector_store_connector.vector_store_config.embedding_fn is None
):
self._vector_store_connector.vector_store_config.embedding_fn = embeddings
self._index_store = index_store
super().__init__(
knowledge=knowledge,
@@ -77,29 +62,23 @@ class DBSchemaAssembler(BaseAssembler):
def load_from_connection(
cls,
connector: BaseConnector,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connector: (BaseConnector) BaseConnector connection.
vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
index_store: (IndexStoreBase) IndexStoreBase to use.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
vector_store_connector=vector_store_connector,
embedding_model=embedding_model,
index_store=index_store,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
)
def get_chunks(self) -> List[Chunk]:
@@ -112,7 +91,7 @@ class DBSchemaAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
return self._vector_store_connector.load_document(self._chunks)
return self._index_store.load_document(self._chunks)
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
@@ -131,5 +110,5 @@ class DBSchemaAssembler(BaseAssembler):
top_k=top_k,
connector=self._connector,
is_embeddings=True,
vector_store_connector=self._vector_store_connector,
index_store=self._index_store,
)