mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-13 21:21:08 +00:00)
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
99
dbgpt/rag/retriever/base.py
Normal file
@@ -0,0 +1,99 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import List

from dbgpt.rag.chunk import Chunk


class RetrieverStrategy(str, Enum):
    """Retriever strategy.

    Args:
        - EMBEDDING: embedding retriever
        - KEYWORD: keyword retriever
        - HYBRID: hybrid retriever
    """

    EMBEDDING = "embedding"
    KEYWORD = "keyword"
    HYBRID = "hybrid"


class BaseRetriever(ABC):
    """Base retriever."""

    def retrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """
        return self._retrieve(query)

    async def aretrieve(self, query: str) -> List[Chunk]:
        """Async retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """
        return await self._aretrieve(query)

    def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold

        Returns:
            List[Chunk]: list of chunks with score
        """
        return self._retrieve_with_score(query, score_threshold)

    async def aretrieve_with_scores(
        self, query: str, score_threshold: float
    ) -> List[Chunk]:
        """Async retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold

        Returns:
            List[Chunk]: list of chunks with score
        """
        return await self._aretrieve_with_score(query, score_threshold)

    @abstractmethod
    def _retrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    async def _aretrieve(self, query: str) -> List[Chunk]:
        """Async retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold

        Returns:
            List[Chunk]: list of chunks with score
        """

    @abstractmethod
    async def _aretrieve_with_score(
        self, query: str, score_threshold: float
    ) -> List[Chunk]:
        """Async retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold

        Returns:
            List[Chunk]: list of chunks with score
        """
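To make the contract concrete, here is a minimal sketch of a retriever built on this base class. The SimpleKeywordRetriever name and its bag-of-words scoring are illustrative only, not part of this commit:

from typing import List

from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever


class SimpleKeywordRetriever(BaseRetriever):
    """Toy retriever: ranks in-memory chunks by keyword overlap with the query."""

    def __init__(self, chunks: List[Chunk], top_k: int = 4):
        self._chunks = chunks
        self._top_k = top_k

    def _retrieve(self, query: str) -> List[Chunk]:
        words = set(query.lower().split())
        # Sort chunks by how many query words they share, highest overlap first.
        scored = sorted(
            self._chunks,
            key=lambda c: len(words & set(c.content.lower().split())),
            reverse=True,
        )
        return scored[: self._top_k]

    async def _aretrieve(self, query: str) -> List[Chunk]:
        return self._retrieve(query)

    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
        return self._retrieve(query)

    async def _aretrieve_with_score(
        self, query: str, score_threshold: float
    ) -> List[Chunk]:
        return self._retrieve(query)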
152
dbgpt/rag/retriever/db_struct.py
Normal file
@@ -0,0 +1,152 @@
from functools import reduce
from typing import List, Optional

from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class DBStructRetriever(BaseRetriever):
    """DBStruct retriever."""

    def __init__(
        self,
        top_k: int = 4,
        connection: Optional[RDBMSDatabase] = None,
        is_embeddings: bool = True,
        query_rewrite: bool = False,
        rerank: Optional[Ranker] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs,
    ):
        """
        Args:
            top_k (int): top k
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection
            is_embeddings (bool): whether to query by embeddings in the vector store. Defaults to True.
            query_rewrite (bool): whether to rewrite the query
            rerank (Ranker): rerank strategy
            vector_store_connector (VectorStoreConnector): vector store connector

        Code example:

        .. code-block:: python

            from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
            from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
            from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory


            def _create_temporary_connection():
                connect = SQLiteTempConnect.create_temporary_db()
                connect.create_temp_tables(
                    {
                        "user": {
                            "columns": {
                                "id": "INTEGER PRIMARY KEY",
                                "name": "TEXT",
                                "age": "INTEGER",
                            },
                            "data": [
                                (1, "Tom", 10),
                                (2, "Jerry", 16),
                                (3, "Jack", 18),
                                (4, "Alice", 20),
                                (5, "Bob", 22),
                            ],
                        }
                    }
                )
                return connect


            connection = _create_temporary_connection()
            vector_store_config = ChromaVectorConfig(name="vector_store_name")
            embedding_model_path = "{your_embedding_model_path}"
            embedding_factory = DefaultEmbeddingFactory()
            embedding_fn = embedding_factory.create(model_name=embedding_model_path)
            vector_connector = VectorStoreConnector.from_default(
                "Chroma",
                vector_store_config=vector_store_config,
                embedding_fn=embedding_fn,
            )
            # get db struct retriever
            retriever = DBStructRetriever(
                top_k=3, vector_store_connector=vector_connector
            )
            chunks = retriever.retrieve("show columns from table")
            print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
        """
        self._top_k = top_k
        self._is_embeddings = is_embeddings
        self._connection = connection
        self._query_rewrite = query_rewrite
        self._vector_store_connector = vector_store_connector
        self._rerank = rerank or DefaultRanker(self._top_k)

    def _retrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text
        """
        if self._is_embeddings:
            queries = [query]
            candidates = [
                self._vector_store_connector.similar_search(query, self._top_k)
                for query in queries
            ]
            candidates = reduce(lambda x, y: x + y, candidates)
            return candidates
        else:
            from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

            table_summaries = _parse_db_summary(self._connection)
            return [Chunk(content=table_summary) for table_summary in table_summaries]

    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
        """
        return self._retrieve(query)

    async def _aretrieve(self, query: str) -> List[Chunk]:
        """Async retrieve knowledge chunks.

        Args:
            query (str): query text
        """
        if self._is_embeddings:
            queries = [query]
            candidates = [self._similarity_search(query) for query in queries]
            candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
            return candidates
        else:
            table_summaries = await run_async_tasks(
                tasks=[self._aparse_db_summary()], concurrency_limit=1
            )
            return [Chunk(content=table_summary) for table_summary in table_summaries]

    async def _aretrieve_with_score(
        self, query: str, score_threshold: float
    ) -> List[Chunk]:
        """Async retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
        """
        return await self._aretrieve(query)

    async def _similarity_search(self, query) -> List[Chunk]:
        """Similar search."""
        return self._vector_store_connector.similar_search(
            query,
            self._top_k,
        )

    async def _aparse_db_summary(self) -> List[str]:
        """Parse db summary."""
        from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

        return _parse_db_summary(self._connection)
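A short sketch of the async path, reusing the connection built by _create_temporary_connection() in the docstring example above. With is_embeddings=False, the retriever reads table summaries straight from the database instead of the vector store:

import asyncio

from dbgpt.rag.retriever.db_struct import DBStructRetriever


async def main():
    # connection comes from _create_temporary_connection() in the example above.
    connection = _create_temporary_connection()
    retriever = DBStructRetriever(connection=connection, is_embeddings=False)
    chunks = await retriever.aretrieve("show columns from table")
    print([chunk.content for chunk in chunks])


asyncio.run(main())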
146
dbgpt/rag/retriever/embedding.py
Normal file
@@ -0,0 +1,146 @@
from functools import reduce
from typing import List, Optional

from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class EmbeddingRetriever(BaseRetriever):
    """Embedding retriever."""

    def __init__(
        self,
        top_k: int = 4,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        vector_store_connector: VectorStoreConnector = None,
    ):
        """
        Args:
            top_k (int): top k
            query_rewrite (Optional[QueryRewrite]): query rewrite
            rerank (Ranker): rerank strategy
            vector_store_connector (VectorStoreConnector): vector store connector

        Code example:

        .. code-block:: python

            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
            from dbgpt.rag.retriever.embedding import EmbeddingRetriever
            from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

            embedding_factory = DefaultEmbeddingFactory()
            embedding_fn = embedding_factory.create(
                model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
            )
            vector_name = "test"
            config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn)
            vector_store_connector = VectorStoreConnector(
                vector_store_type="Chroma",
                vector_store_config=config,
            )
            embedding_retriever = EmbeddingRetriever(
                top_k=3, vector_store_connector=vector_store_connector
            )
            chunks = embedding_retriever.retrieve("your query text")
            print(f"embedding retriever results:{[chunk.content for chunk in chunks]}")
        """
        self._top_k = top_k
        self._query_rewrite = query_rewrite
        self._vector_store_connector = vector_store_connector
        self._rerank = rerank or DefaultRanker(self._top_k)

    def _retrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """
        queries = [query]
        candidates = [
            self._vector_store_connector.similar_search(query, self._top_k)
            for query in queries
        ]
        candidates = reduce(lambda x, y: x + y, candidates)
        return candidates

    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold

        Returns:
            List[Chunk]: list of chunks with score
        """
        queries = [query]
        candidates_with_score = [
            self._vector_store_connector.similar_search_with_scores(
                query, self._top_k, score_threshold
            )
            for query in queries
        ]
        candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
        candidates_with_score = self._rerank.rank(candidates_with_score)
        return candidates_with_score

    async def _aretrieve(self, query: str) -> List[Chunk]:
        """Async retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: list of chunks
        """
        queries = [query]
        if self._query_rewrite:
            new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
            queries.extend(new_queries)
        candidates = [self._similarity_search(query) for query in queries]
        candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
        return candidates

    async def _aretrieve_with_score(
        self, query: str, score_threshold: float
    ) -> List[Chunk]:
        """Async retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold

        Returns:
            List[Chunk]: list of chunks with score
        """
        queries = [query]
        if self._query_rewrite:
            new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
            queries.extend(new_queries)
        candidates_with_score = [
            self._similarity_search_with_score(query, score_threshold)
            for query in queries
        ]
        candidates_with_score = await run_async_tasks(
            tasks=candidates_with_score, concurrency_limit=1
        )
        candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
        candidates_with_score = self._rerank.rank(candidates_with_score)
        return candidates_with_score

    async def _similarity_search(self, query) -> List[Chunk]:
        """Similar search."""
        return self._vector_store_connector.similar_search(
            query,
            self._top_k,
        )

    async def _similarity_search_with_score(
        self, query, score_threshold
    ) -> List[Chunk]:
        """Similar search with score."""
        return self._vector_store_connector.similar_search_with_scores(
            query, self._top_k, score_threshold
        )
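The score-threshold entry point follows the same pattern; a minimal usage sketch, assuming the embedding_retriever from the docstring example and similarity scores in [0, 1] (the 0.3 cutoff is arbitrary):

# Chunks below the threshold are dropped by the vector store search;
# the configured Ranker then deduplicates and trims the rest to top_k.
chunks = embedding_retriever.retrieve_with_scores(
    "your query text", score_threshold=0.3
)
for chunk in chunks:
    print(chunk.score, chunk.content)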
@@ -1,53 +0,0 @@
from typing import List

from dbgpt.app.scene import ChatScene
from dbgpt.app.scene import BaseChat


class QueryReinforce:
    """
    query reinforce, including query rewrite and query correct
    """

    def __init__(
        self, query: str = None, model_name: str = None, llm_chat: BaseChat = None
    ):
        """query reinforce
        Args:
            - query: str, user query
            - model_name: str, llm model name
        """
        self.query = query
        self.model_name = model_name
        self.llm_chat = llm_chat

    async def rewrite(self) -> List[str]:
        """query rewrite"""
        from dbgpt._private.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": self.query,
            "select_param": 2,
            "model_name": self.model_name,
            "model_cache_enable": False,
        }
        tasks = [
            llm_chat_response_nostream(
                ChatScene.QueryRewrite.value(), **{"chat_param": chat_param}
            )
        ]
        from dbgpt._private.chat_util import run_async_tasks

        queries = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        queries = list(
            filter(
                lambda content: "LLMServer Generate Error" not in content,
                queries,
            )
        )
        return queries[0]

    def correct(self) -> List[str]:
        pass
dbgpt/rag/retriever/rerank.py
@@ -1,11 +1,13 @@
 from abc import ABC
-from typing import List, Tuple, Optional
+from typing import List, Optional
+
+from dbgpt.rag.chunk import Chunk


 class Ranker(ABC):
     """Base Ranker"""

-    def __init__(self, topk: int, rank_fn: Optional[callable] = None):
+    def __init__(self, topk: int, rank_fn: Optional[callable] = None) -> None:
         """
         abstract base ranker
         Args:
@@ -15,7 +17,7 @@ class Ranker(ABC):
         self.topk = topk
         self.rank_fn = rank_fn

-    def rank(self, candidates_with_scores: List, topk: int):
+    def rank(self, candidates_with_scores: List) -> List[Chunk]:
         """rank algorithm implementation, returns topk documents by candidate similarity score
         Args:
             candidates_with_scores: List[Tuple]
@@ -26,17 +28,17 @@ class Ranker(ABC):
         """
         pass

-    def _filter(self, candidates_with_scores: List):
+    def _filter(self, candidates_with_scores: List) -> List[Chunk]:
         """filter duplicate candidate documents"""
         candidates_with_scores = sorted(
-            candidates_with_scores, key=lambda x: x[1], reverse=True
+            candidates_with_scores, key=lambda x: x.score, reverse=True
         )
         visited_docs = set()
         new_candidates = []
-        for candidate_doc, score in candidates_with_scores:
-            if candidate_doc.page_content not in visited_docs:
-                new_candidates.append((candidate_doc, score))
-                visited_docs.add(candidate_doc.page_content)
+        for candidate_chunk in candidates_with_scores:
+            if candidate_chunk.content not in visited_docs:
+                new_candidates.append(candidate_chunk)
+                visited_docs.add(candidate_chunk.content)
         return new_candidates


@@ -46,7 +48,7 @@ class DefaultRanker(Ranker):
     def __init__(self, topk: int, rank_fn: Optional[callable] = None):
         super().__init__(topk, rank_fn)

-    def rank(self, candidates_with_scores: List[Tuple]):
+    def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
         """Default rank algorithm implementation:
         return topk documents by candidate similarity score
         Args:
@@ -59,11 +61,9 @@ class DefaultRanker(Ranker):
             candidates_with_scores = self.rank_fn(candidates_with_scores)
         else:
             candidates_with_scores = sorted(
-                candidates_with_scores, key=lambda x: x[1], reverse=True
+                candidates_with_scores, key=lambda x: x.score, reverse=True
             )
-        return [
-            (candidate_doc, score) for candidate_doc, score in candidates_with_scores
-        ][: self.topk]
+        return candidates_with_scores[: self.topk]


 class RRFRanker(Ranker):
@@ -72,7 +72,7 @@ class RRFRanker(Ranker):
     def __init__(self, topk: int, rank_fn: Optional[callable] = None):
         super().__init__(topk, rank_fn)

-    def rank(self, candidates_with_scores: List):
+    def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
         """RRF rank algorithm implementation.
         This implements Reciprocal Rank Fusion (RRF), a method for combining multiple result sets with different relevance indicators into a single result set. RRF requires no tuning, and the different relevance indicators do not have to be related to each other to achieve high-quality results.
         RRF uses the following formula to determine the score for ranking each document:
103
dbgpt/rag/retriever/rewrite.py
Normal file
@@ -0,0 +1,103 @@
from typing import List, Optional

from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType

REWRITE_PROMPT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}. Provide them in the following comma-separated format: 'queries: <queries>'
"original query: {original_query}\n"
"queries:\n"
"""

REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>':
"original_query:{original_query}\n"
"queries:\n"
"""


class QueryRewrite:
    """
    Query rewrite, including query rewriting and query correction.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        language: Optional[str] = "en",
    ) -> None:
        """query rewrite

        Args:
            - model_name: (str), llm model name
            - llm_client: (Optional[LLMClient]), llm client
            - language: (Optional[str]), prompt language, "en" or "zh"
        """
        self._model_name = model_name
        self._llm_client = llm_client
        self._language = language
        self._prompt_template = (
            REWRITE_PROMPT_TEMPLATE_EN
            if language == "en"
            else REWRITE_PROMPT_TEMPLATE_ZH
        )

    async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]:
        """query rewrite

        Args:
            origin_query: str, original query
            nums: Optional[int], number of rewritten queries

        Returns:
            queries: List[str]
        """
        from dbgpt.util.chat_util import run_async_tasks

        prompt = self._prompt_template.format(original_query=origin_query, nums=nums)
        messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
        request = ModelRequest(model=self._model_name, messages=messages)
        tasks = [self._llm_client.generate(request)]
        queries = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        queries = [model_out.text for model_out in queries]
        queries = list(
            filter(
                lambda content: "LLMServer Generate Error" not in content,
                queries,
            )
        )
        print("rewrite queries:", queries)
        return self._parse_llm_output(output=queries[0])

    def correct(self) -> List[str]:
        pass

    def _parse_llm_output(self, output: str) -> List[str]:
        """parse llm output

        Args:
            output: str

        Returns:
            output: List[str]
        """
        lowercase = True
        try:
            results = []
            response = output.strip()

            if response.startswith("queries:"):
                response = response[len("queries:") :]
            queries = response.split(",")
            if len(queries) == 1:
                # Fall back to splitting on question marks when no commas are present.
                queries = response.split("?")
            for k in queries:
                rk = k
                if lowercase:
                    rk = rk.lower()
                s = rk.strip()
                if s == "":
                    continue
                results.append(s)
        except Exception as e:
            print(f"parse query rewrite prompt_response error: {e}")
            return []
        return results
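A usage sketch for the rewriter. The llm_client here stands for any configured dbgpt.core.LLMClient implementation, and the model name is a placeholder:

import asyncio

from dbgpt.rag.retriever.rewrite import QueryRewrite


async def main(llm_client):
    # llm_client: whichever LLMClient implementation you have configured.
    rewriter = QueryRewrite(model_name="your_model_name", llm_client=llm_client)
    # Ask for two reformulations of the original question.
    queries = await rewriter.rewrite(origin_query="compare the two tables", nums=2)
    print(queries)  # e.g. a list of two rewritten search queries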
0
dbgpt/rag/retriever/tests/__init__.py
Normal file
49
dbgpt/rag/retriever/tests/test_db_struct.py
Normal file
@@ -0,0 +1,49 @@
import pytest
from unittest.mock import MagicMock, patch
from typing import List

import dbgpt
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.db_struct import DBStructRetriever

# Imported so that dbgpt.rag.summary.rdbms_db_summary is loaded for patch.object below.
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary


@pytest.fixture
def mock_db_connection():
    return MagicMock()


@pytest.fixture
def mock_vector_store_connector():
    mock_connector = MagicMock()
    mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
    return mock_connector


@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
    return DBStructRetriever(
        connection=mock_db_connection,
        vector_store_connector=mock_vector_store_connector,
    )


def mock_parse_db_summary() -> str:
    """Patch for the _parse_db_summary method."""
    return "Table summary"


# Mock the _parse_db_summary method for the test function
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
    query = "Table summary"
    chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
    assert isinstance(chunks[0], Chunk)
    assert chunks[0].content == "Table summary"


async def async_mock_parse_db_summary() -> str:
    """Asynchronous patch for the _parse_db_summary method."""
    return "Table summary"
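The async path can be covered the same way; a sketch assuming pytest-asyncio is available as a dev dependency. With the default is_embeddings=True, aretrieve goes through the mocked vector store, so no summary patch is needed:

@pytest.mark.asyncio  # requires pytest-asyncio
async def test_aretrieve_with_mocked_vector_store(dbstruct_retriever):
    chunks = await dbstruct_retriever.aretrieve("Table summary")
    assert len(chunks) > 0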
39
dbgpt/rag/retriever/tests/test_embedding.py
Normal file
@@ -0,0 +1,39 @@
from unittest.mock import MagicMock

import pytest

from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever


@pytest.fixture
def top_k():
    return 4


@pytest.fixture
def query():
    return "test query"


@pytest.fixture
def mock_vector_store_connector():
    return MagicMock()


@pytest.fixture
def embedding_retriever(top_k, mock_vector_store_connector):
    return EmbeddingRetriever(
        top_k=top_k,
        query_rewrite=None,
        vector_store_connector=mock_vector_store_connector,
    )


def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever):
    expected_chunks = [Chunk() for _ in range(top_k)]
    mock_vector_store_connector.similar_search.return_value = expected_chunks

    retrieved_chunks = embedding_retriever._retrieve(query)

    mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
    assert len(retrieved_chunks) == top_k
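A companion sketch for the scored path, assuming Chunk exposes the score field the reranker sorts on:

def test_retrieve_with_score(
    query, top_k, mock_vector_store_connector, embedding_retriever
):
    scored_chunks = [Chunk(content=str(i), score=i / 10) for i in range(top_k)]
    mock_vector_store_connector.similar_search_with_scores.return_value = scored_chunks

    results = embedding_retriever._retrieve_with_score(query, score_threshold=0.0)

    mock_vector_store_connector.similar_search_with_scores.assert_called_once_with(
        query, top_k, 0.0
    )
    # DefaultRanker sorts by score descending and cuts to top_k.
    assert [c.content for c in results] == ["3", "2", "1", "0"]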