DB-GPT/dbgpt/rag/assembler/db_schema.py


"""DBSchemaAssembler."""
import os
from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector
from ...serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever
class DBSchemaAssembler(BaseAssembler):
"""DBSchemaAssembler.
Example:
.. code-block:: python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
connection = SQLiteTempConnector.create_temporary_db()
assembler = DBSchemaAssembler.load_from_connection(
connector=connection,
embedding_model=embedding_model_path,
)
assembler.persist()
# get db struct retriever
retriever = assembler.as_retriever(top_k=3)
"""

    def __init__(
        self,
        connector: BaseConnector,
        table_vector_store_connector: VectorStoreConnector,
        field_vector_store_connector: Optional[VectorStoreConnector] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        max_seq_length: int = 512,
        **kwargs: Any,
    ) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
connector: (BaseConnector) BaseConnector connection.
table_vector_store_connector: VectorStoreConnector to load
and retrieve table info.
field_vector_store_connector: VectorStoreConnector to load
and retrieve field info.
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
self._connector = connector
self._table_vector_store_connector = table_vector_store_connector
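        # The field store reuses the table store's name with a "_field" suffix, so
        # table-level and field-level chunks stay paired but live in separate
        # vector store collections.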
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._embedding_model = embedding_model
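        # If only a model name is given, build a default embedding function from it;
        # an explicitly passed ``embeddings`` argument always takes precedence.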
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
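        # Attach the embedding function to either connector that was created without
        # one, so table and field stores embed chunks consistently.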
if (
embeddings
and self._table_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._table_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
if (
embeddings
and self._field_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._field_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
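        # Wrap the connector as DatasourceKnowledge so the base assembler can load
        # and chunk the schema; ``max_seq_length`` is forwarded as the embedding
        # model's max sequence length.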
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
**kwargs,
)

    @classmethod
    def load_from_connection(
        cls,
        connector: BaseConnector,
        table_vector_store_connector: VectorStoreConnector,
        field_vector_store_connector: Optional[VectorStoreConnector] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        max_seq_length: int = 512,
    ) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connector: (BaseConnector) BaseConnector connection.
table_vector_store_connector: used to load table chunks.
field_vector_store_connector: used to load field chunks
if field in table is too much.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
max_seq_length: Embedding model max sequence length
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
max_seq_length=max_seq_length,
)

    def get_chunks(self) -> List[Chunk]:
        """Return the assembled chunks."""
        return self._chunks

    def persist(self, **kwargs: Any) -> List[str]:
        """Persist chunks into the vector store.

        Returns:
            List[str]: List of chunk ids.
        """
table_chunks, field_chunks = [], []
for chunk in self._chunks:
metadata = chunk.metadata
if metadata.get("separated"):
if metadata.get("part") == "table":
table_chunks.append(chunk)
else:
field_chunks.append(chunk)
else:
table_chunks.append(chunk)
self._field_vector_store_connector.load_document(field_chunks)
return self._table_vector_store_connector.load_document(table_chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever:
        """Create DBSchemaRetriever.

        Args:
            top_k (int): Number of results to retrieve. Defaults to 4.

        Returns:
            DBSchemaRetriever
        """
return DBSchemaRetriever(
top_k=top_k,
connector=self._connector,
is_embeddings=True,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)