mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 18:17:45 +00:00
refactor: Refactor datasource module (#1309)
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
from functools import reduce
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSDatabase] = None,
|
||||
connection: Optional[RDBMSConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
rerank: Optional[Ranker] = None,
|
||||
**kwargs
|
||||
@@ -28,14 +28,14 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
top_k (int): top k
|
||||
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
|
||||
query_rewrite (bool): query rewrite
|
||||
rerank (Ranker): rerank
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
@@ -43,7 +43,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
|
||||
def _create_temporary_connection():
|
||||
connect = SQLiteTempConnect.create_temporary_db()
|
||||
connect = SQLiteTempConnector.create_temporary_db()
|
||||
connect.create_temp_tables(
|
||||
{
|
||||
"user": {
|
||||
@@ -109,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
else:
|
||||
if not self._connection:
|
||||
raise RuntimeError("RDBMSDatabase connection is required.")
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
table_summaries = _parse_db_summary(self._connection)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
@@ -174,5 +174,5 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
if not self._connection:
|
||||
raise RuntimeError("RDBMSDatabase connection is required.")
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
return _parse_db_summary(self._connection)
|
||||
|
Reference in New Issue
Block a user