mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[Feature] Add document retrieval QA (#5020)
* add langchain * add langchain * Add files via upload * add langchain * fix style * fix style: remove extra space * add pytest; modified retriever * add pytest; modified retriever * add tests to build_on_pr.yml * fix build_on_pr.yml * fix build on pr; fix environ vars * seperate unit tests for colossalqa from build from pr * fix container setting; fix environ vars * commented dev code * add incremental update * remove stale code * fix style * change to sha3 224 * fix retriever; fix style; add unit test for document loader * fix ci workflow config * fix ci workflow config * add set cuda visible device script in ci * fix doc string * fix style; update readme; refactored * add force log info * change build on pr, ignore colossalqa * fix docstring, captitalize all initial letters * fix indexing; fix text-splitter * remove debug code, update reference * reset previous commit * update LICENSE update README add key-value mode, fix bugs * add files back * revert force push * remove junk file * add test files * fix retriever bug, add intent classification * change conversation chain design * rewrite prompt and conversation chain * add ui v1 * ui v1 * fix atavar * add header * Refactor the RAG Code and support Pangu * Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo. * resolved conversation. tested scripts under examples. web demo still buggy * fix ci tests * Some modifications to add ChatGPT api * modify llm.py and remove unnecessary files * Delete applications/ColossalQA/examples/ui/test_frontend_input.json * Remove OpenAI api key * add colossalqa * move files * move files * move files * move files * fix style * Add Readme and fix some bugs. * Add something to readme and modify some code * modify a directory name for clarity * remove redundant directory * Correct a type in llm.py * fix AI prefix * fix test_memory.py * fix conversation * fix some erros and typos * Fix a missing import in RAG_ChatBot.py * add colossalcloud LLM wrapper, correct issues in code review --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Script for Chinese retrieval based conversation system backed by ChatGLM
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
from colossalqa.chain.retrieval_qa.base import RetrievalQA
|
||||
from colossalqa.local.llm import ColossalAPI, ColossalLLM
|
||||
from colossalqa.memory import ConversationBufferWithSummary
|
||||
from colossalqa.mylogging import get_logger
|
||||
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_EN, PROMPT_RETRIEVAL_QA_EN
|
||||
from colossalqa.retriever import CustomRetriever
|
||||
from langchain import LLMChain
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class EnglishRetrievalConversation:
|
||||
"""
|
||||
Wrapper class for Chinese retrieval conversation system
|
||||
"""
|
||||
|
||||
def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
|
||||
"""
|
||||
Setup retrieval qa chain for Chinese retrieval based QA
|
||||
"""
|
||||
logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
|
||||
colossal_api = ColossalAPI.get_api(model_name, model_path)
|
||||
self.llm = ColossalLLM(n=1, api=colossal_api)
|
||||
|
||||
# Define the retriever
|
||||
self.retriever = retriever
|
||||
|
||||
# Define the chain to preprocess the input
|
||||
# Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
|
||||
# Prompt is summarization prompt
|
||||
self.llm_chain_disambiguate = LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=PROMPT_DISAMBIGUATE_EN,
|
||||
llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
|
||||
)
|
||||
|
||||
self.retriever.set_rephrase_handler(self.disambiguity)
|
||||
# Define memory with summarization ability
|
||||
self.memory = ConversationBufferWithSummary(
|
||||
llm=self.llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
|
||||
)
|
||||
self.memory.initiate_document_retrieval_chain(
|
||||
self.llm,
|
||||
PROMPT_RETRIEVAL_QA_EN,
|
||||
self.retriever,
|
||||
chain_type_kwargs={
|
||||
"chat_history": "",
|
||||
},
|
||||
)
|
||||
self.retrieval_chain = RetrievalQA.from_chain_type(
|
||||
llm=self.llm,
|
||||
verbose=False,
|
||||
chain_type="stuff",
|
||||
retriever=self.retriever,
|
||||
chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": self.memory},
|
||||
llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
|
||||
)
|
||||
|
||||
def disambiguity(self, input: str):
|
||||
out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
|
||||
return out.split("\n")[0]
|
||||
|
||||
@classmethod
|
||||
def from_retriever(
|
||||
cls, retriever: CustomRetriever, model_path: str, model_name: str
|
||||
) -> "EnglishRetrievalConversation":
|
||||
return cls(retriever, model_path, model_name)
|
||||
|
||||
def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
|
||||
if memory:
|
||||
# TODO add translation chain here
|
||||
self.memory.buffered_history.messages = memory.buffered_history.messages
|
||||
self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
|
||||
return (
|
||||
self.retrieval_chain.run(
|
||||
query=user_input,
|
||||
stop=[self.memory.human_prefix + ": "],
|
||||
rejection_trigger_keywrods=["cannot answer the question"],
|
||||
rejection_answer="Sorry, this question cannot be answered based on the information provided.",
|
||||
).split("\n")[0],
|
||||
self.memory,
|
||||
)
|
Reference in New Issue
Block a user