mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 10:34:30 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -3,7 +3,7 @@ from functools import reduce
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSConnector] = None,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
rerank: Optional[Ranker] = None,
|
||||
**kwargs
|
||||
@@ -28,7 +28,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
top_k (int): top k
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
|
||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||
query_rewrite (bool): query rewrite
|
||||
rerank (Ranker): rerank
|
||||
|
||||
@@ -65,7 +65,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
return connect
|
||||
|
||||
|
||||
connection = _create_temporary_connection()
|
||||
connector = _create_temporary_connection()
|
||||
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||
embedding_model_path = "{your_embedding_model_path}"
|
||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||
@@ -76,14 +76,16 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
)
|
||||
# get db struct retriever
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=3, vector_store_connector=vector_connector
|
||||
top_k=3,
|
||||
vector_store_connector=vector_connector,
|
||||
connector=connector,
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
result = [chunk.content for chunk in chunks]
|
||||
print(f"db struct rag example results:{result}")
|
||||
"""
|
||||
self._top_k = top_k
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
self._query_rewrite = query_rewrite
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._need_embeddings = False
|
||||
@@ -108,9 +110,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
else:
|
||||
if not self._connection:
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
table_summaries = _parse_db_summary(self._connection)
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
@@ -173,6 +175,6 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Similar search."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
if not self._connection:
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
return _parse_db_summary(self._connection)
|
||||
return _parse_db_summary(self._connector)
|
||||
|
Reference in New Issue
Block a user