refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -3,14 +3,14 @@
from typing import Any
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
"""The Datasource Retriever Operator."""
def __init__(self, connection: RDBMSDatabase, **kwargs):
def __init__(self, connection: RDBMSConnector, **kwargs):
"""Create a new DatasourceRetrieverOperator."""
super().__init__(**kwargs)
self._connection = connection

View File

@@ -3,7 +3,7 @@
from typing import Any, Optional
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -12,7 +12,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
"""The DBSchema Retriever Operator.
Args:
connection (RDBMSDatabase): The connection.
connection (RDBMSConnector): The connection.
top_k (int, optional): The top k. Defaults to 4.
vector_store_connector (VectorStoreConnector, optional): The vector store
connector. Defaults to None.
@@ -22,7 +22,7 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connection: Optional[RDBMSDatabase] = None,
connection: Optional[RDBMSConnector] = None,
**kwargs
):
"""Create a new DBSchemaRetrieverOperator."""

View File

@@ -3,8 +3,8 @@
from functools import reduce
from typing import List, Optional, Union
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite

View File

@@ -2,12 +2,11 @@
import asyncio
from typing import Any, List, Optional
from dbgpt.core import Chunk
from dbgpt.core.awel import JoinOperator
from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
from dbgpt.core.interface.llm import LLMClient
from ..chunk import Chunk
class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
"""Evaluator for retriever."""

View File

@@ -1,8 +1,8 @@
"""The Rerank Operator."""
from typing import Any, List, Optional
from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker

View File

@@ -7,7 +7,7 @@ from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
def __init__(
self,
connection: RDBMSDatabase,
connection: RDBMSConnector,
model_name: str,
llm: LLMClient,
top_k: int = 5,
@@ -27,7 +27,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
"""Create the schema linking operator.
Args:
connection (RDBMSDatabase): The connection.
connection (RDBMSConnector): The connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)