feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -10,7 +10,7 @@ from dbgpt.core import (
ModelMessageRoleType,
ModelRequest,
)
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -42,7 +42,7 @@ class SchemaLinking(BaseSchemaLinker):
def __init__(
self,
connection: RDBMSConnector,
connector: BaseConnector,
model_name: str,
llm: LLMClient,
top_k: int = 5,
@@ -52,19 +52,19 @@ class SchemaLinking(BaseSchemaLinker):
"""Create the schema linking instance.
Args:
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
connection (Optional[BaseConnector]): BaseConnector connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._top_k = top_k
self._connection = connection
self._connector = connector
self._llm = llm
self._model_name = model_name
self._vector_store_connector = vector_store_connector
def _schema_linking(self, query: str) -> List:
"""Get all db schema info."""
table_summaries = _parse_db_summary(self._connection)
table_summaries = _parse_db_summary(self._connector)
chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
chunks_content = [chunk.content for chunk in chunks]
return chunks_content