[ColossalQA] refactor server and webui & add new feature (#5138)

* refactor server and webui & add new feature

* add requirements

* modify readme and ui
This commit is contained in:
Michelle
2023-11-30 22:55:52 +08:00
committed by GitHub
parent 2a2ec49aa7
commit c7fd9a5213
12 changed files with 374 additions and 251 deletions

View File

@@ -1,3 +1,4 @@
import os
from typing import Dict, Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA
@@ -12,29 +13,11 @@ from colossalqa.prompt.prompt import (
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.retriever import CustomRetriever
from colossalqa.text_splitter import ChineseTextSplitter
from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings
logger = get_logger()
DEFAULT_RAG_CFG = {
"retri_top_k": 3,
"retri_kb_file_path": "./",
"verbose": True,
"mem_summary_prompt": SUMMARY_PROMPT_ZH,
"mem_human_prefix": "用户",
"mem_ai_prefix": "Assistant",
"mem_max_tokens": 2000,
"mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
"disambig_prompt": PROMPT_DISAMBIGUATE_ZH,
"disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
"embed_model_name_or_path": "moka-ai/m3e-base",
"embed_model_device": {"device": "cpu"},
"gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
"gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,
}
class RAG_ChatBot:
def __init__(
@@ -44,13 +27,16 @@ class RAG_ChatBot:
) -> None:
self.llm = llm
self.rag_config = rag_config
self.set_embed_model(**self.rag_config)
self.set_text_splitter(**self.rag_config)
self.set_memory(**self.rag_config)
self.set_info_retriever(**self.rag_config)
self.set_rag_chain(**self.rag_config)
if self.rag_config.get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config)
self.set_embed_model(**self.rag_config["embed"])
self.set_text_splitter(**self.rag_config["splitter"])
self.set_memory(**self.rag_config["chain"])
self.set_info_retriever(**self.rag_config["retrieval"])
self.set_rag_chain(**self.rag_config["chain"])
if self.rag_config["chain"].get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config["chain"])
self.documents = []
self.docs_names = []
def set_embed_model(self, **kwargs):
self.embed_model = HuggingFaceEmbeddings(
@@ -61,7 +47,7 @@ class RAG_ChatBot:
def set_text_splitter(self, **kwargs):
# Initialize text_splitter
self.text_splitter = ChineseTextSplitter()
self.text_splitter = kwargs["name"]()
def set_memory(self, **kwargs):
params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
@@ -91,10 +77,6 @@ class RAG_ChatBot:
**params,
)
def split_docs(self, documents):
doc_splits = self.text_splitter.split_documents(documents)
return doc_splits
def set_disambig_retriv(self, **kwargs):
params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)
@@ -106,42 +88,50 @@ class RAG_ChatBot:
self.info_retriever.set_rephrase_handler(disambiguity)
def load_doc_from_console(self, json_parse_args: Dict = {}):
documents = []
print("Select files for constructing Chinese retriever")
print("Select files for constructing the retriever")
while True:
file = input("Enter a file path or press Enter directly without input to exit:").strip()
if file == "":
break
data_name = input("Enter a short description of the data:")
docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
documents.extend(docs)
self.documents = documents
self.split_docs_and_add_to_mem(**self.rag_config)
self.documents.extend(docs)
self.docs_names.append(data_name)
self.split_docs_and_add_to_mem(**self.rag_config["chain"])
def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Dict = {}):
documents = []
for file in files:
docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
documents.extend(docs)
self.documents = documents
self.split_docs_and_add_to_mem(**self.rag_config)
self.documents.extend(docs)
self.docs_names.append(os.path.basename(file))
self.split_docs_and_add_to_mem(**self.rag_config["chain"])
def split_docs_and_add_to_mem(self, **kwargs):
self.doc_splits = self.split_docs(self.documents)
doc_splits = self.split_docs(self.documents)
self.info_retriever.add_documents(
docs=self.doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
docs=doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
)
self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)
def split_docs(self, documents):
doc_splits = self.text_splitter.split_documents(documents)
return doc_splits
def clear_docs(self, **kwargs):
self.documents = []
self.docs_names = []
self.info_retriever.clear_documents()
self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)
def reset_config(self, rag_config):
self.rag_config = rag_config
self.set_embed_model(**self.rag_config)
self.set_text_splitter(**self.rag_config)
self.set_memory(**self.rag_config)
self.set_info_retriever(**self.rag_config)
self.set_rag_chain(**self.rag_config)
if self.rag_config.get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config)
self.set_embed_model(**self.rag_config["embed"])
self.set_text_splitter(**self.rag_config["splitter"])
self.set_memory(**self.rag_config["chain"])
self.set_info_retriever(**self.rag_config["retrieval"])
self.set_rag_chain(**self.rag_config["chain"])
if self.rag_config["chain"].get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config["chain"])
def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
if memory:
@@ -153,7 +143,7 @@ class RAG_ChatBot:
rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
)
return result.split("\n")[0], memory
return result, memory
def start_test_session(self):
"""
@@ -170,15 +160,18 @@ class RAG_ChatBot:
if __name__ == "__main__":
# Initialize an Langchain LLM(here we use ChatGPT as an example)
import config
from langchain.llms import OpenAI
llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")
# you need to: export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
llm = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))
# chatgpt cannot control temperature, do_sample, etc.
DEFAULT_RAG_CFG["mem_llm_kwargs"] = None
DEFAULT_RAG_CFG["disambig_llm_kwargs"] = None
DEFAULT_RAG_CFG["gen_llm_kwargs"] = None
all_config = config.ALL_CONFIG
all_config["chain"]["mem_llm_kwargs"] = None
all_config["chain"]["disambig_llm_kwargs"] = None
all_config["chain"]["gen_llm_kwargs"] = None
rag = RAG_ChatBot(llm, DEFAULT_RAG_CFG)
rag = RAG_ChatBot(llm, all_config)
rag.load_doc_from_console()
rag.start_test_session()

View File

@@ -16,22 +16,103 @@ cd ColossalAI/applications/ColossalQA/
pip install -e .
```
Install the dependencies for ColossalQA webui demo:
```sh
pip install -r requirements.txt
```
## Configure the RAG Chain
Customize the RAG Chain settings, such as the embedding model (default: moka-ai/m3e) and the language model, in the `start_colossal_qa.sh` script.
Customize the RAG Chain settings, such as the embedding model (default: moka-ai/m3e), the language model, and the prompts, in the `config.py`. Please refer to [`Prepare configuration file`](#prepare-configuration-file) for the details of `config.py`.
For API-based language models (like ChatGPT or Huawei Pangu), provide your API key for authentication. For locally-run models, indicate the path to the model's checkpoint file.
If you want to customize prompts in the RAG Chain, you can have a look at the `RAG_ChatBot.py` file to modify them.
## Prepare configuration file
All configs are defined in `ColossalQA/examples/webui_demo/config.py`.
- embed:
- <mark>embed_name</mark>: the embedding model name
- <mark>embed_model_name_or_path</mark>: path to embedding model, could be a local path or a huggingface path
- embed_model_device: device to load the embedding model
- model:
- <mark>mode</mark>: "local" for loading models, "api" for using model api
- <mark>model_name</mark>: "chatgpt_api", "pangu_api", or your local model name
- <mark>model_path</mark>: path to the model, could be a local path or a huggingface path. don't need if mode="api"
- device: device to load the LLM
- splitter:
- name: text splitter class name, the class should be imported at the beginning of `config.py`
- retrieval:
- retri_top_k: number of retrieval text which will be provided to the model
- retri_kb_file_path: path to store database files
- verbose: Boolean type, to control the level of detail in program output
- chain:
- mem_summary_prompt: summary prompt template
- mem_human_prefix: human prefix for prompt
- mem_ai_prefix: AI assistant prefix for prompt
- mem_max_tokens: max tokens for history information
- mem_llm_kwargs: model's generation kwargs for summarizing history
- max_new_tokens: int
- temperature: int
- do_sample: bool
- disambig_prompt: disambiguate prompt template
- disambig_llm_kwargs: model's generation kwargs for disambiguating user's input
- max_new_tokens: int
- temperature": int
- do_sample: bool
- gen_llm_kwargs: model's generation kwargs
- max_new_tokens: int
- temperature: int
- do_sample: bool
- gen_qa_prompt: generation prompt template
- verbose: Boolean type, to control the level of detail in program output
## Run WebUI Demo
Execute the following command to start the demo:
1. If you want to use a local model as the backend model, you need to specify the model name and model path in `config.py` and run the following commands.
```sh
bash start_colossal_qa.sh
export TMP="path/to/store/tmp/files"
# start the backend server
python server.py --http_host "host" --http_port "port"
# in an another terminal, start the ui
python webui.py --http_host "your-backend-api-host" --http_port "your-backend-api-port"
```
2. If you want to use pangu api as the backend model, you need to change the model mode to "api", change the model name to "chatgpt_api" in `config.py`, and run the following commands.
```sh
export TMP="path/to/store/tmp/files"
# Auth info for OpenAI API
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
# start the backend server
python server.py --http_host "host" --http_port "port"
# in an another terminal, start the ui
python webui.py --http_host "your-backend-api-host" --http_port "your-backend-api-port"
```
3. If you want to use pangu api as the backend model, you need to change the model mode to "api", change the model name to "pangu_api" in `config.py`, and run the following commands.
```sh
export TMP="path/to/store/tmp/files"
# Auth info for Pangu API
export URL=""
export USERNAME=""
export PASSWORD=""
export DOMAIN_NAME=""
# start the backend server
python server.py --http_host "host" --http_port "port"
# in an another terminal, start the ui
python webui.py --http_host "your-backend-api-host" --http_port "your-backend-api-port"
```
After launching the script, you can upload files and engage with the chatbot through your web browser.
![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/img/qa_demo.png)
![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/new_ui.png)

View File

@@ -0,0 +1,58 @@
from colossalqa.prompt.prompt import (
PROMPT_DISAMBIGUATE_ZH,
PROMPT_RETRIEVAL_QA_ZH,
SUMMARY_PROMPT_ZH,
ZH_RETRIEVAL_QA_REJECTION_ANSWER,
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.text_splitter import ChineseTextSplitter
ALL_CONFIG = {
"embed": {
"embed_name": "m3e", # embedding model name
"embed_model_name_or_path": "moka-ai/m3e-base", # path to embedding model, could be a local path or a huggingface path
"embed_model_device": {
"device": "cpu"
}
},
"model": {
"mode": "api", # "local" for loading models, "api" for using model api
"model_name": "chatgpt_api", # local model name, "chatgpt_api" or "pangu_api"
"model_path": "", # path to the model, could be a local path or a huggingface path. don't need if using an api
"device": {
"device": "cuda"
}
},
"splitter": {
"name": ChineseTextSplitter
},
"retrieval": {
"retri_top_k": 3,
"retri_kb_file_path": "./", # path to store database files
"verbose": True
},
"chain": {
"mem_summary_prompt": SUMMARY_PROMPT_ZH, # summary prompt template
"mem_human_prefix": "用户",
"mem_ai_prefix": "Assistant",
"mem_max_tokens": 2000,
"mem_llm_kwargs": {
"max_new_tokens": 50,
"temperature": 1,
"do_sample": True
},
"disambig_prompt": PROMPT_DISAMBIGUATE_ZH, # disambiguate prompt template
"disambig_llm_kwargs": {
"max_new_tokens": 30,
"temperature": 1,
"do_sample": True
},
"gen_llm_kwargs": {
"max_new_tokens": 100,
"temperature": 1,
"do_sample": True
},
"gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH, # generation prompt template
"verbose": True
}
}

View File

@@ -0,0 +1,3 @@
fastapi==0.99.1
uvicorn>=0.24.0
pydantic==1.10.13

View File

@@ -1,117 +1,98 @@
import argparse
import copy
import json
import os
import random
import string
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Union
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.mylogging import get_logger
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from enum import Enum
from fastapi import FastAPI, Request
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from RAG_ChatBot import RAG_ChatBot, DEFAULT_RAG_CFG
from pydantic import BaseModel, Field
import uvicorn
# Define the mapping between embed_model_name(passed from Front End) and the actual path on the back end server
EMBED_MODEL_DICT = {
"m3e": os.environ.get("EMB_MODEL_PATH", DEFAULT_RAG_CFG["embed_model_name_or_path"])
}
# Define the mapping between LLM_name(passed from Front End) and the actual path on the back end server
LLM_DICT = {
"chatglm2": os.environ.get("CHAT_LLM_PATH", "THUDM/chatglm-6b"),
"pangu": "Pangu_API",
"chatgpt": "OpenAI_API"
}
import config
from RAG_ChatBot import RAG_ChatBot
from utils import DocAction
def randomword(length):
letters = string.ascii_lowercase
return "".join(random.choice(letters) for i in range(length))
class ColossalQAServerRequestHandler(BaseHTTPRequestHandler):
chatbot = None
def _set_response(self):
"""
set http header for response
"""
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
logger = get_logger()
def do_POST(self):
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
received_json = json.loads(post_data.decode("utf-8"))
print(received_json)
# conversation_ready is False(user's first request): Need to upload files and initialize the RAG chain
if received_json["conversation_ready"] is False:
self.rag_config = DEFAULT_RAG_CFG.copy()
try:
assert received_json["embed_model_name"] in EMBED_MODEL_DICT
assert received_json["llm_name"] in LLM_DICT
self.docs_files = received_json["docs"]
embed_model_name, llm_name = received_json["embed_model_name"], received_json["llm_name"]
# Find the embed_model/llm ckpt path on the back end server.
embed_model_path, llm_path = EMBED_MODEL_DICT[embed_model_name], LLM_DICT[llm_name]
self.rag_config["embed_model_name_or_path"] = embed_model_path
def parseArgs():
parser = argparse.ArgumentParser()
parser.add_argument("--http_host", default="0.0.0.0")
parser.add_argument("--http_port", type=int, default=13666)
return parser.parse_args()
app = FastAPI()
class DocUpdateReq(BaseModel):
doc_files: Union[List[str], str, None] = None
action: DocAction = DocAction.ADD
class GenerationTaskReq(BaseModel):
user_input: str
@app.post("/update")
def update_docs(data: DocUpdateReq, request: Request):
if data.action == "add":
if isinstance(data.doc_files, str):
data.doc_files = [data.doc_files]
chatbot.load_doc_from_files(files = data.doc_files)
all_docs = ""
for doc in chatbot.docs_names:
all_docs += f"\t{doc}\n\n"
return {"response": f"文件上传完成,所有数据库文件:\n\n{all_docs}让我们开始对话吧!"}
elif data.action == "clear":
chatbot.clear_docs(**all_config["chain"])
return {"response": f"已清空数据库。"}
@app.post("/generate")
def generate(data: GenerationTaskReq, request: Request):
try:
chatbot_response, chatbot.memory = chatbot.run(data.user_input, chatbot.memory)
return {"response": chatbot_response, "error": ""}
except Exception as e:
return {"response": "模型生成回答有误", "error": f"Error in generating answers, details: {e}"}
# Create the storage path for knowledge base files
self.rag_config["retri_kb_file_path"] = os.path.join(os.environ["TMP"], "colossalqa_kb/"+randomword(20))
if not os.path.exists(self.rag_config["retri_kb_file_path"]):
os.makedirs(self.rag_config["retri_kb_file_path"])
if (embed_model_path is not None) and (llm_path is not None):
# ---- Intialize LLM, QA_chatbot here ----
print("Initializing LLM...")
if llm_path == "Pangu_API":
from colossalqa.local.pangu_llm import Pangu
self.llm = Pangu(id=1)
self.llm.set_auth_config() # verify user's auth info here
self.rag_config["mem_llm_kwargs"] = None
self.rag_config["disambig_llm_kwargs"] = None
self.rag_config["gen_llm_kwargs"] = None
elif llm_path == "OpenAI_API":
from langchain.llms import OpenAI
self.llm = OpenAI()
self.rag_config["mem_llm_kwargs"] = None
self.rag_config["disambig_llm_kwargs"] = None
self.rag_config["gen_llm_kwargs"] = None
else:
# ** (For Testing Only) **
# In practice, all LLMs will run on the cloud platform and accessed by API, instead of running locally.
# initialize model from model_path by using ColossalLLM
self.rag_config["mem_llm_kwargs"] = {"max_new_tokens": 50, "temperature": 1, "do_sample": True}
self.rag_config["disambig_llm_kwargs"] = {"max_new_tokens": 30, "temperature": 1, "do_sample": True}
self.rag_config["gen_llm_kwargs"] = {"max_new_tokens": 100, "temperature": 1, "do_sample": True}
self.colossal_api = ColossalAPI(llm_name, llm_path)
self.llm = ColossalLLM(n=1, api=self.colossal_api)
print(f"Initializing RAG Chain...")
print("RAG_CONFIG: ", self.rag_config)
self.__class__.chatbot = RAG_ChatBot(self.llm, self.rag_config)
print("Loading Files....\n", self.docs_files)
self.__class__.chatbot.load_doc_from_files(self.docs_files)
# -----------------------------------------------------------------------------------
res = {"response": f"文件上传完成,模型初始化完成,让我们开始对话吧!(后端模型:{llm_name})", "error": "", "conversation_ready": True}
except Exception as e:
res = {"response": "文件上传或模型初始化有误,无法开始对话。",
"error": f"Error in File Uploading and/or RAG initialization. Error details: {e}",
"conversation_ready": False}
# conversation_ready is True: Chatbot and docs are all set. Ready to chat.
else:
user_input = received_json["user_input"]
chatbot_response, self.__class__.chatbot.memory = self.__class__.chatbot.run(user_input, self.__class__.chatbot.memory)
res = {"response": chatbot_response, "error": "", "conversation_ready": True}
self._set_response()
self.wfile.write(json.dumps(res).encode("utf-8"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chinese retrieval based conversation system")
parser.add_argument("--port", type=int, default=13666, help="port on localhost to start the server")
args = parser.parse_args()
server_address = ("localhost", args.port)
httpd = HTTPServer(server_address, ColossalQAServerRequestHandler)
print(f"Starting server on port {args.port}...")
httpd.serve_forever()
args = parseArgs()
all_config = config.ALL_CONFIG
model_name = all_config["model"]["model_name"]
# initialize chatbot
logger.info(f"Initialize the chatbot from {model_name}")
if all_config["model"]["mode"] == "local":
colossal_api = ColossalAPI(model_name, all_config["model"]["model_path"])
llm = ColossalLLM(n=1, api=colossal_api)
elif all_config["model"]["mode"] == "api":
all_config["chain"]["mem_llm_kwargs"] = None
all_config["chain"]["disambig_llm_kwargs"] = None
all_config["chain"]["gen_llm_kwargs"] = None
if model_name == "pangu_api":
from colossalqa.local.pangu_llm import Pangu
llm = Pangu(id=1)
llm.set_auth_config() # verify user's auth info here
elif model_name == "chatgpt_api":
from langchain.llms import OpenAI
llm = OpenAI()
else:
raise ValueError("Unsupported mode.")
# initialize chatbot
chatbot = RAG_ChatBot(llm, all_config)
app_config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
server = uvicorn.Server(config=app_config)
server.run()

View File

@@ -1,43 +0,0 @@
#!/bin/bash
cleanup() {
echo "Caught Signal ... cleaning up."
pkill -P $$ # kill all subprocess of this script
exit 1 # exit script
}
# 'cleanup' is trigered when receive SIGINT(Ctrl+C) OR SIGTERM(kill) signal
trap cleanup INT TERM
# Disable your proxy
# unset HTTP_PROXY HTTPS_PROXY http_proxy https_proxy
# Path to store knowledge base(Home Directory by default)
export TMP=$HOME
# Use m3e as embedding model
export EMB_MODEL="m3e" # moka-ai/m3e-base model will be download automatically
# export EMB_MODEL_PATH="PATH_TO_LOCAL_CHECKPOINT/m3e-base" # you can also specify the local path to embedding model
# Choose a backend LLM
# - ChatGLM2
# export CHAT_LLM="chatglm2"
# export CHAT_LLM_PATH="PATH_TO_LOCAL_CHECKPOINT/chatglm2-6b"
# - ChatGPT
export CHAT_LLM="chatgpt"
# Auth info for OpenAI API
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
# - Pangu
# export CHAT_LLM="pangu"
# # Auth info for Pangu API
# export URL=""
# export USERNAME=""
# export PASSWORD=""
# export DOMAIN_NAME=""
# Run server.py and colossalqa_webui.py in the background
python server.py &
python webui.py &
# Wait for all processes to finish
wait

View File

@@ -0,0 +1,6 @@
from enum import Enum
class DocAction(str, Enum):
ADD = "add"
CLEAR = "clear"

View File

@@ -1,17 +1,21 @@
import argparse
import json
import os
import gradio as gr
import requests
RAG_STATE = {"conversation_ready": False, # Conversation is not ready until files are uploaded and RAG chain is initialized
"embed_model_name": os.environ.get("EMB_MODEL", "m3e"),
"llm_name": os.environ.get("CHAT_LLM", "chatgpt")}
URL = "http://localhost:13666"
import gradio as gr
def get_response(client_data, URL):
from utils import DocAction
def parseArgs():
parser = argparse.ArgumentParser()
parser.add_argument("--http_host", default="0.0.0.0")
parser.add_argument("--http_port", type=int, default=13666)
return parser.parse_args()
def get_response(data, url):
headers = {"Content-type": "application/json"}
print(f"Sending request to server url: {URL}")
response = requests.post(URL, data=json.dumps(client_data), headers=headers)
response = requests.post(url, json=data, headers=headers)
response = json.loads(response.content)
return response
@@ -19,41 +23,43 @@ def add_text(history, text):
history = history + [(text, None)]
return history, gr.update(value=None, interactive=True)
def add_file(history, files):
global RAG_STATE
RAG_STATE["conversation_ready"] = False # after adding new files, reset the ChatBot
RAG_STATE["upload_files"]=[file.name for file in files]
files_string = "\n".join([os.path.basename(path) for path in RAG_STATE["upload_files"]])
print(files_string)
history = history + [(files_string, None)]
files_string = "\n".join([os.path.basename(file.name) for file in files])
doc_files = [file.name for file in files]
data = {
"doc_files": doc_files,
"action": DocAction.ADD
}
response = get_response(data, update_url)["response"]
history = history + [(files_string, response)]
return history
def bot(history):
print(history)
global RAG_STATE
if not RAG_STATE["conversation_ready"]:
# Upload files and initialize models
client_data = {
"docs": RAG_STATE["upload_files"],
"embed_model_name": RAG_STATE["embed_model_name"], # Select embedding model name here
"llm_name": RAG_STATE["llm_name"], # Select LLM model name here. ["pangu", "chatglm2"]
"conversation_ready": RAG_STATE["conversation_ready"]
}
else:
client_data = {}
client_data["conversation_ready"] = RAG_STATE["conversation_ready"]
client_data["user_input"] = history[-1][0].strip()
def bot(history):
data = {
"user_input": history[-1][0].strip()
}
response = get_response(data, gen_url)
response = get_response(client_data, URL) # TODO: async request, to avoid users waiting the model initialization too long
print(response)
if response["error"] != "":
raise gr.Error(response["error"])
RAG_STATE["conversation_ready"] = response["conversation_ready"]
history[-1][1] = response["response"]
yield history
def restart(chatbot, txt):
# Reset the conversation state and clear the chat history
data = {
"doc_files": "",
"action": DocAction.CLEAR
}
response = get_response(data, update_url)
return gr.update(value=None), gr.update(value=None, interactive=True)
CSS = """
.contain { display: flex; flex-direction: column; height: 100vh }
#component-0 { height: 100%; }
@@ -63,7 +69,7 @@ CSS = """
header_html = """
<div style="background: linear-gradient(to right, #2a0cf4, #7100ed, #9800e6, #b600df, #ce00d9, #dc0cd1, #e81bca, #f229c3, #f738ba, #f946b2, #fb53ab, #fb5fa5); padding: 20px; text-align: left;">
<h1 style="color: white;">ColossalQA</h1>
<h4 style="color: white;">ColossalQA</h4>
<h4 style="color: white;">A powerful Q&A system with knowledge bases</h4>
</div>
"""
@@ -78,25 +84,32 @@ with gr.Blocks(css=CSS) as demo:
(os.path.join(os.path.dirname(__file__), "img/avatar_ai.png")),
),
)
with gr.Row():
btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple", size="sm")
restart_btn = gr.Button(str("\u21BB"), elem_id="restart-btn", scale=1)
txt = gr.Textbox(
scale=4,
scale=8,
show_label=False,
placeholder="Enter text and press enter, or upload an image",
placeholder="Enter text and press enter, or use 📁 to upload files, click \u21BB to clear loaded files and restart chat",
container=True,
autofocus=True,
)
btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple")
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(bot, chatbot, chatbot)
# Clear the original textbox
txt_msg.then(lambda: gr.update(value=None, interactive=True), None, [txt], queue=False)
# Click Upload Button: 1. upload files 2. send config to backend, initalize model 3. get response "conversation_ready" = True/False
file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(bot, chatbot, chatbot)
file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False)
# restart
restart_msg = restart_btn.click(restart, [chatbot, txt], [chatbot, txt], queue=False)
if __name__ == "__main__":
args = parseArgs()
update_url = f"http://{args.http_host}:{args.http_port}/update"
gen_url = f"http://{args.http_host}:{args.http_port}/generate"
demo.queue()
demo.launch(share=True) # share=True will release a public link of the demo