mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 08:11:45 +00:00
fix(rag): Fix db schema aretriever bug (#1755)
This commit is contained in:
parent
55c8b39e2e
commit
25d7d94b89
20
.github/workflows/sync-docs.yaml
vendored
20
.github/workflows/sync-docs.yaml
vendored
@ -1,20 +0,0 @@
|
|||||||
name: Trigger Auto Publish
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- "*"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
trigger-api:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Trigger Publish API
|
|
||||||
run: |
|
|
||||||
curl -X POST ${{secrets.PUBLISH_SECRET_API}} \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{"tag": "${{ github.ref }}"}'
|
|
@ -167,7 +167,7 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
result_candidates = await run_async_tasks(
|
result_candidates = await run_async_tasks(
|
||||||
tasks=candidates, concurrency_limit=1
|
tasks=candidates, concurrency_limit=1
|
||||||
)
|
)
|
||||||
return result_candidates
|
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
|
||||||
else:
|
else:
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
||||||
_parse_db_summary,
|
_parse_db_summary,
|
||||||
@ -177,7 +177,9 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
|
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
|
||||||
concurrency_limit=1,
|
concurrency_limit=1,
|
||||||
)
|
)
|
||||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
return [
|
||||||
|
Chunk(content=table_summary) for table_summary in table_summaries[0]
|
||||||
|
]
|
||||||
|
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
self,
|
self,
|
||||||
|
@ -22,29 +22,35 @@ def mock_vector_store_connector():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||||
return DBSchemaRetriever(
|
return DBSchemaRetriever(
|
||||||
connector=mock_db_connection,
|
connector=mock_db_connection,
|
||||||
index_store=mock_vector_store_connector,
|
index_store=mock_vector_store_connector,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def mock_parse_db_summary() -> str:
|
def mock_parse_db_summary(conn) -> List[str]:
|
||||||
"""Patch _parse_db_summary method."""
|
"""Patch _parse_db_summary method."""
|
||||||
return "Table summary"
|
return ["Table summary"]
|
||||||
|
|
||||||
|
|
||||||
# Mocking the _parse_db_summary method in your test function
|
# Mocking the _parse_db_summary method in your test function
|
||||||
@patch.object(
|
@patch.object(
|
||||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||||
)
|
)
|
||||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
def test_retrieve_with_mocked_summary(db_struct_retriever):
|
||||||
query = "Table summary"
|
query = "Table summary"
|
||||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
|
||||||
assert isinstance(chunks[0], Chunk)
|
assert isinstance(chunks[0], Chunk)
|
||||||
assert chunks[0].content == "Table summary"
|
assert chunks[0].content == "Table summary"
|
||||||
|
|
||||||
|
|
||||||
async def async_mock_parse_db_summary() -> str:
|
@pytest.mark.asyncio
|
||||||
"""Asynchronous patch for _parse_db_summary method."""
|
@patch.object(
|
||||||
return "Table summary"
|
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||||
|
)
|
||||||
|
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
|
||||||
|
query = "Table summary"
|
||||||
|
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
|
||||||
|
assert isinstance(chunks[0], Chunk)
|
||||||
|
assert chunks[0].content == "Table summary"
|
||||||
|
Loading…
Reference in New Issue
Block a user