mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -1 +1,11 @@
|
||||
"""Module of RAG."""

from dbgpt.core import Chunk, Document  # noqa: F401

from .chunk_manager import ChunkParameters  # noqa: F401

# NOTE: `__all__` (lowercase) is the spelling Python actually recognizes for
# declaring the public API / star-import list; the previous `__ALL__` had no
# effect on `from dbgpt.rag import *`.
__all__ = [
    "Chunk",
    "Document",
    "ChunkParameters",
]
|
||||
|
16
dbgpt/rag/assembler/__init__.py
Normal file
16
dbgpt/rag/assembler/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Assembler Module For RAG.
|
||||
|
||||
The Assembler is a module that is responsible for assembling the knowledge.
|
||||
"""
|
||||
|
||||
from .base import BaseAssembler # noqa: F401
|
||||
from .db_schema import DBSchemaAssembler # noqa: F401
|
||||
from .embedding import EmbeddingAssembler # noqa: F401
|
||||
from .summary import SummaryAssembler # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"BaseAssembler",
|
||||
"DBSchemaAssembler",
|
||||
"EmbeddingAssembler",
|
||||
"SummaryAssembler",
|
||||
]
|
75
dbgpt/rag/assembler/base.py
Normal file
75
dbgpt/rag/assembler/base.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Base Assembler."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
from ..chunk_manager import ChunkManager, ChunkParameters
|
||||
from ..extractor.base import Extractor
|
||||
from ..knowledge.base import Knowledge
|
||||
from ..retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class BaseAssembler(ABC):
    """Base Assembler.

    Loads a Knowledge datasource, splits it into chunks through a
    ``ChunkManager``, and exposes those chunks to concrete assemblers.
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        extractor: Optional[Extractor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Assembler arguments.

        Args:
            knowledge(Knowledge): Knowledge datasource.
            chunk_parameters(Optional[ChunkParameters]): Parameters controlling
                how documents are chunked.
            extractor(Optional[Extractor]): Extractor to use for summarization.
        """
        self._knowledge = knowledge
        self._chunk_parameters = chunk_parameters or ChunkParameters()
        self._extractor = extractor
        self._chunk_manager = ChunkManager(
            knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
        )
        self._chunks: List[Chunk] = []

        # Describe the knowledge source in the trace span for observability.
        if self._knowledge:
            knowledge_cls = self._knowledge.__class__.__name__
            knowledge_type = self._knowledge.type().value
            path = getattr(self._knowledge, "_path", None)
        else:
            knowledge_cls, knowledge_type, path = None, None, None
        metadata = {
            "knowledge_cls": knowledge_cls,
            "knowledge_type": knowledge_type,
            "path": path,
            "chunk_parameters": self._chunk_parameters.dict(),
        }
        with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
            self.load_knowledge(self._knowledge)

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        """Load knowledge Pipeline: load documents, then split into chunks."""
        if not knowledge:
            raise ValueError("knowledge must be provided.")
        with root_tracer.start_span("BaseAssembler.knowledge.load"):
            docs = knowledge.load()
        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
            self._chunks = self._chunk_manager.split(docs)

    @abstractmethod
    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever."""

    @abstractmethod
    def persist(self) -> List[str]:
        """Persist chunks.

        Returns:
            List[str]: List of persisted chunk ids.
        """

    def get_chunks(self) -> List[Chunk]:
        """Return the chunks produced by :meth:`load_knowledge`."""
        return self._chunks
|
135
dbgpt/rag/assembler/db_schema.py
Normal file
135
dbgpt/rag/assembler/db_schema.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""DBSchemaAssembler."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from ..knowledge.datasource import DatasourceKnowledge
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
|
||||
|
||||
class DBSchemaAssembler(BaseAssembler):
    """DBSchemaAssembler.

    Example:
        .. code-block:: python

            from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
            from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

            connection = SQLiteTempConnector.create_temporary_db()
            assembler = DBSchemaAssembler.load_from_connection(
                connector=connection,
                embedding_model=embedding_model_path,
            )
            assembler.persist()
            # get db struct retriever
            retriever = assembler.as_retriever(top_k=3)
    """

    def __init__(
        self,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            connector(BaseConnector): Datasource connection to read schema from.
            chunk_parameters(Optional[ChunkParameters]): Chunking configuration.
            vector_store_connector(VectorStoreConnector): Vector store to use.
            embedding_model(Optional[str]): Embedding model name to use.
            embeddings(Optional[Embeddings]): Pre-built embeddings to use.
        """
        knowledge = DatasourceKnowledge(connector)
        self._connector = connector
        self._vector_store_connector = vector_store_connector

        self._embedding_model = embedding_model
        # Build embeddings from the model name when none were supplied.
        if self._embedding_model and not embeddings:
            factory = DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            embeddings = factory.create(self._embedding_model)

        # Only install the embedding function when the store has none yet.
        if (
            embeddings
            and self._vector_store_connector.vector_store_config.embedding_fn is None
        ):
            self._vector_store_connector.vector_store_config.embedding_fn = embeddings

        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_connection(
        cls,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
    ) -> "DBSchemaAssembler":
        """Build an assembler from a datasource connection.

        Args:
            connector(BaseConnector): Datasource connection.
            vector_store_connector(VectorStoreConnector): Vector store to use.
            chunk_parameters(Optional[ChunkParameters]): Chunking configuration.
            embedding_model(Optional[str]): Embedding model name to use.
            embeddings(Optional[Embeddings]): Pre-built embeddings to use.

        Returns:
            DBSchemaAssembler
        """
        return cls(
            connector=connector,
            vector_store_connector=vector_store_connector,
            embedding_model=embedding_model,
            chunk_parameters=chunk_parameters,
            embeddings=embeddings,
        )

    def get_chunks(self) -> List[Chunk]:
        """Return the chunks split from the database schema."""
        return self._chunks

    def persist(self) -> List[str]:
        """Persist chunks into vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for schema assembling)."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever:
        """Create DBSchemaRetriever.

        Args:
            top_k(int): Number of results to retrieve; default 4.

        Returns:
            DBSchemaRetriever
        """
        return DBSchemaRetriever(
            top_k=top_k,
            connector=self._connector,
            is_embeddings=True,
            vector_store_connector=self._vector_store_connector,
        )
|
124
dbgpt/rag/assembler/embedding.py
Normal file
124
dbgpt/rag/assembler/embedding.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Embedding Assembler."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from ..knowledge.base import Knowledge
|
||||
from ..retriever.embedding import EmbeddingRetriever
|
||||
|
||||
|
||||
class EmbeddingAssembler(BaseAssembler):
    """Embedding Assembler.

    Example:
        .. code-block:: python

            from dbgpt.rag.assembler import EmbeddingAssembler

            pdf_path = "path/to/document.pdf"
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            assembler = EmbeddingAssembler.load_from_knowledge(
                knowledge=knowledge,
                embedding_model="text2vec",
            )
    """

    def __init__(
        self,
        knowledge: Knowledge,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            knowledge(Knowledge): Knowledge datasource.
            vector_store_connector(VectorStoreConnector): Vector store to use.
            chunk_parameters(Optional[ChunkParameters]): Chunking configuration.
            embedding_model(Optional[str]): Embedding model name to use.
            embeddings(Optional[Embeddings]): Pre-built embeddings to use.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        self._vector_store_connector = vector_store_connector

        self._embedding_model = embedding_model
        # Derive embeddings from the model name when none were supplied.
        if self._embedding_model and not embeddings:
            factory = DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            embeddings = factory.create(self._embedding_model)

        # Only install the embedding function when the store has none yet.
        if (
            embeddings
            and self._vector_store_connector.vector_store_config.embedding_fn is None
        ):
            self._vector_store_connector.vector_store_config.embedding_fn = embeddings

        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
    ) -> "EmbeddingAssembler":
        """Load document embedding into vector store from path.

        Args:
            knowledge(Knowledge): Knowledge datasource.
            vector_store_connector(VectorStoreConnector): Vector store to use.
            chunk_parameters(Optional[ChunkParameters]): Chunking configuration.
            embedding_model(Optional[str]): Embedding model name to use.
            embeddings(Optional[Embeddings]): Pre-built embeddings to use.

        Returns:
            EmbeddingAssembler
        """
        return cls(
            knowledge=knowledge,
            vector_store_connector=vector_store_connector,
            chunk_parameters=chunk_parameters,
            embedding_model=embedding_model,
            embeddings=embeddings,
        )

    def persist(self) -> List[str]:
        """Persist chunks into vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for embedding assembling)."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> EmbeddingRetriever:
        """Create a retriever.

        Args:
            top_k(int): Number of results to retrieve; default 4.

        Returns:
            EmbeddingRetriever
        """
        return EmbeddingRetriever(
            top_k=top_k, vector_store_connector=self._vector_store_connector
        )
|
131
dbgpt/rag/assembler/summary.py
Normal file
131
dbgpt/rag/assembler/summary.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Summary Assembler."""
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk, LLMClient
|
||||
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..extractor.base import Extractor
|
||||
from ..knowledge.base import Knowledge
|
||||
from ..retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class SummaryAssembler(BaseAssembler):
    """Summary Assembler.

    Example:
        .. code-block:: python

            pdf_path = "../../../DB-GPT/docs/docs/awel.md"
            OPEN_AI_KEY = "{your_api_key}"
            OPEN_AI_BASE = "{your_api_base}"
            llm_client = OpenAILLMClient(api_key=OPEN_AI_KEY, api_base=OPEN_AI_BASE)
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
            assembler = SummaryAssembler.load_from_knowledge(
                knowledge=knowledge,
                chunk_parameters=chunk_parameters,
                llm_client=llm_client,
                model_name="gpt-3.5-turbo",
            )
            summary = await assembler.generate_summary()
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> None:
        """Initialize with Summary Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use
                for chunking.
            model_name: (Optional[str]) llm model to use; falls back to the
                ``LLM_MODEL`` environment variable.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".

        Raises:
            ValueError: If knowledge is missing, or no extractor was given and
                llm_client/model_name are insufficient to build one.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")

        model_name = model_name or os.getenv("LLM_MODEL")

        if not extractor:
            from ..extractor.summary import SummaryExtractor

            if not llm_client:
                raise ValueError("llm_client must be provided.")
            if not model_name:
                raise ValueError("model_name must be provided.")
            extractor = SummaryExtractor(
                llm_client=llm_client,
                model_name=model_name,
                language=language,
            )
        # `extractor` is now always set: either the caller supplied one or a
        # SummaryExtractor was just built above, so the previous
        # `if not extractor: raise ...` re-check was dead code and is removed.

        self._extractor: Extractor = extractor
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            extractor=self._extractor,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> "SummaryAssembler":
        """Create a SummaryAssembler from a knowledge datasource.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            model_name: (Optional[str]) llm model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".

        Returns:
            SummaryAssembler
        """
        return cls(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            model_name=model_name,
            llm_client=llm_client,
            extractor=extractor,
            language=language,
            **kwargs,
        )

    async def generate_summary(self) -> str:
        """Generate a summary from the loaded chunks via the extractor."""
        return await self._extractor.aextract(self._chunks)

    def persist(self) -> List[str]:
        """Persist chunks into store.

        Not supported: summaries are generated, not persisted.
        """
        raise NotImplementedError

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for summary assembling)."""
        return []

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever.

        Not supported: a summary assembler has no retrieval backend.
        """
        raise NotImplementedError
|
0
dbgpt/rag/assembler/tests/__init__.py
Normal file
0
dbgpt/rag/assembler/tests/__init__.py
Normal file
76
dbgpt/rag/assembler/tests/test_db_struct_assembler.py
Normal file
76
dbgpt/rag/assembler/tests/test_db_struct_assembler.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import Knowledge
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
@pytest.fixture
def mock_db_connection():
    """Create a temporary database connection for testing."""
    connect = SQLiteTempConnector.create_temporary_db()
    # Seed a single `user` table so schema-reading code has concrete data.
    connect.create_temp_tables(
        {
            "user": {
                "columns": {
                    "id": "INTEGER PRIMARY KEY",
                    "name": "TEXT",
                    "age": "INTEGER",
                },
                "data": [
                    (1, "Tom", 10),
                    (2, "Jerry", 16),
                    (3, "Jack", 18),
                    (4, "Alice", 20),
                    (5, "Bob", 22),
                ],
            }
        }
    )
    return connect
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chunk_parameters():
    """MagicMock constrained to the ChunkParameters interface."""
    return MagicMock(spec=ChunkParameters)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embedding_factory():
    """MagicMock constrained to the EmbeddingFactory interface."""
    return MagicMock(spec=EmbeddingFactory)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_store_connector():
    """MagicMock constrained to the VectorStoreConnector interface."""
    return MagicMock(spec=VectorStoreConnector)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_knowledge():
    """MagicMock constrained to the Knowledge interface."""
    return MagicMock(spec=Knowledge)
|
||||
|
||||
|
||||
def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    """Construct an EmbeddingAssembler on mocks and reload knowledge.

    NOTE(review): despite living in test_db_struct_assembler.py, this test
    exercises EmbeddingAssembler — presumably swapped with the sibling test
    file; confirm intent before renaming. The mocked Knowledge yields no real
    documents, so the assembler is asserted to hold zero chunks.
    """
    # Force the user-defined splitter path in ChunkManager.
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = EmbeddingAssembler(
        knowledge=mock_knowledge,
        chunk_parameters=mock_chunk_parameters,
        embeddings=mock_embedding_factory.create(),
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    assert len(assembler._chunks) == 0
|
68
dbgpt/rag/assembler/tests/test_embedding_assembler.py
Normal file
68
dbgpt/rag/assembler/tests/test_embedding_assembler.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
@pytest.fixture
def mock_db_connection():
    """Create a temporary database connection for testing."""
    connect = SQLiteTempConnector.create_temporary_db()
    # Seed a single `user` table so schema-reading code has concrete data.
    connect.create_temp_tables(
        {
            "user": {
                "columns": {
                    "id": "INTEGER PRIMARY KEY",
                    "name": "TEXT",
                    "age": "INTEGER",
                },
                "data": [
                    (1, "Tom", 10),
                    (2, "Jerry", 16),
                    (3, "Jack", 18),
                    (4, "Alice", 20),
                    (5, "Bob", 22),
                ],
            }
        }
    )
    return connect
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chunk_parameters():
    """MagicMock constrained to the ChunkParameters interface."""
    return MagicMock(spec=ChunkParameters)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embedding_factory():
    """MagicMock constrained to the EmbeddingFactory interface."""
    return MagicMock(spec=EmbeddingFactory)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_store_connector():
    """MagicMock constrained to the VectorStoreConnector interface."""
    return MagicMock(spec=VectorStoreConnector)
|
||||
|
||||
|
||||
def test_load_knowledge(
    mock_db_connection,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    """Construct a DBSchemaAssembler on a real temp DB and mocked collaborators.

    NOTE(review): despite living in test_embedding_assembler.py, this test
    exercises DBSchemaAssembler — presumably swapped with the sibling test
    file; confirm intent before renaming. The temp DB has one `user` table,
    and the assembler is asserted to produce exactly one chunk from it.
    """
    # Force the user-defined splitter path in ChunkManager.
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = DBSchemaAssembler(
        connector=mock_db_connection,
        chunk_parameters=mock_chunk_parameters,
        embeddings=mock_embedding_factory.create(),
        vector_store_connector=mock_vector_store_connector,
    )
    assert len(assembler._chunks) == 1
|
@@ -1,6 +1,10 @@
|
||||
"""Module for embedding related classes and functions."""
|
||||
|
||||
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory # noqa: F401
|
||||
from .embedding_factory import ( # noqa: F401
|
||||
DefaultEmbeddingFactory,
|
||||
EmbeddingFactory,
|
||||
WrappedEmbeddingFactory,
|
||||
)
|
||||
from .embeddings import ( # noqa: F401
|
||||
Embeddings,
|
||||
HuggingFaceBgeEmbeddings,
|
||||
@@ -21,4 +25,5 @@ __ALL__ = [
|
||||
"OpenAPIEmbeddings",
|
||||
"DefaultEmbeddingFactory",
|
||||
"EmbeddingFactory",
|
||||
"WrappedEmbeddingFactory",
|
||||
]
|
||||
|
32
dbgpt/rag/embedding/_wrapped.py
Normal file
32
dbgpt/rag/embedding/_wrapped.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Wraps the third-party language model embeddings to the common interface."""
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from dbgpt.core import Embeddings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
|
||||
|
||||
|
||||
class WrappedEmbeddings(Embeddings):
    """Wraps the third-party language model embeddings to the common interface.

    Every method simply delegates to the wrapped object, adapting (for
    example) a LangChain embeddings instance to dbgpt's ``Embeddings`` API.
    """

    def __init__(self, embeddings: "LangChainEmbeddings") -> None:
        """Create a new WrappedEmbeddings around *embeddings*."""
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        return self._embeddings.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self._embeddings.embed_query(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs."""
        return await self._embeddings.aembed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text."""
        return await self._embeddings.aembed_query(text)
|
@@ -1,15 +1,14 @@
|
||||
"""EmbeddingFactory class and DefaultEmbeddingFactory class."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
|
||||
from dbgpt.core import Embeddings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.rag.embedding.embeddings import Embeddings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingFactory(BaseComponent, ABC):
|
||||
@@ -20,7 +19,7 @@ class EmbeddingFactory(BaseComponent, ABC):
|
||||
@abstractmethod
|
||||
def create(
|
||||
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
|
||||
) -> "Embeddings":
|
||||
) -> Embeddings:
|
||||
"""Create an embedding instance.
|
||||
|
||||
Args:
|
||||
@@ -39,12 +38,19 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
default_model_name: Optional[str] = None,
|
||||
default_model_path: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new DefaultEmbeddingFactory."""
|
||||
super().__init__(system_app=system_app)
|
||||
if not default_model_path:
|
||||
default_model_path = default_model_name
|
||||
if not default_model_name:
|
||||
default_model_name = default_model_path
|
||||
self._default_model_name = default_model_name
|
||||
self.kwargs = kwargs
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
"""Init the app."""
|
||||
@@ -52,20 +58,166 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
|
||||
def create(
|
||||
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
|
||||
) -> "Embeddings":
|
||||
) -> Embeddings:
|
||||
"""Create an embedding instance.
|
||||
|
||||
Args:
|
||||
model_name (str): The model name.
|
||||
embedding_cls (Type): The embedding class.
|
||||
"""
|
||||
if not model_name:
|
||||
model_name = self._default_model_name
|
||||
|
||||
new_kwargs = {k: v for k, v in self.kwargs.items()}
|
||||
new_kwargs["model_name"] = model_name
|
||||
|
||||
if embedding_cls:
|
||||
return embedding_cls(**new_kwargs)
|
||||
else:
|
||||
return HuggingFaceEmbeddings(**new_kwargs)
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> Embeddings:
    """Load the embedding model configured by this factory's defaults.

    Resolves the parameter class registered for the default model name,
    parses the embedding parameters, and loads the model through
    ``EmbeddingLoader``.
    """
    from dbgpt.model.adapter.embeddings_loader import (
        EmbeddingLoader,
        _parse_embedding_params,
    )
    from dbgpt.model.parameter import (
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
        BaseEmbeddingModelParameters,
        EmbeddingModelParameters,
    )

    # Fall back to the generic parameter class for unregistered models.
    param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
        self._default_model_name, EmbeddingModelParameters
    )
    model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
        model_name=self._default_model_name,
        model_path=self._default_model_path,
        param_cls=param_cls,
        **self._kwargs,
    )
    logger.info(model_params)
    # Ignore model_name args: prefer the factory default, then the parsed params.
    resolved_name = self._default_model_name or model_params.model_name
    if not resolved_name:
        raise ValueError("model_name must be provided.")
    return EmbeddingLoader().load(resolved_name, model_params)
|
||||
|
||||
@classmethod
def openai(
    cls,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_name: str = "text-embedding-3-small",
    timeout: int = 60,
    **kwargs: Any,
) -> Embeddings:
    """Create an OpenAI embeddings.

    If api_url and api_key are not provided, we will try to get them from
    environment variables (``OPENAI_API_BASE`` / ``OPENAI_API_KEY``).

    Args:
        api_url (Optional[str], optional): The api url. Defaults to None.
        api_key (Optional[str], optional): The api key. Defaults to None.
        model_name (str, optional): The model name.
            Defaults to "text-embedding-3-small".
        timeout (int, optional): The timeout. Defaults to 60.

    Returns:
        Embeddings: The embeddings instance.

    Raises:
        ValueError: If no api key is given or found in the environment.
    """
    if not api_url:
        base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
        api_url = base + "/embeddings"
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("api_key must be provided.")
    # OpenAI's endpoint is API-compatible with our generic remote embeddings.
    return cls.remote(
        api_url=api_url,
        api_key=api_key,
        model_name=model_name,
        timeout=timeout,
        **kwargs,
    )
|
||||
|
||||
@classmethod
def default(
    cls, model_name: str, model_path: Optional[str] = None, **kwargs: Any
) -> Embeddings:
    """Create a default embeddings.

    It will try to load the model from the model name or model path.

    Args:
        model_name (str): The model name.
        model_path (Optional[str], optional): The model path. Defaults to None.
            If not provided, the model name is used as the path to load the
            model.

    Returns:
        Embeddings: The embeddings instance.
    """
    factory = cls(
        default_model_name=model_name, default_model_path=model_path, **kwargs
    )
    return factory.create()
|
||||
|
||||
@classmethod
def remote(
    cls,
    api_url: str = "http://localhost:8100/api/v1/embeddings",
    api_key: Optional[str] = None,
    model_name: str = "text2vec",
    timeout: int = 60,
    **kwargs: Any,
) -> Embeddings:
    """Create a remote embeddings.

    Create a remote embeddings which API compatible with the OpenAI's API. So if
    your model is compatible with OpenAI's API, you can use this method to create
    a remote embeddings.

    Args:
        api_url (str, optional): The api url. Defaults to
            "http://localhost:8100/api/v1/embeddings".
        api_key (Optional[str], optional): The api key. Defaults to None.
        model_name (str, optional): The model name. Defaults to "text2vec".
        timeout (int, optional): The timeout. Defaults to 60.
    """
    # Imported lazily to avoid pulling the HTTP client in at module import.
    from .embeddings import OpenAPIEmbeddings

    return OpenAPIEmbeddings(
        api_url=api_url,
        api_key=api_key,
        model_name=model_name,
        timeout=timeout,
        **kwargs,
    )
|
||||
|
||||
|
||||
class WrappedEmbeddingFactory(EmbeddingFactory):
    """Factory that always returns a pre-built Embeddings instance."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new WrappedEmbeddingFactory wrapping *embeddings*."""
        super().__init__(system_app=system_app)
        if not embeddings:
            raise ValueError("embeddings must be provided.")
        self._model = embeddings

    def init_app(self, system_app):
        """Init the app."""
        pass

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Create an embedding instance.

        Args:
            model_name (str): The model name (ignored; the wrapped instance
                is always returned).
            embedding_cls (Type): The embedding class (not supported here).
        """
        if embedding_cls:
            raise NotImplementedError
        return self._model
|
||||
|
@@ -1,23 +1,50 @@
|
||||
"""Module Of Knowledge."""
|
||||
|
||||
from .base import ChunkStrategy, Knowledge, KnowledgeType # noqa: F401
|
||||
from .csv import CSVKnowledge # noqa: F401
|
||||
from .docx import DocxKnowledge # noqa: F401
|
||||
from .factory import KnowledgeFactory # noqa: F401
|
||||
from .html import HTMLKnowledge # noqa: F401
|
||||
from .markdown import MarkdownKnowledge # noqa: F401
|
||||
from .pdf import PDFKnowledge # noqa: F401
|
||||
from .pptx import PPTXKnowledge # noqa: F401
|
||||
from .string import StringKnowledge # noqa: F401
|
||||
from .txt import TXTKnowledge # noqa: F401
|
||||
from .url import URLKnowledge # noqa: F401
|
||||
from typing import Any, Dict
|
||||
|
||||
__ALL__ = [
|
||||
_MODULE_CACHE: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
# Lazy load
|
||||
import importlib
|
||||
|
||||
if name in _MODULE_CACHE:
|
||||
return _MODULE_CACHE[name]
|
||||
|
||||
_LIBS = {
|
||||
"KnowledgeFactory": "factory",
|
||||
"Knowledge": "base",
|
||||
"KnowledgeType": "base",
|
||||
"ChunkStrategy": "base",
|
||||
"CSVKnowledge": "csv",
|
||||
"DatasourceKnowledge": "datasource",
|
||||
"DocxKnowledge": "docx",
|
||||
"HTMLKnowledge": "html",
|
||||
"MarkdownKnowledge": "markdown",
|
||||
"PDFKnowledge": "pdf",
|
||||
"PPTXKnowledge": "pptx",
|
||||
"StringKnowledge": "string",
|
||||
"TXTKnowledge": "txt",
|
||||
"URLKnowledge": "url",
|
||||
}
|
||||
|
||||
if name in _LIBS:
|
||||
module_path = "." + _LIBS[name]
|
||||
module = importlib.import_module(module_path, __name__)
|
||||
attr = getattr(module, name)
|
||||
_MODULE_CACHE[name] = attr
|
||||
return attr
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeFactory",
|
||||
"Knowledge",
|
||||
"KnowledgeType",
|
||||
"ChunkStrategy",
|
||||
"CSVKnowledge",
|
||||
"DatasourceKnowledge",
|
||||
"DocxKnowledge",
|
||||
"HTMLKnowledge",
|
||||
"MarkdownKnowledge",
|
||||
|
@@ -25,6 +25,7 @@ class DocumentType(Enum):
|
||||
DOCX = "docx"
|
||||
TXT = "txt"
|
||||
HTML = "html"
|
||||
DATASOURCE = "datasource"
|
||||
|
||||
|
||||
class KnowledgeType(Enum):
|
||||
|
57
dbgpt/rag/knowledge/datasource.py
Normal file
57
dbgpt/rag/knowledge/datasource.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Datasource Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.datasource import BaseConnector
|
||||
|
||||
from ..summary.rdbms_db_summary import _parse_db_summary
|
||||
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
|
||||
|
||||
|
||||
class DatasourceKnowledge(Knowledge):
|
||||
"""Datasource Knowledge."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
summary_template: str = "{table_name}({columns})",
|
||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create Datasource Knowledge with Knowledge arguments.
|
||||
|
||||
Args:
|
||||
path(str, optional): file path
|
||||
knowledge_type(KnowledgeType, optional): knowledge type
|
||||
data_loader(Any, optional): loader
|
||||
"""
|
||||
self._connector = connector
|
||||
self._summary_template = summary_template
|
||||
super().__init__(knowledge_type=knowledge_type, **kwargs)
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
"""Load datasource document from data_loader."""
|
||||
docs = []
|
||||
for table_summary in _parse_db_summary(self._connector, self._summary_template):
|
||||
docs.append(
|
||||
Document(content=table_summary, metadata={"source": "database"})
|
||||
)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def support_chunk_strategy(cls) -> List[ChunkStrategy]:
|
||||
"""Return support chunk strategy."""
|
||||
return [
|
||||
ChunkStrategy.CHUNK_BY_SIZE,
|
||||
ChunkStrategy.CHUNK_BY_SEPARATOR,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def type(cls) -> KnowledgeType:
|
||||
"""Knowledge type of Datasource."""
|
||||
return KnowledgeType.DOCUMENT
|
||||
|
||||
@classmethod
|
||||
def document_type(cls) -> DocumentType:
|
||||
"""Return document type."""
|
||||
return DocumentType.DATASOURCE
|
@@ -156,6 +156,7 @@ class KnowledgeFactory:
|
||||
"""Get all knowledge subclasses."""
|
||||
from dbgpt.rag.knowledge.base import Knowledge # noqa: F401
|
||||
from dbgpt.rag.knowledge.csv import CSVKnowledge # noqa: F401
|
||||
from dbgpt.rag.knowledge.datasource import DatasourceKnowledge # noqa: F401
|
||||
from dbgpt.rag.knowledge.docx import DocxKnowledge # noqa: F401
|
||||
from dbgpt.rag.knowledge.html import HTMLKnowledge # noqa: F401
|
||||
from dbgpt.rag.knowledge.markdown import MarkdownKnowledge # noqa: F401
|
||||
|
@@ -1,8 +1,14 @@
|
||||
"""Module for RAG operators."""
|
||||
|
||||
from .datasource import DatasourceRetrieverOperator # noqa: F401
|
||||
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
|
||||
from .embedding import EmbeddingRetrieverOperator # noqa: F401
|
||||
from .db_schema import ( # noqa: F401
|
||||
DBSchemaAssemblerOperator,
|
||||
DBSchemaRetrieverOperator,
|
||||
)
|
||||
from .embedding import ( # noqa: F401
|
||||
EmbeddingAssemblerOperator,
|
||||
EmbeddingRetrieverOperator,
|
||||
)
|
||||
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
|
||||
from .knowledge import KnowledgeOperator # noqa: F401
|
||||
from .rerank import RerankOperator # noqa: F401
|
||||
@@ -12,7 +18,9 @@ from .summary import SummaryAssemblerOperator # noqa: F401
|
||||
__all__ = [
|
||||
"DatasourceRetrieverOperator",
|
||||
"DBSchemaRetrieverOperator",
|
||||
"DBSchemaAssemblerOperator",
|
||||
"EmbeddingRetrieverOperator",
|
||||
"EmbeddingAssemblerOperator",
|
||||
"KnowledgeOperator",
|
||||
"RerankOperator",
|
||||
"QueryRewriteOperator",
|
||||
|
24
dbgpt/rag/operators/assembler.py
Normal file
24
dbgpt/rag/operators/assembler.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Base Assembler Operator."""
|
||||
from abc import abstractmethod
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN, OUT
|
||||
|
||||
|
||||
class AssemblerOperator(MapOperator[IN, OUT]):
|
||||
"""The Base Assembler Operator."""
|
||||
|
||||
async def map(self, input_value: IN) -> OUT:
|
||||
"""Map input value to output value.
|
||||
|
||||
Args:
|
||||
input_value (IN): The input value.
|
||||
|
||||
Returns:
|
||||
OUT: The output value.
|
||||
"""
|
||||
return await self.blocking_func_to_async(self.assemble, input_value)
|
||||
|
||||
@abstractmethod
|
||||
def assemble(self, input_value: IN) -> OUT:
|
||||
"""Assemble knowledge for input value."""
|
@@ -1,21 +1,21 @@
|
||||
"""Datasource operator for RDBMS database."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, List[str]]):
|
||||
"""The Datasource Retriever Operator."""
|
||||
|
||||
def __init__(self, connection: RDBMSConnector, **kwargs):
|
||||
def __init__(self, connector: BaseConnector, **kwargs):
|
||||
"""Create a new DatasourceRetrieverOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
|
||||
def retrieve(self, input_value: Any) -> Any:
|
||||
def retrieve(self, input_value: Any) -> List[str]:
|
||||
"""Retrieve the database summary."""
|
||||
summary = _parse_db_summary(self._connection)
|
||||
summary = _parse_db_summary(self._connector)
|
||||
return summary
|
||||
|
@@ -1,18 +1,22 @@
|
||||
"""The DBSchema Retriever Operator."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.db_schema import DBSchemaAssembler
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
from .assembler import AssemblerOperator
|
||||
|
||||
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
|
||||
class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
||||
"""The DBSchema Retriever Operator.
|
||||
|
||||
Args:
|
||||
connection (RDBMSConnector): The connection.
|
||||
connector (BaseConnector): The connection.
|
||||
top_k (int, optional): The top k. Defaults to 4.
|
||||
vector_store_connector (VectorStoreConnector, optional): The vector store
|
||||
connector. Defaults to None.
|
||||
@@ -22,21 +26,57 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSConnector] = None,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new DBSchemaRetrieverOperator."""
|
||||
super().__init__(**kwargs)
|
||||
self._retriever = DBSchemaRetriever(
|
||||
top_k=top_k,
|
||||
connection=connection,
|
||||
connector=connector,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
|
||||
def retrieve(self, query: Any) -> Any:
|
||||
def retrieve(self, query: str) -> List[Chunk]:
|
||||
"""Retrieve the table schemas.
|
||||
|
||||
Args:
|
||||
query (IN): query.
|
||||
query (str): The query.
|
||||
"""
|
||||
return self._retriever.retrieve(query)
|
||||
|
||||
|
||||
class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
"""The DBSchema Assembler Operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new DBSchemaAssemblerOperator.
|
||||
|
||||
Args:
|
||||
connector (BaseConnector): The connection.
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
"""
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._connector = connector
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def assemble(self, dummy_value) -> List[Chunk]:
|
||||
"""Persist the database schema.
|
||||
|
||||
Args:
|
||||
dummy_value: Dummy value, not used.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: The chunks.
|
||||
"""
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=self._connector,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@@ -5,11 +5,16 @@ from typing import List, Optional, Union
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from ..assembler.embedding import EmbeddingAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..knowledge import Knowledge
|
||||
from ..retriever.embedding import EmbeddingRetriever
|
||||
from ..retriever.rerank import Ranker
|
||||
from ..retriever.rewrite import QueryRewrite
|
||||
from .assembler import AssemblerOperator
|
||||
|
||||
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
|
||||
"""The Embedding Retriever Operator."""
|
||||
@@ -43,3 +48,36 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
|
||||
for q in query
|
||||
]
|
||||
return reduce(lambda x, y: x + y, candidates)
|
||||
|
||||
|
||||
class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
|
||||
"""The Embedding Assembler Operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
|
||||
chunk_strategy="CHUNK_BY_SIZE"
|
||||
),
|
||||
**kwargs
|
||||
):
|
||||
"""Create a new EmbeddingAssemblerOperator.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
||||
parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
|
||||
"""
|
||||
self._chunk_parameters = chunk_parameters
|
||||
self._vector_store_connector = vector_store_connector
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def assemble(self, knowledge: Knowledge) -> List[Chunk]:
|
||||
"""Assemble knowledge for input value."""
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=self._chunk_parameters,
|
||||
vector_store_connector=self._vector_store_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Knowledge Operator."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.flow import (
|
||||
@@ -14,7 +14,7 @@ from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
|
||||
|
||||
class KnowledgeOperator(MapOperator[Any, Any]):
|
||||
class KnowledgeOperator(MapOperator[str, Knowledge]):
|
||||
"""Knowledge Factory Operator."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
@@ -26,7 +26,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
|
||||
IOField.build_from(
|
||||
"knowledge datasource",
|
||||
"knowledge datasource",
|
||||
dict,
|
||||
str,
|
||||
"knowledge datasource",
|
||||
)
|
||||
],
|
||||
@@ -85,7 +85,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
|
||||
self._datasource = datasource
|
||||
self._knowledge_type = KnowledgeType.get_by_value(knowledge_type)
|
||||
|
||||
async def map(self, datasource: Any) -> Knowledge:
|
||||
async def map(self, datasource: str) -> Knowledge:
|
||||
"""Create knowledge from datasource."""
|
||||
if self._datasource:
|
||||
datasource = self._datasource
|
||||
|
@@ -1,12 +1,12 @@
|
||||
"""The Rerank Operator."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker
|
||||
|
||||
|
||||
class RerankOperator(MapOperator[Any, Any]):
|
||||
class RerankOperator(MapOperator[List[Chunk], List[Chunk]]):
|
||||
"""The Rewrite Operator."""
|
||||
|
||||
def __init__(
|
||||
|
@@ -7,7 +7,7 @@ from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: RDBMSConnector,
|
||||
connector: BaseConnector,
|
||||
model_name: str,
|
||||
llm: LLMClient,
|
||||
top_k: int = 5,
|
||||
@@ -27,14 +27,14 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
|
||||
"""Create the schema linking operator.
|
||||
|
||||
Args:
|
||||
connection (RDBMSConnector): The connection.
|
||||
connector (BaseConnector): The connection.
|
||||
llm (Optional[LLMClient]): base llm
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._schema_linking = SchemaLinking(
|
||||
top_k=top_k,
|
||||
connection=connection,
|
||||
connector=connector,
|
||||
llm=llm,
|
||||
model_name=model_name,
|
||||
vector_store_connector=vector_store_connector,
|
||||
|
@@ -4,9 +4,9 @@ from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
|
||||
from dbgpt.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt.rag.knowledge.base import Knowledge
|
||||
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt.serve.rag.operators.base import AssemblerOperator
|
||||
from dbgpt.rag.operators.assembler import AssemblerOperator
|
||||
|
||||
|
||||
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
|
||||
|
@@ -3,7 +3,7 @@ from functools import reduce
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSConnector] = None,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
rerank: Optional[Ranker] = None,
|
||||
**kwargs
|
||||
@@ -28,7 +28,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
top_k (int): top k
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
|
||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||
query_rewrite (bool): query rewrite
|
||||
rerank (Ranker): rerank
|
||||
|
||||
@@ -65,7 +65,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
return connect
|
||||
|
||||
|
||||
connection = _create_temporary_connection()
|
||||
connector = _create_temporary_connection()
|
||||
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||
embedding_model_path = "{your_embedding_model_path}"
|
||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||
@@ -76,14 +76,16 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
)
|
||||
# get db struct retriever
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=3, vector_store_connector=vector_connector
|
||||
top_k=3,
|
||||
vector_store_connector=vector_connector,
|
||||
connector=connector,
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
result = [chunk.content for chunk in chunks]
|
||||
print(f"db struct rag example results:{result}")
|
||||
"""
|
||||
self._top_k = top_k
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
self._query_rewrite = query_rewrite
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._need_embeddings = False
|
||||
@@ -108,9 +110,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
else:
|
||||
if not self._connection:
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
table_summaries = _parse_db_summary(self._connection)
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
@@ -173,6 +175,6 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Similar search."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
if not self._connection:
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
return _parse_db_summary(self._connection)
|
||||
return _parse_db_summary(self._connector)
|
||||
|
@@ -24,7 +24,7 @@ def mock_vector_store_connector():
|
||||
@pytest.fixture
|
||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
return DBSchemaRetriever(
|
||||
connection=mock_db_connection,
|
||||
connector=mock_db_connection,
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
@@ -10,7 +10,7 @@ from dbgpt.core import (
|
||||
ModelMessageRoleType,
|
||||
ModelRequest,
|
||||
)
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
@@ -42,7 +42,7 @@ class SchemaLinking(BaseSchemaLinker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: RDBMSConnector,
|
||||
connector: BaseConnector,
|
||||
model_name: str,
|
||||
llm: LLMClient,
|
||||
top_k: int = 5,
|
||||
@@ -52,19 +52,19 @@ class SchemaLinking(BaseSchemaLinker):
|
||||
"""Create the schema linking instance.
|
||||
|
||||
Args:
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
|
||||
connection (Optional[BaseConnector]): BaseConnector connection.
|
||||
llm (Optional[LLMClient]): base llm
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._top_k = top_k
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
self._llm = llm
|
||||
self._model_name = model_name
|
||||
self._vector_store_connector = vector_store_connector
|
||||
|
||||
def _schema_linking(self, query: str) -> List:
|
||||
"""Get all db schema info."""
|
||||
table_summaries = _parse_db_summary(self._connection)
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
chunks_content = [chunk.content for chunk in chunks]
|
||||
return chunks_content
|
||||
|
@@ -97,10 +97,10 @@ class DBSummaryClient:
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
if not vector_connector.vector_name_exists():
|
||||
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
|
||||
db_assembler = DBSchemaAssembler.load_from_connection(
|
||||
connection=db_summary_client.db, vector_store_connector=vector_connector
|
||||
connector=db_summary_client.db, vector_store_connector=vector_connector
|
||||
)
|
||||
if len(db_assembler.get_chunks()) > 0:
|
||||
db_assembler.persist()
|
||||
|
@@ -3,7 +3,7 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource import BaseConnector
|
||||
from dbgpt.rag.summary.db_summary import DBSummary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -64,12 +64,12 @@ class RdbmsSummary(DBSummary):
|
||||
|
||||
|
||||
def _parse_db_summary(
|
||||
conn: RDBMSConnector, summary_template: str = "{table_name}({columns})"
|
||||
conn: BaseConnector, summary_template: str = "{table_name}({columns})"
|
||||
) -> List[str]:
|
||||
"""Get db summary for database.
|
||||
|
||||
Args:
|
||||
conn (RDBMSConnector): database connection
|
||||
conn (BaseConnector): database connection
|
||||
summary_template (str): summary template
|
||||
"""
|
||||
tables = conn.get_table_names()
|
||||
@@ -81,12 +81,12 @@ def _parse_db_summary(
|
||||
|
||||
|
||||
def _parse_table_summary(
|
||||
conn: RDBMSConnector, summary_template: str, table_name: str
|
||||
conn: BaseConnector, summary_template: str, table_name: str
|
||||
) -> str:
|
||||
"""Get table summary for table.
|
||||
|
||||
Args:
|
||||
conn (RDBMSConnector): database connection
|
||||
conn (BaseConnector): database connection
|
||||
summary_template (str): summary template
|
||||
table_name (str): table name
|
||||
|
||||
|
Reference in New Issue
Block a user