mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -1,8 +1,14 @@
|
||||
"""Module for RAG operators."""
|
||||
|
||||
from .datasource import DatasourceRetrieverOperator # noqa: F401
|
||||
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
|
||||
from .embedding import EmbeddingRetrieverOperator # noqa: F401
|
||||
from .db_schema import ( # noqa: F401
|
||||
DBSchemaAssemblerOperator,
|
||||
DBSchemaRetrieverOperator,
|
||||
)
|
||||
from .embedding import ( # noqa: F401
|
||||
EmbeddingAssemblerOperator,
|
||||
EmbeddingRetrieverOperator,
|
||||
)
|
||||
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
|
||||
from .knowledge import KnowledgeOperator # noqa: F401
|
||||
from .rerank import RerankOperator # noqa: F401
|
||||
@@ -12,7 +18,9 @@ from .summary import SummaryAssemblerOperator # noqa: F401
|
||||
__all__ = [
|
||||
"DatasourceRetrieverOperator",
|
||||
"DBSchemaRetrieverOperator",
|
||||
"DBSchemaAssemblerOperator",
|
||||
"EmbeddingRetrieverOperator",
|
||||
"EmbeddingAssemblerOperator",
|
||||
"KnowledgeOperator",
|
||||
"RerankOperator",
|
||||
"QueryRewriteOperator",
|
||||
|
24
dbgpt/rag/operators/assembler.py
Normal file
24
dbgpt/rag/operators/assembler.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Base Assembler Operator."""
|
||||
from abc import abstractmethod
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN, OUT
|
||||
|
||||
|
||||
class AssemblerOperator(MapOperator[IN, OUT]):
|
||||
"""The Base Assembler Operator."""
|
||||
|
||||
async def map(self, input_value: IN) -> OUT:
|
||||
"""Map input value to output value.
|
||||
|
||||
Args:
|
||||
input_value (IN): The input value.
|
||||
|
||||
Returns:
|
||||
OUT: The output value.
|
||||
"""
|
||||
return await self.blocking_func_to_async(self.assemble, input_value)
|
||||
|
||||
@abstractmethod
|
||||
def assemble(self, input_value: IN) -> OUT:
|
||||
"""Assemble knowledge for input value."""
|
@@ -1,21 +1,21 @@
|
||||
"""Datasource operator for RDBMS database."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, List[str]]):
|
||||
"""The Datasource Retriever Operator."""
|
||||
|
||||
def __init__(self, connection: RDBMSConnector, **kwargs):
|
||||
def __init__(self, connector: BaseConnector, **kwargs):
|
||||
"""Create a new DatasourceRetrieverOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
|
||||
def retrieve(self, input_value: Any) -> Any:
|
||||
def retrieve(self, input_value: Any) -> List[str]:
|
||||
"""Retrieve the database summary."""
|
||||
summary = _parse_db_summary(self._connection)
|
||||
summary = _parse_db_summary(self._connector)
|
||||
return summary
|
||||
|
@@ -1,18 +1,22 @@
|
||||
"""The DBSchema Retriever Operator."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.db_schema import DBSchemaAssembler
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
from .assembler import AssemblerOperator
|
||||
|
||||
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
|
||||
class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
||||
"""The DBSchema Retriever Operator.
|
||||
|
||||
Args:
|
||||
connection (RDBMSConnector): The connection.
|
||||
connector (BaseConnector): The connection.
|
||||
top_k (int, optional): The top k. Defaults to 4.
|
||||
vector_store_connector (VectorStoreConnector, optional): The vector store
|
||||
connector. Defaults to None.
|
||||
@@ -22,21 +26,57 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSConnector] = None,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new DBSchemaRetrieverOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._retriever = DBSchemaRetriever(
|
||||
top_k=top_k,
|
||||
connection=connection,
|
||||
connector=connector,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
|
||||
def retrieve(self, query: Any) -> Any:
|
||||
def retrieve(self, query: str) -> List[Chunk]:
|
||||
"""Retrieve the table schemas.
|
||||
|
||||
Args:
|
||||
query (IN): query.
|
||||
query (str): The query.
|
||||
"""
|
||||
return self._retriever.retrieve(query)
|
||||
|
||||
|
||||
class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
"""The DBSchema Assembler Operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new DBSchemaAssemblerOperator.
|
||||
|
||||
Args:
|
||||
connector (BaseConnector): The connection.
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
"""
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._connector = connector
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def assemble(self, dummy_value) -> List[Chunk]:
|
||||
"""Persist the database schema.
|
||||
|
||||
Args:
|
||||
dummy_value: Dummy value, not used.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: The chunks.
|
||||
"""
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=self._connector,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@@ -5,11 +5,16 @@ from typing import List, Optional, Union
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.embedding import EmbeddingAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..knowledge import Knowledge
|
||||
from ..retriever.embedding import EmbeddingRetriever
|
||||
from ..retriever.rerank import Ranker
|
||||
from ..retriever.rewrite import QueryRewrite
|
||||
from .assembler import AssemblerOperator
|
||||
|
||||
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
|
||||
"""The Embedding Retriever Operator."""
|
||||
@@ -43,3 +48,36 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
|
||||
for q in query
|
||||
]
|
||||
return reduce(lambda x, y: x + y, candidates)
|
||||
|
||||
|
||||
class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
|
||||
"""The Embedding Assembler Operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
|
||||
chunk_strategy="CHUNK_BY_SIZE"
|
||||
),
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new EmbeddingAssemblerOperator.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
||||
parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
|
||||
"""
|
||||
self._chunk_parameters = chunk_parameters
|
||||
self._vector_store_connector = vector_store_connector
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def assemble(self, knowledge: Knowledge) -> List[Chunk]:
|
||||
"""Assemble knowledge for input value."""
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=self._chunk_parameters,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Knowledge Operator."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.flow import (
|
||||
@@ -14,7 +14,7 @@ from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
|
||||
|
||||
class KnowledgeOperator(MapOperator[Any, Any]):
|
||||
class KnowledgeOperator(MapOperator[str, Knowledge]):
|
||||
"""Knowledge Factory Operator."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
@@ -26,7 +26,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
|
||||
IOField.build_from(
|
||||
"knowledge datasource",
|
||||
"knowledge datasource",
|
||||
dict,
|
||||
str,
|
||||
"knowledge datasource",
|
||||
)
|
||||
],
|
||||
@@ -85,7 +85,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
|
||||
self._datasource = datasource
|
||||
self._knowledge_type = KnowledgeType.get_by_value(knowledge_type)
|
||||
|
||||
async def map(self, datasource: Any) -> Knowledge:
|
||||
async def map(self, datasource: str) -> Knowledge:
|
||||
"""Create knowledge from datasource."""
|
||||
if self._datasource:
|
||||
datasource = self._datasource
|
||||
|
@@ -1,12 +1,12 @@
|
||||
"""The Rerank Operator."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker
|
||||
|
||||
|
||||
class RerankOperator(MapOperator[Any, Any]):
|
||||
class RerankOperator(MapOperator[List[Chunk], List[Chunk]]):
|
||||
"""The Rewrite Operator."""
|
||||
|
||||
def __init__(
|
||||
|
@@ -7,7 +7,7 @@ from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: RDBMSConnector,
|
||||
connector: BaseConnector,
|
||||
model_name: str,
|
||||
llm: LLMClient,
|
||||
top_k: int = 5,
|
||||
@@ -27,14 +27,14 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
"""Create the schema linking operator.
|
||||
|
||||
Args:
|
||||
connection (RDBMSConnector): The connection.
|
||||
connector (BaseConnector): The connection.
|
||||
llm (Optional[LLMClient]): base llm
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._schema_linking = SchemaLinking(
|
||||
top_k=top_k,
|
||||
connection=connection,
|
||||
connector=connector,
|
||||
llm=llm,
|
||||
model_name=model_name,
|
||||
vector_store_connector=vector_store_connector,
|
||||
|
@@ -4,9 +4,9 @@ from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
|
||||
from dbgpt.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt.rag.knowledge.base import Knowledge
|
||||
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt.serve.rag.operators.base import AssemblerOperator
|
||||
from dbgpt.rag.operators.assembler import AssemblerOperator
|
||||
|
||||
|
||||
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
|
||||
|
Reference in New Issue
Block a user