DB-GPT/examples/awel/simple_nl_schema_sql_chart_example.py
2025-03-12 10:24:22 +08:00

262 lines
8.5 KiB
Python

import os
from typing import Any, Dict, Optional
from pandas import DataFrame
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.util.chat_util import run_async_tasks
from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt_ext.rag.operators.schema_linking import SchemaLinkingOperator
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
"""AWEL: Simple nl-schemalinking-sql-chart operator example
pre-requirements:
1. install openai python sdk
```
pip install "db-gpt[openai]"
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
or
```
import os
os.environ["OPENAI_API_KEY"] = {your_openai_key}
os.environ["OPENAI_API_BASE"] = {your_openai_base}
```
python examples/awel/simple_nl_schema_sql_chart_example.py
Examples:
..code-block:: shell
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/schema_linking' \
--header 'Content-Type: application/json' \
--data '{"query": "Statistics of user age in the user table are based on three categories: age is less than 10, age is greater than or equal to 10 and less than or equal to 20, and age is greater than 20. The first column of the statistical results is different ages, and the second column is count."}'
"""
INSTRUCTION = (
"I want you to act as a SQL terminal in front of an example database, you need only to return the sql "
"command to me.Below is an instruction that describes a task, Write a response that appropriately "
"completes the request.\n###Instruction:\n{}"
)
INPUT_PROMPT = "\n###Input:\n{}\n###Response:"
def _create_vector_connector():
"""Create vector connector."""
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="embedding_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
return ChromaStore(config)
def _create_temporary_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 8),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
connect.create_temp_tables(
{
"job": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "student", 8),
(2, "student", 16),
(3, "student", 18),
(4, "teacher", 20),
(5, "teacher", 22),
],
}
}
)
connect.create_temp_tables(
{
"student": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
"info": "TEXT",
},
"data": [
(1, "Andy", 8, "good"),
(2, "Jerry", 16, "bad"),
(3, "Wendy", 18, "good"),
(4, "Spider", 20, "bad"),
(5, "David", 22, "bad"),
],
}
}
)
return connect
def _prompt_join_fn(query: str, chunks: str) -> str:
prompt = INSTRUCTION.format(chunks + INPUT_PROMPT.format(query))
return prompt
class TriggerReqBody(BaseModel):
query: str = Field(..., description="User query")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> Dict:
params = {
"query": input_value.query,
}
print(f"Receive input value: {input_value.query}")
return params
class SqlGenOperator(MapOperator[Any, Any]):
"""The Sql Generation Operator."""
def __init__(self, llm: Optional[LLMClient], model_name: str, **kwargs):
"""Init the sql generation operator
Args:
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._llm = llm
self._model_name = model_name
async def map(self, prompt_with_query_and_schema: str) -> str:
"""generate sql by llm.
Args:
prompt_with_query_and_schema (str): prompt
Return:
str: sql
"""
messages = [
ModelMessage(
role=ModelMessageRoleType.SYSTEM, content=prompt_with_query_and_schema
)
]
request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm.generate(request)]
output = await run_async_tasks(tasks=tasks, concurrency_limit=1)
sql = output[0].text
return sql
class SqlExecOperator(MapOperator[Any, Any]):
"""The Sql Execution Operator."""
def __init__(self, connector: Optional[RDBMSConnector] = None, **kwargs):
"""
Args:
connection (Optional[RDBMSConnector]): RDBMSConnector connection
"""
super().__init__(**kwargs)
self._connector = connector
def map(self, sql: str) -> DataFrame:
"""retrieve table schemas.
Args:
sql (str): query.
Return:
str: sql execution
"""
dataframe = self._connector.run_to_df(command=sql, fetch="all")
print(f"sql data is \n{dataframe}")
return dataframe
class ChartDrawOperator(MapOperator[Any, Any]):
"""The Chart Draw Operator."""
def __init__(self, **kwargs):
"""
Args:
connection (RDBMSConnector): The connection.
"""
super().__init__(**kwargs)
def map(self, df: DataFrame) -> str:
"""get sql result in db and draw.
Args:
sql (str): str.
"""
import matplotlib.pyplot as plt
category_column = df.columns[0]
count_column = df.columns[1]
plt.figure(figsize=(8, 4))
plt.bar(df[category_column], df[count_column])
plt.xlabel(category_column)
plt.ylabel(count_column)
plt.show()
return str(df)
with DAG("simple_nl_schema_sql_chart_example") as dag:
trigger = HttpTrigger(
"/examples/rag/schema_linking", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
query_operator = MapOperator(lambda request: request["query"])
llm = (OpenAILLMClient(api_key=os.getenv("OPENAI_API_KEY", "your api key")),)
model_name = "gpt-3.5-turbo"
retriever_task = SchemaLinkingOperator(
connector=_create_temporary_connection(), llm=llm, model_name=model_name
)
prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
sql_exec_operator = SqlExecOperator(connector=_create_temporary_connection())
draw_chart_operator = ChartDrawOperator(connector=_create_temporary_connection())
trigger >> request_handle_task >> query_operator >> prompt_join_operator
(
trigger
>> request_handle_task
>> query_operator
>> retriever_task
>> prompt_join_operator
)
prompt_join_operator >> sql_gen_operator >> sql_exec_operator >> draw_chart_operator
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
pass