feat: add schema-linking awel example (#1081)

junewgl 2024-01-21 09:59:59 +08:00 committed by GitHub
parent 2d905191f8
commit 4f833634df
5 changed files with 445 additions and 0 deletions

View File: dbgpt/rag/operator/schema_linking.py

@@ -0,0 +1,44 @@
from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class SchemaLinkingOperator(MapOperator[Any, Any]):
"""The Schema Linking Operator."""
def __init__(
self,
top_k: int = 5,
connection: Optional[RDBMSDatabase] = None,
llm: Optional[LLMClient] = None,
model_name: Optional[str] = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
**kwargs
):
"""Init the schema linking operator
Args:
connection (RDBMSDatabase): The connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._schema_linking = SchemaLinking(
top_k=top_k,
connection=connection,
llm=llm,
model_name=model_name,
vector_store_connector=vector_store_connector,
)
async def map(self, query: str) -> str:
"""retrieve table schemas.
Args:
query (str): query.
Return:
str: schema info
"""
return str(await self._schema_linking.schema_linking_with_llm(query))
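For quick review, a minimal sketch (not part of this commit) of driving the new operator directly, without the HTTP trigger used in the full example later in this diff. The `my_connection` placeholder, the OpenAI-backed client, and the model name are assumptions:

```python
# Hedged sketch only: call the operator's map() coroutine directly.
import asyncio

from dbgpt.core.awel import DAG
from dbgpt.model import OpenAILLMClient


async def _demo(connection):
    # `connection` is a placeholder for a real RDBMSDatabase instance.
    with DAG("schema_linking_sketch"):
        operator = SchemaLinkingOperator(
            connection=connection,
            llm=OpenAILLMClient(),
            model_name="gpt-3.5-turbo",
        )
    # map() runs LLM-based schema linking and returns the schema info as a string.
    return await operator.map("Find the age of student table")


# print(asyncio.run(_demo(my_connection)))  # my_connection: hypothetical connection
```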

View File

View File: dbgpt/rag/schemalinker/base_linker.py

@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import List
class BaseSchemaLinker(ABC):
    """Base schema linker."""

    def schema_linking(self, query: str) -> List:
        """Get the schema info for the query.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """
        return self._schema_linking(query)

    def schema_linking_with_vector_db(self, query: str) -> List:
        """Get the schema info for the query from the vector store.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """
        return self._schema_linking_with_vector_db(query)

    async def schema_linking_with_llm(self, query: str) -> List:
        """Get the schema info for the query, filtered by the LLM.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """
        return await self._schema_linking_with_llm(query)

    @abstractmethod
    def _schema_linking(self, query: str) -> List:
        """Retrieve the schema info for the query.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """

    @abstractmethod
    def _schema_linking_with_vector_db(self, query: str) -> List:
        """Retrieve the schema info for the query from the vector store.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """

    @abstractmethod
    async def _schema_linking_with_llm(self, query: str) -> List:
        """Retrieve the schema info for the query via the LLM.

        Args:
            query (str): query text
        Returns:
            List: list of schema
        """

View File: dbgpt/rag/schemalinker/schema_linking.py

@@ -0,0 +1,78 @@
from functools import reduce
from typing import List, Optional
from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
INSTRUCTION = (
"You need to filter out the most relevant database table schema information (it may be a single "
"table or multiple tables) required to generate the SQL of the question query from the given "
"database schema information. First, I will show you an example of an instruction followed by "
"the correct schema response. Then, I will give you a new instruction, and you should write "
"the schema response that appropriately completes the request.\n### Example1 Instruction:\n"
"['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']\n### Example1 "
"Input:\nFind the age of student table\n### Example1 Response:\n['student(id, name, age, info)']"
"\n###New Instruction:\n{}"
)
INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:"
class SchemaLinking(BaseSchemaLinker):
"""SchemaLinking by LLM"""
def __init__(
self,
top_k: int = 5,
connection: Optional[RDBMSDatabase] = None,
llm: Optional[LLMClient] = None,
model_name: Optional[str] = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
**kwargs
):
"""
Args:
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._top_k = top_k
self._connection = connection
self._llm = llm
self._model_name = model_name
self._vector_store_connector = vector_store_connector
    def _schema_linking(self, query: str) -> List:
        """Get all table schema info from the database connection."""
        table_summaries = _parse_db_summary(self._connection)
        chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
        chunks_content = [chunk.content for chunk in chunks]
        return chunks_content

    def _schema_linking_with_vector_db(self, query: str) -> List:
        """Get the top-k schema candidates from the vector store."""
        queries = [query]
        candidates = [
            self._vector_store_connector.similar_search(query, self._top_k)
            for query in queries
        ]
        candidates = reduce(lambda x, y: x + y, candidates)
        return candidates
    async def _schema_linking_with_llm(self, query: str) -> List:
        """Ask the LLM to keep only the schema info relevant to the query."""
        chunks_content = self.schema_linking(query)
        schema_prompt = INSTRUCTION.format(
            str(chunks_content) + INPUT_PROMPT.format(query)
        )
        messages = [
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        tasks = [self._llm.generate(request)]
        # get the filtered schema info from the llm
        schema = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        schema_text = schema[0].text
        return schema_text
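A hedged usage sketch (not part of this commit) of running the LLM-based linking step directly, outside AWEL; the connection and model name are placeholders:

```python
# Illustrative only: run schema linking end to end with an OpenAI-backed client.
import asyncio

from dbgpt.model import OpenAILLMClient


async def _link(connection, query: str):
    linker = SchemaLinking(
        top_k=5,
        connection=connection,  # placeholder RDBMSDatabase connection
        llm=OpenAILLMClient(),
        model_name="gpt-3.5-turbo",
    )
    # Builds the INSTRUCTION/INPUT_PROMPT prompt above and asks the LLM to keep
    # only the tables relevant to the query.
    return await linker.schema_linking_with_llm(query)


# asyncio.run(_link(my_connection, "Find the age of student table"))  # my_connection: hypothetical
```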

View File: examples/awel/simple_nl_schema_sql_chart_example.py

@@ -0,0 +1,263 @@
import os
from typing import Any, Dict, Optional
from pandas import DataFrame
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.schema_linking import SchemaLinkingOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
"""AWEL: Simple nl-schemalinking-sql-chart operator example
pre-requirements:
1. install openai python sdk
```
pip install "db-gpt[openai]"
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
or
```
import os
os.environ["OPENAI_API_KEY"] = {your_openai_key}
os.environ["OPENAI_API_BASE"] = {your_openai_base}
```
python examples/awel/simple_nl_schema_sql_chart_example.py
Examples:
..code-block:: shell
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/schema_linking' \
--header 'Content-Type: application/json' \
--data '{"query": "Statistics of user age in the user table are based on three categories: age is less than 10, age is greater than or equal to 10 and less than or equal to 20, and age is greater than 20. The first column of the statistical results is different ages, and the second column is count."}'
"""
INSTRUCTION = (
    "I want you to act as a SQL terminal in front of an example database. You need only to return the sql "
    "command to me. Below is an instruction that describes a task. Write a response that appropriately "
    "completes the request.\n###Instruction:\n{}"
)
INPUT_PROMPT = "\n###Input:\n{}\n###Response:"
def _create_vector_connector():
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="vector_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
def _create_temporary_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnect.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 8),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
connect.create_temp_tables(
{
"job": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "student", 8),
(2, "student", 16),
(3, "student", 18),
(4, "teacher", 20),
(5, "teacher", 22),
],
}
}
)
connect.create_temp_tables(
{
"student": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
"info": "TEXT",
},
"data": [
(1, "Andy", 8, "good"),
(2, "Jerry", 16, "bad"),
(3, "Wendy", 18, "good"),
(4, "Spider", 20, "bad"),
(5, "David", 22, "bad"),
],
}
}
)
return connect
def _prompt_join_fn(query: str, chunks: str) -> str:
prompt = INSTRUCTION.format(chunks + INPUT_PROMPT.format(query))
return prompt
class TriggerReqBody(BaseModel):
query: str = Field(..., description="User query")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> Dict:
params = {
"query": input_value.query,
}
print(f"Receive input value: {input_value.query}")
return params
class SqlGenOperator(MapOperator[Any, Any]):
"""The Sql Generation Operator."""
def __init__(self, llm: Optional[LLMClient], model_name: str, **kwargs):
"""Init the sql generation operator
Args:
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._llm = llm
self._model_name = model_name
async def map(self, prompt_with_query_and_schema: str) -> str:
"""generate sql by llm.
Args:
prompt_with_query_and_schema (str): prompt
Return:
str: sql
"""
messages = [
ModelMessage(
role=ModelMessageRoleType.SYSTEM, content=prompt_with_query_and_schema
)
]
request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm.generate(request)]
output = await run_async_tasks(tasks=tasks, concurrency_limit=1)
sql = output[0].text
return sql
class SqlExecOperator(MapOperator[Any, Any]):
"""The Sql Execution Operator."""
def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs):
"""
Args:
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection
"""
super().__init__(**kwargs)
self._connection = connection
def map(self, sql: str) -> DataFrame:
"""retrieve table schemas.
Args:
sql (str): query.
Return:
str: sql execution
"""
dataframe = self._connection.run_to_df(command=sql, fetch="all")
print(f"sql data is \n{dataframe}")
return dataframe
class ChartDrawOperator(MapOperator[Any, Any]):
"""The Chart Draw Operator."""
def __init__(self, **kwargs):
"""
Args:
connection (RDBMSDatabase): The connection.
"""
super().__init__(**kwargs)
def map(self, df: DataFrame) -> str:
"""get sql result in db and draw.
Args:
sql (str): str.
"""
import matplotlib.pyplot as plt
category_column = df.columns[0]
count_column = df.columns[1]
plt.figure(figsize=(8, 4))
plt.bar(df[category_column], df[count_column])
plt.xlabel(category_column)
plt.ylabel(count_column)
plt.show()
return str(df)
with DAG("simple_nl_schema_sql_chart_example") as dag:
trigger = HttpTrigger(
"/examples/rag/schema_linking", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
query_operator = MapOperator(lambda request: request["query"])
llm = OpenAILLMClient()
model_name = "gpt-3.5-turbo"
retriever_task = SchemaLinkingOperator(
connection=_create_temporary_connection(), llm=llm, model_name=model_name
)
prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
sql_exec_operator = SqlExecOperator(connection=_create_temporary_connection())
draw_chart_operator = ChartDrawOperator(connection=_create_temporary_connection())
    # The query reaches the JoinOperator along two branches: directly, and
    # through the schema linking operator that filters the relevant tables.
    trigger >> request_handle_task >> query_operator >> prompt_join_operator
    (
        trigger
        >> request_handle_task
        >> query_operator
        >> retriever_task
        >> prompt_join_operator
    )
    prompt_join_operator >> sql_gen_operator >> sql_exec_operator >> draw_chart_operator
if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode: run the DAG locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
    else:
        # Production mode: DB-GPT loads and runs the DAG automatically after startup.
        pass