[Feature] Add document retrieval QA (#5020)

* add langchain

* add langchain

* Add files via upload

* add langchain

* fix style

* fix style: remove extra space

* add pytest; modified retriever

* add pytest; modified retriever

* add tests to build_on_pr.yml

* fix build_on_pr.yml

* fix build on pr; fix environ vars

* separate unit tests for colossalqa from build from pr

* fix container setting; fix environ vars

* commented dev code

* add incremental update

* remove stale code

* fix style

* change to sha3 224

* fix retriever; fix style; add unit test for document loader

* fix ci workflow config

* fix ci workflow config

* add set cuda visible device script in ci

* fix doc string

* fix style; update readme; refactored

* add force log info

* change build on pr, ignore colossalqa

* fix docstring, capitalize all initial letters

* fix indexing; fix text-splitter

* remove debug code, update reference

* reset previous commit

* update LICENSE update README add key-value mode, fix bugs

* add files back

* revert force push

* remove junk file

* add test files

* fix retriever bug, add intent classification

* change conversation chain design

* rewrite prompt and conversation chain

* add ui v1

* ui v1

* fix avatar

* add header

* Refactor the RAG Code and support Pangu

* Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo.

* resolved conversation. tested scripts under examples. web demo still buggy

* fix ci tests

* Some modifications to add ChatGPT api

* modify llm.py and remove unnecessary files

* Delete applications/ColossalQA/examples/ui/test_frontend_input.json

* Remove OpenAI api key

* add colossalqa

* move files

* move files

* move files

* move files

* fix style

* Add Readme and fix some bugs.

* Add something to readme and modify some code

* modify a directory name for clarity

* remove redundant directory

* Correct a type in  llm.py

* fix AI prefix

* fix test_memory.py

* fix conversation

* fix some errors and typos

* Fix a missing import in RAG_ChatBot.py

* add colossalcloud LLM wrapper, correct issues in code review

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
This commit is contained in:
YeAnbang
2023-11-23 10:33:48 +08:00
committed by GitHub
parent 3acbf6d496
commit e53e729d8e
69 changed files with 6758 additions and 0 deletions

View File

@@ -0,0 +1,184 @@
from typing import Dict, Optional, Tuple

from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.mylogging import get_logger
from colossalqa.prompt.prompt import (
    PROMPT_DISAMBIGUATE_ZH,
    PROMPT_RETRIEVAL_QA_ZH,
    SUMMARY_PROMPT_ZH,
    ZH_RETRIEVAL_QA_REJECTION_ANSWER,
    ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.retriever import CustomRetriever
from colossalqa.text_splitter import ChineseTextSplitter
from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings
logger = get_logger()

# Default configuration for RAG_ChatBot. Each key below is read by one of the
# chatbot's set_* methods (or by split_docs_and_add_to_mem for "gen_qa_prompt").
DEFAULT_RAG_CFG = {
    # --- retriever (CustomRetriever) ---
    "retri_top_k": 3,  # passed as `k`
    "retri_kb_file_path": "./",  # passed as `sql_file_path`
    "verbose": True,  # shared by the retriever and the RAG chain
    # --- conversation memory (ConversationBufferWithSummary) ---
    "mem_summary_prompt": SUMMARY_PROMPT_ZH,
    "mem_human_prefix": "用户",
    "mem_ai_prefix": "Assistant",
    "mem_max_tokens": 2000,
    "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
    # --- query disambiguation (skipped when "disambig_prompt" is falsy) ---
    "disambig_prompt": PROMPT_DISAMBIGUATE_ZH,
    "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
    # --- embedding model (HuggingFaceEmbeddings) ---
    "embed_model_name_or_path": "moka-ai/m3e-base",
    "embed_model_device": {"device": "cpu"},  # passed as `model_kwargs`
    # --- answer generation (RetrievalQA chain) ---
    "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
    "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,
}
class RAG_ChatBot:
    """Retrieval-augmented chatbot assembled from an LLM and a config dict.

    The configuration dict (see ``DEFAULT_RAG_CFG`` for the expected keys) is
    unpacked into each ``set_*`` method, which builds one component: embedding
    model, text splitter, conversation memory, retriever, RAG chain and,
    optionally, a query-disambiguation handler.
    """

    def __init__(
        self,
        llm,
        rag_config,
    ) -> None:
        """Store ``llm`` and build every component described by ``rag_config``."""
        self.llm = llm
        # Construction and reconfiguration share one code path so they cannot drift apart.
        self.reset_config(rag_config)

    def set_embed_model(self, **kwargs):
        """Create the HuggingFace embedding model.

        Reads ``embed_model_name_or_path`` and ``embed_model_device`` from ``kwargs``.
        """
        self.embed_model = HuggingFaceEmbeddings(
            model_name=kwargs["embed_model_name_or_path"],
            model_kwargs=kwargs["embed_model_device"],
            encode_kwargs={"normalize_embeddings": False},
        )

    def set_text_splitter(self, **kwargs):
        """Create the text splitter used by ``split_docs`` (``kwargs`` is not read)."""
        self.text_splitter = ChineseTextSplitter()

    def set_memory(self, **kwargs):
        """Create the conversation memory with summarization ability.

        Reads ``mem_summary_prompt``, ``mem_human_prefix``, ``mem_ai_prefix`` and
        ``mem_max_tokens``; ``mem_llm_kwargs`` is forwarded as ``llm_kwargs`` only
        when it is present and truthy.
        """
        params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
        self.memory = ConversationBufferWithSummary(
            llm=self.llm,
            prompt=kwargs["mem_summary_prompt"],
            human_prefix=kwargs["mem_human_prefix"],
            ai_prefix=kwargs["mem_ai_prefix"],
            max_tokens=kwargs["mem_max_tokens"],
            **params,
        )

    def set_info_retriever(self, **kwargs):
        """Create the retriever from ``retri_top_k``, ``retri_kb_file_path`` and ``verbose``."""
        self.info_retriever = CustomRetriever(
            k=kwargs["retri_top_k"], sql_file_path=kwargs["retri_kb_file_path"], verbose=kwargs["verbose"]
        )

    def set_rag_chain(self, **kwargs):
        """Create the retrieval QA chain on top of the retriever and the memory.

        Reads ``verbose`` and ``gen_qa_prompt``; ``gen_llm_kwargs`` is forwarded as
        ``llm_kwargs`` only when it is present and truthy. Must run after
        ``set_memory`` and ``set_info_retriever`` because it uses both.
        """
        params = {"llm_kwargs": kwargs["gen_llm_kwargs"]} if kwargs.get("gen_llm_kwargs", None) else {}
        self.rag_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            verbose=kwargs["verbose"],
            chain_type="stuff",
            retriever=self.info_retriever,
            chain_type_kwargs={"prompt": kwargs["gen_qa_prompt"], "memory": self.memory},
            **params,
        )

    def split_docs(self, documents):
        """Return ``documents`` split into chunks by the configured text splitter."""
        return self.text_splitter.split_documents(documents)

    def set_disambig_retriv(self, **kwargs):
        """Install a query-rephrasing handler on the retriever.

        Builds an ``LLMChain`` from ``disambig_prompt`` (``disambig_llm_kwargs`` is
        forwarded as ``llm_kwargs`` only when truthy) and registers a handler that
        rewrites the query using the current chat history.
        """
        params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
        self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)

        # The parameter name `input` is kept (although it shadows the builtin) because
        # the retriever may invoke the handler by keyword - TODO confirm and rename.
        def disambiguity(input: str):
            out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
            # Only the first line of the LLM output is used as the rephrased query.
            return out.split("\n")[0]

        self.info_retriever.set_rephrase_handler(disambiguity)

    def load_doc_from_console(self, json_parse_args: Optional[Dict] = None):
        """Interactively read file paths from stdin and index them.

        Prompts for a file path and a short description until an empty path is
        entered, then splits the loaded documents and adds them to the retriever.

        Args:
            json_parse_args: extra keyword arguments forwarded to ``DocumentLoader``.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        json_parse_args = {} if json_parse_args is None else json_parse_args
        documents = []
        print("Select files for constructing Chinese retriever")
        while True:
            file = input("Enter a file path or press Enter directly without input to exit:").strip()
            if file == "":
                break
            data_name = input("Enter a short description of the data:")
            # Spaces in the description are replaced with underscores before it is used as the data name.
            docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
            documents.extend(docs)
        self.documents = documents
        self.split_docs_and_add_to_mem(**self.rag_config)

    def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Optional[Dict] = None):
        """Load every path in ``files`` and index the resulting documents.

        Args:
            files: iterable of file paths to load.
            data_name: data name attached to every file (spaces become underscores).
            json_parse_args: extra keyword arguments forwarded to ``DocumentLoader``.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        json_parse_args = {} if json_parse_args is None else json_parse_args
        normalized_name = data_name.replace(" ", "_")
        documents = []
        for file in files:
            documents.extend(DocumentLoader([[file, normalized_name]], **json_parse_args).all_data)
        self.documents = documents
        self.split_docs_and_add_to_mem(**self.rag_config)

    def split_docs_and_add_to_mem(self, **kwargs):
        """Split ``self.documents``, add the chunks to the retriever and hook up the memory.

        Reads ``gen_qa_prompt`` from ``kwargs`` for the memory's document retrieval chain.
        """
        self.doc_splits = self.split_docs(self.documents)
        self.info_retriever.add_documents(
            docs=self.doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
        )
        self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)

    def reset_config(self, rag_config):
        """Replace the configuration and rebuild every component from it.

        The order matters: the RAG chain needs the memory and the retriever, and
        the disambiguation handler is attached to the retriever.
        """
        self.rag_config = rag_config
        self.set_embed_model(**self.rag_config)
        self.set_text_splitter(**self.rag_config)
        self.set_memory(**self.rag_config)
        self.set_info_retriever(**self.rag_config)
        self.set_rag_chain(**self.rag_config)
        if self.rag_config.get("disambig_prompt", None):
            self.set_disambig_retriv(**self.rag_config)

    def run(
        self, user_input: str, memory: Optional["ConversationBufferWithSummary"] = None
    ) -> Tuple[str, "ConversationBufferWithSummary"]:
        """Answer ``user_input`` with the RAG chain.

        Args:
            user_input: the user's query.
            memory: memory whose ``human_prefix`` builds the stop sequence; falls
                back to ``self.memory`` when ``None``.

        Returns:
            A tuple of (first line of the chain output, the memory that was used).
        """
        # Previously a None memory crashed on `memory.human_prefix`; fall back to our own.
        # The former self-assignments of the memory's message lists were no-ops and are gone.
        if memory is None:
            memory = self.memory
        result = self.rag_chain.run(
            query=user_input,
            stop=[memory.human_prefix + ": "],
            # NOTE(review): keyword spelling kept as-is; presumably it matches the name the
            # chain expects - verify against RetrievalQA before renaming.
            rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
            rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
        )
        return result.split("\n")[0], memory

    def start_test_session(self):
        """
        Simple console session for testing purpose; type END to quit
        """
        while True:
            user_input = input("User: ")
            if "END" == user_input:
                print("Agent: Happy to chat with you :)")
                break
            agent_response, self.memory = self.run(user_input, self.memory)
            print(f"Agent: {agent_response}")
if __name__ == "__main__":
    # Initialize a LangChain LLM (ChatGPT is used here as an example)
    from langchain.llms import OpenAI

    llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")

    # ChatGPT cannot control temperature, do_sample, etc., so the per-stage LLM kwargs
    # are disabled. The overrides go into a copy so the shared module-level
    # DEFAULT_RAG_CFG is left untouched.
    rag_config = {
        **DEFAULT_RAG_CFG,
        "mem_llm_kwargs": None,
        "disambig_llm_kwargs": None,
        "gen_llm_kwargs": None,
    }

    rag = RAG_ChatBot(llm, rag_config)
    rag.load_doc_from_console()
    rag.start_test_session()

View File

@@ -0,0 +1,37 @@
# ColossalQA WebUI Demo
This demo provides a simple WebUI for ColossalQA, enabling you to upload your files as a knowledge base and interact with them through a chat interface in your browser.
`server.py` initializes the backend RAG chain, which can be backed by various language models (e.g., ChatGPT, Huawei Pangu, ChatGLM2). Meanwhile, `webui.py` launches a Gradio-based chatbot interface.
# Usage
## Installation
First, install the necessary dependencies for ColossalQA:
```sh
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI/applications/ColossalQA/
pip install -e .
```
## Configure the RAG Chain
Customize the RAG Chain settings, such as the embedding model (default: `moka-ai/m3e-base`) and the language model, in the `start_colossal_qa.sh` script.
For API-based language models (like ChatGPT or Huawei Pangu), provide your API key for authentication. For locally-run models, indicate the path to the model's checkpoint file.
If you want to customize prompts in the RAG Chain, you can have a look at the `RAG_ChatBot.py` file to modify them.
## Run WebUI Demo
Execute the following command to start the demo:
```sh
bash start_colossal_qa.sh
```
After launching the script, you can upload files and engage with the chatbot through your web browser.
![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/img/qa_demo.png)

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 KiB

View File

@@ -0,0 +1,117 @@
import argparse
import copy
import json
import os
import random
import string
from http.server import BaseHTTPRequestHandler, HTTPServer
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from RAG_ChatBot import RAG_ChatBot, DEFAULT_RAG_CFG
# Map the embed_model_name passed from the front end to the actual model path on the back end server.
# EMB_MODEL_PATH overrides the default embedding model from DEFAULT_RAG_CFG.
EMBED_MODEL_DICT = {
    "m3e": os.environ.get("EMB_MODEL_PATH", DEFAULT_RAG_CFG["embed_model_name_or_path"]),
}
# Map the llm_name passed from the front end to the actual model path on the back end server.
# "Pangu_API" and "OpenAI_API" are markers for API-backed models rather than checkpoint paths.
LLM_DICT = {
    # The default must be the ChatGLM2 checkpoint; "THUDM/chatglm-6b" is the first-generation model.
    "chatglm2": os.environ.get("CHAT_LLM_PATH", "THUDM/chatglm2-6b"),
    "pangu": "Pangu_API",
    "chatgpt": "OpenAI_API",
}
def randomword(length):
    """Return a random string made of ``length`` lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)
class ColossalQAServerRequestHandler(BaseHTTPRequestHandler):
    """HTTP handler that initializes a RAG chatbot and answers chat requests.

    Every POST body is a JSON object. When ``conversation_ready`` is ``False`` the
    request must carry ``docs``, ``embed_model_name`` and ``llm_name`` and triggers
    (re)initialization of the chatbot; otherwise it must carry ``user_input`` and
    is answered by the existing chatbot. Every reply is a JSON object with the
    keys ``response``, ``error`` and ``conversation_ready``.
    """

    # Stored on the class because http.server creates one handler instance per request;
    # the chatbot has to outlive the request that initialized it.
    chatbot = None

    def _set_response(self):
        """
        set http header for response
        """
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()

    def do_POST(self):
        """Parse the JSON request, dispatch it, and always write a JSON reply."""
        try:
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            received_json = json.loads(post_data.decode("utf-8"))
            conversation_ready = received_json["conversation_ready"]
        except (TypeError, ValueError, KeyError) as e:
            # Missing/invalid Content-Length, undecodable or non-JSON body, or a body
            # without "conversation_ready": report it instead of dropping the connection.
            res = {
                "response": "",
                "error": f"Invalid request. Error details: {e!r}",
                "conversation_ready": False,
            }
        else:
            print(received_json)
            if conversation_ready is False:
                # User's first request: need to upload files and initialize the RAG chain.
                res = self._initialize_chatbot(received_json)
            else:
                # Chatbot and docs are all set. Ready to chat.
                res = self._chat(received_json)
        self._set_response()
        self.wfile.write(json.dumps(res).encode("utf-8"))

    def _initialize_chatbot(self, received_json):
        """Build the chatbot described by ``received_json`` and load its documents.

        Returns the reply dict; any failure is reported through its ``error`` field.
        """
        import tempfile  # local import: only needed as a fallback for the storage root

        self.rag_config = DEFAULT_RAG_CFG.copy()
        try:
            embed_model_name, llm_name = received_json["embed_model_name"], received_json["llm_name"]
            # Explicit checks instead of `assert`, which is stripped under `python -O`.
            if embed_model_name not in EMBED_MODEL_DICT:
                raise ValueError(f"Unsupported embed_model_name: {embed_model_name!r}")
            if llm_name not in LLM_DICT:
                raise ValueError(f"Unsupported llm_name: {llm_name!r}")
            # NOTE(review): these are paths supplied by the client and read on the server;
            # acceptable while the server only listens on localhost.
            self.docs_files = received_json["docs"]
            # Find the embed_model/llm ckpt path on the back end server.
            embed_model_path, llm_path = EMBED_MODEL_DICT[embed_model_name], LLM_DICT[llm_name]
            if embed_model_path is None or llm_path is None:
                raise ValueError("No model path is configured for the requested embedding model or LLM")
            self.rag_config["embed_model_name_or_path"] = embed_model_path
            # Create the storage path for knowledge base files; fall back to the
            # system temp dir when TMP is not exported.
            kb_root = os.environ.get("TMP") or tempfile.gettempdir()
            self.rag_config["retri_kb_file_path"] = os.path.join(kb_root, "colossalqa_kb", randomword(20))
            os.makedirs(self.rag_config["retri_kb_file_path"], exist_ok=True)

            print("Initializing LLM...")
            self.llm = self._build_llm(llm_name, llm_path)
            print("Initializing RAG Chain...")
            print("RAG_CONFIG: ", self.rag_config)
            chatbot = RAG_ChatBot(self.llm, self.rag_config)
            print("Loading Files....\n", self.docs_files)
            chatbot.load_doc_from_files(self.docs_files)
            # Publish only a fully initialized chatbot.
            type(self).chatbot = chatbot
            return {
                "response": f"文件上传完成,模型初始化完成,让我们开始对话吧!(后端模型:{llm_name})",
                "error": "",
                "conversation_ready": True,
            }
        except Exception as e:
            # Request boundary: model loading and file parsing can fail in many ways,
            # and the client must always receive a JSON reply.
            return {
                "response": "文件上传或模型初始化有误,无法开始对话。",
                "error": f"Error in File Uploading and/or RAG initialization. Error details: {e}",
                "conversation_ready": False,
            }

    def _build_llm(self, llm_name, llm_path):
        """Create the LLM for ``llm_path`` and set the per-stage LLM kwargs in ``self.rag_config``."""
        if llm_path in ("Pangu_API", "OpenAI_API"):
            if llm_path == "Pangu_API":
                from colossalqa.local.pangu_llm import Pangu

                llm = Pangu(id=1)
                llm.set_auth_config()  # verify user's auth info here
            else:
                from langchain.llms import OpenAI

                llm = OpenAI()
            # API-backed models are run without per-stage generation kwargs.
            self.rag_config["mem_llm_kwargs"] = None
            self.rag_config["disambig_llm_kwargs"] = None
            self.rag_config["gen_llm_kwargs"] = None
            return llm

        # ** (For Testing Only) **
        # In practice, all LLMs will run on the cloud platform and accessed by API, instead of running locally.
        # initialize model from model_path by using ColossalLLM
        self.rag_config["mem_llm_kwargs"] = {"max_new_tokens": 50, "temperature": 1, "do_sample": True}
        self.rag_config["disambig_llm_kwargs"] = {"max_new_tokens": 30, "temperature": 1, "do_sample": True}
        self.rag_config["gen_llm_kwargs"] = {"max_new_tokens": 100, "temperature": 1, "do_sample": True}
        self.colossal_api = ColossalAPI(llm_name, llm_path)
        return ColossalLLM(n=1, api=self.colossal_api)

    def _chat(self, received_json):
        """Answer ``received_json["user_input"]`` with the initialized chatbot and return the reply dict."""
        chatbot = type(self).chatbot
        if chatbot is None:
            # The client believes the conversation is ready, but nothing was initialized
            # (e.g. the server restarted); ask it to initialize again.
            return {
                "response": "",
                "error": "Chatbot is not initialized. Upload files to initialize the RAG chain first.",
                "conversation_ready": False,
            }
        try:
            user_input = received_json["user_input"]
            chatbot_response, chatbot.memory = chatbot.run(user_input, chatbot.memory)
        except Exception as e:
            # Request boundary: report the failure instead of dropping the connection.
            return {
                "response": "",
                "error": f"Error in generating the response. Error details: {e!r}",
                "conversation_ready": True,
            }
        return {"response": chatbot_response, "error": "", "conversation_ready": True}
if __name__ == "__main__":
    # Command line: only the port is configurable; the server always binds to localhost.
    arg_parser = argparse.ArgumentParser(description="Chinese retrieval based conversation system")
    arg_parser.add_argument("--port", type=int, default=13666, help="port on localhost to start the server")
    port = arg_parser.parse_args().port

    httpd = HTTPServer(("localhost", port), ColossalQAServerRequestHandler)
    print(f"Starting server on port {port}...")
    # Blocks until the process is interrupted.
    httpd.serve_forever()

View File

@@ -0,0 +1,43 @@
#!/bin/bash
cleanup() {
    # Signal handler: announce the shutdown, stop the child processes, then exit non-zero.
    printf '%s\n' "Caught Signal ... cleaning up."
    pkill -P "$$" # kill every child process of this script
    exit 1
}
# 'cleanup' is triggered when the script receives a SIGINT (Ctrl+C) or SIGTERM (kill) signal
trap cleanup INT TERM
# Disable your proxy
# unset HTTP_PROXY HTTPS_PROXY http_proxy https_proxy
# Root path for knowledge base storage (home directory by default); read by server.py
export TMP=$HOME
# Use m3e as the embedding model (EMB_MODEL is read by webui.py)
export EMB_MODEL="m3e" # the moka-ai/m3e-base model will be downloaded automatically
# export EMB_MODEL_PATH="PATH_TO_LOCAL_CHECKPOINT/m3e-base" # you can also specify a local path to the embedding model (read by server.py)
# Choose a backend LLM (CHAT_LLM is read by webui.py, CHAT_LLM_PATH by server.py)
# - ChatGLM2
# export CHAT_LLM="chatglm2"
# export CHAT_LLM_PATH="PATH_TO_LOCAL_CHECKPOINT/chatglm2-6b"
# - ChatGPT
export CHAT_LLM="chatgpt"
# Auth info for OpenAI API (replace the placeholder with your own key)
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
# - Pangu
# export CHAT_LLM="pangu"
# # Auth info for Pangu API
# export URL=""
# export USERNAME=""
# export PASSWORD=""
# export DOMAIN_NAME=""
# Run server.py and webui.py in the background
python server.py &
python webui.py &
# Wait for all background processes to finish
wait

View File

@@ -0,0 +1,102 @@
import json
import os
import gradio as gr
import requests
# Mutable state shared by the Gradio callbacks below.
RAG_STATE = {
    # Conversation is not ready until files are uploaded and RAG chain is initialized
    "conversation_ready": False,
    # Model names are taken from the environment, with the demo defaults as fallback.
    "embed_model_name": os.environ.get("EMB_MODEL", "m3e"),
    "llm_name": os.environ.get("CHAT_LLM", "chatgpt"),
}
# Address that requests to the backend server are posted to.
URL = "http://localhost:13666"
def get_response(client_data, URL):
    """POST ``client_data`` as JSON to ``URL`` and return the decoded JSON reply."""
    print(f"Sending request to server url: {URL}")
    payload = json.dumps(client_data)
    reply = requests.post(URL, data=payload, headers={"Content-type": "application/json"})
    return json.loads(reply.content)
def add_text(history, text):
    """Append the user's text as a new, unanswered turn.

    Returns the extended history together with a textbox update (value cleared,
    still interactive).
    """
    pending_turn = (text, None)  # the answer slot is filled in later by `bot`
    extended_history = history + [pending_turn]
    textbox_update = gr.update(value=None, interactive=True)
    return extended_history, textbox_update
def add_file(history, files):
    """Record the uploaded files in RAG_STATE and show their names as a new chat turn."""
    # after adding new files, reset the ChatBot
    RAG_STATE["conversation_ready"] = False
    uploaded_paths = [file.name for file in files]
    RAG_STATE["upload_files"] = uploaded_paths
    # One base name per line is what gets displayed in the chat window.
    files_string = "\n".join(os.path.basename(path) for path in uploaded_paths)
    print(files_string)
    return history + [(files_string, None)]
def bot(history):
    """Send the latest turn to the backend and fill in its answer.

    While the conversation is not ready, the uploaded file list and the model
    names are sent so the backend can initialize the RAG chain; afterwards the
    text of the last turn is sent as the user's input.

    Yields:
        ``history`` with the answer slot of its last turn set to the backend's response.

    Raises:
        gr.Error: when no file has been uploaded yet, or when the backend reports an error.
    """
    print(history)
    if not RAG_STATE["conversation_ready"]:
        uploaded_files = RAG_STATE.get("upload_files")
        if uploaded_files is None:
            # Text was submitted before any upload; previously this surfaced as a bare KeyError.
            raise gr.Error("Please upload files to build the knowledge base before chatting.")
        # Upload files and initialize models
        client_data = {
            "docs": uploaded_files,
            "embed_model_name": RAG_STATE["embed_model_name"],  # embedding model name selected via EMB_MODEL
            "llm_name": RAG_STATE["llm_name"],  # LLM name selected via CHAT_LLM
            "conversation_ready": RAG_STATE["conversation_ready"],
        }
    else:
        client_data = {
            "conversation_ready": RAG_STATE["conversation_ready"],
            "user_input": history[-1][0].strip(),
        }
    # TODO: async request, to avoid users waiting the model initialization too long
    response = get_response(client_data, URL)
    print(response)
    if response["error"] != "":
        raise gr.Error(response["error"])
    RAG_STATE["conversation_ready"] = response["conversation_ready"]
    history[-1][1] = response["response"]
    yield history
# Page-level CSS passed to gr.Blocks(css=CSS); the "#chatbot" rule targets the
# Chatbot component, which is created with elem_id="chatbot".
CSS = """
.contain { display: flex; flex-direction: column; height: 100vh }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; }
"""
# HTML banner rendered through gr.HTML at the top of the page.
header_html = """
<div style="background: linear-gradient(to right, #2a0cf4, #7100ed, #9800e6, #b600df, #ce00d9, #dc0cd1, #e81bca, #f229c3, #f738ba, #f946b2, #fb53ab, #fb5fa5); padding: 20px; text-align: left;">
<h1 style="color: white;">ColossalQA</h1>
<h4 style="color: white;">ColossalQA</h4>
</div>
"""
with gr.Blocks(css=CSS) as demo:
    html = gr.HTML(header_html)

    # Avatar images are resolved relative to this file.
    base_dir = os.path.dirname(__file__)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(
            os.path.join(base_dir, "img/avatar_user.png"),
            os.path.join(base_dir, "img/avatar_ai.png"),
        ),
    )

    # Input row: a textbox for questions next to the upload button for knowledge base files.
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Enter text and press enter, or upload an image",
            container=True,
            autofocus=True,
        )
        btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple")

    # Submit text: 1. append the text to the history 2. ask the backend for the answer
    text_added = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False)
    txt_msg = text_added.then(bot, chatbot, chatbot)
    # Clear the original textbox
    txt_msg.then(lambda: gr.update(value=None, interactive=True), None, [txt], queue=False)

    # Click Upload Button: 1. upload files 2. send config to backend, initialize model
    # 3. get response "conversation_ready" = True/False
    files_added = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False)
    file_msg = files_added.then(bot, chatbot, chatbot)
if __name__ == "__main__":
    demo.queue()
    # share=True will release a public link of the demo
    demo.launch(share=True)