mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 12:30:14 +00:00
chore: Add pylint for DB-GPT rag lib (#1267)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""DBSchema retriever."""
|
||||
from functools import reduce
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
@@ -15,23 +16,23 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connection: Optional[RDBMSDatabase] = None,
|
||||
query_rewrite: bool = False,
|
||||
rerank: Ranker = None,
|
||||
vector_store_connector: Optional[VectorStoreConnector] = None,
|
||||
rerank: Optional[Ranker] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
"""Create DBSchemaRetriever.
|
||||
|
||||
Args:
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
top_k (int): top k
|
||||
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
|
||||
query_rewrite (bool): query rewrite
|
||||
rerank (Ranker): rerank
|
||||
vector_store_connector (VectorStoreConnector): vector store connector
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||
@@ -78,12 +79,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
top_k=3, vector_store_connector=vector_connector
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
print(
|
||||
f"db struct rag example results:{[chunk.content for chunk in chunks]}"
|
||||
)
|
||||
|
||||
result = [chunk.content for chunk in chunks]
|
||||
print(f"db struct rag example results:{result}")
|
||||
"""
|
||||
|
||||
self._top_k = top_k
|
||||
self._connection = connection
|
||||
self._query_rewrite = query_rewrite
|
||||
@@ -95,8 +93,12 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
def _retrieve(self, query: str) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
@@ -104,32 +106,45 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self._vector_store_connector.similar_search(query, self._top_k)
|
||||
for query in queries
|
||||
]
|
||||
candidates = reduce(lambda x, y: x + y, candidates)
|
||||
return candidates
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
else:
|
||||
if not self._connection:
|
||||
raise RuntimeError("RDBMSDatabase connection is required.")
|
||||
table_summaries = _parse_db_summary(self._connection)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return self._retrieve(query)
|
||||
|
||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [self._similarity_search(query) for query in queries]
|
||||
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
||||
return candidates
|
||||
result_candidates = await run_async_tasks(
|
||||
tasks=candidates, concurrency_limit=1
|
||||
)
|
||||
return result_candidates
|
||||
else:
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
||||
_parse_db_summary,
|
||||
)
|
||||
|
||||
table_summaries = await run_async_tasks(
|
||||
tasks=[self._aparse_db_summary()], concurrency_limit=1
|
||||
@@ -140,6 +155,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self, query: str, score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
@@ -157,4 +173,6 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Similar search."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
return _parse_db_summary()
|
||||
if not self._connection:
|
||||
raise RuntimeError("RDBMSDatabase connection is required.")
|
||||
return _parse_db_summary(self._connection)
|
||||
|
Reference in New Issue
Block a user