fix(rag): Fix db schema aretriever bug (#1755)

This commit is contained in:
Fangyin Cheng 2024-07-30 10:30:42 +08:00 committed by GitHub
parent 55c8b39e2e
commit 25d7d94b89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 30 deletions

View File

@ -1,20 +0,0 @@
name: Trigger Auto Publish
on:
push:
tags:
- "*"
jobs:
trigger-api:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Trigger Publish API
run: |
curl -X POST ${{secrets.PUBLISH_SECRET_API}} \
-H "Content-Type: application/json" \
-d '{"tag": "${{ github.ref }}"}'

View File

@ -167,7 +167,7 @@ class DBSchemaRetriever(BaseRetriever):
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return result_candidates
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
else:
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
@ -177,7 +177,9 @@ class DBSchemaRetriever(BaseRetriever):
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
concurrency_limit=1,
)
return [Chunk(content=table_summary) for table_summary in table_summaries]
return [
Chunk(content=table_summary) for table_summary in table_summaries[0]
]
async def _aretrieve_with_score(
self,

View File

@ -22,29 +22,35 @@ def mock_vector_store_connector():
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
return DBSchemaRetriever(
connector=mock_db_connection,
index_store=mock_vector_store_connector,
)
def mock_parse_db_summary() -> str:
def mock_parse_db_summary(conn) -> List[str]:
"""Patch _parse_db_summary method."""
return "Table summary"
return ["Table summary"]
# Replace _parse_db_summary with the stub above for the duration of this test
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
def test_retrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return "Table summary"
@pytest.mark.asyncio
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"