Merge branch 'dbgpt_doc' into ty_test

This commit is contained in:
yhjun1026
2023-06-13 22:39:12 +08:00
54 changed files with 1597 additions and 130 deletions

View File

@@ -392,6 +392,14 @@ class Database:
indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes]
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
ans = cursor.fetchall()
return ans[0][1]
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
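
A minimal usage sketch of the new get_show_create_table helper; the connection URI and table name are illustrative, and from_uri is assumed to follow the langchain-style constructor this Database class mirrors.

from pilot.common.sql_database import Database

db = Database.from_uri("mysql+pymysql://root:pass@127.0.0.1:3306/db_gpt")  # illustrative URI
ddl = db.get_show_create_table("book")  # "book" is an illustrative table
# SHOW CREATE TABLE returns (table_name, create_statement) rows, so the
# method's ans[0][1] is the full CREATE TABLE statement.
print(ddl)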

View File

@@ -10,5 +10,6 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
# Load the users .env file into environment variables
load_dotenv(verbose=True, override=True)
load_dotenv(".plugin_env")
del load_dotenv
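
A minimal sketch of the module-hygiene pattern this hunk completes: load the env files at import time, then delete the helper name so it does not leak from the module namespace.

from dotenv import load_dotenv

load_dotenv(verbose=True, override=True)  # load the user's .env into os.environ
load_dotenv(".plugin_env")                # plugin-specific overrides
del load_dotenv                           # drop the name from the module namespace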

View File

@@ -17,14 +17,10 @@ nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
# Get the current working directory
current_directory = os.getcwd()
print("Current working directory:", current_directory)
# Set the current working directory
new_directory = PILOT_PATH
os.chdir(new_directory)
print("New working directory:", os.getcwd())
DEVICE = (
"cuda"

View File

@@ -7,7 +7,7 @@ lang_dicts = {
"learn_more_markdown": "该服务是仅供非商业用途的研究预览。受 Vicuna-13B 模型 [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) 的约束",
"model_control_param": "模型参数",
"sql_generate_mode_direct": "直接执行结果",
"sql_generate_mode_none": "不直接执行结果",
"sql_generate_mode_none": "db问答",
"max_input_token_size": "最大输出Token数",
"please_choose_database": "请选择数据",
"sql_generate_diagnostics": "SQL生成与诊断",
@@ -44,7 +44,7 @@ lang_dicts = {
"learn_more_markdown": "The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B",
"model_control_param": "Model Parameters",
"sql_generate_mode_direct": "Execute directly",
"sql_generate_mode_none": "Execute without mode",
"sql_generate_mode_none": "db chat",
"max_input_token_size": "Maximum output token size",
"please_choose_database": "Please choose database",
"sql_generate_diagnostics": "SQL Generation & Diagnostics",

View File

@@ -51,7 +51,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
}
)
# 把最后一个用户的信息移动到末尾
# Move the last user's information to the end
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
@@ -66,7 +66,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
"messages": history,
"temperature": params.get("temperature"),
"max_tokens": params.get("max_new_tokens"),
"stream": True
"stream": True,
}
res = requests.post(
@@ -78,30 +78,9 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
if line:
json_data = line.split(b': ', 1)[1]
decoded_line = json_data.decode("utf-8")
if decoded_line.lower() != '[DONE]'.lower():
if decoded_line.lower() != "[DONE]".lower():
obj = json.loads(json_data)
if obj['choices'][0]['delta'].get('content') is not None:
content = obj['choices'][0]['delta']['content']
if obj["choices"][0]["delta"].get("content") is not None:
content = obj["choices"][0]["delta"]["content"]
text += content
yield text
# native result.
# payloads = {
# "model": "gpt-3.5-turbo", # just for test, remove this later
# "messages": history,
# "temperature": params.get("temperature"),
# "max_tokens": params.get("max_new_tokens"),
# }
#
# res = requests.post(
# CFG.proxy_server_url, headers=headers, json=payloads, stream=True
# )
#
# text = ""
# line = res.content
# if line:
# decoded_line = line.decode("utf-8")
# json_line = json.loads(decoded_line)
# print(json_line)
# text += json_line["choices"][0]["message"]["content"]
# yield text
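
For reference, a self-contained sketch of the streaming loop this file implements; it assumes an OpenAI-compatible /chat/completions endpoint, and the url, headers, and payload are illustrative.

import json
import requests

def stream_chat(url, headers, payload):
    res = requests.post(url, headers=headers, json=payload, stream=True)
    text = ""
    for line in res.iter_lines():
        if not line:
            continue
        json_data = line.split(b": ", 1)[1]  # strip the "data: " SSE prefix
        decoded_line = json_data.decode("utf-8")
        if decoded_line.lower() == "[done]":  # end-of-stream sentinel
            break
        delta = json.loads(decoded_line)["choices"][0]["delta"]
        if delta.get("content") is not None:
            text += delta["content"]
            yield text  # yield the accumulated text, as the code above does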

View File

@@ -52,7 +52,7 @@ class ChatWithDbQA(BaseChat):
raise ValueError("Could not import DBSummaryClient. ")
if self.db_name:
client = DBSummaryClient()
table_info = client.get_similar_tables(
table_info = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
# table_info = self.database.table_simple_info(self.db_connect)
@@ -60,8 +60,8 @@ class ChatWithDbQA(BaseChat):
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": dialect,
# "top_k": str(self.top_k),
# "dialect": dialect,
"table_info": table_info,
}
return input_values
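
For context, a hedged sketch of the new call (names taken from the hunks in this commit; the database name and query are illustrative):

client = DBSummaryClient()
table_info = client.get_db_summary(
    dbname="db_gpt",                      # illustrative database name
    query="How many tables are there?",   # the user's current input
    topk=5,
)
# get_db_summary returns a list of page_content strings from the
# "<dbname>_profile" vector store (see the db_summary hunks below).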

View File

@@ -10,22 +10,44 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. """
PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
# PROMPT_SUFFIX = """Only use the following tables generate sql if have any table info:
# {table_info}
#
# Question: {input}
#
# """
# _DEFAULT_TEMPLATE = """
# You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
# Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
# You can order the results by a relevant column to return the most interesting examples in the database.
# Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
# Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
#
# """
_DEFAULT_TEMPLATE_EN = """
You are a database expert. You will be given metadata information about a database or table, and should then provide a brief summary and answer to the question. For example, question: "How many tables are there in database 'db_gpt'?", answer: "There are 5 tables in database 'db_gpt', which are 'book', 'book_category', 'borrower', 'borrowing', and 'category'."
Based on the database metadata information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
database metadata information:
{table_info}
question:
{input}
"""
_DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
"""
_DEFAULT_TEMPLATE_ZH = """
你是一位数据库专家。你将获得有关数据库或表的元数据信息,然后提供简要的总结和回答。例如,问题:“数据库 'db_gpt' 中有多少个表?” 答案:“数据库 'db_gpt' 中有 5 个表,分别是 'book'、'book_category'、'borrower'、'borrowing'、'category'。”
根据以下数据库元数据信息,为用户提供专业简洁的答案。如果无法从提供的内容中获取答案,请说:“知识库中提供的信息不足以回答此问题。” 禁止随意捏造信息。
数据库元数据信息:
{table_info}
问题:
{input}
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -33,10 +55,10 @@ PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbQA.value,
input_variables=["input", "table_info", "dialect", "top_k"],
input_variables=["input", "table_info"],
response_format=None,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
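
Since the prompt now takes only input and table_info, rendering it is a plain two-field format call; an illustrative sketch:

rendered = _DEFAULT_TEMPLATE_EN.format(
    table_info="table name:book,table description:CREATE TABLE `book` (...)",  # illustrative
    input="How many tables are there in database 'db_gpt'?",
)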

View File

@@ -1,3 +1,5 @@
from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@@ -46,12 +48,15 @@ class ChatDefaultKnowledge(BaseChat):
)
def generate_input_values(self):
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d.page_content for d in docs]
context = context[:2000]
input_values = {"context": context, "question": self.current_user_input}
try:
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d.page_content for d in docs]
context = context[:2000]
input_values = {"context": context, "question": self.current_user_input}
except NoIndexException:
raise ValueError("you have no default knowledge store, please execute python knowledge_init.py")
return input_values
def do_with_prompt_response(self, prompt_response):

View File

@@ -38,7 +38,7 @@ class ChatUrlKnowledge(BaseChat):
)
self.url = url
vector_store_config = {
"vector_store_name": url,
"vector_store_name": url.replace(":", ""),
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = KnowledgeEmbedding(
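
The replace(":", "") keeps the derived vector store name path-safe, since the name typically becomes part of the persistence path under KNOWLEDGE_UPLOAD_ROOT_PATH; a sketch with an illustrative URL:

url = "https://db-gpt.readthedocs.io/en/latest"
vector_store_name = url.replace(":", "")
# -> "https//db-gpt.readthedocs.io/en/latest"; the colon (illegal in Windows
# paths) is gone, while slashes are left untouched by this commit.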

View File

@@ -667,7 +667,8 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info(f"args: {args}")
# 配置初始化
# init config
cfg = Config()
load_native_plugins(cfg)
@@ -676,12 +677,12 @@ if __name__ == "__main__":
async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# 加载插件可执行命令
# Load plugins and commands
command_categories = [
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# 排除禁用命令
# exclude commands
command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories
]

View File

@@ -1,11 +1,13 @@
from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector
@@ -19,6 +21,8 @@ KnowledgeEmbeddingType = {
".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}),
}
@@ -42,8 +46,12 @@ class KnowledgeEmbedding:
self.knowledge_embedding_client = self.init_knowledge_embedding()
self.knowledge_embedding_client.source_embedding()
def knowledge_embedding_batch(self):
self.knowledge_embedding_client.batch_embedding()
def knowledge_embedding_batch(self, docs):
# docs = self.knowledge_embedding_client.read_batch()
self.knowledge_embedding_client.index_to_store(docs)
def read(self):
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
if self.file_type == "url":
@@ -68,7 +76,11 @@ class KnowledgeEmbedding:
vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
return vector_client.similar_search(text, topk)
try:
ans = vector_client.similar_search(text, topk)
except NotEnoughElementsException:
ans = vector_client.similar_search(text, 1)
return ans
def vector_exist(self):
vector_client = VectorStoreConnector(
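
A hedged end-to-end sketch of the reworked batch API; the constructor arguments are assumptions based on the surrounding hunks, and the file path and store name are illustrative.

client = KnowledgeEmbedding(
    file_path="docs/guide.md",
    model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
    vector_store_config={"vector_store_name": "guide"},
)
docs = client.read()                    # read_batch(): load and split, no indexing
client.knowledge_embedding_batch(docs)  # index_to_store(docs): embed and persist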

View File

@@ -5,8 +5,8 @@ from typing import List
import markdown
from bs4 import BeautifulSoup
from langchain.document_loaders import TextLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
@@ -30,32 +30,8 @@ class MarkdownEmbedding(SourceEmbedding):
def read(self):
"""Load from markdown path."""
loader = EncodeTextLoader(self.file_path)
text_splitter = CHNDocumentSplitter(
pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
)
return loader.load_and_split(text_splitter)
@register
def read_batch(self):
"""Load from markdown path."""
docments = []
for root, _, files in os.walk(self.file_path, topdown=False):
for file in files:
filename = os.path.join(root, file)
loader = TextLoader(filename)
# text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len)
# docs = loader.load_and_split()
docs = loader.load()
# Update the metadata
new_docs = []
for doc in docs:
doc.metadata = {
"source": doc.metadata["source"].replace(self.file_path, "")
}
print("doc is embedding ... ", doc.metadata)
new_docs.append(doc)
docments += new_docs
return docments
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
return loader.load_and_split(textsplitter)
@register
def data_process(self, documents: List[Document]):

View File

@@ -29,7 +29,7 @@ class PDFEmbedding(SourceEmbedding):
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# )
textsplitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", chunk_size=1000, chunk_overlap=200
pipeline="zh_core_web_sm", chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200
)
return loader.load_and_split(textsplitter)
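
The splitter now honors the configured chunk size instead of a hard-coded 1000. Note that SpacyTextSplitter with the "zh_core_web_sm" pipeline requires that spaCy model to be installed (python -m spacy download zh_core_web_sm); a sketch with an illustrative input file:

from langchain.text_splitter import SpacyTextSplitter

textsplitter = SpacyTextSplitter(
    pipeline="zh_core_web_sm",
    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,  # previously hard-coded to 1000
    chunk_overlap=200,
)
chunks = textsplitter.split_text(open("doc.txt").read())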

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter
from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
CFG = Config()
class PPTEmbedding(SourceEmbedding):
"""ppt embedding for read ppt document."""
def __init__(self, file_path, vector_store_config):
"""Initialize with pdf path."""
super().__init__(file_path, vector_store_config)
self.file_path = file_path
self.vector_store_config = vector_store_config
@register
def read(self):
"""Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path)
textsplitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=200)
return loader.load_and_split(textsplitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents
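
A hypothetical usage of the new PPTEmbedding (the path and store name are illustrative):

ppt = PPTEmbedding(
    file_path="slides/intro.pptx",
    vector_store_config={"vector_store_name": "intro_slides"},
)
docs = ppt.read()              # load the deck and split with SpacyTextSplitter
docs = ppt.data_process(docs)  # strip newlines from each chunk's page_content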

View File

@@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector
@@ -62,7 +64,11 @@ class SourceEmbedding(ABC):
@register
def similar_search(self, doc, topk):
"""vector store similarity_search"""
return self.vector_client.similar_search(doc, topk)
try:
ans = self.vector_client.similar_search(doc, topk)
except NotEnoughElementsException:
ans = self.vector_client.similar_search(doc, 1)
return ans
def vector_name_exist(self):
return self.vector_client.vector_name_exists()
@@ -79,14 +85,11 @@ class SourceEmbedding(ABC):
if "index_to_store" in registered_methods:
self.index_to_store(text)
def batch_embedding(self):
if "read_batch" in registered_methods:
text = self.read_batch()
def read_batch(self):
if "read" in registered_methods:
text = self.read()
if "data_process" in registered_methods:
text = self.data_process(text)
if "text_split" in registered_methods:
self.text_split(text)
if "text_to_vector" in registered_methods:
self.text_to_vector(text)
if "index_to_store" in registered_methods:
self.index_to_store(text)
return text
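
The similar_search change above guards against Chroma's NotEnoughElementsException, raised when topk exceeds the number of stored vectors; the same retry-with-k=1 pattern, extracted as a sketch:

from chromadb.errors import NotEnoughElementsException

def safe_similar_search(client, text, topk):
    try:
        return client.similar_search(text, topk)
    except NotEnoughElementsException:
        # fewer than topk vectors in the store: fall back to the single best hit
        return client.similar_search(text, 1)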

View File

@@ -32,13 +32,14 @@ class DBSummaryClient:
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_name": dbname + "_summary",
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_summery(),
vector_store_config=vector_store_config,
)
self.init_db_profile(db_summary_client, dbname, embeddings)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summery():
@@ -69,10 +70,22 @@ class DBSummaryClient:
logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk):
vector_store_config = {
"vector_store_name": dbname + "_profile",
}
knowledge_embedding_client = KnowledgeEmbedding(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)
table_docs = knowledge_embedding_client.similar_search(query, topk)
ans = [d.page_content for d in table_docs]
return ans
def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
vector_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_name": dbname + "_summary",
}
knowledge_embedding_client = KnowledgeEmbedding(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
@@ -112,6 +125,29 @@ class DBSummaryClient:
for dbname in dbs:
self.db_summary_embedding(dbname)
def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = {
"vector_store_name": dbname + "_profile",
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_db_summery(),
vector_store_config=profile_store_config,
)
if not embedding.vector_name_exist():
docs = []
docs.extend(embedding.read_batch())
for table_summary in db_summary_client.table_info_json():
embedding = StringEmbedding(
table_summary,
profile_store_config,
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)
logger.info("init db profile success...")
def _get_llm_response(query, db_input, dbsummary):
chat_param = {
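
After this change the client keeps two vector stores per database; a sketch of the naming convention as it appears in these hunks (the dbname is illustrative):

dbname = "db_gpt"
summary_store = dbname + "_summary"  # per-table summaries; read by get_similar_tables
profile_store = dbname + "_profile"  # db profile plus per-table DDL; read by get_db_summary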

View File

@@ -5,6 +5,43 @@ from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, Inde
CFG = Config()
# {
# "database_name": "mydatabase",
# "tables": [
# {
# "table_name": "customers",
# "columns": [
# {"name": "id", "type": "int(11)", "is_primary_key": true},
# {"name": "name", "type": "varchar(255)", "is_primary_key": false},
# {"name": "email", "type": "varchar(255)", "is_primary_key": false}
# ],
# "indexes": [
# {"name": "PRIMARY", "type": "primary", "columns": ["id"]},
# {"name": "idx_name", "type": "index", "columns": ["name"]},
# {"name": "idx_email", "type": "index", "columns": ["email"]}
# ],
# "size_in_bytes": 1024,
# "rows": 1000
# },
# {
# "table_name": "orders",
# "columns": [
# {"name": "id", "type": "int(11)", "is_primary_key": true},
# {"name": "customer_id", "type": "int(11)", "is_primary_key": false},
# {"name": "order_date", "type": "date", "is_primary_key": false},
# {"name": "total_amount", "type": "decimal(10,2)", "is_primary_key": false}
# ],
# "indexes": [
# {"name": "PRIMARY", "type": "primary", "columns": ["id"]},
# {"name": "fk_customer_id", "type": "foreign_key", "columns": ["customer_id"], "referenced_table": "customers", "referenced_columns": ["id"]}
# ],
# "size_in_bytes": 2048,
# "rows": 500
# }
# ],
# "qps": 100,
# "tps": 50
# }
class MysqlSummary(DBSummary):
"""Get mysql summary template."""
@@ -13,7 +50,7 @@ class MysqlSummary(DBSummary):
self.name = name
self.type = "MYSQL"
self.summery = (
"""database name:{name}, database type:{type}, table infos:{table_info}"""
"""{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
)
self.tables = {}
self.tables_info = []
@@ -31,12 +68,14 @@ class MysqlSummary(DBSummary):
)
tables = self.db.get_table_names()
self.table_comments = self.db.get_table_comments(name)
comment_map = {}
for table_comment in self.table_comments:
self.tables_info.append(
"table name:{table_name},table description:{table_comment}".format(
table_name=table_comment[0], table_comment=table_comment[1]
)
)
comment_map[table_comment[0]] = table_comment[1]
vector_table = json.dumps(
{"table_name": table_comment[0], "table_description": table_comment[1]}
@@ -45,11 +84,18 @@ class MysqlSummary(DBSummary):
vector_table.encode("utf-8").decode("unicode_escape")
)
self.table_columns_info = []
self.table_columns_json = []
for table_name in tables:
table_summary = MysqlTableSummary(self.db, name, table_name)
table_summary = MysqlTableSummary(self.db, name, table_name, comment_map)
# self.tables[table_name] = table_summary.get_summery()
self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json())
table_profile = "table name:{table_name},table description:{table_comment}".format(
table_name=table_name, table_comment=self.db.get_show_create_table(table_name)
)
self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery())
def get_summery(self):
@@ -60,23 +106,29 @@ class MysqlSummary(DBSummary):
name=self.name, type=self.type, table_info=";".join(self.tables_info)
)
def get_db_summery(self):
return self.summery.format(
name=self.name, type=self.type, tables=";".join(self.vector_tables_info), qps=1000, tps=1000
)
def get_table_summary(self):
return self.tables
def get_table_comments(self):
return self.table_comments
def get_columns(self):
return self.table_columns_info
def table_info_json(self):
return self.table_columns_json
class MysqlTableSummary(TableSummary):
"""Get mysql table summary template."""
def __init__(self, instance, dbname, name):
def __init__(self, instance, dbname, name, comment_map):
self.name = name
self.dbname = dbname
self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.fields = []
self.fields_info = []
self.indexes = []
@@ -100,6 +152,10 @@ class MysqlTableSummary(TableSummary):
self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summery())
self.json_summery = self.json_summery_template.format(
name=name, comment=comment_map[name], fields=self.fields_info, indexes=self.indexes_info, size_in_bytes=1000, rows=1000
)
def get_summery(self):
return self.summery.format(
name=self.name,
@@ -111,20 +167,24 @@ class MysqlTableSummary(TableSummary):
def get_columns(self):
return self.column_summery
def get_summary_json(self):
return self.json_summery
class MysqlFieldsSummary(FieldSummary):
"""Get mysql field summary template."""
def __init__(self, field):
self.name = field[0]
self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
self.data_type = field[1]
self.default_value = field[2]
self.is_nullable = field[3]
self.comment = field[4]
def get_summery(self):
return self.summery.format(
return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
name=self.name,
data_type=self.data_type,
is_nullable=self.is_nullable,
@@ -138,11 +198,12 @@ class MysqlIndexSummary(IndexSummary):
def __init__(self, index):
self.name = index[0]
self.summery = """index name:{name}, index bind columns:{bind_fields}"""
# self.summery = """index name:{name}, index bind columns:{bind_fields}"""
self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}'
self.bind_fields = index[1]
def get_summery(self):
return self.summery.format(name=self.name, bind_fields=self.bind_fields)
return self.summery_template.format(name=self.name, bind_fields=self.bind_fields)
if __name__ == "__main__":