Merge branch 'dev' of github.com:csunny/DB-GPT into dev

This commit is contained in:
csunny 2023-05-07 05:15:09 +08:00
commit 4ce257e84f
4 changed files with 25 additions and 17 deletions

View File

@ -46,7 +46,7 @@ def get_similar(q):
for doc in docs:
dc, s = doc
print(dc.page_content)
print(s)
yield dc.page_content
if __name__ == "__main__":

View File

@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
MODEL_PATH = os.path.join(ROOT_PATH, "models")
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LLM_MODEL_CONFIG = {

View File

@ -12,9 +12,9 @@ import requests
from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS
from pilot.connections.mysql_conn import MySQLOperator
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
from pilot.conversation import (
default_conversation,
@ -50,7 +50,7 @@ priority = {
def get_simlar(q):
docsearch = load_knownledge_from_doc()
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
contents = [dc.page_content for dc, _ in docs]
@ -160,7 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
query = state.messages[-2][1]
if len(state.messages) == state.offset + 2:
# 第一轮对话需要加入提示Prompt
@ -168,15 +168,25 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
new_state = conv_templates[template_name].copy()
new_state.conv_id = uuid.uuid4().hex
query = state.messages[-2][1]
# prompt 中添加上下文提示
if db_selector:
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
new_state.append_message(new_state.roles[1], None)
state = new_state
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
new_state.append_message(new_state.roles[1], None)
state = new_state
if not db_selector:
state.append_message(new_state.roles[0], get_simlar(query) + query)
# try:
# if not db_selector:
# sim_q = get_simlar(query)
# print("********vector similar info*************: ", sim_q)
# state.append_message(new_state.roles[0], sim_q + query)
# except Exception as e:
# print(e)
prompt = state.get_prompt()

View File

@ -50,17 +50,15 @@ def load_knownledge_from_doc():
from pilot.configs.model_config import LLM_MODEL_CONFIG
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
docs = []
files = os.listdir(DATASETS_DIR)
for file in files:
if not os.path.isdir(file):
with open(file, "r") as f:
doc = f.read()
docs.append(docs)
filename = os.path.join(DATASETS_DIR, file)
with open(filename, "r") as f:
knownledge = f.read()
print(doc)
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0)
texts = text_splitter.split_text("\n".join(docs))
texts = text_splitter.split_text(knownledge)
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
return docsearch