refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -2,8 +2,8 @@
from functools import reduce
from typing import List, Optional, cast
-from dbgpt.datasource.rdbms.base import RDBMSDatabase
-from dbgpt.rag.chunk import Chunk
+from dbgpt.core import Chunk
+from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
-connection: Optional[RDBMSDatabase] = None,
+connection: Optional[RDBMSConnector] = None,
query_rewrite: bool = False,
rerank: Optional[Ranker] = None,
**kwargs
@@ -28,14 +28,14 @@ class DBSchemaRetriever(BaseRetriever):
Args:
vector_store_connector (VectorStoreConnector): vector store connector
top_k (int): top k
-connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
+connection (Optional[RDBMSConnector]): RDBMSConnector connection.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
Examples:
.. code-block:: python
-from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
+from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
@@ -43,7 +43,7 @@ class DBSchemaRetriever(BaseRetriever):
def _create_temporary_connection():
-connect = SQLiteTempConnect.create_temporary_db()
+connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
@@ -109,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
else:
if not self._connection:
-raise RuntimeError("RDBMSDatabase connection is required.")
+raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connection)
return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -174,5 +174,5 @@ class DBSchemaRetriever(BaseRetriever):
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
if not self._connection:
-raise RuntimeError("RDBMSDatabase connection is required.")
+raise RuntimeError("RDBMSConnector connection is required.")
return _parse_db_summary(self._connection)