feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Aries-ckt
2024-06-13 13:49:17 +08:00
committed by GitHub
parent 162e2c9b1c
commit 58d08780d6
86 changed files with 948 additions and 440 deletions

View File

@@ -4,10 +4,10 @@ from typing import List, Optional, cast
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
@@ -17,7 +17,7 @@ class DBSchemaRetriever(BaseRetriever):
def __init__(
self,
vector_store_connector: VectorStoreConnector,
index_store: IndexStoreBase,
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
@@ -27,7 +27,7 @@ class DBSchemaRetriever(BaseRetriever):
"""Create DBSchemaRetriever.
Args:
vector_store_connector (VectorStoreConnector): vector store connector
index_store(IndexStore): index connector
top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
@@ -67,18 +67,22 @@ class DBSchemaRetriever(BaseRetriever):
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="dbschema_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(
MODEL_PATH, "text2vec-large-chinese"
),
).create(),
)
vector_store = ChromaStore(config)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3,
vector_store_connector=vector_connector,
index_store=vector_store,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
@@ -88,9 +92,9 @@ class DBSchemaRetriever(BaseRetriever):
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._index_store = index_store
self._need_embeddings = False
if self._vector_store_connector:
if self._index_store:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
@@ -109,7 +113,7 @@ class DBSchemaRetriever(BaseRetriever):
if self._need_embeddings:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k, filters)
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
@@ -185,7 +189,7 @@ class DBSchemaRetriever(BaseRetriever):
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(query, self._top_k, filters)
return self._index_store.similar_search(query, self._top_k, filters)
async def _aparse_db_summary(self) -> List[str]:
"""Similar search."""