mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
refactor: Refactor datasource module (#1309)
This commit is contained in:
@@ -3,14 +3,14 @@
|
||||
from typing import Any
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
"""The Datasource Retriever Operator."""
|
||||
|
||||
def __init__(self, connection: RDBMSDatabase, **kwargs):
|
||||
def __init__(self, connection: RDBMSConnector, **kwargs):
|
||||
"""Create a new DatasourceRetrieverOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
|
@@ -3,7 +3,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -12,7 +12,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
"""The DBSchema Retriever Operator.
|
||||
|
||||
Args:
|
||||
connection (RDBMSDatabase): The connection.
|
||||
connection (RDBMSConnector): The connection.
|
||||
top_k (int, optional): The top k. Defaults to 4.
|
||||
vector_store_connector (VectorStoreConnector, optional): The vector store
|
||||
connector. Defaults to None.
|
||||
@@ -22,7 +22,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSDatabase] = None,
|
||||
connection: Optional[RDBMSConnector] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new DBSchemaRetrieverOperator."""
|
||||
|
@@ -3,8 +3,8 @@
|
||||
from functools import reduce
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
|
@@ -2,12 +2,11 @@
|
||||
import asyncio
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import JoinOperator
|
||||
from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
|
||||
from dbgpt.core.interface.llm import LLMClient
|
||||
|
||||
from ..chunk import Chunk
|
||||
|
||||
|
||||
class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
|
||||
"""Evaluator for retriever."""
|
||||
|
@@ -1,8 +1,8 @@
|
||||
"""The Rerank Operator."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker
|
||||
|
||||
|
||||
|
@@ -7,7 +7,7 @@ from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: RDBMSDatabase,
|
||||
connection: RDBMSConnector,
|
||||
model_name: str,
|
||||
llm: LLMClient,
|
||||
top_k: int = 5,
|
||||
@@ -27,7 +27,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
"""Create the schema linking operator.
|
||||
|
||||
Args:
|
||||
connection (RDBMSDatabase): The connection.
|
||||
connection (RDBMSConnector): The connection.
|
||||
llm (Optional[LLMClient]): base llm
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
Reference in New Issue
Block a user