mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 18:17:45 +00:00
feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -4,10 +4,10 @@ from typing import List, Optional, cast
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
|
||||
@@ -17,7 +17,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
index_store: IndexStoreBase,
|
||||
top_k: int = 4,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
@@ -27,7 +27,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Create DBSchemaRetriever.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
index_store (IndexStoreBase): index store used for similarity search
|
||||
top_k (int): top k
|
||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||
query_rewrite (bool): query rewrite
|
||||
@@ -67,18 +67,22 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
|
||||
connector = _create_temporary_connection()
|
||||
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||
embedding_model_path = "{your_embedding_model_path}"
|
||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_fn=embedding_fn,
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
name="dbschema_rag_test",
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(
|
||||
MODEL_PATH, "text2vec-large-chinese"
|
||||
),
|
||||
).create(),
|
||||
)
|
||||
|
||||
vector_store = ChromaStore(config)
|
||||
# get db struct retriever
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=3,
|
||||
vector_store_connector=vector_connector,
|
||||
index_store=vector_store,
|
||||
connector=connector,
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
@@ -88,9 +92,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self._top_k = top_k
|
||||
self._connector = connector
|
||||
self._query_rewrite = query_rewrite
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._index_store = index_store
|
||||
self._need_embeddings = False
|
||||
if self._vector_store_connector:
|
||||
if self._index_store:
|
||||
self._need_embeddings = True
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
@@ -109,7 +113,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
self._index_store.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
@@ -185,7 +189,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
return self._index_store.similar_search(query, self._top_k, filters)
|
||||
|
||||
async def _aparse_db_summary(self) -> List[str]:
|
||||
"""Parse the database summary asynchronously."""
|
||||
|
Reference in New Issue
Block a user