feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -1,8 +1,14 @@
"""Module for RAG operators."""
from .datasource import DatasourceRetrieverOperator # noqa: F401
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
from .embedding import EmbeddingRetrieverOperator # noqa: F401
from .db_schema import ( # noqa: F401
DBSchemaAssemblerOperator,
DBSchemaRetrieverOperator,
)
from .embedding import ( # noqa: F401
EmbeddingAssemblerOperator,
EmbeddingRetrieverOperator,
)
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
from .knowledge import KnowledgeOperator # noqa: F401
from .rerank import RerankOperator # noqa: F401
@@ -12,7 +18,9 @@ from .summary import SummaryAssemblerOperator # noqa: F401
__all__ = [
"DatasourceRetrieverOperator",
"DBSchemaRetrieverOperator",
"DBSchemaAssemblerOperator",
"EmbeddingRetrieverOperator",
"EmbeddingAssemblerOperator",
"KnowledgeOperator",
"RerankOperator",
"QueryRewriteOperator",

View File

@@ -0,0 +1,24 @@
"""Base Assembler Operator."""
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
class AssemblerOperator(MapOperator[IN, OUT]):
    """Base class for assembler operators.

    Subclasses implement the synchronous ``assemble`` method; ``map`` hands it
    to ``blocking_func_to_async`` and awaits the result.
    """

    async def map(self, input_value: IN) -> OUT:
        """Run ``assemble`` for the input value and return its result.

        Args:
            input_value (IN): The value passed on to ``assemble``.

        Returns:
            OUT: The value produced by ``assemble``.
        """
        assembled = await self.blocking_func_to_async(self.assemble, input_value)
        return assembled

    @abstractmethod
    def assemble(self, input_value: IN) -> OUT:
        """Assemble knowledge for the input value.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The assembled output.
        """

View File

@@ -1,21 +1,21 @@
"""Datasource operator for RDBMS database."""
from typing import Any
from typing import Any, List
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
class DatasourceRetrieverOperator(RetrieverOperator[Any, List[str]]):
"""The Datasource Retriever Operator.

Returns the database summary that ``_parse_db_summary`` produces for the
connector supplied at construction time.
"""
def __init__(self, connection: RDBMSConnector, **kwargs):
def __init__(self, connector: BaseConnector, **kwargs):
"""Create a new DatasourceRetrieverOperator.

Args:
connector (BaseConnector): The connector later passed to
``_parse_db_summary`` by ``retrieve``.
**kwargs: Forwarded unchanged to the parent operator.
"""
super().__init__(**kwargs)
self._connection = connection
self._connector = connector
def retrieve(self, input_value: Any) -> Any:
def retrieve(self, input_value: Any) -> List[str]:
"""Retrieve the database summary.

Args:
input_value (Any): Not used by this method.

Returns:
List[str]: The summary returned by ``_parse_db_summary``.
"""
summary = _parse_db_summary(self._connection)
summary = _parse_db_summary(self._connector)
return summary

View File

@@ -1,18 +1,22 @@
"""The DBSchema Retriever Operator."""
from typing import Any, Optional
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.db_schema import DBSchemaAssembler
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
"""The DBSchema Retriever Operator.
Args:
connection (RDBMSConnector): The connection.
connector (BaseConnector): The connection.
top_k (int, optional): The top k. Defaults to 4.
vector_store_connector (VectorStoreConnector, optional): The vector store
connector. Defaults to None.
@@ -22,21 +26,57 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connection: Optional[RDBMSConnector] = None,
connector: Optional[BaseConnector] = None,
**kwargs
):
"""Create a new DBSchemaRetrieverOperator."""
super().__init__(**kwargs)
self._retriever = DBSchemaRetriever(
top_k=top_k,
connection=connection,
connector=connector,
vector_store_connector=vector_store_connector,
)
def retrieve(self, query: Any) -> Any:
def retrieve(self, query: str) -> List[Chunk]:
"""Retrieve the table schemas.
Args:
query (IN): query.
query (str): The query.
"""
return self._retriever.retrieve(query)
class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
    """Operator that persists a database schema via ``DBSchemaAssembler``."""

    def __init__(
        self,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        **kwargs
    ):
        """Initialise the operator.

        Args:
            connector (BaseConnector): The database connector handed to the
                assembler.
            vector_store_connector (VectorStoreConnector): The vector store
                connector handed to the assembler.
            **kwargs: Forwarded unchanged to the parent operator.
        """
        # Attributes are set first; the parent initialiser runs last.
        self._vector_store_connector = vector_store_connector
        self._connector = connector
        super().__init__(**kwargs)

    def assemble(self, dummy_value) -> List[Chunk]:
        """Persist the database schema and return the resulting chunks.

        Args:
            dummy_value: Placeholder input; it is never read.

        Returns:
            List[Chunk]: The chunks reported by the assembler after persisting.
        """
        schema_assembler = DBSchemaAssembler.load_from_connection(
            connector=self._connector,
            vector_store_connector=self._vector_store_connector,
        )
        schema_assembler.persist()
        chunks = schema_assembler.get_chunks()
        return chunks

View File

@@ -5,11 +5,16 @@ from typing import List, Optional, Union
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.embedding import EmbeddingAssembler
from ..chunk_manager import ChunkParameters
from ..knowledge import Knowledge
from ..retriever.embedding import EmbeddingRetriever
from ..retriever.rerank import Ranker
from ..retriever.rewrite import QueryRewrite
from .assembler import AssemblerOperator
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
"""The Embedding Retriever Operator."""
@@ -43,3 +48,36 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
for q in query
]
return reduce(lambda x, y: x + y, candidates)
class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
    """Operator that embeds and persists a ``Knowledge`` object."""

    def __init__(
        self,
        vector_store_connector: VectorStoreConnector,
        # NOTE(review): this default instance is created once, when the
        # function is defined, and is shared by every operator that relies on
        # it.
        chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
            chunk_strategy="CHUNK_BY_SIZE"
        ),
        **kwargs
    ):
        """Initialise the operator.

        Args:
            vector_store_connector (VectorStoreConnector): The vector store
                connector handed to the assembler.
            chunk_parameters (Optional[ChunkParameters], optional): The chunk
                parameters handed to the assembler. Defaults to
                ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
            **kwargs: Forwarded unchanged to the parent operator.
        """
        # Attributes are set first; the parent initialiser runs last.
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        super().__init__(**kwargs)

    def assemble(self, knowledge: Knowledge) -> List[Chunk]:
        """Persist the given knowledge and return the resulting chunks.

        Args:
            knowledge (Knowledge): The knowledge to load into the assembler.

        Returns:
            List[Chunk]: The chunks reported by the assembler after persisting.
        """
        embedding_assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        embedding_assembler.persist()
        chunks = embedding_assembler.get_chunks()
        return chunks

View File

@@ -1,6 +1,6 @@
"""Knowledge Operator."""
from typing import Any, Optional
from typing import Optional
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
@@ -14,7 +14,7 @@ from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
class KnowledgeOperator(MapOperator[Any, Any]):
class KnowledgeOperator(MapOperator[str, Knowledge]):
"""Knowledge Factory Operator."""
metadata = ViewMetadata(
@@ -26,7 +26,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
IOField.build_from(
"knowledge datasource",
"knowledge datasource",
dict,
str,
"knowledge datasource",
)
],
@@ -85,7 +85,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
self._datasource = datasource
self._knowledge_type = KnowledgeType.get_by_value(knowledge_type)
async def map(self, datasource: Any) -> Knowledge:
async def map(self, datasource: str) -> Knowledge:
"""Create knowledge from datasource."""
if self._datasource:
datasource = self._datasource

View File

@@ -1,12 +1,12 @@
"""The Rerank Operator."""
from typing import Any, List, Optional
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker
class RerankOperator(MapOperator[Any, Any]):
class RerankOperator(MapOperator[List[Chunk], List[Chunk]]):
"""The Rerank Operator."""
def __init__(

View File

@@ -7,7 +7,7 @@ from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
def __init__(
self,
connection: RDBMSConnector,
connector: BaseConnector,
model_name: str,
llm: LLMClient,
top_k: int = 5,
@@ -27,14 +27,14 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
"""Create the schema linking operator.
Args:
connection (RDBMSConnector): The connection.
connector (BaseConnector): The connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._schema_linking = SchemaLinking(
top_k=top_k,
connection=connection,
connector=connector,
llm=llm,
model_name=model_name,
vector_store_connector=vector_store_connector,

View File

@@ -4,9 +4,9 @@ from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.rag.operators.assembler import AssemblerOperator
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):