mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
93 lines
3.2 KiB
Python
93 lines
3.2 KiB
Python
"""SchemaLinking by LLM."""
|
|
|
|
from functools import reduce
|
|
from typing import List, Optional, cast
|
|
|
|
from dbgpt.core import (
|
|
Chunk,
|
|
LLMClient,
|
|
ModelMessage,
|
|
ModelMessageRoleType,
|
|
ModelRequest,
|
|
)
|
|
from dbgpt.datasource.base import BaseConnector
|
|
from dbgpt.rag.index.base import IndexStoreBase
|
|
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
|
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
|
from dbgpt.util.chat_util import run_async_tasks
|
|
|
|
# Prompt template for LLM-based schema linking: a task description, one worked
# example (instruction / input / response), and a single positional `{}` field
# that is filled through `INSTRUCTION.format(...)`.
# NOTE(review): the lone `|` lines inside this literal look like extraction
# artifacts from the page this file was mirrored from - confirm against upstream.
INSTRUCTION = """
|
|
You need to filter out the most relevant database table schema information (it may be a
|
|
single table or multiple tables) required to generate the SQL of the question query
|
|
from the given database schema information. First, I will show you an example of an
|
|
instruction followed by the correct schema response. Then, I will give you a new
|
|
instruction, and you should write the schema response that appropriately completes the
|
|
request.
|
|
|
|
### Example1 Instruction:
|
|
['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']
|
|
### Example1 Input:
|
|
Find the age of student table
|
|
### Example1 Response:
|
|
['student(id, name, age, info)']
|
|
###New Instruction:
|
|
{}
|
|
"""
|
|
|
|
# Suffix appended after the stringified schema list before it is substituted
# into INSTRUCTION; its `{}` field receives the user's question.
INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:"
|
|
|
|
|
|
class SchemaLinking(BaseSchemaLinker):
    """Schema linker backed by a database connector, an index store and an LLM.

    Three private strategies are provided: returning every table summary parsed
    from the connector, running a similarity search on the optional index
    store, and asking the LLM to select the schemas relevant to a question.
    """

    def __init__(
        self,
        connector: BaseConnector,
        model_name: str,
        llm: LLMClient,
        top_k: int = 5,
        index_store: Optional[IndexStoreBase] = None,
    ):
        """Create the schema linking instance.

        Args:
            connector (BaseConnector): connector handed to
                ``_parse_db_summary`` to obtain the table summaries.
            model_name (str): value used as ``model`` in the request sent to
                the LLM.
            llm (LLMClient): client whose ``generate`` method is invoked.
            top_k (int): second argument passed to the index store's
                ``similar_search``. Defaults to 5.
            index_store (Optional[IndexStoreBase]): store queried by the
                vector-db strategy; that strategy raises if it is missing.
        """
        self._connector = connector
        self._index_store = index_store
        self._top_k = top_k
        self._model_name = model_name
        self._llm = llm

    def _schema_linking(self, query: str) -> List:
        """Return the content of every table summary of the database.

        Args:
            query (str): not used by this strategy.

        Returns:
            List: the ``content`` of one ``Chunk`` per parsed table summary.
        """
        summaries = _parse_db_summary(self._connector)
        # Each summary is still wrapped in a Chunk before its content is read
        # back, exactly as before.
        return [Chunk(content=summary).content for summary in summaries]

    def _schema_linking_with_vector_db(self, query: str) -> List[Chunk]:
        """Look up candidate schemas for ``query`` in the index store.

        Args:
            query (str): text passed to the store's ``similar_search``.

        Returns:
            List[Chunk]: the search results, concatenated across queries.

        Raises:
            ValueError: if no index store was supplied at construction time.
        """
        search_terms = [query]
        if not self._index_store:
            raise ValueError("Vector store connector is not provided.")
        hits_per_term = []
        for term in search_terms:
            hits_per_term.append(
                self._index_store.similar_search(term, self._top_k)
            )
        # Concatenate the per-term result lists into a single list.
        merged = reduce(lambda acc, nxt: acc + nxt, hits_per_term)
        return cast(List[Chunk], merged)

    async def _schema_linking_with_llm(self, query: str) -> List:
        """Ask the LLM to pick the schemas relevant to ``query``.

        The candidate schemas come from ``self.schema_linking`` (not defined in
        this class; presumably inherited from ``BaseSchemaLinker``).

        Args:
            query (str): the question to link against the schemas.

        Returns:
            The ``text`` attribute of the first result returned by
            ``run_async_tasks``.
        """
        candidate_schemas = self.schema_linking(query)
        schema_repr = str(candidate_schemas)
        question_part = INPUT_PROMPT.format(query)
        prompt = INSTRUCTION.format(schema_repr + question_part)

        system_message = ModelMessage(
            role=ModelMessageRoleType.SYSTEM, content=prompt
        )
        request = ModelRequest(model=self._model_name, messages=[system_message])

        # Let the LLM narrow the candidates down to the accurate schema info.
        pending = [self._llm.generate(request)]
        responses = await run_async_tasks(tasks=pending, concurrency_limit=1)
        return responses[0].text
|