refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Aries-ckt
2024-01-03 09:45:26 +08:00
committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@@ -0,0 +1,99 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import List

from dbgpt.rag.chunk import Chunk
class RetrieverStrategy(str, Enum):
"""Retriever strategy.
Args:
- EMBEDDING: embedding retriever
- KEYWORD: keyword retriever
- HYBRID: hybrid retriever
"""
EMBEDDING = "embedding"
KEYWORD = "keyword"
HYBRID = "hybrid"
class BaseRetriever(ABC):
"""Base retriever."""
def retrieve(self, query: str) -> List[Chunk]:
"""
Args:
query (str): query text
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve(query)
async def aretrieve(self, query: str) -> List[Chunk]:
"""
Args:
query (str): async query text
Returns:
List[Chunk]: list of chunks
"""
return await self._aretrieve(query)
def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]:
"""
Args:
query (str): query text
score_threshold (float): score threshold
"""
return self._retrieve_with_score(query, score_threshold)
async def aretrieve_with_scores(
self, query: str, score_threshold: float
) -> List[Chunk]:
"""
Args:
query (str): query text
score_threshold (float): score threshold
Returns:
List[Chunk]: list of chunks
"""
return await self._aretrieve_with_score(query, score_threshold)
@abstractmethod
def _retrieve(self, query: str) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
Returns:
List[Chunk]: list of chunks
"""
@abstractmethod
async def _aretrieve(self, query: str) -> List[Chunk]:
"""Async Retrieve knowledge chunks.
Args:
query (str): query text
Returns:
List[Chunk]: list of chunks
"""
@abstractmethod
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
Returns:
List[Chunk]: list of chunks
"""
@abstractmethod
async def _aretrieve_with_score(
self, query: str, score_threshold: float
) -> List[Chunk]:
"""Async Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
Returns:
List[Chunk]: list of chunks
"""

View File

@@ -0,0 +1,152 @@
from functools import reduce
from typing import List, Optional
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBStructRetriever(BaseRetriever):
"""DBStruct retriever."""
def __init__(
self,
top_k: int = 4,
connection: Optional[RDBMSDatabase] = None,
is_embeddings: bool = True,
query_rewrite: bool = False,
        rerank: Optional[Ranker] = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
**kwargs
):
"""
Args:
top_k (int): top k
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
            is_embeddings (bool): Whether to query by embeddings in the vector store. Defaults to True.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
        Examples:
            .. code-block:: python

                from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
                from dbgpt.storage.vector_store.connector import VectorStoreConnector
                from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
                from dbgpt.rag.retriever.db_struct import DBStructRetriever
                from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
def _create_temporary_connection():
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
return connect
connection = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
                embedding_factory = DefaultEmbeddingFactory()
                embedding_model_path = "{your_embedding_model_path}"
                embedding_fn = embedding_factory.create(model_name=embedding_model_path)
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn
)
# get db struct retriever
retriever = DBStructRetriever(top_k=3, vector_store_connector=vector_connector)
chunks = retriever.retrieve("show columns from table")
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
"""
self._top_k = top_k
self._is_embeddings = is_embeddings
self._connection = connection
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._rerank = rerank or DefaultRanker(self._top_k)
    def _retrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """
if self._is_embeddings:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
for query in queries
]
candidates = reduce(lambda x, y: x + y, candidates)
return candidates
else:
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
table_summaries = _parse_db_summary(self._connection)
return [Chunk(content=table_summary) for table_summary in table_summaries]
    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold (currently unused here;
                results are returned without score filtering)
        """
return self._retrieve(query)
    async def _aretrieve(self, query: str) -> List[Chunk]:
        """Async retrieve knowledge chunks.

        Args:
            query (str): query text
        """
        if self._is_embeddings:
            queries = [query]
            candidates = [self._similarity_search(query) for query in queries]
            # run_async_tasks returns one chunk list per query; flatten them
            # into a single candidate list, matching the synchronous path.
            candidate_lists = await run_async_tasks(tasks=candidates, concurrency_limit=1)
            return reduce(lambda x, y: x + y, candidate_lists)
        else:
            # run_async_tasks wraps results in a list, one entry per task;
            # unpack the single summary list before building chunks.
            summary_results = await run_async_tasks(
                tasks=[self._aparse_db_summary()], concurrency_limit=1
            )
            table_summaries = summary_results[0]
            return [Chunk(content=table_summary) for table_summary in table_summaries]
async def _aretrieve_with_score(
self, query: str, score_threshold: float
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
"""
return await self._aretrieve(query)
async def _similarity_search(self, query) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(
query,
self._top_k,
)
    async def _aparse_db_summary(self) -> List[str]:
        """Parse the database schema into table summaries."""
        from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

        return _parse_db_summary(self._connection)
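
A hedged usage sketch of the async path, reusing the `retriever` built in the docstring example above:

import asyncio


async def main():
    # aretrieve fans out over run_async_tasks and flattens the results
    chunks = await retriever.aretrieve("show columns from table")
    print([chunk.content for chunk in chunks])


asyncio.run(main())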

View File

@@ -0,0 +1,146 @@
from functools import reduce
from typing import List, Optional
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingRetriever(BaseRetriever):
"""Embedding retriever."""
def __init__(
self,
top_k: int = 4,
query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
):
"""
Args:
top_k (int): top k
query_rewrite (Optional[QueryRewrite]): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
        Examples:
            .. code-block:: python

                from dbgpt.storage.vector_store.connector import VectorStoreConnector
                from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
                from dbgpt.rag.retriever.embedding import EmbeddingRetriever
                from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

                embedding_factory = DefaultEmbeddingFactory()
                embedding_fn = embedding_factory.create(
                    model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
                )
                vector_name = "test"
                config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn)
                vector_store_connector = VectorStoreConnector(
                    vector_store_type="Chroma",
                    vector_store_config=config,
                )
embedding_retriever = EmbeddingRetriever(
top_k=3, vector_store_connector=vector_store_connector
)
chunks = embedding_retriever.retrieve("your query text")
print(f"embedding retriever results:{[chunk.content for chunk in chunks]}")
"""
self._top_k = top_k
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(self, query: str) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
        Returns:
List[Chunk]: list of chunks
"""
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
for query in queries
]
candidates = reduce(lambda x, y: x + y, candidates)
return candidates
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
        Returns:
List[Chunk]: list of chunks with score
"""
queries = [query]
candidates_with_score = [
self._vector_store_connector.similar_search_with_scores(
query, self._top_k, score_threshold
)
for query in queries
]
candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
candidates_with_score = self._rerank.rank(candidates_with_score)
return candidates_with_score
async def _aretrieve(self, query: str) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
        Returns:
List[Chunk]: list of chunks
"""
queries = [query]
if self._query_rewrite:
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
queries.extend(new_queries)
        candidates = [self._similarity_search(query) for query in queries]
        # run_async_tasks returns one chunk list per query; flatten them
        # into a single candidate list, as the synchronous _retrieve does.
        candidate_lists = await run_async_tasks(tasks=candidates, concurrency_limit=1)
        return reduce(lambda x, y: x + y, candidate_lists)
async def _aretrieve_with_score(
self, query: str, score_threshold: float
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
        Returns:
List[Chunk]: list of chunks with score
"""
queries = [query]
if self._query_rewrite:
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
queries.extend(new_queries)
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
for query in queries
]
candidates_with_score = await run_async_tasks(
tasks=candidates_with_score, concurrency_limit=1
)
candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
candidates_with_score = self._rerank.rank(candidates_with_score)
return candidates_with_score
async def _similarity_search(self, query) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(
query,
self._top_k,
)
async def _similarity_search_with_score(
self, query, score_threshold
) -> List[Chunk]:
"""Similar search with score."""
return self._vector_store_connector.similar_search_with_scores(
query, self._top_k, score_threshold
)
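
A short hedged sketch of the scored path above; the 0.3 threshold is illustrative, not from this commit:

# chunks come back ranked by DefaultRanker (or the ranker you passed in),
# with vector-store hits below the threshold already filtered out
chunks = embedding_retriever.retrieve_with_scores(
    "your query text", score_threshold=0.3
)
for chunk in chunks:
    print(chunk.score, chunk.content)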

View File

@@ -1,53 +0,0 @@
from typing import List
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene import BaseChat
class QueryReinforce:
"""
query reinforce, include query rewrite, query correct
"""
def __init__(
self, query: str = None, model_name: str = None, llm_chat: BaseChat = None
):
"""query reinforce
Args:
- query: str, user query
- model_name: str, llm model name
"""
self.query = query
self.model_name = model_name
self.llm_chat = llm_chat
async def rewrite(self) -> List[str]:
"""query rewrite"""
from dbgpt._private.chat_util import llm_chat_response_nostream
import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": self.query,
"select_param": 2,
"model_name": self.model_name,
"model_cache_enable": False,
}
tasks = [
llm_chat_response_nostream(
ChatScene.QueryRewrite.value(), **{"chat_param": chat_param}
)
]
from dbgpt._private.chat_util import run_async_tasks
queries = await run_async_tasks(tasks=tasks, concurrency_limit=1)
queries = list(
filter(
lambda content: "LLMServer Generate Error" not in content,
queries,
)
)
return queries[0]
def correct(self) -> List[str]:
pass

View File

@@ -1,11 +1,13 @@
from abc import ABC
-from typing import List, Tuple, Optional
+from typing import List, Optional
+
+from dbgpt.rag.chunk import Chunk
class Ranker(ABC):
"""Base Ranker"""
-    def __init__(self, topk: int, rank_fn: Optional[callable] = None):
+    def __init__(self, topk: int, rank_fn: Optional[callable] = None) -> None:
"""
abstract base ranker
Args:
@@ -15,7 +17,7 @@ class Ranker(ABC):
self.topk = topk
self.rank_fn = rank_fn
-    def rank(self, candidates_with_scores: List, topk: int):
+    def rank(self, candidates_with_scores: List) -> List[Chunk]:
"""rank algorithm implementation return topk documents by candidates similarity score
Args:
candidates_with_scores: List[Tuple]
@@ -26,17 +28,17 @@ class Ranker(ABC):
pass
-    def _filter(self, candidates_with_scores: List):
+    def _filter(self, candidates_with_scores: List) -> List[Chunk]:
"""filter duplicate candidates documents"""
candidates_with_scores = sorted(
-            candidates_with_scores, key=lambda x: x[1], reverse=True
+            candidates_with_scores, key=lambda x: x.score, reverse=True
)
visited_docs = set()
new_candidates = []
-        for candidate_doc, score in candidates_with_scores:
-            if candidate_doc.page_content not in visited_docs:
-                new_candidates.append((candidate_doc, score))
-                visited_docs.add(candidate_doc.page_content)
+        for candidate_chunk in candidates_with_scores:
+            if candidate_chunk.content not in visited_docs:
+                new_candidates.append(candidate_chunk)
+                visited_docs.add(candidate_chunk.content)
return new_candidates
@@ -46,7 +48,7 @@ class DefaultRanker(Ranker):
def __init__(self, topk: int, rank_fn: Optional[callable] = None):
super().__init__(topk, rank_fn)
-    def rank(self, candidates_with_scores: List[Tuple]):
+    def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
"""Default rank algorithm implementation
return topk documents by candidates similarity score
Args:
@@ -59,11 +61,9 @@ class DefaultRanker(Ranker):
candidates_with_scores = self.rank_fn(candidates_with_scores)
else:
candidates_with_scores = sorted(
-                candidates_with_scores, key=lambda x: x[1], reverse=True
+                candidates_with_scores, key=lambda x: x.score, reverse=True
)
-        return [
-            (candidate_doc, score) for candidate_doc, score in candidates_with_scores
-        ][: self.topk]
+        return candidates_with_scores[: self.topk]
class RRFRanker(Ranker):
@@ -72,7 +72,7 @@ class RRFRanker(Ranker):
def __init__(self, topk: int, rank_fn: Optional[callable] = None):
super().__init__(topk, rank_fn)
-    def rank(self, candidates_with_scores: List):
+    def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
"""RRF rank algorithm implementation
        This code implements Reciprocal Rank Fusion (RRF), a method for combining multiple result sets with different relevance indicators into a single result set. RRF requires no tuning, and the different relevance indicators do not have to be related to each other to achieve high-quality results.
RRF uses the following formula to determine the score for ranking each document:
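For completeness, the standard RRF scoring rule the docstring refers to (a fact about the algorithm, not copied from this hunk):

        score = 0.0
        for q in queries:
            if d in result(q):
                score += 1.0 / (k + rank(result(q), d))
        # where k is a ranking constant (commonly 60), q ranges over the
        # queries, d is a document in result(q), and rank(result(q), d)
        # is d's 1-based rank within result(q).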

View File

@@ -0,0 +1,103 @@
from typing import List, Optional
from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
REWRITE_PROMPT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}. Provide the following comma-separated format: 'queries: <queries>'
original query: {original_query}
queries:
"""
# Chinese-language variant of the rewrite prompt above (used when language != "en").
REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>'
original_query:{original_query}
queries:
"""
class QueryRewrite:
    """Query rewrite: rewrites the original query into several related, answerable search queries."""
    def __init__(
        self,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        language: Optional[str] = "en",
    ) -> None:
        """Create a query rewriter.

        Args:
            - model_name: (Optional[str]), llm model name
            - llm_client: (Optional[LLMClient]), llm client
            - language: (Optional[str]), prompt language, "en" or "zh"
        """
self._model_name = model_name
self._llm_client = llm_client
self._language = language
self._prompt_template = (
REWRITE_PROMPT_TEMPLATE_EN
if language == "en"
else REWRITE_PROMPT_TEMPLATE_ZH
)
async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]:
"""query rewrite
Args:
origin_query: str original query
nums: Optional[int] rewrite nums
Returns:
queries: List[str]
"""
from dbgpt.util.chat_util import run_async_tasks
prompt = self._prompt_template.format(original_query=origin_query, nums=nums)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm_client.generate(request)]
queries = await run_async_tasks(tasks=tasks, concurrency_limit=1)
queries = [model_out.text for model_out in queries]
queries = list(
filter(
lambda content: "LLMServer Generate Error" not in content,
queries,
)
)
print("rewrite queries:", queries)
return self._parse_llm_output(output=queries[0])
    def correct(self) -> List[str]:
        """Query correction (not yet implemented)."""
        pass
def _parse_llm_output(self, output: str) -> List[str]:
"""parse llm output
Args:
output: str
Returns:
output: List[str]
"""
lowercase = True
try:
results = []
response = output.strip()
if response.startswith("queries:"):
response = response[len("queries:") :]
queries = response.split(",")
if len(queries) == 1:
queries = response.split("")
if len(queries) == 1:
queries = response.split("?")
if len(queries) == 1:
queries = response.split("")
for k in queries:
rk = k
if lowercase:
rk = rk.lower()
s = rk.strip()
if s == "":
continue
results.append(s)
except Exception as e:
print(f"parse query rewrite prompt_response error: {e}")
return []
return results
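
A hedged usage sketch of the pieces above; the rendered prompt and the parse result follow directly from the template and `_parse_llm_output`:

# render the EN template for two rewrites of a sample query
print(REWRITE_PROMPT_TEMPLATE_EN.format(nums=2, original_query="how to tune chunk size"))

# parse a typical model reply: the "queries:" prefix is stripped and
# items are lowercased and whitespace-trimmed
rw = QueryRewrite()
assert rw._parse_llm_output("queries: How to use RAG?, What is RRF?") == [
    "how to use rag?",
    "what is rrf?",
]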

View File

@@ -0,0 +1,49 @@
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from typing import List
import dbgpt
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.db_struct import DBStructRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@pytest.fixture
def mock_db_connection():
return MagicMock()
@pytest.fixture
def mock_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
return mock_connector
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
return DBStructRetriever(
connection=mock_db_connection,
vector_store_connector=mock_vector_store_connector,
)
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return "Table summary"
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return "Table summary"

View File

@@ -0,0 +1,39 @@
from unittest.mock import MagicMock
import pytest
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
@pytest.fixture
def top_k():
return 4
@pytest.fixture
def query():
return "test query"
@pytest.fixture
def mock_vector_store_connector():
return MagicMock()
@pytest.fixture
def embedding_retriever(top_k, mock_vector_store_connector):
return EmbeddingRetriever(
top_k=top_k,
        query_rewrite=None,
vector_store_connector=mock_vector_store_connector,
)
def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever):
expected_chunks = [Chunk() for _ in range(top_k)]
mock_vector_store_connector.similar_search.return_value = expected_chunks
retrieved_chunks = embedding_retriever._retrieve(query)
mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
assert len(retrieved_chunks) == top_k
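
A hedged companion test for the scored path; assumes Chunk accepts a `score` field, as the rerank changes in this commit suggest:

def test_retrieve_with_score(query, top_k, mock_vector_store_connector, embedding_retriever):
    # similar_search_with_scores is mocked the same way similar_search is above
    expected_chunks = [Chunk(content=str(i), score=0.9) for i in range(top_k)]
    mock_vector_store_connector.similar_search_with_scores.return_value = expected_chunks
    chunks = embedding_retriever._retrieve_with_score(query, score_threshold=0.5)
    mock_vector_store_connector.similar_search_with_scores.assert_called_once_with(
        query, top_k, 0.5
    )
    assert len(chunks) == top_k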