mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-26 21:37:40 +00:00
262 lines
8.5 KiB
Python
262 lines
8.5 KiB
Python
import os
|
|
from typing import Any, Dict, Optional
|
|
|
|
from pandas import DataFrame
|
|
|
|
from dbgpt._private.pydantic import BaseModel, Field
|
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
|
from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
|
|
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
|
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
|
from dbgpt.model.proxy import OpenAILLMClient
|
|
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
|
from dbgpt.util.chat_util import run_async_tasks
|
|
from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
|
from dbgpt_ext.rag.operators.schema_linking import SchemaLinkingOperator
|
|
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
|
|
|
|
"""AWEL: Simple nl-schemalinking-sql-chart operator example
|
|
|
|
Prerequisites:
|
|
1. Install the OpenAI Python SDK:
|
|
```
|
|
pip install "db-gpt[openai]"
|
|
```
|
|
2. Set the OpenAI API key and base URL:
|
|
```
|
|
export OPENAI_API_KEY={your_openai_key}
|
|
export OPENAI_API_BASE={your_openai_base}
|
|
```
|
|
or
|
|
```
|
|
import os
|
|
os.environ["OPENAI_API_KEY"] = {your_openai_key}
|
|
os.environ["OPENAI_API_BASE"] = {your_openai_base}
|
|
```
|
|
3. Run the example: python examples/awel/simple_nl_schema_sql_chart_example.py
|
|
Examples:
|
|
.. code-block:: shell
|
|
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/schema_linking' \
|
|
--header 'Content-Type: application/json' \
|
|
--data '{"query": "Statistics of user age in the user table are based on three categories: age is less than 10, age is greater than or equal to 10 and less than or equal to 20, and age is greater than 20. The first column of the statistical results is different ages, and the second column is count."}'
|
|
"""
|
|
|
|
# Prompt template for SQL generation. The single ``{}`` placeholder receives the
# linked schema text followed by INPUT_PROMPT (see _prompt_join_fn).
INSTRUCTION = (
    "I want you to act as a SQL terminal in front of an example database, you need only to return the sql "
    "command to me. Below is an instruction that describes a task. Write a response that appropriately "
    "completes the request.\n###Instruction:\n{}"
)

# Appended after the schema inside INSTRUCTION; ``{}`` receives the user query.
INPUT_PROMPT = "\n###Input:\n{}\n###Response:"
|
|
|
|
|
|
def _create_vector_connector():
    """Build a Chroma vector store for this example.

    The store persists under ``PILOT_PATH`` with the name
    ``embedding_rag_test`` and embeds text with the ``text2vec-large-chinese``
    model located under ``MODEL_PATH``.

    Returns:
        ChromaStore: The configured vector store.
    """
    embedding_model_path = os.path.join(MODEL_PATH, "text2vec-large-chinese")
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=embedding_model_path,
    ).create()

    vector_config = ChromaVectorConfig(
        persist_path=PILOT_PATH,
        name="embedding_rag_test",
        embedding_fn=embedding_fn,
    )
    return ChromaStore(vector_config)
|
|
|
|
|
|
def _create_temporary_connection():
    """Create a temporary SQLite database populated with demo tables.

    Three tables are created, each with five rows:

    * ``user``    -- id, name, age
    * ``job``     -- id, name, age
    * ``student`` -- id, name, age, info

    Returns:
        The ``SQLiteTempConnector`` that owns the temporary database.
    """
    connect = SQLiteTempConnector.create_temporary_db()

    person_columns = {
        "id": "INTEGER PRIMARY KEY",
        "name": "TEXT",
        "age": "INTEGER",
    }
    table_specs = [
        (
            "user",
            dict(person_columns),
            [
                (1, "Tom", 8),
                (2, "Jerry", 16),
                (3, "Jack", 18),
                (4, "Alice", 20),
                (5, "Bob", 22),
            ],
        ),
        (
            "job",
            dict(person_columns),
            [
                (1, "student", 8),
                (2, "student", 16),
                (3, "student", 18),
                (4, "teacher", 20),
                (5, "teacher", 22),
            ],
        ),
        (
            "student",
            {**person_columns, "info": "TEXT"},
            [
                (1, "Andy", 8, "good"),
                (2, "Jerry", 16, "bad"),
                (3, "Wendy", 18, "good"),
                (4, "Spider", 20, "bad"),
                (5, "David", 22, "bad"),
            ],
        ),
    ]

    # Create the tables one call at a time, in the same order as before.
    for table_name, columns, rows in table_specs:
        connect.create_temp_tables(
            {table_name: {"columns": columns, "data": rows}}
        )

    return connect
|
|
|
|
|
|
def _prompt_join_fn(query: str, chunks: str) -> str:
    """Combine the user query and linked schema into one SQL-generation prompt.

    Args:
        query: The user's natural-language question.
        chunks: Schema text produced by the schema-linking step.

    Returns:
        ``INSTRUCTION`` with ``chunks`` followed by the formatted
        ``INPUT_PROMPT`` placed into its slot.
    """
    input_section = INPUT_PROMPT.format(query)
    return INSTRUCTION.format(chunks + input_section)
|
|
|
|
|
|
class TriggerReqBody(BaseModel):
    """JSON body accepted by the ``/examples/rag/schema_linking`` trigger."""

    # The natural-language question to turn into SQL; required.
    query: str = Field(default=..., description="User query")
|
|
|
|
|
|
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Convert a ``TriggerReqBody`` into a plain ``{"query": ...}`` dict."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``MapOperator``."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Print the incoming query and return it wrapped in a dict.

        Args:
            input_value: The parsed HTTP request body.

        Returns:
            Dict: ``{"query": input_value.query}``.
        """
        user_query = input_value.query
        print(f"Receive input value: {user_query}")
        return {"query": user_query}
|
|
|
|
|
|
class SqlGenOperator(MapOperator[Any, Any]):
    """Generate a SQL statement from a prompt using an LLM."""

    def __init__(self, llm: Optional[LLMClient], model_name: str, **kwargs):
        """Init the sql generation operator.

        Args:
            llm (Optional[LLMClient]): The LLM client used to generate SQL.
                ``map`` raises ``ValueError`` if it is ``None``.
            model_name (str): Model name sent with every ``ModelRequest``.
            **kwargs: Forwarded to ``MapOperator``.
        """
        super().__init__(**kwargs)
        self._llm = llm
        self._model_name = model_name

    @staticmethod
    def _extract_sql(text: str) -> str:
        """Return the body of the first Markdown code fence in ``text``.

        Falls back to the whole text when no fence is present. Surrounding
        whitespace is stripped in both cases.
        """
        import re

        # Optional language tag (e.g. ``sql``) on the opening fence line.
        match = re.search(r"```(?:[A-Za-z]*[ \t]*\r?\n)?(.*?)```", text, re.DOTALL)
        return (match.group(1) if match else text).strip()

    async def map(self, prompt_with_query_and_schema: str) -> str:
        """Generate SQL with the LLM.

        Args:
            prompt_with_query_and_schema (str): Full prompt containing the
                schema and the user query.

        Returns:
            str: The generated SQL, with any Markdown code fence removed.

        Raises:
            ValueError: If no LLM client was configured.
        """
        if self._llm is None:
            raise ValueError("SqlGenOperator requires an LLM client, got None.")

        messages = [
            ModelMessage(
                role=ModelMessageRoleType.SYSTEM, content=prompt_with_query_and_schema
            )
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        # A single request needs no task pool; await it directly.
        output = await self._llm.generate(request)
        # Chat models often wrap SQL in ```sql ... ``` fences. The database
        # cannot execute those, so keep only the statement itself.
        return self._extract_sql(output.text)
|
|
|
|
|
|
class SqlExecOperator(MapOperator[Any, Any]):
    """Execute a SQL statement and return the result as a DataFrame."""

    def __init__(self, connector: Optional[RDBMSConnector] = None, **kwargs):
        """Init the sql execution operator.

        Args:
            connector (Optional[RDBMSConnector]): Connector used to run the
                SQL. Defaults to ``None``, but ``map`` raises ``ValueError``
                if it is still ``None`` when called.
            **kwargs: Forwarded to ``MapOperator``.
        """
        super().__init__(**kwargs)
        self._connector = connector

    def map(self, sql: str) -> DataFrame:
        """Run ``sql`` and return all result rows.

        Args:
            sql (str): The SQL statement to execute.

        Returns:
            DataFrame: The query result returned by ``run_to_df``.

        Raises:
            ValueError: If no connector was configured.
        """
        if self._connector is None:
            raise ValueError("SqlExecOperator requires a connector, got None.")
        dataframe = self._connector.run_to_df(command=sql, fetch="all")
        print(f"sql data is \n{dataframe}")
        return dataframe
|
|
|
|
|
|
class ChartDrawOperator(MapOperator[Any, Any]):
    """Draw a bar chart from a query result."""

    def __init__(self, **kwargs):
        """Init the chart draw operator.

        Args:
            **kwargs: Forwarded unchanged to ``MapOperator``.
        """
        super().__init__(**kwargs)

    def map(self, df: DataFrame) -> str:
        """Plot ``df`` as a bar chart and return its string form.

        The first column supplies the bar categories and the second the bar
        heights. Any further columns are ignored.

        Args:
            df (DataFrame): The SQL result to plot.

        Returns:
            str: ``str(df)``.

        Raises:
            ValueError: If ``df`` has fewer than two columns.
        """
        import matplotlib.pyplot as plt

        if len(df.columns) < 2:
            raise ValueError(
                "ChartDrawOperator needs at least two columns (category, count), "
                f"got {list(df.columns)}"
            )

        category_column = df.columns[0]
        count_column = df.columns[1]
        fig, ax = plt.subplots(figsize=(8, 4))
        try:
            # Positional access stays correct even if column labels repeat.
            ax.bar(df.iloc[:, 0], df.iloc[:, 1])
            ax.set_xlabel(category_column)
            ax.set_ylabel(count_column)
            plt.show()
        finally:
            # Each call creates a new figure. Release it so repeated requests
            # do not accumulate open figures in pyplot.
            plt.close(fig)
        return str(df)
|
|
|
|
|
|
with DAG("simple_nl_schema_sql_chart_example") as dag:
    trigger = HttpTrigger(
        "/examples/rag/schema_linking", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    query_operator = MapOperator(lambda request: request["query"])

    # One client instance (not a 1-tuple), shared by schema linking and SQL
    # generation, which both call ``llm.generate``.
    llm = OpenAILLMClient(api_key=os.getenv("OPENAI_API_KEY", "your api key"))
    model_name = "gpt-3.5-turbo"

    # Build the temporary database once, so that schema linking and SQL
    # execution operate on the same tables.
    connector = _create_temporary_connection()

    retriever_task = SchemaLinkingOperator(
        connector=connector, llm=llm, model_name=model_name
    )
    prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
    sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
    sql_exec_operator = SqlExecOperator(connector=connector)
    draw_chart_operator = ChartDrawOperator()

    trigger >> request_handle_task >> query_operator
    # Connect query_operator to the join before retriever_task. This matches
    # the (query, chunks) parameter order of _prompt_join_fn.
    query_operator >> prompt_join_operator
    query_operator >> retriever_task >> prompt_join_operator
    prompt_join_operator >> sql_gen_operator >> sql_exec_operator >> draw_chart_operator
|
|
|
|
if __name__ == "__main__":
    # Launch the local dev server only when the DAG runs in development mode.
    # Outside dev mode there is nothing to do.
    first_leaf = dag.leaf_nodes[0]
    if first_leaf.dev_mode:
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
|