chore: Add pylint for DB-GPT rag lib (#1267)

This commit is contained in:
Fangyin Cheng
2024-03-07 23:27:43 +08:00
committed by GitHub
parent aaaf34db17
commit 7446817340
70 changed files with 1135 additions and 587 deletions

View File

@@ -1,5 +1,6 @@
"""DBSchema retriever."""
from functools import reduce
from typing import List, Optional
from typing import List, Optional, cast
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
@@ -15,23 +16,23 @@ class DBSchemaRetriever(BaseRetriever):
def __init__(
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connection: Optional[RDBMSDatabase] = None,
query_rewrite: bool = False,
rerank: Ranker = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
rerank: Optional[Ranker] = None,
**kwargs
):
"""
"""Create DBSchemaRetriever.
Args:
vector_store_connector (VectorStoreConnector): vector store connector
top_k (int): top k
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
Examples:
.. code-block:: python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
@@ -78,12 +79,9 @@ class DBSchemaRetriever(BaseRetriever):
top_k=3, vector_store_connector=vector_connector
)
chunks = retriever.retrieve("show columns from table")
print(
f"db struct rag example results:{[chunk.content for chunk in chunks]}"
)
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._top_k = top_k
self._connection = connection
self._query_rewrite = query_rewrite
@@ -95,8 +93,12 @@ class DBSchemaRetriever(BaseRetriever):
def _retrieve(self, query: str) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
@@ -104,32 +106,45 @@ class DBSchemaRetriever(BaseRetriever):
self._vector_store_connector.similar_search(query, self._top_k)
for query in queries
]
candidates = reduce(lambda x, y: x + y, candidates)
return candidates
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
else:
if not self._connection:
raise RuntimeError("RDBMSDatabase connection is required.")
table_summaries = _parse_db_summary(self._connection)
return [Chunk(content=table_summary) for table_summary in table_summaries]
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve(query)
async def _aretrieve(self, query: str) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [self._similarity_search(query) for query in queries]
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
return candidates
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return result_candidates
else:
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
)
table_summaries = await run_async_tasks(
tasks=[self._aparse_db_summary()], concurrency_limit=1
@@ -140,6 +155,7 @@ class DBSchemaRetriever(BaseRetriever):
self, query: str, score_threshold: float
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
@@ -157,4 +173,6 @@ class DBSchemaRetriever(BaseRetriever):
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
return _parse_db_summary()
if not self._connection:
raise RuntimeError("RDBMSDatabase connection is required.")
return _parse_db_summary(self._connection)