mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -10,7 +10,7 @@ from dbgpt.core import (
|
||||
ModelMessageRoleType,
|
||||
ModelRequest,
|
||||
)
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
@@ -42,7 +42,7 @@ class SchemaLinking(BaseSchemaLinker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: RDBMSConnector,
|
||||
connector: BaseConnector,
|
||||
model_name: str,
|
||||
llm: LLMClient,
|
||||
top_k: int = 5,
|
||||
@@ -52,19 +52,19 @@ class SchemaLinking(BaseSchemaLinker):
|
||||
"""Create the schema linking instance.
|
||||
|
||||
Args:
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
|
||||
connection (Optional[BaseConnector]): BaseConnector connection.
|
||||
llm (Optional[LLMClient]): base llm
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._top_k = top_k
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
self._llm = llm
|
||||
self._model_name = model_name
|
||||
self._vector_store_connector = vector_store_connector
|
||||
|
||||
def _schema_linking(self, query: str) -> List:
|
||||
"""Get all db schema info."""
|
||||
table_summaries = _parse_db_summary(self._connection)
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
chunks_content = [chunk.content for chunk in chunks]
|
||||
return chunks_content
|
||||
|
Reference in New Issue
Block a user