fix(rag): Fix schema linking error (#1637)

This commit is contained in:
Fangyin Cheng 2024-06-15 14:15:58 +08:00 committed by GitHub
parent bb7f41bdba
commit e1e94f997a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,9 +11,9 @@ from dbgpt.core import (
ModelRequest,
)
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
INSTRUCTION = """
@ -46,8 +46,7 @@ class SchemaLinking(BaseSchemaLinker):
model_name: str,
llm: LLMClient,
top_k: int = 5,
vector_store_connector: Optional[VectorStoreConnector] = None,
**kwargs
index_store: Optional[IndexStoreBase] = None,
):
"""Create the schema linking instance.
@ -55,12 +54,11 @@ class SchemaLinking(BaseSchemaLinker):
connection (Optional[BaseConnector]): BaseConnector connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._top_k = top_k
self._connector = connector
self._llm = llm
self._model_name = model_name
self._vector_store_connector = vector_store_connector
self._index_store = index_store
def _schema_linking(self, query: str) -> List:
"""Get all db schema info."""
@ -71,11 +69,10 @@ class SchemaLinking(BaseSchemaLinker):
def _schema_linking_with_vector_db(self, query: str) -> List[Chunk]:
queries = [query]
if not self._vector_store_connector:
if not self._index_store:
raise ValueError("Vector store connector is not provided.")
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
for query in queries
self._index_store.similar_search(query, self._top_k) for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))