mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00
fix(rag): Fix db schema aretriever bug (#1755)
This commit is contained in:
parent
55c8b39e2e
commit
25d7d94b89
20
.github/workflows/sync-docs.yaml
vendored
20
.github/workflows/sync-docs.yaml
vendored
@ -1,20 +0,0 @@
|
||||
name: Trigger Auto Publish
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
jobs:
|
||||
trigger-api:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Trigger Publish API
|
||||
run: |
|
||||
curl -X POST ${{secrets.PUBLISH_SECRET_API}} \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"tag": "${{ github.ref }}"}'
|
@ -167,7 +167,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
result_candidates = await run_async_tasks(
|
||||
tasks=candidates, concurrency_limit=1
|
||||
)
|
||||
return result_candidates
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
|
||||
else:
|
||||
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
||||
_parse_db_summary,
|
||||
@ -177,7 +177,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
|
||||
concurrency_limit=1,
|
||||
)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
return [
|
||||
Chunk(content=table_summary) for table_summary in table_summaries[0]
|
||||
]
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self,
|
||||
|
@ -22,29 +22,35 @@ def mock_vector_store_connector():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
return DBSchemaRetriever(
|
||||
connector=mock_db_connection,
|
||||
index_store=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_db_summary() -> str:
|
||||
def mock_parse_db_summary(conn) -> List[str]:
|
||||
"""Patch _parse_db_summary method."""
|
||||
return "Table summary"
|
||||
return ["Table summary"]
|
||||
|
||||
|
||||
# Mocking the _parse_db_summary method in your test function
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||
def test_retrieve_with_mocked_summary(db_struct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
|
||||
|
||||
async def async_mock_parse_db_summary() -> str:
|
||||
"""Asynchronous patch for _parse_db_summary method."""
|
||||
return "Table summary"
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
|
Loading…
Reference in New Issue
Block a user