Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 21:40:02 +00:00)
[Feature] Add document retrieval QA (#5020)
* add langchain
* add langchain
* Add files via upload
* add langchain
* fix style
* fix style: remove extra space
* add pytest; modified retriever
* add pytest; modified retriever
* add tests to build_on_pr.yml
* fix build_on_pr.yml
* fix build on pr; fix environ vars
* separate unit tests for colossalqa from build from pr
* fix container setting; fix environ vars
* commented dev code
* add incremental update
* remove stale code
* fix style
* change to sha3 224
* fix retriever; fix style; add unit test for document loader
* fix ci workflow config
* fix ci workflow config
* add set cuda visible device script in ci
* fix doc string
* fix style; update readme; refactored
* add force log info
* change build on pr, ignore colossalqa
* fix docstring, capitalize all initial letters
* fix indexing; fix text-splitter
* remove debug code, update reference
* reset previous commit
* update LICENSE, update README, add key-value mode, fix bugs
* add files back
* revert force push
* remove junk file
* add test files
* fix retriever bug, add intent classification
* change conversation chain design
* rewrite prompt and conversation chain
* add ui v1
* ui v1
* fix avatar
* add header
* Refactor the RAG code and support Pangu
* Refactor the ColossalQA chain to object-oriented programming and the UI demo
* resolved conversation; tested scripts under examples; web demo still buggy
* fix ci tests
* Some modifications to add ChatGPT api
* modify llm.py and remove unnecessary files
* Delete applications/ColossalQA/examples/ui/test_frontend_input.json
* Remove OpenAI api key
* add colossalqa
* move files
* move files
* move files
* move files
* fix style
* Add Readme and fix some bugs
* Add something to readme and modify some code
* modify a directory name for clarity
* remove redundant directory
* Correct a type in llm.py
* fix AI prefix
* fix test_memory.py
* fix conversation
* fix some errors and typos
* Fix a missing import in RAG_ChatBot.py
* add colossalcloud LLM wrapper, correct issues in code review

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
applications/ColossalQA/colossalqa/memory.py (new file, 168 additions)
@@ -0,0 +1,168 @@
"""
Implement a memory class for storing conversation history
Support long term and short term memory
"""
from typing import Any, Dict, List

from colossalqa.chain.memory.summary import ConversationSummaryMemory
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage
from langchain.schema.retriever import BaseRetriever
from pydantic import Field


class ConversationBufferWithSummary(ConversationSummaryMemory):
    """Memory class that keeps the most recent conversation verbatim and summarizes older turns."""

    # Buffer storing the most recent conversation history (short-term memory)
    buffered_history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    # Temporary buffer holding turns that are about to be folded into the summary (long-term memory)
    summarized_history_temp: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"
    buffer: str = ""  # Formatted conversation as a string
    existing_summary: str = ""  # Summarization of stale conversation as a string
    # Key under which the formatted history is passed into the prompt
    memory_key: str = "chat_history"
    input_key: str = "question"
    retriever: BaseRetriever = None
    max_tokens: int = 2000
    chain: BaseCombineDocumentsChain = None
    input_chain_type_kwargs: Dict[str, Any] = {}

    @property
    def buffer(self) -> Any:
        """String buffer of memory."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is True."""
        self.buffer = self.format_dialogue()
        return self.buffer

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is False."""
        return self.buffered_history.messages

    def clear(self):
        """Clear all the memory"""
        self.buffered_history.clear()
        self.summarized_history_temp.clear()

    def initiate_document_retrieval_chain(
        self, llm: Any, prompt_template: Any, retriever: Any, chain_type_kwargs: Dict[str, Any] = {}
    ) -> None:
        """
        Initiate a retrieval chain so that the length of the final prompt (retrieved documents included)
        can be calculated when loading the memory variables.
        Args:
            llm: the language model for the retrieval chain (we won't actually return its output)
            prompt_template: the prompt template for constructing the retrieval chain
            retriever: the retriever for the retrieval chain
            max_tokens: the max length of the prompt (not including the output)
            chain_type_kwargs: the kwargs for the retrieval chain
            memory_key: the key for the chat history
            input_key: the key for the input query
        """
        self.retriever = retriever
        input_chain_type_kwargs = {k: v for k, v in chain_type_kwargs.items() if k not in [self.memory_key]}
        self.input_chain_type_kwargs = input_chain_type_kwargs
        self.chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_template, **self.input_chain_type_kwargs)

    @property
    def memory_variables(self) -> List[str]:
        """Define the variables we are providing to the prompt."""
        return [self.memory_key]

    def format_dialogue(self, lang: str = "en") -> str:
        """Format memory into two parts: a summarization of the historical conversation and the most recent conversation"""
        if len(self.summarized_history_temp.messages) != 0:
            # Fold each (human, assistant) pair in the temp buffer into the running summary
            for i in range(int(len(self.summarized_history_temp.messages) / 2)):
                self.existing_summary = (
                    self.predict_new_summary(
                        self.summarized_history_temp.messages[i * 2 : i * 2 + 2], self.existing_summary, stop=["\n\n"]
                    )
                    .strip()
                    .split("\n")[0]
                    .strip()
                )
            for i in range(int(len(self.summarized_history_temp.messages) / 2)):
                self.summarized_history_temp.messages.pop(0)
                self.summarized_history_temp.messages.pop(0)
        conversation_buffer = []
        for t in self.buffered_history.messages:
            if t.type == "human":
                prefix = self.human_prefix
            else:
                prefix = self.ai_prefix
            conversation_buffer.append(prefix + ": " + t.content)
        conversation_buffer = "\n".join(conversation_buffer)
        if len(self.existing_summary) > 0:
            if lang == "en":
                message = f"A summarization of historical conversation:\n{self.existing_summary}\nMost recent conversation:\n{conversation_buffer}"
            elif lang == "zh":
                message = f"历史对话概要:\n{self.existing_summary}\n最近的对话:\n{conversation_buffer}"
            else:
                raise ValueError("Unsupported language")
            return message
        else:
            message = conversation_buffer
            return message

    def get_conversation_length(self):
        """Get the number of tokens in the formatted conversation"""
        prompt = self.format_dialogue()
        length = self.llm.get_num_tokens(prompt)
        return length

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load the memory variables.
        Summarize oversize conversation to fit into the length constraint defined by max_tokens.
        Args:
            inputs: the kwargs of the chain of your definition
        Returns:
            a dict that maps from the memory key to the formatted dialogue
            the formatted dialogue has the following format
            if the conversation is too long:
                A summarization of historical conversation:
                {summarization}
                Most recent conversation:
                Human: XXX
                Assistant: XXX
                ...
            otherwise:
                Human: XXX
                Assistant: XXX
                ...
        """
        # Calculate the remaining token budget after accounting for the rest of the prompt
        if "input_documents" in inputs:
            # Run in a retrieval qa chain
            docs = inputs["input_documents"]
        else:
            # For testing
            docs = self.retriever.get_relevant_documents(inputs[self.input_key])
        inputs[self.memory_key] = ""
        inputs = {k: v for k, v in inputs.items() if k in [self.chain.input_key, self.input_key, self.memory_key]}
        prompt_length = self.chain.prompt_length(docs, **inputs)
        remain = self.max_tokens - prompt_length
        while self.get_conversation_length() > remain:
            if len(self.buffered_history.messages) <= 2:
                raise RuntimeError("Exceeded max_tokens, chunk size of retrieved documents is too large")
            # Move the oldest (human, assistant) pair into the temp buffer to be summarized
            temp = self.buffered_history.messages.pop(0)
            self.summarized_history_temp.messages.append(temp)
            temp = self.buffered_history.messages.pop(0)
            self.summarized_history_temp.messages.append(temp)
        return {self.memory_key: self.format_dialogue()}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.buffered_history.add_user_message(input_str.strip())
        self.buffered_history.add_ai_message(output_str.strip())
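For context, a minimal usage sketch of this memory class (not part of the commit): it assumes a LangChain-compatible `llm` wrapper and a `retriever` (e.g. built from a vector store) have already been constructed elsewhere, that the subclassed ConversationSummaryMemory accepts an `llm` argument as in upstream LangChain, and that the prompt variable names (`context`, `chat_history`, `question`) follow the defaults used by the file above; the prompt text and example questions are illustrative only.

# Usage sketch under the assumptions stated above; `llm` and `retriever` are placeholders.
from langchain.prompts import PromptTemplate
from colossalqa.memory import ConversationBufferWithSummary

prompt = PromptTemplate(
    input_variables=["context", "chat_history", "question"],
    template="Context:\n{context}\n\n{chat_history}\nHuman: {question}\nAssistant:",
)

memory = ConversationBufferWithSummary(llm=llm, max_tokens=2000, ai_prefix="Assistant")
# Build the internal "stuff" chain that is only used to measure prompt length against max_tokens.
memory.initiate_document_retrieval_chain(llm, prompt, retriever)

# Record one exchange; load_memory_variables() later returns the (possibly partially
# summarized) history under the "chat_history" key, trimmed to fit the token budget.
memory.save_context(
    {"question": "What is ColossalAI?"},
    {"output": "A system for large-scale model training."},
)
chat_history = memory.load_memory_variables({"question": "Does it cover inference too?"})["chat_history"]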