Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-11 22:09:44 +00:00)
feat(rag): Support RAG SDK (#1322)
@@ -1,70 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.util.tracer import root_tracer


class BaseAssembler(ABC):
    """Base Assembler."""

    def __init__(
        self,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        extractor: Optional[Extractor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
        """
        self._knowledge = knowledge
        self._chunk_parameters = chunk_parameters or ChunkParameters()
        self._extractor = extractor
        self._chunk_manager = ChunkManager(
            knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
        )
        self._chunks = None
        metadata = {
            "knowledge_cls": self._knowledge.__class__.__name__
            if self._knowledge
            else None,
            "knowledge_type": self._knowledge.type().value if self._knowledge else None,
            "path": self._knowledge._path
            if self._knowledge and hasattr(self._knowledge, "_path")
            else None,
            "chunk_parameters": self._chunk_parameters.dict(),
        }
        with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
            self.load_knowledge(self._knowledge)

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        """Load the knowledge pipeline: load documents, then split them into chunks."""
        if not knowledge:
            raise ValueError("knowledge must be provided.")
        with root_tracer.start_span("BaseAssembler.knowledge.load"):
            documents = knowledge.load()
        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
            self._chunks = self._chunk_manager.split(documents)

    @abstractmethod
    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever."""

    @abstractmethod
    def persist(self) -> List[str]:
        """Persist chunks.

        Returns:
            List[str]: List of persisted chunk ids.
        """

    def get_chunks(self) -> List[Chunk]:
        """Return chunks."""
        return self._chunks
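Note: to make the abstract contract above concrete, here is a minimal illustrative sketch of a subclass; the class name and its list-backed "store" are hypothetical and only show how persist() and as_retriever() relate to the chunks produced by load_knowledge().

from typing import Any, List

from dbgpt.core import Chunk
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever


class InMemoryAssembler(BaseAssembler):
    """Hypothetical assembler that keeps chunks in a plain Python list."""

    def __init__(self, knowledge: Knowledge, **kwargs: Any) -> None:
        self._store: List[Chunk] = []  # stands in for a real vector store
        super().__init__(knowledge=knowledge, **kwargs)

    def persist(self) -> List[str]:
        # "Persist" by appending to the in-memory list; return the chunk ids
        # (Chunk is assumed to carry a chunk_id field, as used elsewhere in dbgpt).
        self._store.extend(self._chunks)
        return [chunk.chunk_id for chunk in self._chunks]

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        # A real implementation would return a retriever bound to its store.
        raise NotImplementedError("illustrative sketch only")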
@@ -1,153 +0,0 @@
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBSchemaAssembler(BaseAssembler):
    """DBSchemaAssembler.

    Example:
        .. code-block:: python

            from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
            from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

            connection = SQLiteTempConnector.create_temporary_db()
            assembler = DBSchemaAssembler.load_from_connection(
                connection=connection,
                embedding_model=embedding_model_path,
            )
            assembler.persist()
            # get db struct retriever
            retriever = assembler.as_retriever(top_k=3)
    """

    def __init__(
        self,
        connection: Optional[RDBMSConnector] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with DB Schema Assembler arguments.

        Args:
            connection: (RDBMSConnector) RDBMSConnector connection.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        """
        if connection is None:
            raise ValueError("datasource connection must be provided.")
        self._connection = connection
        self._vector_store_connector = vector_store_connector
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        self._embedding_model = embedding_model
        if self._embedding_model:
            embedding_factory = embedding_factory or DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            self.embedding_fn = embedding_factory.create(self._embedding_model)
        if self._vector_store_connector.vector_store_config.embedding_fn is None:
            self._vector_store_connector.vector_store_config.embedding_fn = (
                self.embedding_fn
            )

        super().__init__(
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_connection(
        cls,
        connection: Optional[RDBMSConnector] = None,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
    ) -> "DBSchemaAssembler":
        """Load the database schema into the vector store from a connection.

        Args:
            connection: (RDBMSConnector) RDBMSConnector connection.
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        Returns:
            DBSchemaAssembler
        """
        chunk_parameters = chunk_parameters or ChunkParameters(
            chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
        )

        return cls(
            connection=connection,
            knowledge=knowledge,
            embedding_model=embedding_model,
            chunk_parameters=chunk_parameters,
            embedding_factory=embedding_factory,
            vector_store_connector=vector_store_connector,
        )

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        """Load schema knowledge: build one chunk per table summary."""
        table_summaries = _parse_db_summary(self._connection)
        self._chunks = []
        self._knowledge = knowledge
        for table_summary in table_summaries:
            from dbgpt.rag.knowledge.base import KnowledgeType

            self._knowledge = KnowledgeFactory.from_text(
                text=table_summary, knowledge_type=KnowledgeType.DOCUMENT
            )
            self._chunk_parameters.chunk_size = len(table_summary)
            self._chunk_manager = ChunkManager(
                knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
            )
            self._chunks.extend(self._chunk_manager.split(self._knowledge.load()))

    def get_chunks(self) -> List[Chunk]:
        """Return chunks."""
        return self._chunks

    def persist(self) -> List[str]:
        """Persist chunks into the vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""

    def as_retriever(self, top_k: Optional[int] = 4) -> DBSchemaRetriever:
        """Create a DBSchemaRetriever.

        Args:
            top_k: (Optional[int]) The number of schema chunks to retrieve. Defaults to 4.
        Returns:
            DBSchemaRetriever
        """
        return DBSchemaRetriever(
            top_k=top_k,
            connection=self._connection,
            is_embeddings=True,
            vector_store_connector=self._vector_store_connector,
        )
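Note: the docstring example above leaves the vector store and embedding model unspecified. The following is a fuller usage sketch under assumed settings; the VectorStoreConnector.from_default call, the ChromaVectorConfig fields, and the embedding model path are illustrative assumptions, not prescribed by this change.

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# Temporary SQLite database with one table to index.
connection = SQLiteTempConnector.create_temporary_db()
connection.create_temp_tables(
    {"user": {"columns": {"id": "INTEGER PRIMARY KEY", "name": "TEXT"}, "data": [(1, "Tom")]}}
)

# Assumed vector store setup; adjust the store type and config to your environment.
vector_store_connector = VectorStoreConnector.from_default(
    "Chroma", vector_store_config=ChromaVectorConfig(name="db_schema_demo")
)

assembler = DBSchemaAssembler.load_from_connection(
    connection=connection,
    embedding_model="/path/to/embedding/model",  # hypothetical local model path
    vector_store_connector=vector_store_connector,
)
chunk_ids = assembler.persist()             # one chunk per table summary
retriever = assembler.as_retriever(top_k=3)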
@@ -1,122 +0,0 @@
import os
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class EmbeddingAssembler(BaseAssembler):
    """Embedding Assembler.

    Example:
        .. code-block:: python

            from dbgpt.rag.assembler import EmbeddingAssembler
            from dbgpt.rag.knowledge import KnowledgeFactory

            pdf_path = "path/to/document.pdf"
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            assembler = EmbeddingAssembler.load_from_knowledge(
                knowledge=knowledge,
                embedding_model="text2vec",
            )
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        self._vector_store_connector = vector_store_connector
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        self._embedding_model = embedding_model
        if self._embedding_model:
            embedding_factory = embedding_factory or DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            self.embedding_fn = embedding_factory.create(self._embedding_model)
        if self._vector_store_connector.vector_store_config.embedding_fn is None:
            self._vector_store_connector.vector_store_config.embedding_fn = (
                self.embedding_fn
            )

        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
    ) -> "EmbeddingAssembler":
        """Load document embeddings into the vector store from a knowledge source.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        Returns:
            EmbeddingAssembler
        """
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        embedding_factory = embedding_factory or DefaultEmbeddingFactory(
            default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
        )
        return cls(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            embedding_model=embedding_model,
            embedding_factory=embedding_factory,
            vector_store_connector=vector_store_connector,
        )

    def persist(self) -> List[str]:
        """Persist chunks into the vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        pass

    def as_retriever(self, top_k: Optional[int] = 4) -> EmbeddingRetriever:
        """Create an EmbeddingRetriever.

        Args:
            top_k: (Optional[int]) The number of chunks to retrieve. Defaults to 4.
        Returns:
            EmbeddingRetriever
        """
        return EmbeddingRetriever(
            top_k=top_k, vector_store_connector=self._vector_store_connector
        )
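Note: a short end-to-end sketch of the load/persist/retrieve flow for EmbeddingAssembler; the document path, store configuration, embedding model name, and the retriever.retrieve(query) call are assumptions about a typical setup rather than part of this commit.

from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

knowledge = KnowledgeFactory.from_file_path("path/to/document.pdf")  # assumed path
vector_store_connector = VectorStoreConnector.from_default(
    "Chroma", vector_store_config=ChromaVectorConfig(name="embedding_demo")
)

assembler = EmbeddingAssembler.load_from_knowledge(
    knowledge=knowledge,
    embedding_model="text2vec",  # assumed model name, as in the docstring example
    vector_store_connector=vector_store_connector,
)
chunk_ids = assembler.persist()               # write chunks into the vector store
retriever = assembler.as_retriever(top_k=3)
chunks = retriever.retrieve("What is AWEL?")  # assumed retrieve(query) signature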
@@ -1,112 +0,0 @@
import os
from typing import Any, List, Optional

from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler


class SummaryAssembler(BaseAssembler):
    """Summary Assembler.

    Example:
        .. code-block:: python

            pdf_path = "../../../DB-GPT/docs/docs/awel.md"
            OPEN_AI_KEY = "{your_api_key}"
            OPEN_AI_BASE = "{your_api_base}"
            llm_client = OpenAILLMClient(api_key=OPEN_AI_KEY, api_base=OPEN_AI_BASE)
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
            assembler = SummaryAssembler.load_from_knowledge(
                knowledge=knowledge,
                chunk_parameters=chunk_parameters,
                llm_client=llm_client,
                model_name="gpt-3.5-turbo",
            )
            summary = await assembler.generate_summary()
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> None:
        """Initialize with Summary Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            model_name: (Optional[str]) LLM model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")

        self._model_name = model_name or os.getenv("LLM_MODEL")
        self._llm_client = llm_client
        from dbgpt.rag.extractor.summary import SummaryExtractor

        self._extractor = extractor or SummaryExtractor(
            llm_client=self._llm_client, model_name=self._model_name, language=language
        )
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            extractor=self._extractor,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> "SummaryAssembler":
        """Load a knowledge source and prepare it for summarization.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            model_name: (Optional[str]) LLM model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".
        Returns:
            SummaryAssembler
        """
        return cls(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            model_name=model_name,
            llm_client=llm_client,
            extractor=extractor,
            language=language,
            **kwargs,
        )

    async def generate_summary(self) -> str:
        """Generate a summary from the loaded chunks."""
        return await self._extractor.aextract(self._chunks)

    def persist(self) -> List[str]:
        """Persist chunks into a store (not supported for SummaryAssembler)."""
        raise NotImplementedError

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever."""
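Note: generate_summary() is a coroutine and must be awaited. A minimal driver sketch follows, assuming an OpenAI-compatible endpoint; the OpenAILLMClient import path, credentials, and file path are assumptions for illustration.

import asyncio

from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.serve.rag.assembler.summary import SummaryAssembler


async def main() -> None:
    # Assumed credentials and document path; adjust to your environment.
    llm_client = OpenAILLMClient(api_key="{your_api_key}", api_base="{your_api_base}")
    knowledge = KnowledgeFactory.from_file_path("docs/docs/awel.md")
    assembler = SummaryAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
        llm_client=llm_client,
        model_name="gpt-3.5-turbo",
    )
    summary = await assembler.generate_summary()
    print(summary)


asyncio.run(main())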
@@ -1,76 +0,0 @@
from unittest.mock import MagicMock

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


@pytest.fixture
def mock_db_connection():
    """Create a temporary database connection for testing."""
    connect = SQLiteTempConnector.create_temporary_db()
    connect.create_temp_tables(
        {
            "user": {
                "columns": {
                    "id": "INTEGER PRIMARY KEY",
                    "name": "TEXT",
                    "age": "INTEGER",
                },
                "data": [
                    (1, "Tom", 10),
                    (2, "Jerry", 16),
                    (3, "Jack", 18),
                    (4, "Alice", 20),
                    (5, "Bob", 22),
                ],
            }
        }
    )
    return connect


@pytest.fixture
def mock_chunk_parameters():
    return MagicMock(spec=ChunkParameters)


@pytest.fixture
def mock_embedding_factory():
    return MagicMock(spec=EmbeddingFactory)


@pytest.fixture
def mock_vector_store_connector():
    return MagicMock(spec=VectorStoreConnector)


@pytest.fixture
def mock_knowledge():
    return MagicMock(spec=Knowledge)


def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = EmbeddingAssembler(
        knowledge=mock_knowledge,
        chunk_parameters=mock_chunk_parameters,
        embedding_factory=mock_embedding_factory,
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    assert len(assembler._chunks) == 0
@@ -1,76 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


@pytest.fixture
def mock_db_connection():
    """Create a temporary database connection for testing."""
    connect = SQLiteTempConnector.create_temporary_db()
    connect.create_temp_tables(
        {
            "user": {
                "columns": {
                    "id": "INTEGER PRIMARY KEY",
                    "name": "TEXT",
                    "age": "INTEGER",
                },
                "data": [
                    (1, "Tom", 10),
                    (2, "Jerry", 16),
                    (3, "Jack", 18),
                    (4, "Alice", 20),
                    (5, "Bob", 22),
                ],
            }
        }
    )
    return connect


@pytest.fixture
def mock_chunk_parameters():
    return MagicMock(spec=ChunkParameters)


@pytest.fixture
def mock_embedding_factory():
    return MagicMock(spec=EmbeddingFactory)


@pytest.fixture
def mock_vector_store_connector():
    return MagicMock(spec=VectorStoreConnector)


@pytest.fixture
def mock_knowledge():
    return MagicMock(spec=Knowledge)


def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = DBSchemaAssembler(
        connection=mock_db_connection,
        chunk_parameters=mock_chunk_parameters,
        embedding_factory=mock_embedding_factory,
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    assert len(assembler._chunks) == 1
@@ -1,23 +0,0 @@
from abc import abstractmethod

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT


class AssemblerOperator(MapOperator[IN, OUT]):
    """The Base Assembler Operator."""

    async def map(self, input_value: IN) -> OUT:
        """Map input value to output value.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The output value.
        """
        return await self.blocking_func_to_async(self.assemble, input_value)

    @abstractmethod
    def assemble(self, input_value: IN) -> OUT:
        """Assemble knowledge for the input value."""
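Note: concrete operators only implement assemble(); map() offloads that blocking work via blocking_func_to_async. A small hypothetical subclass as illustration; the operator name and the word-count logic are made up.

from typing import List

from dbgpt.rag.knowledge.base import Knowledge


class WordCountAssemblerOperator(AssemblerOperator[Knowledge, List[int]]):
    """Hypothetical operator that 'assembles' a knowledge source into word counts."""

    def assemble(self, input_value: Knowledge) -> List[int]:
        # Runs in a worker thread via blocking_func_to_async, so blocking I/O is fine here.
        documents = input_value.load()
        return [len(doc.content.split()) for doc in documents]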
@@ -1,36 +0,0 @@
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]):
    """The DBSchema Assembler Operator.

    Args:
        connection (RDBMSConnector): The connection.
        vector_store_connector (Optional[VectorStoreConnector], optional): The vector store connector. Defaults to None.
    """

    def __init__(
        self,
        connection: Optional[RDBMSConnector] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        self._connection = connection
        self._vector_store_connector = vector_store_connector
        self._assembler = DBSchemaAssembler.load_from_connection(
            connection=self._connection,
            vector_store_connector=self._vector_store_connector,
        )
        super().__init__(**kwargs)

    def assemble(self, input_value: IN) -> Any:
        """Assemble schema knowledge and return the chunks for the input value."""
        if self._vector_store_connector:
            self._assembler.persist()
        return self._assembler.get_chunks()
@@ -1,44 +0,0 @@
from typing import Any, Optional

from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class EmbeddingAssemblerOperator(AssemblerOperator[Any, Any]):
    """The Embedding Assembler Operator.

    Args:
        chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
        vector_store_connector (Optional[VectorStoreConnector], optional): The vector store connector. Defaults to None.
    """

    def __init__(
        self,
        chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
            chunk_strategy="CHUNK_BY_SIZE"
        ),
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Initialize the operator.

        Args:
            chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
            vector_store_connector (Optional[VectorStoreConnector], optional): The vector store connector. Defaults to None.
        """
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        super().__init__(**kwargs)

    def assemble(self, knowledge: IN) -> Any:
        """Assemble chunks from the knowledge input, persist them, and return them."""
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()
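Note: a rough AWEL wiring sketch for EmbeddingAssemblerOperator; the InputOperator/SimpleCallDataInputSource pattern, VectorStoreConnector.from_default call, and the call(call_data=...) invocation follow common DB-GPT AWEL examples and are assumptions, not part of this change.

from dbgpt.core.awel import DAG, InputOperator, SimpleCallDataInputSource
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# Assumed vector store setup for the example.
vector_store_connector = VectorStoreConnector.from_default(
    "Chroma", vector_store_config=ChromaVectorConfig(name="awel_embedding_demo")
)

with DAG("embedding_assembler_example") as dag:
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    assembler_task = EmbeddingAssemblerOperator(
        vector_store_connector=vector_store_connector,
    )
    input_task >> assembler_task

# In async code, feed a Knowledge object through the operator:
# knowledge = KnowledgeFactory.from_file_path("path/to/document.pdf")
# chunks = await assembler_task.call(call_data=knowledge)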
@@ -23,6 +23,7 @@ from dbgpt.configs.model_config import (
)
from dbgpt.core import Chunk
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import EmbeddingFactory
from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeFactory, KnowledgeType
@@ -43,7 +44,6 @@ from ..api.schemas import (
    SpaceServeRequest,
    SpaceServeResponse,
)
from ..assembler.embedding import EmbeddingAssembler
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity