mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 18:33:52 +00:00
knownledge based qa
This commit is contained in:
parent
529f077409
commit
c1758f030b
@ -40,8 +40,8 @@ def get_answer(q):
|
|||||||
return response.response
|
return response.response
|
||||||
|
|
||||||
def get_similar(q):
|
def get_similar(q):
|
||||||
from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
|
from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
|
||||||
docsearch = knownledge_tovec_st("./datasets/plan.md")
|
docsearch = load_knownledge_from_doc()
|
||||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||||
|
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
|||||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||||
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
|
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
|
||||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||||
DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
|
DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
|
||||||
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
LLM_MODEL_CONFIG = {
|
LLM_MODEL_CONFIG = {
|
||||||
|
@ -12,9 +12,9 @@ import requests
|
|||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
from pilot.configs.model_config import DB_SETTINGS
|
from pilot.configs.model_config import DB_SETTINGS
|
||||||
from pilot.connections.mysql_conn import MySQLOperator
|
from pilot.connections.mysql_conn import MySQLOperator
|
||||||
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
|
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
|
||||||
|
|
||||||
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
|
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
|
||||||
|
|
||||||
from pilot.conversation import (
|
from pilot.conversation import (
|
||||||
default_conversation,
|
default_conversation,
|
||||||
@ -50,7 +50,7 @@ priority = {
|
|||||||
|
|
||||||
def get_simlar(q):
|
def get_simlar(q):
|
||||||
|
|
||||||
docsearch = load_knownledge_from_doc()
|
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
||||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||||
|
|
||||||
contents = [dc.page_content for dc, _ in docs]
|
contents = [dc.page_content for dc, _ in docs]
|
||||||
@ -171,12 +171,20 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
|
|||||||
# prompt 中添加上下文提示
|
# prompt 中添加上下文提示
|
||||||
if db_selector:
|
if db_selector:
|
||||||
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
|
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
|
||||||
|
new_state.append_message(new_state.roles[1], None)
|
||||||
|
state = new_state
|
||||||
|
else:
|
||||||
|
new_state.append_message(new_state.roles[0], query)
|
||||||
|
new_state.append_message(new_state.roles[1], None)
|
||||||
|
state = new_state
|
||||||
|
|
||||||
new_state.append_message(new_state.roles[1], None)
|
try:
|
||||||
state = new_state
|
if not db_selector:
|
||||||
|
sim_q = get_simlar(query)
|
||||||
if not db_selector:
|
print("********vector similar info*************: ", sim_q)
|
||||||
state.append_message(new_state.roles[0], get_simlar(query) + query)
|
state.append_message(new_state.roles[0], sim_q + query)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
prompt = state.get_prompt()
|
prompt = state.get_prompt()
|
||||||
|
|
||||||
|
@ -50,17 +50,15 @@ def load_knownledge_from_doc():
|
|||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
|
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
|
||||||
|
|
||||||
docs = []
|
|
||||||
files = os.listdir(DATASETS_DIR)
|
files = os.listdir(DATASETS_DIR)
|
||||||
for file in files:
|
for file in files:
|
||||||
if not os.path.isdir(file):
|
if not os.path.isdir(file):
|
||||||
with open(file, "r") as f:
|
filename = os.path.join(DATASETS_DIR, file)
|
||||||
doc = f.read()
|
with open(filename, "r") as f:
|
||||||
docs.append(docs)
|
knownledge = f.read()
|
||||||
|
|
||||||
print(doc)
|
|
||||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0)
|
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0)
|
||||||
texts = text_splitter.split_text("\n".join(docs))
|
texts = text_splitter.split_text(knownledge)
|
||||||
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
|
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))],
|
||||||
persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
|
persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
|
||||||
return docsearch
|
return docsearch
|
||||||
|
Loading…
Reference in New Issue
Block a user