feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -3,7 +3,7 @@ from functools import reduce
from typing import List, Optional, cast
from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connection: Optional[RDBMSConnector] = None,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
rerank: Optional[Ranker] = None,
**kwargs
@@ -28,7 +28,7 @@ class DBSchemaRetriever(BaseRetriever):
Args:
vector_store_connector (VectorStoreConnector): vector store connector
top_k (int): top k
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
@@ -65,7 +65,7 @@ class DBSchemaRetriever(BaseRetriever):
return connect
connection = _create_temporary_connection()
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
@@ -76,14 +76,16 @@ class DBSchemaRetriever(BaseRetriever):
)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3, vector_store_connector=vector_connector
top_k=3,
vector_store_connector=vector_connector,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._top_k = top_k
self._connection = connection
self._connector = connector
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._need_embeddings = False
@@ -108,9 +110,9 @@ class DBSchemaRetriever(BaseRetriever):
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
else:
if not self._connection:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connection)
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
@@ -173,6 +175,6 @@ class DBSchemaRetriever(BaseRetriever):
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
if not self._connection:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
return _parse_db_summary(self._connection)
return _parse_db_summary(self._connector)