connect database

This commit is contained in:
csunny
2023-05-03 22:10:35 +08:00
parent a71c8b6d56
commit 88a5c57646
4 changed files with 87 additions and 26 deletions

View File

@@ -4,7 +4,7 @@
import dataclasses
from enum import auto, Enum
from typing import List, Any
from pilot.configs.model_config import DB_SETTINGS
class SeparatorStyle(Enum):
@@ -88,6 +88,19 @@ class Conversation:
}
def gen_sqlgen_conversation(dbname):
from pilot.connections.mysql_conn import MySQLOperator
mo = MySQLOperator(
**DB_SETTINGS
)
message = ""
schemas = mo.get_schema(dbname)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
conv_one_shot = Conversation(
system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
@@ -121,7 +134,7 @@ conv_one_shot = Conversation(
sep_style=SeparatorStyle.SINGLE,
sep="###"
)
conv_vicuna_v1 = Conversation(
system = "A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
@@ -137,5 +150,10 @@ default_conversation = conv_one_shot
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1
"vicuna_v1": conv_vicuna_v1,
}
if __name__ == "__main__":
message = gen_sqlgen_conversation("dbgpt")
print(message)