[ColossalQA] refactor server and webui & add new feature (#5138)

* refactor server and webui & add new feature
* add requirements
* modify readme and ui
applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
@@ -1,3 +1,4 @@
+import os
 from typing import Dict, Tuple

 from colossalqa.chain.retrieval_qa.base import RetrievalQA
@@ -12,29 +13,11 @@ from colossalqa.prompt.prompt import (
     ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
 )
 from colossalqa.retriever import CustomRetriever
-from colossalqa.text_splitter import ChineseTextSplitter
 from langchain import LLMChain
 from langchain.embeddings import HuggingFaceEmbeddings

 logger = get_logger()

-DEFAULT_RAG_CFG = {
-    "retri_top_k": 3,
-    "retri_kb_file_path": "./",
-    "verbose": True,
-    "mem_summary_prompt": SUMMARY_PROMPT_ZH,
-    "mem_human_prefix": "用户",
-    "mem_ai_prefix": "Assistant",
-    "mem_max_tokens": 2000,
-    "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
-    "disambig_prompt": PROMPT_DISAMBIGUATE_ZH,
-    "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
-    "embed_model_name_or_path": "moka-ai/m3e-base",
-    "embed_model_device": {"device": "cpu"},
-    "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
-    "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,
-}
-

 class RAG_ChatBot:
     def __init__(
@@ -44,13 +27,16 @@ class RAG_ChatBot:
     ) -> None:
         self.llm = llm
         self.rag_config = rag_config
-        self.set_embed_model(**self.rag_config)
-        self.set_text_splitter(**self.rag_config)
-        self.set_memory(**self.rag_config)
-        self.set_info_retriever(**self.rag_config)
-        self.set_rag_chain(**self.rag_config)
-        if self.rag_config.get("disambig_prompt", None):
-            self.set_disambig_retriv(**self.rag_config)
+        self.set_embed_model(**self.rag_config["embed"])
+        self.set_text_splitter(**self.rag_config["splitter"])
+        self.set_memory(**self.rag_config["chain"])
+        self.set_info_retriever(**self.rag_config["retrieval"])
+        self.set_rag_chain(**self.rag_config["chain"])
+        if self.rag_config["chain"].get("disambig_prompt", None):
+            self.set_disambig_retriv(**self.rag_config["chain"])
+
+        self.documents = []
+        self.docs_names = []

     def set_embed_model(self, **kwargs):
         self.embed_model = HuggingFaceEmbeddings(
@@ -61,7 +47,7 @@ class RAG_ChatBot:

     def set_text_splitter(self, **kwargs):
         # Initialize text_splitter
-        self.text_splitter = ChineseTextSplitter()
+        self.text_splitter = kwargs["name"]()

     def set_memory(self, **kwargs):
         params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
@@ -91,10 +77,6 @@ class RAG_ChatBot:
             **params,
         )

-    def split_docs(self, documents):
-        doc_splits = self.text_splitter.split_documents(documents)
-        return doc_splits
-
     def set_disambig_retriv(self, **kwargs):
         params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
         self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)
@@ -106,42 +88,50 @@ class RAG_ChatBot:
         self.info_retriever.set_rephrase_handler(disambiguity)

     def load_doc_from_console(self, json_parse_args: Dict = {}):
-        documents = []
-        print("Select files for constructing Chinese retriever")
+        print("Select files for constructing the retriever")
         while True:
             file = input("Enter a file path or press Enter directly without input to exit:").strip()
             if file == "":
                 break
             data_name = input("Enter a short description of the data:")
             docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
-            documents.extend(docs)
-        self.documents = documents
-        self.split_docs_and_add_to_mem(**self.rag_config)
+            self.documents.extend(docs)
+            self.docs_names.append(data_name)
+        self.split_docs_and_add_to_mem(**self.rag_config["chain"])

     def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Dict = {}):
-        documents = []
         for file in files:
             docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
-            documents.extend(docs)
-        self.documents = documents
-        self.split_docs_and_add_to_mem(**self.rag_config)
+            self.documents.extend(docs)
+            self.docs_names.append(os.path.basename(file))
+        self.split_docs_and_add_to_mem(**self.rag_config["chain"])

     def split_docs_and_add_to_mem(self, **kwargs):
-        self.doc_splits = self.split_docs(self.documents)
+        doc_splits = self.split_docs(self.documents)
         self.info_retriever.add_documents(
-            docs=self.doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
+            docs=doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
         )
         self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)

+    def split_docs(self, documents):
+        doc_splits = self.text_splitter.split_documents(documents)
+        return doc_splits
+
+    def clear_docs(self, **kwargs):
+        self.documents = []
+        self.docs_names = []
+        self.info_retriever.clear_documents()
+        self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)
+
     def reset_config(self, rag_config):
         self.rag_config = rag_config
-        self.set_embed_model(**self.rag_config)
-        self.set_text_splitter(**self.rag_config)
-        self.set_memory(**self.rag_config)
-        self.set_info_retriever(**self.rag_config)
-        self.set_rag_chain(**self.rag_config)
-        if self.rag_config.get("disambig_prompt", None):
-            self.set_disambig_retriv(**self.rag_config)
+        self.set_embed_model(**self.rag_config["embed"])
+        self.set_text_splitter(**self.rag_config["splitter"])
+        self.set_memory(**self.rag_config["chain"])
+        self.set_info_retriever(**self.rag_config["retrieval"])
+        self.set_rag_chain(**self.rag_config["chain"])
+        if self.rag_config["chain"].get("disambig_prompt", None):
+            self.set_disambig_retriv(**self.rag_config["chain"])

     def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
         if memory:
@@ -153,7 +143,7 @@ class RAG_ChatBot:
             rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
             rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
         )
-        return result.split("\n")[0], memory
+        return result, memory

     def start_test_session(self):
         """
@@ -170,15 +160,18 @@ class RAG_ChatBot:

 if __name__ == "__main__":
     # Initialize a LangChain LLM (here we use ChatGPT as an example)
+    import config
     from langchain.llms import OpenAI

-    llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")
+    # you need to: export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
+    llm = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))

     # chatgpt cannot control temperature, do_sample, etc.
-    DEFAULT_RAG_CFG["mem_llm_kwargs"] = None
-    DEFAULT_RAG_CFG["disambig_llm_kwargs"] = None
-    DEFAULT_RAG_CFG["gen_llm_kwargs"] = None
+    all_config = config.ALL_CONFIG
+    all_config["chain"]["mem_llm_kwargs"] = None
+    all_config["chain"]["disambig_llm_kwargs"] = None
+    all_config["chain"]["gen_llm_kwargs"] = None

-    rag = RAG_ChatBot(llm, DEFAULT_RAG_CFG)
+    rag = RAG_ChatBot(llm, all_config)
     rag.load_doc_from_console()
     rag.start_test_session()
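Read together with the `__main__` block above, the refactored chatbot can also be driven programmatically. A minimal sketch, not part of the diff, assuming it runs from the `webui_demo` directory with `OPENAI_API_KEY` exported; the document path is a placeholder:

```python
# Hypothetical driver mirroring the __main__ block of RAG_ChatBot.py.
import config
from langchain.llms import OpenAI
from RAG_ChatBot import RAG_ChatBot

all_config = config.ALL_CONFIG
# ChatGPT manages its own sampling, so the per-stage generation kwargs are disabled
for key in ("mem_llm_kwargs", "disambig_llm_kwargs", "gen_llm_kwargs"):
    all_config["chain"][key] = None

rag = RAG_ChatBot(OpenAI(), all_config)
rag.load_doc_from_files(["./my_knowledge_base.txt"])  # placeholder path
answer, rag.memory = rag.run("这篇文档讲了什么?", rag.memory)
print(answer)
```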
applications/ColossalQA/examples/webui_demo/README.md
@@ -16,22 +16,103 @@ cd ColossalAI/applications/ColossalQA/
 pip install -e .
 ```

+Install the dependencies for the ColossalQA webui demo:
+```sh
+pip install -r requirements.txt
+```
+
 ## Configure the RAG Chain

-Customize the RAG Chain settings, such as the embedding model (default: moka-ai/m3e) and the language model, in the `start_colossal_qa.sh` script.
+Customize the RAG Chain settings, such as the embedding model (default: moka-ai/m3e), the language model, and the prompts, in `config.py`. Please refer to [`Prepare configuration file`](#prepare-configuration-file) for the details of `config.py`.

 For API-based language models (like ChatGPT or Huawei Pangu), provide your API key for authentication. For locally-run models, indicate the path to the model's checkpoint file.

-If you want to customize prompts in the RAG Chain, you can have a look at the `RAG_ChatBot.py` file to modify them.
+## Prepare configuration file
+
+All configs are defined in `ColossalQA/examples/webui_demo/config.py`; a short override sketch follows the list below.
+
+- embed:
+  - embed_name: the name of the embedding model
+  - embed_model_name_or_path: path to the embedding model; can be a local path or a Hugging Face path
+  - embed_model_device: device on which to load the embedding model
+- model:
+  - mode: "local" for loading models locally, "api" for using a model API
+  - model_name: "chatgpt_api", "pangu_api", or your local model name
+  - model_path: path to the model; can be a local path or a Hugging Face path; not needed when mode is "api"
+  - device: device on which to load the LLM
+- splitter:
+  - name: the text splitter class name; the class should be imported at the beginning of `config.py`
+- retrieval:
+  - retri_top_k: number of retrieved texts provided to the model
+  - retri_kb_file_path: path for storing the database files
+  - verbose: Boolean, controls the level of detail in program output
+- chain:
+  - mem_summary_prompt: summary prompt template
+  - mem_human_prefix: human prefix for the prompt
+  - mem_ai_prefix: AI assistant prefix for the prompt
+  - mem_max_tokens: maximum number of tokens kept as history information
+  - mem_llm_kwargs: the model's generation kwargs for summarizing history
+    - max_new_tokens: int
+    - temperature: int
+    - do_sample: bool
+  - disambig_prompt: disambiguation prompt template
+  - disambig_llm_kwargs: the model's generation kwargs for disambiguating the user's input
+    - max_new_tokens: int
+    - temperature: int
+    - do_sample: bool
+  - gen_llm_kwargs: the model's generation kwargs for answer generation
+    - max_new_tokens: int
+    - temperature: int
+    - do_sample: bool
+  - gen_qa_prompt: generation prompt template
+  - verbose: Boolean, controls the level of detail in program output
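Since these settings live in a plain Python dict, they can be inspected or overridden before the server starts. A hedged sketch, not part of the diff, run from the `webui_demo` directory; the model choice and paths are placeholders:

```python
# Hypothetical override of config.py for a locally hosted model.
import config

cfg = config.ALL_CONFIG
cfg["model"]["mode"] = "local"
cfg["model"]["model_name"] = "chatglm2"            # placeholder local model name
cfg["model"]["model_path"] = "/ckpts/chatglm2-6b"  # placeholder checkpoint path
cfg["retrieval"]["retri_top_k"] = 5                # retrieve more passages per query
print(cfg["model"])
```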
 ## Run WebUI Demo

 Execute the following commands to start the demo:

+1. If you want to use a local model as the backend model, you need to specify the model name and model path in `config.py` and run the following commands.
+
 ```sh
-bash start_colossal_qa.sh
+export TMP="path/to/store/tmp/files"
+
+# start the backend server
+python server.py --http_host "host" --http_port "port"
+
+# in another terminal, start the ui
+python webui.py --http_host "your-backend-api-host" --http_port "your-backend-api-port"
 ```
+
+2. If you want to use the ChatGPT API as the backend model, you need to change the model mode to "api" and the model name to "chatgpt_api" in `config.py`, and run the following commands.
+```sh
+export TMP="path/to/store/tmp/files"
+
+# Auth info for OpenAI API
+export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
+
+# start the backend server
+python server.py --http_host "host" --http_port "port"
+
+# in another terminal, start the ui
+python webui.py --http_host "your-backend-api-host" --http_port "your-backend-api-port"
+```
+
+3. If you want to use the Pangu API as the backend model, you need to change the model mode to "api" and the model name to "pangu_api" in `config.py`, and run the following commands.
+```sh
+export TMP="path/to/store/tmp/files"
+
+# Auth info for Pangu API
+export URL=""
+export USERNAME=""
+export PASSWORD=""
+export DOMAIN_NAME=""
+
+# start the backend server
+python server.py --http_host "host" --http_port "port"
+
+# in another terminal, start the ui
+python webui.py --http_host "your-backend-api-host" --http_port "your-backend-api-port"
+```

 After launching the demo, you can upload files and engage with the chatbot through your web browser.

 
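The webui talks to the two backend endpoints defined in `server.py` (`/update` and `/generate`), so the backend can also be smoke-tested without the UI. A minimal sketch, assuming the server runs on the default port 13666 and using the payload shapes of `DocUpdateReq`/`GenerationTaskReq` shown below; the document path is a placeholder:

```python
# Hypothetical smoke test for the backend; server.py must already be running.
import requests

base = "http://localhost:13666"

# Register a document with the knowledge base (mirrors DocUpdateReq)
r = requests.post(f"{base}/update", json={"doc_files": ["./my_doc.txt"], "action": "add"})
print(r.json()["response"])

# Ask a question (mirrors GenerationTaskReq)
r = requests.post(f"{base}/generate", json={"user_input": "你好"})
print(r.json())
```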
applications/ColossalQA/examples/webui_demo/config.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+from colossalqa.prompt.prompt import (
+    PROMPT_DISAMBIGUATE_ZH,
+    PROMPT_RETRIEVAL_QA_ZH,
+    SUMMARY_PROMPT_ZH,
+    ZH_RETRIEVAL_QA_REJECTION_ANSWER,
+    ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+)
+from colossalqa.text_splitter import ChineseTextSplitter
+
+ALL_CONFIG = {
+    "embed": {
+        "embed_name": "m3e",  # embedding model name
+        "embed_model_name_or_path": "moka-ai/m3e-base",  # path to embedding model, could be a local path or a huggingface path
+        "embed_model_device": {
+            "device": "cpu"
+        }
+    },
+    "model": {
+        "mode": "api",  # "local" for loading models, "api" for using model api
+        "model_name": "chatgpt_api",  # local model name, "chatgpt_api" or "pangu_api"
+        "model_path": "",  # path to the model, could be a local path or a huggingface path. don't need if using an api
+        "device": {
+            "device": "cuda"
+        }
+    },
+    "splitter": {
+        "name": ChineseTextSplitter
+    },
+    "retrieval": {
+        "retri_top_k": 3,
+        "retri_kb_file_path": "./",  # path to store database files
+        "verbose": True
+    },
+    "chain": {
+        "mem_summary_prompt": SUMMARY_PROMPT_ZH,  # summary prompt template
+        "mem_human_prefix": "用户",
+        "mem_ai_prefix": "Assistant",
+        "mem_max_tokens": 2000,
+        "mem_llm_kwargs": {
+            "max_new_tokens": 50,
+            "temperature": 1,
+            "do_sample": True
+        },
+        "disambig_prompt": PROMPT_DISAMBIGUATE_ZH,  # disambiguate prompt template
+        "disambig_llm_kwargs": {
+            "max_new_tokens": 30,
+            "temperature": 1,
+            "do_sample": True
+        },
+        "gen_llm_kwargs": {
+            "max_new_tokens": 100,
+            "temperature": 1,
+            "do_sample": True
+        },
+        "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,  # generation prompt template
+        "verbose": True
+    }
+}
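Note that `splitter.name` holds a class object rather than a string, and `RAG_ChatBot.set_text_splitter` instantiates it with no arguments (`kwargs["name"]()`), so swapping splitters is a one-line change. A hedged sketch, assuming a corpus where the Chinese-specific splitter is not needed:

```python
# Hypothetical config.py variation: use LangChain's generic splitter instead of
# ChineseTextSplitter; the class must be constructible with no arguments.
from langchain.text_splitter import RecursiveCharacterTextSplitter

ALL_CONFIG["splitter"] = {"name": RecursiveCharacterTextSplitter}
```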
applications/ColossalQA/examples/webui_demo/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
+fastapi==0.99.1
+uvicorn>=0.24.0
+pydantic==1.10.13
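The pins are presumably coupled: FastAPI 0.99.x is the last release line built against Pydantic v1, which matches the `pydantic==1.10.13` pin, so upgrading either one independently would likely break the request models in `server.py`.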
applications/ColossalQA/examples/webui_demo/server.py
@@ -1,117 +1,98 @@
 import argparse
-import copy
-import json
 import os
-import random
-import string
-from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import List, Union

 from colossalqa.local.llm import ColossalAPI, ColossalLLM
 from colossalqa.data_loader.document_loader import DocumentLoader
 from colossalqa.mylogging import get_logger
-from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
-from colossalqa.retriever import CustomRetriever
+from enum import Enum
+from fastapi import FastAPI, Request
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from RAG_ChatBot import RAG_ChatBot, DEFAULT_RAG_CFG
+from pydantic import BaseModel, Field
+import uvicorn

-# Define the mapping between embed_model_name(passed from Front End) and the actual path on the back end server
-EMBED_MODEL_DICT = {
-    "m3e": os.environ.get("EMB_MODEL_PATH", DEFAULT_RAG_CFG["embed_model_name_or_path"])
-}
-# Define the mapping between LLM_name(passed from Front End) and the actual path on the back end server
-LLM_DICT = {
-    "chatglm2": os.environ.get("CHAT_LLM_PATH", "THUDM/chatglm-6b"),
-    "pangu": "Pangu_API",
-    "chatgpt": "OpenAI_API"
-}
+import config
+from RAG_ChatBot import RAG_ChatBot
+from utils import DocAction

-def randomword(length):
-    letters = string.ascii_lowercase
-    return "".join(random.choice(letters) for i in range(length))

-class ColossalQAServerRequestHandler(BaseHTTPRequestHandler):
-    chatbot = None
-    def _set_response(self):
-        """
-        set http header for response
-        """
-        self.send_response(200)
-        self.send_header("Content-type", "application/json")
-        self.end_headers()
+logger = get_logger()

-    def do_POST(self):
-        content_length = int(self.headers["Content-Length"])
-        post_data = self.rfile.read(content_length)
-        received_json = json.loads(post_data.decode("utf-8"))
-        print(received_json)
-        # conversation_ready is False (user's first request): need to upload files and initialize the RAG chain
-        if received_json["conversation_ready"] is False:
-            self.rag_config = DEFAULT_RAG_CFG.copy()
-            try:
-                assert received_json["embed_model_name"] in EMBED_MODEL_DICT
-                assert received_json["llm_name"] in LLM_DICT
-                self.docs_files = received_json["docs"]
-                embed_model_name, llm_name = received_json["embed_model_name"], received_json["llm_name"]
-
-                # Find the embed_model/llm ckpt path on the back end server.
-                embed_model_path, llm_path = EMBED_MODEL_DICT[embed_model_name], LLM_DICT[llm_name]
-                self.rag_config["embed_model_name_or_path"] = embed_model_path
+def parseArgs():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--http_host", default="0.0.0.0")
+    parser.add_argument("--http_port", type=int, default=13666)
+    return parser.parse_args()
+
+
+app = FastAPI()
+
+
+class DocUpdateReq(BaseModel):
+    doc_files: Union[List[str], str, None] = None
+    action: DocAction = DocAction.ADD
+
+
+class GenerationTaskReq(BaseModel):
+    user_input: str
+
+
+@app.post("/update")
+def update_docs(data: DocUpdateReq, request: Request):
+    if data.action == "add":
+        if isinstance(data.doc_files, str):
+            data.doc_files = [data.doc_files]
+        chatbot.load_doc_from_files(files=data.doc_files)
+        all_docs = ""
+        for doc in chatbot.docs_names:
+            all_docs += f"\t{doc}\n\n"
+        return {"response": f"文件上传完成,所有数据库文件:\n\n{all_docs}让我们开始对话吧!"}
+    elif data.action == "clear":
+        chatbot.clear_docs(**all_config["chain"])
+        return {"response": f"已清空数据库。"}
+
+
+@app.post("/generate")
+def generate(data: GenerationTaskReq, request: Request):
+    try:
+        chatbot_response, chatbot.memory = chatbot.run(data.user_input, chatbot.memory)
+        return {"response": chatbot_response, "error": ""}
+    except Exception as e:
+        return {"response": "模型生成回答有误", "error": f"Error in generating answers, details: {e}"}

-            # Create the storage path for knowledge base files
-            self.rag_config["retri_kb_file_path"] = os.path.join(os.environ["TMP"], "colossalqa_kb/"+randomword(20))
-            if not os.path.exists(self.rag_config["retri_kb_file_path"]):
-                os.makedirs(self.rag_config["retri_kb_file_path"])
-
-            if (embed_model_path is not None) and (llm_path is not None):
-                # ---- Initialize LLM, QA_chatbot here ----
-                print("Initializing LLM...")
-                if llm_path == "Pangu_API":
-                    from colossalqa.local.pangu_llm import Pangu
-                    self.llm = Pangu(id=1)
-                    self.llm.set_auth_config()  # verify user's auth info here
-                    self.rag_config["mem_llm_kwargs"] = None
-                    self.rag_config["disambig_llm_kwargs"] = None
-                    self.rag_config["gen_llm_kwargs"] = None
-                elif llm_path == "OpenAI_API":
-                    from langchain.llms import OpenAI
-                    self.llm = OpenAI()
-                    self.rag_config["mem_llm_kwargs"] = None
-                    self.rag_config["disambig_llm_kwargs"] = None
-                    self.rag_config["gen_llm_kwargs"] = None
-                else:
-                    # ** (For Testing Only) **
-                    # In practice, all LLMs will run on the cloud platform and be accessed by API, instead of running locally.
-                    # initialize model from model_path by using ColossalLLM
-                    self.rag_config["mem_llm_kwargs"] = {"max_new_tokens": 50, "temperature": 1, "do_sample": True}
-                    self.rag_config["disambig_llm_kwargs"] = {"max_new_tokens": 30, "temperature": 1, "do_sample": True}
-                    self.rag_config["gen_llm_kwargs"] = {"max_new_tokens": 100, "temperature": 1, "do_sample": True}
-                    self.colossal_api = ColossalAPI(llm_name, llm_path)
-                    self.llm = ColossalLLM(n=1, api=self.colossal_api)
-
-                print(f"Initializing RAG Chain...")
-                print("RAG_CONFIG: ", self.rag_config)
-                self.__class__.chatbot = RAG_ChatBot(self.llm, self.rag_config)
-                print("Loading Files....\n", self.docs_files)
-                self.__class__.chatbot.load_doc_from_files(self.docs_files)
-                # -----------------------------------------------------------------------------------
-                res = {"response": f"文件上传完成,模型初始化完成,让我们开始对话吧!(后端模型:{llm_name})", "error": "", "conversation_ready": True}
-            except Exception as e:
-                res = {"response": "文件上传或模型初始化有误,无法开始对话。",
-                       "error": f"Error in File Uploading and/or RAG initialization. Error details: {e}",
-                       "conversation_ready": False}
-        # conversation_ready is True: Chatbot and docs are all set. Ready to chat.
-        else:
-            user_input = received_json["user_input"]
-            chatbot_response, self.__class__.chatbot.memory = self.__class__.chatbot.run(user_input, self.__class__.chatbot.memory)
-            res = {"response": chatbot_response, "error": "", "conversation_ready": True}
-        self._set_response()
-        self.wfile.write(json.dumps(res).encode("utf-8"))

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Chinese retrieval based conversation system")
-    parser.add_argument("--port", type=int, default=13666, help="port on localhost to start the server")
-    args = parser.parse_args()
-    server_address = ("localhost", args.port)
-    httpd = HTTPServer(server_address, ColossalQAServerRequestHandler)
-    print(f"Starting server on port {args.port}...")
-    httpd.serve_forever()
+    args = parseArgs()
+
+    all_config = config.ALL_CONFIG
+    model_name = all_config["model"]["model_name"]
+
+    # initialize chatbot
+    logger.info(f"Initialize the chatbot from {model_name}")
+
+    if all_config["model"]["mode"] == "local":
+        colossal_api = ColossalAPI(model_name, all_config["model"]["model_path"])
+        llm = ColossalLLM(n=1, api=colossal_api)
+    elif all_config["model"]["mode"] == "api":
+        all_config["chain"]["mem_llm_kwargs"] = None
+        all_config["chain"]["disambig_llm_kwargs"] = None
+        all_config["chain"]["gen_llm_kwargs"] = None
+        if model_name == "pangu_api":
+            from colossalqa.local.pangu_llm import Pangu
+            llm = Pangu(id=1)
+            llm.set_auth_config()  # verify user's auth info here
+        elif model_name == "chatgpt_api":
+            from langchain.llms import OpenAI
+            llm = OpenAI()
+    else:
+        raise ValueError("Unsupported mode.")
+
+    # initialize chatbot
+    chatbot = RAG_ChatBot(llm, all_config)
+
+    app_config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
+    server = uvicorn.Server(config=app_config)
+    server.run()
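One design consequence worth noting: `chatbot` and `all_config` are created inside the `__main__` block, and the `/update` and `/generate` handlers reference them as globals, so the server is meant to be launched directly with `python server.py` (as the README shows) rather than through an external runner such as `uvicorn server:app`, which would leave those globals undefined.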
applications/ColossalQA/examples/webui_demo/start_colossal_qa.sh (deleted, 43 lines)
@@ -1,43 +0,0 @@
-#!/bin/bash
-cleanup() {
-    echo "Caught Signal ... cleaning up."
-    pkill -P $$  # kill all subprocesses of this script
-    exit 1  # exit script
-}
-# 'cleanup' is triggered when receiving a SIGINT (Ctrl+C) or SIGTERM (kill) signal
-trap cleanup INT TERM
-
-# Disable your proxy
-# unset HTTP_PROXY HTTPS_PROXY http_proxy https_proxy
-
-# Path to store knowledge base (Home Directory by default)
-export TMP=$HOME
-
-# Use m3e as embedding model
-export EMB_MODEL="m3e"  # the moka-ai/m3e-base model will be downloaded automatically
-# export EMB_MODEL_PATH="PATH_TO_LOCAL_CHECKPOINT/m3e-base"  # you can also specify the local path to the embedding model
-
-# Choose a backend LLM
-# - ChatGLM2
-# export CHAT_LLM="chatglm2"
-# export CHAT_LLM_PATH="PATH_TO_LOCAL_CHECKPOINT/chatglm2-6b"
-
-# - ChatGPT
-export CHAT_LLM="chatgpt"
-# Auth info for OpenAI API
-export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
-
-# - Pangu
-# export CHAT_LLM="pangu"
-# # Auth info for Pangu API
-# export URL=""
-# export USERNAME=""
-# export PASSWORD=""
-# export DOMAIN_NAME=""
-
-# Run server.py and webui.py in the background
-python server.py &
-python webui.py &
-
-# Wait for all processes to finish
-wait
applications/ColossalQA/examples/webui_demo/utils.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class DocAction(str, Enum):
+    ADD = "add"
+    CLEAR = "clear"
applications/ColossalQA/examples/webui_demo/webui.py
@@ -1,17 +1,21 @@
+import argparse
 import json
 import os
+import gradio as gr
 import requests

-RAG_STATE = {"conversation_ready": False,  # Conversation is not ready until files are uploaded and RAG chain is initialized
-             "embed_model_name": os.environ.get("EMB_MODEL", "m3e"),
-             "llm_name": os.environ.get("CHAT_LLM", "chatgpt")}
-URL = "http://localhost:13666"
-import gradio as gr
+from utils import DocAction

-def get_response(client_data, URL):
+
+def parseArgs():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--http_host", default="0.0.0.0")
+    parser.add_argument("--http_port", type=int, default=13666)
+    return parser.parse_args()
+
+
+def get_response(data, url):
     headers = {"Content-type": "application/json"}
-    print(f"Sending request to server url: {URL}")
-    response = requests.post(URL, data=json.dumps(client_data), headers=headers)
+    response = requests.post(url, json=data, headers=headers)
     response = json.loads(response.content)
     return response

@@ -19,41 +23,43 @@ def add_text(history, text):
     history = history + [(text, None)]
     return history, gr.update(value=None, interactive=True)


 def add_file(history, files):
-    global RAG_STATE
-    RAG_STATE["conversation_ready"] = False  # after adding new files, reset the ChatBot
-    RAG_STATE["upload_files"] = [file.name for file in files]
-    files_string = "\n".join([os.path.basename(path) for path in RAG_STATE["upload_files"]])
-    print(files_string)
-    history = history + [(files_string, None)]
+    files_string = "\n".join([os.path.basename(file.name) for file in files])
+
+    doc_files = [file.name for file in files]
+    data = {
+        "doc_files": doc_files,
+        "action": DocAction.ADD
+    }
+    response = get_response(data, update_url)["response"]
+    history = history + [(files_string, response)]
     return history

-def bot(history):
-    print(history)
-    global RAG_STATE
-    if not RAG_STATE["conversation_ready"]:
-        # Upload files and initialize models
-        client_data = {
-            "docs": RAG_STATE["upload_files"],
-            "embed_model_name": RAG_STATE["embed_model_name"],  # Select embedding model name here
-            "llm_name": RAG_STATE["llm_name"],  # Select LLM model name here. ["pangu", "chatglm2"]
-            "conversation_ready": RAG_STATE["conversation_ready"]
-        }
-    else:
-        client_data = {}
-        client_data["conversation_ready"] = RAG_STATE["conversation_ready"]
-    client_data["user_input"] = history[-1][0].strip()
+def bot(history):
+    data = {
+        "user_input": history[-1][0].strip()
+    }
+    response = get_response(data, gen_url)

-    response = get_response(client_data, URL)  # TODO: async request, to avoid users waiting too long for model initialization
-    print(response)
     if response["error"] != "":
         raise gr.Error(response["error"])

-    RAG_STATE["conversation_ready"] = response["conversation_ready"]
     history[-1][1] = response["response"]
     yield history


+def restart(chatbot, txt):
+    # Reset the conversation state and clear the chat history
+    data = {
+        "doc_files": "",
+        "action": DocAction.CLEAR
+    }
+    response = get_response(data, update_url)
+
+    return gr.update(value=None), gr.update(value=None, interactive=True)


 CSS = """
 .contain { display: flex; flex-direction: column; height: 100vh }
 #component-0 { height: 100%; }
@@ -63,7 +69,7 @@ CSS = """
 header_html = """
 <div style="background: linear-gradient(to right, #2a0cf4, #7100ed, #9800e6, #b600df, #ce00d9, #dc0cd1, #e81bca, #f229c3, #f738ba, #f946b2, #fb53ab, #fb5fa5); padding: 20px; text-align: left;">
-    <h1 style="color: white;">ColossalQA</h1>
+    <h4 style="color: white;">ColossalQA</h4>
     <h4 style="color: white;">A powerful Q&A system with knowledge bases</h4>
 </div>
 """

@@ -78,25 +84,32 @@ with gr.Blocks(css=CSS) as demo:
             (os.path.join(os.path.dirname(__file__), "img/avatar_ai.png")),
         ),
     )

     with gr.Row():
+        btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple", size="sm")
+        restart_btn = gr.Button(str("\u21BB"), elem_id="restart-btn", scale=1)
         txt = gr.Textbox(
-            scale=4,
+            scale=8,
             show_label=False,
-            placeholder="Enter text and press enter, or upload an image",
+            placeholder="Enter text and press enter, or use 📁 to upload files, click \u21BB to clear loaded files and restart chat",
             container=True,
+            autofocus=True,
         )
-        btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple")

     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(bot, chatbot, chatbot)
     # Clear the original textbox
     txt_msg.then(lambda: gr.update(value=None, interactive=True), None, [txt], queue=False)
-    # Click Upload Button: 1. upload files 2. send config to backend, initialize model 3. get response "conversation_ready" = True/False
-    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(bot, chatbot, chatbot)
+    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False)
+
+    # restart
+    restart_msg = restart_btn.click(restart, [chatbot, txt], [chatbot, txt], queue=False)


 if __name__ == "__main__":
+    args = parseArgs()
+
+    update_url = f"http://{args.http_host}:{args.http_port}/update"
+    gen_url = f"http://{args.http_host}:{args.http_port}/generate"
+
     demo.queue()
     demo.launch(share=True)  # share=True will release a public link of the demo