refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Aries-ckt authored 2024-01-03 09:45:26 +08:00, committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever


class BaseAssembler(ABC):
    """Base Assembler."""
def __init__(
self,
knowledge: Optional[Knowledge] = None,
chunk_parameters: Optional[ChunkParameters] = None,
extractor: Optional[Extractor] = None,
**kwargs: Any,
) -> None:
"""Initialize with Assembler arguments.
Args:
knowledge: (Knowledge) Knowledge datasource.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking.
extractor: (Optional[Extractor]) Extractor to use for summarization."""
self._knowledge = knowledge
self._chunk_parameters = chunk_parameters or ChunkParameters()
self._extractor = extractor
self._chunk_manager = ChunkManager(
knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
)
self._chunks = None
self.load_knowledge(self._knowledge)

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        """Load knowledge pipeline: load documents and split them into chunks."""
        if knowledge is None:
            raise ValueError("knowledge must be provided.")
        documents = knowledge.load()
        self._chunks = self._chunk_manager.split(documents)

    @abstractmethod
    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever."""

    @abstractmethod
    def persist(self, chunks: List[Chunk]) -> None:
        """Persist chunks."""

    def get_chunks(self) -> List[Chunk]:
        """Return chunks."""
        return self._chunks
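
For orientation, here is a minimal sketch of what a concrete subclass has to supply on top of this base class; the InMemoryAssembler name and its list-backed store are hypothetical and not part of this change.

from typing import Any, List

from dbgpt.rag.chunk import Chunk
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler


class InMemoryAssembler(BaseAssembler):
    """Hypothetical assembler that keeps persisted chunks in a plain list."""

    def __init__(self, **kwargs: Any) -> None:
        self._store: List[Chunk] = []
        super().__init__(**kwargs)

    def persist(self, chunks: List[Chunk]) -> None:
        # Base class contract: write the given chunks to some store.
        self._store.extend(chunks)

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        # A real subclass would return a retriever bound to its store.
        raise NotImplementedError("retrieval is out of scope for this sketch")


# Loading happens inside BaseAssembler.__init__: documents are split into chunks
# by the ChunkManager, so get_chunks() is populated right after construction.
knowledge = KnowledgeFactory.from_text(
    text="hello world", knowledge_type=KnowledgeType.DOCUMENT
)
assembler = InMemoryAssembler(knowledge=knowledge)
assembler.persist(assembler.get_chunks())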

View File

@@ -0,0 +1,151 @@
import os
from typing import Any, List, Optional

from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.db_struct import DBStructRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBStructAssembler(BaseAssembler):
"""DBStructAssembler

    Example:
        .. code-block:: python

            from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
            from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

            connection = SQLiteTempConnect.create_temporary_db()
            assembler = DBStructAssembler.load_from_connection(
                connection=connection,
                embedding_model=embedding_model_path,
            )
            assembler.persist()
            # get db struct retriever
            retriever = assembler.as_retriever(top_k=3)
    """

    def __init__(
        self,
        connection: Optional[RDBMSDatabase] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with DBStructAssembler arguments.

        Args:
            connection: (RDBMSDatabase) RDBMSDatabase connection.
            knowledge: (Optional[Knowledge]) Knowledge datasource, passed through to the base assembler via **kwargs.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        """
        if connection is None:
            raise ValueError("datasource connection must be provided.")
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        embedding_factory = embedding_factory or DefaultEmbeddingFactory(
            default_model_name=os.getenv("EMBEDDING_MODEL")
        )
        self._connection = connection
        embedding_fn = (
            embedding_factory.create(model_name=embedding_model)
            if embedding_model
            else None
        )
        self._vector_store_connector = (
            vector_store_connector
            or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
        )
        super().__init__(
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_connection(
        cls,
        connection: Optional[RDBMSDatabase] = None,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
    ) -> "DBStructAssembler":
        """Load the database schema summaries into the vector store from a connection.

        Args:
            connection: (RDBMSDatabase) RDBMSDatabase connection.
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.

        Returns:
            DBStructAssembler
        """
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
)
chunk_parameters = chunk_parameters or ChunkParameters(
chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
)
return cls(
connection=connection,
knowledge=knowledge,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters,
embedding_factory=embedding_factory,
vector_store_connector=vector_store_connector,
)

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        """Load knowledge from the database: one chunk per table summary."""
        from dbgpt.rag.knowledge.base import KnowledgeType

        table_summaries = _parse_db_summary(self._connection)
        self._chunks = []
        self._knowledge = knowledge
        for table_summary in table_summaries:
            self._knowledge = KnowledgeFactory.from_text(
                text=table_summary, knowledge_type=KnowledgeType.DOCUMENT
            )
            self._chunk_parameters.chunk_size = len(table_summary)
            self._chunk_manager = ChunkManager(
                knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
            )
            self._chunks.extend(self._chunk_manager.split(self._knowledge.load()))

    def get_chunks(self) -> List[Chunk]:
        """Return the chunks built from the table summaries."""
        return self._chunks

    def persist(self) -> List[str]:
        """Persist chunks into the vector store."""
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, top_k: Optional[int] = 4) -> DBStructRetriever:
        """Create a DBStructRetriever over the persisted table summaries.

        Args:
            top_k: (Optional[int]) The number of results to return. Defaults to 4.

        Returns:
            DBStructRetriever
        """
        return DBStructRetriever(
            top_k=top_k,
            connection=self._connection,
            is_embeddings=True,
            vector_store_connector=self._vector_store_connector,
        )
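
A minimal end-to-end sketch of the intended flow, reusing the temporary SQLite connection shown in the tests below. The local embedding model path is a placeholder, and persisting assumes the default vector store behind VectorStoreConnector.from_default is reachable.

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler

# Temporary in-memory database with a single table, as in the unit tests.
connection = SQLiteTempConnect.create_temporary_db()
connection.create_temp_tables(
    {
        "user": {
            "columns": {"id": "INTEGER PRIMARY KEY", "name": "TEXT"},
            "data": [(1, "Tom")],
        }
    }
)

# One chunk per table summary is built in load_knowledge() and embedded on persist().
assembler = DBStructAssembler.load_from_connection(
    connection=connection,
    embedding_model="/path/to/text2vec-model",  # hypothetical local model path
)
assembler.persist()

# top_k bounds how many table summaries the retriever returns for a question.
retriever = assembler.as_retriever(top_k=1)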

View File

@@ -0,0 +1,116 @@
import os
from typing import Any, List, Optional

from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class EmbeddingAssembler(BaseAssembler):
"""Embedding Assembler

    Example:
        .. code-block:: python

            from dbgpt.rag.knowledge.factory import KnowledgeFactory
            from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler

            pdf_path = "path/to/document.pdf"
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            assembler = EmbeddingAssembler.load_from_knowledge(
                knowledge=knowledge,
                embedding_model="text2vec",
            )
    """

    def __init__(
        self,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        embedding_factory = embedding_factory or DefaultEmbeddingFactory(
            default_model_name=os.getenv("EMBEDDING_MODEL")
        )
        embedding_fn = (
            embedding_factory.create(model_name=embedding_model)
            if embedding_model
            else None
        )
        self._vector_store_connector = (
            vector_store_connector
            or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
        )
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embedding_factory: Optional[EmbeddingFactory] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
    ) -> "EmbeddingAssembler":
        """Load document embeddings into the vector store from a knowledge source.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
            vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.

        Returns:
            EmbeddingAssembler
        """
        from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

        embedding_factory = embedding_factory or DefaultEmbeddingFactory(
            default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
        )
        return cls(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            embedding_model=embedding_model,
            embedding_factory=embedding_factory,
            vector_store_connector=vector_store_connector,
        )

    def persist(self) -> List[str]:
        """Persist chunks into the vector store."""
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, top_k: Optional[int] = 4) -> EmbeddingRetriever:
        """Create an EmbeddingRetriever over the persisted chunks.

        Args:
            top_k: (Optional[int]) The number of results to return. Defaults to 4.

        Returns:
            EmbeddingRetriever
        """
        return EmbeddingRetriever(
            top_k=top_k, vector_store_connector=self._vector_store_connector
        )
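
A short usage sketch of the embedding pipeline. The document path and model name are placeholders, and the module path follows the dbgpt.serve.rag.assembler layout used elsewhere in this diff.

from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler

# Chunking happens at construction time; embedding happens on persist().
knowledge = KnowledgeFactory.from_file_path("path/to/document.pdf")
assembler = EmbeddingAssembler.load_from_knowledge(
    knowledge=knowledge,
    embedding_model="text2vec",  # placeholder model name
)
chunk_ids = assembler.persist()

# Similarity search over the persisted chunks.
retriever = assembler.as_retriever(top_k=3)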

View File

@@ -0,0 +1,113 @@
import os
from typing import Any, List, Optional

from dbgpt.core import LLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.assembler.base import BaseAssembler


class SummaryAssembler(BaseAssembler):
"""Summary Assembler

    Example:
        .. code-block:: python

            pdf_path = "../../../DB-GPT/docs/docs/awel.md"
            OPEN_AI_KEY = "{your_api_key}"
            OPEN_AI_BASE = "{your_api_base}"
            llm_client = OpenAILLMClient(api_key=OPEN_AI_KEY, api_base=OPEN_AI_BASE)
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
            assembler = SummaryAssembler.load_from_knowledge(
                knowledge=knowledge,
                chunk_parameters=chunk_parameters,
                llm_client=llm_client,
                model_name="gpt-3.5-turbo",
            )
            summary = await assembler.generate_summary()
    """

    def __init__(
        self,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> None:
        """Initialize with Summary Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            model_name: (Optional[str]) LLM model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        self._model_name = model_name or os.getenv("LLM_MODEL")
        self._llm_client = llm_client
        from dbgpt.rag.extractor.summary import SummaryExtractor

        self._extractor = extractor or SummaryExtractor(
            llm_client=self._llm_client, model_name=self._model_name
        )
        self._language = language
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            extractor=self._extractor,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Optional[Knowledge] = None,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> "SummaryAssembler":
        """Load a SummaryAssembler from a knowledge source.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Chunk parameters to use for chunking.
            model_name: (Optional[str]) LLM model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".

        Returns:
            SummaryAssembler
        """
        return cls(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            model_name=model_name,
            llm_client=llm_client,
            extractor=extractor,
            language=language,
            **kwargs,
        )

    async def generate_summary(self) -> str:
        """Generate a summary from the loaded chunks."""
        return await self._extractor.aextract(self._chunks)

    def persist(self) -> List[str]:
        """Persist chunks into a store. The summary pipeline does not persist chunks."""
        return []

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever. Not implemented for SummaryAssembler."""

View File

@@ -0,0 +1,76 @@
from unittest.mock import MagicMock

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


@pytest.fixture
def mock_db_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
    return connect


@pytest.fixture
def mock_chunk_parameters():
    """Mock chunk parameters."""
    return MagicMock(spec=ChunkParameters)


@pytest.fixture
def mock_embedding_factory():
    """Mock embedding factory."""
    return MagicMock(spec=EmbeddingFactory)


@pytest.fixture
def mock_vector_store_connector():
    """Mock vector store connector."""
    return MagicMock(spec=VectorStoreConnector)


@pytest.fixture
def mock_knowledge():
    """Mock knowledge source."""
    return MagicMock(spec=Knowledge)


def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    """Loading from a mocked Knowledge source yields no chunks."""
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = EmbeddingAssembler(
        knowledge=mock_knowledge,
        chunk_parameters=mock_chunk_parameters,
        embedding_factory=mock_embedding_factory,
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    assert len(assembler._chunks) == 0

View File

@@ -0,0 +1,76 @@
from unittest.mock import MagicMock

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector


@pytest.fixture
def mock_db_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
    return connect


@pytest.fixture
def mock_chunk_parameters():
    """Mock chunk parameters."""
    return MagicMock(spec=ChunkParameters)


@pytest.fixture
def mock_embedding_factory():
    """Mock embedding factory."""
    return MagicMock(spec=EmbeddingFactory)


@pytest.fixture
def mock_vector_store_connector():
    """Mock vector store connector."""
    return MagicMock(spec=VectorStoreConnector)


@pytest.fixture
def mock_knowledge():
    """Mock knowledge source."""
    return MagicMock(spec=Knowledge)


def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    """load_knowledge builds one chunk per table in the temporary database."""
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = DBStructAssembler(
        connection=mock_db_connection,
        chunk_parameters=mock_chunk_parameters,
        embedding_factory=mock_embedding_factory,
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    assert len(assembler._chunks) == 1