Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-18 07:30:40 +00:00)
Merge remote-tracking branch 'origin/main' into dev
# Conflicts:
#   pilot/configs/config.py
#   pilot/connections/mysql.py
#   pilot/conversation.py
#   pilot/server/webserver.py
pilot/server/chat_adapter.py (new file, 90 additions)
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from functools import cache
from typing import List

from pilot.model.inference import generate_stream


class BaseChatAdpter:
    """The Base class for chat with llm models. it will match the model,
    and fetch output from model"""

    def match(self, model_path: str):
        return True

    def get_generate_stream_func(self):
        """Return the generate stream handler func"""
        pass


llm_model_chat_adapters: List[BaseChatAdpter] = []


def register_llm_model_chat_adapter(cls):
    """Register a chat adapter"""
    llm_model_chat_adapters.append(cls())


@cache
def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    """Get a chat generate func for a model"""
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter

    raise ValueError(f"Invalid model for chat adapter {model_path}")


class VicunaChatAdapter(BaseChatAdpter):

    """Model chat Adapter for vicuna"""

    def match(self, model_path: str):
        return "vicuna" in model_path

    def get_generate_stream_func(self):
        return generate_stream


class ChatGLMChatAdapter(BaseChatAdpter):
    """Model chat Adapter for ChatGLM"""

    def match(self, model_path: str):
        return "chatglm" in model_path

    def get_generate_stream_func(self):
        from pilot.model.chatglm_llm import chatglm_generate_stream

        return chatglm_generate_stream


class CodeT5ChatAdapter(BaseChatAdpter):

    """Model chat adapter for CodeT5"""

    def match(self, model_path: str):
        return "codet5" in model_path

    def get_generate_stream_func(self):
        # TODO
        pass


class CodeGenChatAdapter(BaseChatAdpter):

    """Model chat adapter for CodeGen"""

    def match(self, model_path: str):
        return "codegen" in model_path

    def get_generate_stream_func(self):
        # TODO
        pass


register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)

register_llm_model_chat_adapter(BaseChatAdpter)
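For orientation, a minimal usage sketch of the registry defined above; the checkpoint path is a made-up example and the call site is illustrative, not part of this commit:

from pilot.server.chat_adapter import get_llm_chat_adapter

# Any path containing "vicuna" selects VicunaChatAdapter; anything else falls through
# to the catch-all BaseChatAdpter registered last (its match() always returns True).
adapter = get_llm_chat_adapter("/models/vicuna-13b")  # hypothetical model path
generate_stream_func = adapter.get_generate_stream_func()  # vicuna resolves to generate_stream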
@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

code_highlight_css = (
"""
code_highlight_css = """
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
@@ -71,6 +70,5 @@ code_highlight_css = (
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
""")
#.highlight { background: #f8f8f8; }

"""
# .highlight { background: #f8f8f8; }
@@ -49,7 +49,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
            warnings.warn(
                "The 'color_map' parameter has been deprecated.",
            )
        #self.md = utils.get_markdown_parser()
        # self.md = utils.get_markdown_parser()
        self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
        self.select: EventListenerMethod
        """
@@ -112,7 +112,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
        ): # This happens for previously processed messages
            return chat_message
        elif isinstance(chat_message, str):
            #return self.md.render(chat_message)
            # return self.md.render(chat_message)
            return str(self.md.convert(chat_message))
        else:
            raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
@@ -141,9 +141,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
            ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
            processed_messages.append(
                (
                    #self._process_chat_messages(message_pair[0]),
                    '<pre style="font-family: var(--font)">' +
                    message_pair[0] + "</pre>",
                    # self._process_chat_messages(message_pair[0]),
                    '<pre style="font-family: var(--font)">'
                    + message_pair[0]
                    + "</pre>",
                    self._process_chat_messages(message_pair[1]),
                )
            )
@@ -163,5 +164,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
            **kwargs,
        )
        return self
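The patch above swaps gradio's markdown parser for markdown2; a small standalone sketch of the same conversion (the sample message is invented, the extras mirror the ones set in __init__):

from markdown2 import Markdown

md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
# Plain chat strings go through md.convert(), as _process_chat_messages now does.
html = md.convert("Here is the query:\n```sql\nSELECT 1;\n```")
print(html)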
@@ -1,40 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import uvicorn
import asyncio
import json
from typing import Optional, List
from fastapi import FastAPI, Request, BackgroundTasks
import os
import sys

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from pilot.model.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings

from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *
from pilot.configs.config import Config


CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]


global_counter = 0
model_semaphore = None

ml = ModelLoader(model_path=model_path)
model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.inference import generate_output, generate_stream, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.server.chat_adapter import get_llm_chat_adapter

CFG = Config()


class ModelWorker:
    def __init__(self):
        pass
    def __init__(self, model_path, model_name, device, num_gpus=1):
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_name = model_name or model_path.split("/")[-1]
        self.device = device

        self.ml = ModelLoader(model_path=model_path)
        self.model, self.tokenizer = self.ml.loader(
            num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
        )

        if hasattr(self.model.config, "max_sequence_length"):
            self.context_len = self.model.config.max_sequence_length
        elif hasattr(self.model.config, "max_position_embeddings"):
            self.context_len = self.model.config.max_position_embeddings

        else:
            self.context_len = 2048

        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()

    def get_queue_length(self):
        if (
            model_semaphore is None
            or model_semaphore._value is None
            or model_semaphore._waiters is None
        ):
            return 0
        else:
            (
                CFG.LIMIT_MODEL_CONCURRENCY
                - model_semaphore._value
                + len(model_semaphore._waiters)
            )

    def generate_stream_gate(self, params):
        try:
            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
                print("output: ", output)
                ret = {
                    "text": output,
                    "error_code": 0,
                }
                yield json.dumps(ret).encode() + b"\0"

        except torch.cuda.CudaError:
            ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):
        return get_embeddings(self.model, self.tokenizer, prompt)

# TODO

app = FastAPI()


class PromptRequest(BaseModel):
    prompt: str
    temperature: float
@@ -42,6 +93,7 @@ class PromptRequest(BaseModel):
    model: str
    stop: str = None


class StreamRequest(BaseModel):
    model: str
    prompt: str
@@ -49,64 +101,43 @@ class StreamRequest(BaseModel):
    max_new_tokens: int
    stop: str


class EmbeddingRequest(BaseModel):
    prompt: str


def release_model_semaphore():
    model_semaphore.release()


def generate_stream_gate(params):
    try:
        for output in generate_stream(
            model,
            tokenizer,
            params,
            DEVICE,
            CFG.MAX_POSITION_EMBEDDINGS,
        ):
            print("output: ", output)
            ret = {
                "text": output,
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"
    except torch.cuda.CudaError:
        ret = {
            "text": "**GPU OutOfMemory, Please Refresh.**",
            "error_code": 0
        }
        yield json.dumps(ret).encode() + b"\0"


@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()
    print(model, tokenizer, params, DEVICE)

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()
    await model_semaphore.acquire()

    generator = generate_stream_gate(params)
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/generate")
def generate(prompt_request: PromptRequest):
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop
        "stop": prompt_request.stop,
    }

    response = []
    response = []
    rsp_str = ""
    output = generate_stream_gate(params)
    output = worker.generate_stream_gate(params)
    for rsp in output:
        # rsp = rsp.decode("utf-8")
        rsp_str = str(rsp, "utf-8")
@@ -114,15 +145,22 @@ def generate(prompt_request: PromptRequest):
        response.append(rsp_str)

    return {"response": rsp_str}


@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = get_embeddings(model, tokenizer, params["prompt"])
    output = worker.get_embeddings(params["prompt"])
    return {"response": [float(x) for x in output]}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", log_level="info")
    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
    print(model_path, DEVICE)

    worker = ModelWorker(
        model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
    )

    uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
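A hedged sketch of calling the model server defined above from a client; the base URL is an assumption (the server binds 0.0.0.0 on CFG.MODEL_PORT), and the payload fields mirror PromptRequest and EmbeddingRequest:

import requests

BASE_URL = "http://127.0.0.1:8000"  # assumed host/port; the real port comes from CFG.MODEL_PORT

# Blocking generation via /generate (fields taken from PromptRequest).
payload = {
    "prompt": "Hello, DB-GPT",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "model": "vicuna-13b",
    "stop": None,
}
print(requests.post(f"{BASE_URL}/generate", json=payload).json()["response"])

# Embeddings via /embedding (fields taken from EmbeddingRequest).
emb = requests.post(f"{BASE_URL}/embedding", json={"prompt": "DB-GPT"}).json()
print(len(emb["response"]))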
@@ -1,29 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pilot.vector_store.file_loader import KnownLedge2Vector
from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qa_prompt_template

from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector


class KnownLedgeBaseQA:

    def __init__(self) -> None:
        k2v = KnownLedge2Vector()
        self.vector_store = k2v.init_vector_store()
        self.llm = VicunaLLM()


    def get_similar_answer(self, query):

        prompt = PromptTemplate(
            template=conv_qa_prompt_template,
            input_variables=["context", "question"]
            template=conv_qa_prompt_template, input_variables=["context", "question"]
        )

        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
        retriever = self.vector_store.as_retriever(
            search_kwargs={"k": VECTOR_SEARCH_TOP_K}
        )
        docs = retriever.get_relevant_documents(query=query)

        context = [d.page_content for d in docs]
        context = [d.page_content for d in docs]
        result = prompt.format(context="\n".join(context), question=query)
        return result
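An illustrative sketch of using the class above to build a retrieval-grounded prompt; it assumes the default vector store has already been populated, and the question text is invented:

from pilot.server.vectordb_qa import KnownLedgeBaseQA

qa = KnownLedgeBaseQA()  # loads the default vector store and a VicunaLLM client
prompt_text = qa.get_similar_answer("How do I add an index to a MySQL table?")
# prompt_text is conv_qa_prompt_template filled with the top-k retrieved context
# and the question; the caller is responsible for sending it to the LLM.
print(prompt_text)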
@@ -2,50 +2,63 @@
# -*- coding: utf-8 -*-

import argparse
import datetime
import json
import os
import shutil
import uuid
import json
import sys
import time
import gradio as gr
import datetime
import requests
import uuid
from urllib.parse import urljoin

import gradio as gr
import requests
from langchain import PromptTemplate

from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.configs.model_config import LOGDIR, DATASETS_DIR

from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command import execute_ai_response_json
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.scene.base_chat import BaseChat

from pilot.commands.exception_not_commands import NotCommands

from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    LOGDIR,
    VECTOR_SEARCH_TOP_K,
)
from pilot.connections.mysql import MySQLOperator
from pilot.conversation import (
    default_conversation,
    SeparatorStyle,
    conv_qa_prompt_template,
    conv_templates,
    conversation_types,
    conversation_sql_mode,
    SeparatorStyle, conv_qa_prompt_template
    conversation_types,
    default_conversation,
)

from pilot.utils import (
    build_logger,
    server_error_msg,
)

from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.vector_store.extract_tovec import (
    get_vector_storelist,
    knownledge_tovec_st,
    load_knownledge_from_doc,
)

from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
@@ -66,9 +79,7 @@ autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}

priority = {
    "vicuna-13b": "aaa"
}
priority = {"vicuna-13b": "aaa"}

# 加载插件
CFG = Config()
@@ -76,10 +87,12 @@ CHAT_FACTORY = ChatFactory()

DB_SETTINGS = {
    "user": CFG.LOCAL_DB_USER,
    "password": CFG.LOCAL_DB_PASSWORD,
    "password": CFG.LOCAL_DB_PASSWORD,
    "host": CFG.LOCAL_DB_HOST,
    "port": CFG.LOCAL_DB_PORT
    "port": CFG.LOCAL_DB_PORT,
}


def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
    docs = docsearch.similarity_search_with_score(q, k=1)
@@ -89,9 +102,7 @@ def get_simlar(q):


def gen_sqlgen_conversation(dbname):
    mo = MySQLOperator(
        **DB_SETTINGS
    )
    mo = MySQLOperator(**DB_SETTINGS)

    message = ""

@@ -334,8 +345,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re


block_css = (
    code_highlight_css
    + """
    code_highlight_css
    + """
pre {
    white-space: pre-wrap; /* Since CSS 2.1 */
    white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -372,7 +383,7 @@ def build_single_model_ui():
    notice_markdown = """
# DB-GPT

[DB-GPT](https://github.com/csunny/DB-GPT) 是一个实验性的开源应用程序,它基于[FastChat](https://github.com/lm-sys/FastChat),并使用vicuna-13b作为基础模型。此外,此程序结合了langchain和llama-index基于现有知识库进行In-Context Learning来对其进行数据库相关知识的增强。它可以进行SQL生成、SQL诊断、数据库知识问答等一系列的工作。 总的来说,它是一个用于数据库的复杂且创新的AI工具。如果您对如何在工作中使用或实施DB-GPT有任何具体问题,请联系我, 我会尽力提供帮助, 同时也欢迎大家参与到项目建设中, 做一些有趣的事情。
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密,100% 安全。
"""
    learn_more_markdown = """
### Licence
@@ -396,7 +407,7 @@ def build_single_model_ui():
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=1024,
            value=512,
            step=64,
            interactive=True,
            label="最大输出Token数",
@@ -412,7 +423,8 @@ def build_single_model_ui():
            choices=dbs,
            value=dbs[0] if len(models) > 0 else "",
            interactive=True,
            show_label=True).style(container=False)
            show_label=True,
        ).style(container=False)

        sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
        sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
@@ -420,7 +432,9 @@ def build_single_model_ui():

        tab_qa = gr.TabItem("知识问答", elem_id="QA")
        with tab_qa:
            mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
            mode = gr.Radio(
                ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
            )
            vs_setting = gr.Accordion("配置知识库", open=False)
            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
            with vs_setting:
@@ -429,18 +443,22 @@ def build_single_model_ui():
                with gr.Column() as doc2vec:
                    gr.Markdown("向知识库中添加文件")
                    with gr.Tab("上传文件"):
                        files = gr.File(label="添加文件",
                                        file_types=[".txt", ".md", ".docx", ".pdf"],
                                        file_count="multiple",
                                        show_label=False
                                        )
                        files = gr.File(
                            label="添加文件",
                            file_types=[".txt", ".md", ".docx", ".pdf"],
                            file_count="multiple",
                            allow_flagged_uploads=True,
                            show_label=False,
                        )

                        load_file_button = gr.Button("上传并加载到知识库")
                    with gr.Tab("上传文件夹"):
                        folder_files = gr.File(label="添加文件夹",
                                               accept_multiple_files=True,
                                               file_count="directory",
                                               show_label=False)
                        folder_files = gr.File(
                            label="添加文件夹",
                            accept_multiple_files=True,
                            file_count="directory",
                            show_label=False,
                        )
                        load_folder_button = gr.Button("上传并加载到知识库")

    with gr.Blocks():
@@ -481,28 +499,32 @@ def build_single_model_ui():
    ).then(
        http_bot,
        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list
        [state, chatbot] + btn_list,
    )
    vs_add.click(
        fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
    )
    load_file_button.click(
        fn=knowledge_embedding_store,
        show_progress=True,
        inputs=[vs_name, files],
        outputs=[vs_name],
    )
    load_folder_button.click(
        fn=knowledge_embedding_store,
        show_progress=True,
        inputs=[vs_name, folder_files],
        outputs=[vs_name],
    )
    vs_add.click(fn=save_vs_name, show_progress=True,
                 inputs=[vs_name],
                 outputs=[vs_name])
    load_file_button.click(fn=knowledge_embedding_store,
                           show_progress=True,
                           inputs=[vs_name, files],
                           outputs=[vs_name])
    load_folder_button.click(fn=knowledge_embedding_store,
                             show_progress=True,
                             inputs=[vs_name, folder_files],
                             outputs=[vs_name])
    return state, chatbot, textbox, send_btn, button_row, parameter_row


def build_webdemo():
    with gr.Blocks(
        title="数据库智能助手",
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
        title="数据库智能助手",
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (
@@ -544,15 +566,21 @@ def knowledge_embedding_store(vs_id, files):
        os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
    for file in files:
        filename = os.path.split(file.name)[-1]
        shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
        shutil.move(
            file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
        )
    knowledge_embedding_client = KnowledgeEmbedding(
        file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
        model_name=LLM_MODEL_CONFIG["sentence-transforms"],
        model_name=LLM_MODEL_CONFIG["text2vec"],
        local_persist=False,
        vector_store_config={
            "vector_store_name": vector_store_name["vs_name"],
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        },
    )
    knowledge_embedding_client.knowledge_embedding()


    logger.info("knowledge embedding success")
    return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
@@ -596,5 +624,8 @@ if __name__ == "__main__":
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
    )
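For reference, a sketch of the embedding call that knowledge_embedding_store now issues for each uploaded file; the vs_id and filename below are hypothetical, while the "text2vec" key and the vector_store_config shape come from the hunk above:

import os

from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

vs_id = "my_kb"       # hypothetical knowledge-base name
filename = "plan.md"  # hypothetical uploaded file

client = KnowledgeEmbedding(
    file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config={
        "vector_store_name": vs_id,
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    },
)
client.knowledge_embedding()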