Files
DB-GPT/dbgpt/rag/operators/db_schema.py
2024-03-27 22:47:28 +08:00

91 lines
2.9 KiB
Python

"""The DBSchema Retriever Operator."""
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.db_schema import DBSchemaAssembler
from ..chunk_manager import ChunkParameters
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator
class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
    """The DBSchema Retriever Operator.

    Wraps a ``DBSchemaRetriever`` so it can be used as a retriever operator:
    the operator input is a query string and the output is a list of chunks.

    Args:
        vector_store_connector (VectorStoreConnector): The vector store
            connector. Required.
        top_k (int, optional): The top k, passed through to
            ``DBSchemaRetriever``. Defaults to 4.
        connector (Optional[BaseConnector], optional): The database
            connection, passed through to ``DBSchemaRetriever``.
            Defaults to None.
    """

    def __init__(
        self,
        vector_store_connector: VectorStoreConnector,
        top_k: int = 4,
        connector: Optional[BaseConnector] = None,
        **kwargs,
    ):
        """Create a new DBSchemaRetrieverOperator.

        Args:
            vector_store_connector (VectorStoreConnector): The vector store
                connector.
            top_k (int, optional): The top k. Defaults to 4.
            connector (Optional[BaseConnector], optional): The database
                connection. Defaults to None.
            **kwargs: Forwarded unchanged to the ``RetrieverOperator`` base
                class constructor.
        """
        super().__init__(**kwargs)
        # All retrieval work is delegated to this retriever instance.
        self._retriever = DBSchemaRetriever(
            top_k=top_k,
            connector=connector,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: str) -> List[Chunk]:
        """Retrieve the table schemas.

        Args:
            query (str): The query.

        Returns:
            List[Chunk]: The chunks returned by the underlying
            ``DBSchemaRetriever`` for ``query``.
        """
        return self._retriever.retrieve(query)
class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
    """The DBSchema Assembler Operator.

    On each ``assemble`` call it builds a ``DBSchemaAssembler`` from the
    connector supplied at construction time, calls ``persist()`` on it and
    returns the assembler's chunks.
    """

    def __init__(
        self,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        **kwargs,
    ):
        """Create a new DBSchemaAssemblerOperator.

        Args:
            connector (BaseConnector): The database connection.
            vector_store_connector (VectorStoreConnector): The vector store
                connector.
            chunk_parameters (Optional[ChunkParameters], optional): The chunk
                parameters. When omitted (None), a ``ChunkParameters`` with
                ``chunk_strategy="CHUNK_BY_SIZE"`` is used.
            **kwargs: Forwarded unchanged to the ``AssemblerOperator`` base
                class constructor.
        """
        # Explicit None check: only an omitted argument should fall back to
        # the default, never a caller-supplied parameters object.
        if chunk_parameters is None:
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        self._connector = connector
        super().__init__(**kwargs)

    def assemble(self, dummy_value) -> List[Chunk]:
        """Persist the database schema.

        Args:
            dummy_value: Dummy value, not used. The connector given to the
                constructor is always used instead of the operator input.

        Returns:
            List[Chunk]: The chunks produced by the ``DBSchemaAssembler``.
        """
        assembler = DBSchemaAssembler.load_from_connection(
            connector=self._connector,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()