feat(RAG):add rag operators and rag awel examples (#1061)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-13 16:14:48 +08:00
committed by GitHub
parent 99ea6ac1a4
commit a035433170
29 changed files with 1010 additions and 102 deletions

View File

@@ -0,0 +1,151 @@
from functools import reduce
from typing import List, Optional
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBSchemaRetriever(BaseRetriever):
    """Retrieve database schema information as knowledge chunks.

    Two modes are supported, selected at construction time:

    * If a ``vector_store_connector`` is given, queries are answered with a
      similarity search against that vector store.
    * Otherwise the table summaries are parsed directly from ``connection``
      and every summary is returned as a chunk (the query text is ignored).
    """

    def __init__(
        self,
        top_k: int = 4,
        connection: Optional[RDBMSDatabase] = None,
        query_rewrite: bool = False,
        rerank: Optional[Ranker] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Create a DBSchemaRetriever.

        Args:
            top_k (int): number of chunks requested from the vector store.
            connection (Optional[RDBMSDatabase]): database connection used to
                parse table summaries when no vector store is configured.
            query_rewrite (bool): stored on the instance; not consulted by the
                retrieval methods of this class.
            rerank (Optional[Ranker]): ranker to keep on the instance; defaults
                to ``DefaultRanker(top_k)``. Not applied by the retrieval
                methods of this class.
            vector_store_connector (Optional[VectorStoreConnector]): vector
                store connector; when provided, similarity search is used.
            **kwargs: accepted for compatibility and ignored.

        Example:
            .. code-block:: python

                from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
                from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
                from dbgpt.storage.vector_store.connector import VectorStoreConnector
                from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
                from dbgpt.rag.retriever.embedding import EmbeddingRetriever


                def _create_temporary_connection():
                    connect = SQLiteTempConnect.create_temporary_db()
                    connect.create_temp_tables(
                        {
                            "user": {
                                "columns": {
                                    "id": "INTEGER PRIMARY KEY",
                                    "name": "TEXT",
                                    "age": "INTEGER",
                                },
                                "data": [
                                    (1, "Tom", 10),
                                    (2, "Jerry", 16),
                                    (3, "Jack", 18),
                                    (4, "Alice", 20),
                                    (5, "Bob", 22),
                                ],
                            }
                        }
                    )
                    return connect


                connection = _create_temporary_connection()
                vector_store_config = ChromaVectorConfig(name="vector_store_name")
                embedding_model_path = "{your_embedding_model_path}"
                # NOTE: `embedding_factory` is not imported above; it is assumed
                # to be an embedding factory instance provided by the caller.
                embedding_fn = embedding_factory.create(
                    model_name=embedding_model_path
                )
                vector_connector = VectorStoreConnector.from_default(
                    "Chroma",
                    vector_store_config=vector_store_config,
                    embedding_fn=embedding_fn
                )
                # get db struct retriever
                retriever = DBSchemaRetriever(
                    top_k=3, vector_store_connector=vector_connector
                )
                chunks = retriever.retrieve("show columns from table")
                print(
                    f"db struct rag example results:"
                    f"{[chunk.content for chunk in chunks]}"
                )
        """
        self._top_k = top_k
        self._connection = connection
        self._query_rewrite = query_rewrite
        self._vector_store_connector = vector_store_connector
        # Similarity search is only possible when a vector store is configured.
        self._need_embeddings = bool(self._vector_store_connector)
        self._rerank = rerank or DefaultRanker(self._top_k)

    def _retrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text

        Returns:
            List[Chunk]: similarity-search hits when a vector store is
            configured, otherwise one chunk per parsed table summary.
        """
        if self._need_embeddings:
            queries = [query]
            per_query_hits = [
                self._vector_store_connector.similar_search(text, self._top_k)
                for text in queries
            ]
            # Merge the per-query result lists into one flat list.
            return reduce(lambda x, y: x + y, per_query_hits)
        table_summaries = _parse_db_summary(self._connection)
        return [Chunk(content=table_summary) for table_summary in table_summaries]

    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold; currently not applied,
                the call is forwarded to ``_retrieve``.
        """
        return self._retrieve(query)

    async def _aretrieve(self, query: str) -> List[Chunk]:
        """Retrieve knowledge chunks asynchronously.

        Mirrors ``_retrieve``: both branches return a flat ``List[Chunk]``.

        Args:
            query (str): query text
        """
        if self._need_embeddings:
            queries = [query]
            tasks = [self._similarity_search(text) for text in queries]
            per_query_hits = await run_async_tasks(tasks=tasks, concurrency_limit=1)
            # run_async_tasks is expected to yield one result list per task;
            # flatten them exactly as the synchronous path does.
            return reduce(lambda x, y: x + y, per_query_hits)
        # A single coroutine needs no task fan-out; awaiting it directly gives
        # the list of summaries itself rather than a list wrapping that list.
        table_summaries = await self._aparse_db_summary()
        return [Chunk(content=table_summary) for table_summary in table_summaries]

    async def _aretrieve_with_score(
        self, query: str, score_threshold: float
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score asynchronously.

        Args:
            query (str): query text
            score_threshold (float): score threshold; currently not applied,
                the call is forwarded to ``_aretrieve``.
        """
        return await self._aretrieve(query)

    async def _similarity_search(self, query: str) -> List[Chunk]:
        """Run a similarity search for ``query`` against the vector store.

        NOTE: the underlying ``similar_search`` call is made synchronously
        inside this coroutine.
        """
        return self._vector_store_connector.similar_search(
            query,
            self._top_k,
        )

    async def _aparse_db_summary(self) -> List[str]:
        """Parse the table summaries from the configured database connection."""
        # The connection must be forwarded, as in the synchronous path.
        return _parse_db_summary(self._connection)