mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[Feature] Add document retrieval QA (#5020)
* add langchain * add langchain * Add files via upload * add langchain * fix style * fix style: remove extra space * add pytest; modified retriever * add pytest; modified retriever * add tests to build_on_pr.yml * fix build_on_pr.yml * fix build on pr; fix environ vars * seperate unit tests for colossalqa from build from pr * fix container setting; fix environ vars * commented dev code * add incremental update * remove stale code * fix style * change to sha3 224 * fix retriever; fix style; add unit test for document loader * fix ci workflow config * fix ci workflow config * add set cuda visible device script in ci * fix doc string * fix style; update readme; refactored * add force log info * change build on pr, ignore colossalqa * fix docstring, captitalize all initial letters * fix indexing; fix text-splitter * remove debug code, update reference * reset previous commit * update LICENSE update README add key-value mode, fix bugs * add files back * revert force push * remove junk file * add test files * fix retriever bug, add intent classification * change conversation chain design * rewrite prompt and conversation chain * add ui v1 * ui v1 * fix atavar * add header * Refactor the RAG Code and support Pangu * Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo. * resolved conversation. tested scripts under examples. web demo still buggy * fix ci tests * Some modifications to add ChatGPT api * modify llm.py and remove unnecessary files * Delete applications/ColossalQA/examples/ui/test_frontend_input.json * Remove OpenAI api key * add colossalqa * move files * move files * move files * move files * fix style * Add Readme and fix some bugs. 
* Add something to readme and modify some code * modify a directory name for clarity * remove redundant directory * Correct a type in llm.py * fix AI prefix * fix test_memory.py * fix conversation * fix some erros and typos * Fix a missing import in RAG_ChatBot.py * add colossalcloud LLM wrapper, correct issues in code review --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Chain that combines documents by stuffing into context
|
||||
|
||||
Modified from Original Source
|
||||
|
||||
This code is based on LangChain AI's langchain, which can be found at
|
||||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
import copy
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import format_document
|
||||
|
||||
|
||||
class CustomStuffDocumentsChain(StuffDocumentsChain):
    """Chain that combines documents by stuffing into context.

    This chain takes a list of documents and first combines them into a single string.
    It does this by formatting each document into a string with the `document_prompt`
    and then joining them together with `document_separator`. It then adds that new
    string to the inputs with the variable name set by `document_variable_name`.
    Those inputs are then passed to the `llm_chain`.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain.prompts import PromptTemplate
            from langchain.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
    """

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with name
        `self.document_variable_name`. Then pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input.
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Prefix used to label each supporting document in the prompt,
        # e.g. "Supporting Document0: ...".
        doc_prefix = kwargs.get("doc_prefix", "Supporting Document")
        docs_ = []
        for idx, doc in enumerate(docs):
            # Deep-copy so the caller's documents are never mutated.
            doc_ = copy.deepcopy(doc)
            # Key-value mode: if the document has 'is_key_value_mapping'=True
            # and a 'value' in metadata, the page_content holds the retrieval
            # key and metadata['value'] holds the real content — swap it in.
            if doc_.metadata.get("is_key_value_mapping", False) and "value" in doc_.metadata:
                doc_.page_content = str(doc_.metadata["value"])
            prefix = doc_prefix + str(idx)
            # Insert a space after the prefix unless the content already
            # starts with one. startswith() also guards against an empty
            # page_content (the original indexed [0] and raised IndexError
            # on empty documents).
            separator = "" if doc_.page_content.startswith(" ") else " "
            doc_.page_content = f"{prefix}:{separator}{doc_.page_content}"
            docs_.append(doc_)

        doc_strings = [format_document(doc, self.document_prompt) for doc in docs_]
        # Only forward kwargs the underlying LLM chain actually accepts:
        # common generation args plus the prompt's declared input variables.
        arg_list = ["stop", "temperature", "top_k", "top_p", "max_new_tokens"]
        arg_list.extend(self.llm_chain.prompt.input_variables)
        inputs = {k: v for k, v in kwargs.items() if k in arg_list}
        # Join the documents together to put them in the prompt.
        inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
        return inputs
|
Reference in New Issue
Block a user