mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[Feature] Add document retrieval QA (#5020)
* add langchain * add langchain * Add files via upload * add langchain * fix style * fix style: remove extra space * add pytest; modified retriever * add pytest; modified retriever * add tests to build_on_pr.yml * fix build_on_pr.yml * fix build on pr; fix environ vars * seperate unit tests for colossalqa from build from pr * fix container setting; fix environ vars * commented dev code * add incremental update * remove stale code * fix style * change to sha3 224 * fix retriever; fix style; add unit test for document loader * fix ci workflow config * fix ci workflow config * add set cuda visible device script in ci * fix doc string * fix style; update readme; refactored * add force log info * change build on pr, ignore colossalqa * fix docstring, captitalize all initial letters * fix indexing; fix text-splitter * remove debug code, update reference * reset previous commit * update LICENSE update README add key-value mode, fix bugs * add files back * revert force push * remove junk file * add test files * fix retriever bug, add intent classification * change conversation chain design * rewrite prompt and conversation chain * add ui v1 * ui v1 * fix atavar * add header * Refactor the RAG Code and support Pangu * Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo. * resolved conversation. tested scripts under examples. web demo still buggy * fix ci tests * Some modifications to add ChatGPT api * modify llm.py and remove unnecessary files * Delete applications/ColossalQA/examples/ui/test_frontend_input.json * Remove OpenAI api key * add colossalqa * move files * move files * move files * move files * fix style * Add Readme and fix some bugs. * Add something to readme and modify some code * modify a directory name for clarity * remove redundant directory * Correct a type in llm.py * fix AI prefix * fix test_memory.py * fix conversation * fix some erros and typos * Fix a missing import in RAG_ChatBot.py * add colossalcloud LLM wrapper, correct issues in code review --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
This commit is contained in:
0
applications/ColossalQA/colossalqa/__init__.py
Normal file
0
applications/ColossalQA/colossalqa/__init__.py
Normal file
103
applications/ColossalQA/colossalqa/chain/memory/summary.py
Normal file
103
applications/ColossalQA/colossalqa/chain/memory/summary.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Custom SummarizerMixin base class and ConversationSummaryMemory class
|
||||
|
||||
Modified from Original Source
|
||||
|
||||
This code is based on LangChain Ai's langchain, which can be found at
|
||||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema import BaseChatMessageHistory, BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
|
||||
|
||||
class SummarizerMixin(BaseModel):
|
||||
"""
|
||||
Mixin for summarizer.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "Assistant"
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
llm_kwargs: Dict = {}
|
||||
|
||||
def predict_new_summary(self, messages: List[BaseMessage], existing_summary: str, stop: List = []) -> str:
|
||||
"""
|
||||
Recursively summarize a conversation by generating a new summary using
|
||||
the last round of conversation and the existing summary.
|
||||
"""
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
chain = LLMChain(llm=self.llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines, stop=stop)
|
||||
|
||||
|
||||
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Conversation summarizer to chat memory."""
|
||||
|
||||
buffer: str = ""
|
||||
memory_key: str = "history"
|
||||
|
||||
@classmethod
|
||||
def from_messages(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
chat_memory: BaseChatMessageHistory,
|
||||
summarize_step: int = 2,
|
||||
**kwargs: Any,
|
||||
) -> ConversationSummaryMemory:
|
||||
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
|
||||
for i in range(0, len(obj.chat_memory.messages), summarize_step):
|
||||
obj.buffer = obj.predict_new_summary(obj.chat_memory.messages[i : i + summarize_step], obj.buffer)
|
||||
return obj
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
if self.return_messages:
|
||||
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||
else:
|
||||
buffer = self.buffer
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.buffer = self.predict_new_summary(self.chat_memory.messages[-2:], self.buffer)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.buffer = ""
|
214
applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
Normal file
214
applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Chain for question-answering against a vector database.
|
||||
|
||||
Modified from Original Source
|
||||
|
||||
This code is based on LangChain Ai's langchain, which can be found at
|
||||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
|
||||
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
|
||||
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
class CustomBaseRetrievalQA(BaseRetrievalQA):
|
||||
"""Base class for question-answering chains."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseRetrievalQA:
|
||||
"""Initialize from LLM."""
|
||||
llm_kwargs = kwargs.pop("llm_kwargs", {})
|
||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
|
||||
document_prompt = kwargs.get(
|
||||
"document_prompt", PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")
|
||||
)
|
||||
combine_documents_chain = CustomStuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name="context",
|
||||
document_prompt=document_prompt,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
return cls(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseRetrievalQA:
|
||||
"""Load chain from chain type."""
|
||||
llm_kwargs = kwargs.pop("llm_kwargs", {})
|
||||
_chain_type_kwargs = chain_type_kwargs or {}
|
||||
combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run get_relevant_text and llm on input query.
|
||||
|
||||
If chain has 'return_source_documents' as 'True', returns
|
||||
the retrieved documents as well under the key 'source_documents'.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
|
||||
if accepts_run_manager:
|
||||
docs = self._get_docs(question, run_manager=_run_manager)
|
||||
else:
|
||||
docs = self._get_docs(question) # type: ignore[call-arg]
|
||||
|
||||
kwargs = {
|
||||
k: v
|
||||
for k, v in inputs.items()
|
||||
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
|
||||
}
|
||||
answers = []
|
||||
if self.combine_documents_chain.memory is not None:
|
||||
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
|
||||
self.combine_documents_chain.memory.buffered_history
|
||||
), copy.deepcopy(self.combine_documents_chain.memory.summarized_history_temp)
|
||||
else:
|
||||
buffered_history_backup = None
|
||||
summarized_history_temp_backup = None
|
||||
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
|
||||
)
|
||||
if summarized_history_temp_backup is not None and buffered_history_backup is not None:
|
||||
(
|
||||
self.combine_documents_chain.memory.buffered_history,
|
||||
self.combine_documents_chain.memory.summarized_history_temp,
|
||||
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
|
||||
|
||||
# if rejection_trigger_keywords is not given, return the response from LLM directly
|
||||
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
|
||||
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None
|
||||
if answer is None:
|
||||
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
|
||||
if self.combine_documents_chain.memory is not None:
|
||||
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run get_relevant_text and llm on input query.
|
||||
|
||||
If chain has 'return_source_documents' as 'True', returns
|
||||
the retrieved documents as well under the key 'source_documents'.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(question, run_manager=_run_manager)
|
||||
else:
|
||||
docs = await self._aget_docs(question) # type: ignore[call-arg]
|
||||
kwargs = {
|
||||
k: v
|
||||
for k, v in inputs.items()
|
||||
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
|
||||
}
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
|
||||
)
|
||||
# if rejection_trigger_keywords is not given, return the response from LLM directly
|
||||
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
|
||||
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None
|
||||
if answer is None:
|
||||
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
|
||||
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
|
||||
class RetrievalQA(CustomBaseRetrievalQA):
|
||||
"""Chain for question-answering against an index.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.faiss import FAISS
|
||||
from langchain.vectorstores.base import VectorStoreRetriever
|
||||
retriever = VectorStoreRetriever(vectorstore=FAISS(...))
|
||||
retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
|
||||
|
||||
"""
|
||||
|
||||
retriever: BaseRetriever = Field(exclude=True)
|
||||
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())
|
||||
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
"""Return the chain type."""
|
||||
return "retrieval_qa"
|
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Load question answering chains.
|
||||
For now, only the stuffed chain is modified
|
||||
|
||||
Modified from Original Source
|
||||
|
||||
This code is based on LangChain Ai's langchain, which can be found at
|
||||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
import copy
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: BaseLanguageModel, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
document_variable_name: str = "context",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> CustomStuffDocumentsChain:
|
||||
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
|
||||
if "llm_kwargs" in kwargs:
|
||||
llm_kwargs = copy.deepcopy(kwargs["llm_kwargs"])
|
||||
del kwargs["llm_kwargs"]
|
||||
else:
|
||||
llm_kwargs = {}
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
return CustomStuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_qa_chain(
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering chain.
|
||||
|
||||
Args:
|
||||
llm: Language Model to use in the chain.
|
||||
chain_type: Type of document combining chain to use. Should be one of "stuff",
|
||||
"map_reduce", "map_rerank", and "refine".
|
||||
verbose: Whether chains should be run in verbose mode or not. Note that this
|
||||
applies to all chains that make up the final chain.
|
||||
callback_manager: Callback manager to use for the chain.
|
||||
|
||||
Returns:
|
||||
A chain to use for question answering.
|
||||
"""
|
||||
loader_mapping: Mapping[str, LoadingCallable] = {"stuff": _load_stuff_chain}
|
||||
if chain_type not in loader_mapping:
|
||||
raise ValueError(f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}")
|
||||
return loader_mapping[chain_type](llm, verbose=verbose, callback_manager=callback_manager, **kwargs)
|
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Chain that combines documents by stuffing into context
|
||||
|
||||
Modified from Original Source
|
||||
|
||||
This code is based on LangChain Ai's langchain, which can be found at
|
||||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
import copy
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import format_document
|
||||
|
||||
|
||||
class CustomStuffDocumentsChain(StuffDocumentsChain):
|
||||
"""Chain that combines documents by stuffing into context.
|
||||
|
||||
This chain takes a list of documents and first combines them into a single string.
|
||||
It does this by formatting each document into a string with the `document_prompt`
|
||||
and then joining them together with `document_separator`. It then adds that new
|
||||
string to the inputs with the variable name set by `document_variable_name`.
|
||||
Those inputs are then passed to the `llm_chain`.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import StuffDocumentsChain, LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
# This controls how each document will be formatted. Specifically,
|
||||
# it will be passed to `format_document` - see that function for more
|
||||
# details.
|
||||
document_prompt = PromptTemplate(
|
||||
input_variables=["page_content"],
|
||||
template="{page_content}"
|
||||
)
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
prompt = PromptTemplate.from_template(
|
||||
"Summarize this content: {context}"
|
||||
)
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
document_variable_name=document_variable_name
|
||||
)
|
||||
"""
|
||||
|
||||
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
||||
"""Construct inputs from kwargs and docs.
|
||||
|
||||
Format and the join all the documents together into one input with name
|
||||
`self.document_variable_name`. The pluck any additional variables
|
||||
from **kwargs.
|
||||
|
||||
Args:
|
||||
docs: List of documents to format and then join into single input
|
||||
**kwargs: additional inputs to chain, will pluck any other required
|
||||
arguments from here.
|
||||
|
||||
Returns:
|
||||
dictionary of inputs to LLMChain
|
||||
"""
|
||||
# Format each document according to the prompt
|
||||
|
||||
# if the document is in the key-value format has a 'is_key_value_mapping'=True in meta_data and has 'value' in metadata
|
||||
# use the value to replace the key
|
||||
doc_prefix = kwargs.get("doc_prefix", "Supporting Document")
|
||||
docs_ = []
|
||||
for id, doc in enumerate(docs):
|
||||
doc_ = copy.deepcopy(doc)
|
||||
if doc_.metadata.get("is_key_value_mapping", False) and "value" in doc_.metadata:
|
||||
doc_.page_content = str(doc_.metadata["value"])
|
||||
prefix = doc_prefix + str(id)
|
||||
doc_.page_content = str(prefix + ":" + (" " if doc_.page_content[0] != " " else "") + doc_.page_content)
|
||||
docs_.append(doc_)
|
||||
|
||||
doc_strings = [format_document(doc, self.document_prompt) for doc in docs_]
|
||||
arg_list = ["stop", "temperature", "top_k", "top_p", "max_new_tokens"]
|
||||
arg_list.extend(self.llm_chain.prompt.input_variables)
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {k: v for k, v in kwargs.items() if k in arg_list}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
|
||||
return inputs
|
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Class for loading document type data
|
||||
"""
|
||||
|
||||
import glob
|
||||
from typing import List
|
||||
|
||||
from colossalqa.mylogging import get_logger
|
||||
from langchain.document_loaders import (
|
||||
JSONLoader,
|
||||
PyPDFLoader,
|
||||
TextLoader,
|
||||
UnstructuredHTMLLoader,
|
||||
UnstructuredMarkdownLoader,
|
||||
)
|
||||
from langchain.document_loaders.csv_loader import CSVLoader
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
SUPPORTED_DATA_FORMAT = [".csv", ".json", ".html", ".md", ".pdf", ".txt", ".jsonl"]
|
||||
|
||||
|
||||
class DocumentLoader:
|
||||
"""
|
||||
Load documents from different files into list of langchain Documents
|
||||
"""
|
||||
|
||||
def __init__(self, files: List, **kwargs) -> None:
|
||||
"""
|
||||
Args:
|
||||
files: list of files (list[file path, name])
|
||||
**kwargs: keyword type arguments, useful for certain document types
|
||||
"""
|
||||
self.data = {}
|
||||
self.kwargs = kwargs
|
||||
|
||||
for item in files:
|
||||
path = item[0] if isinstance(item, list) else item
|
||||
logger.info(f"Loading data from {path}")
|
||||
self.load_data(path)
|
||||
logger.info("Data loaded")
|
||||
|
||||
self.all_data = []
|
||||
for key in self.data:
|
||||
if isinstance(self.data[key], list):
|
||||
for item in self.data[key]:
|
||||
if isinstance(item, list):
|
||||
self.all_data.extend(item)
|
||||
else:
|
||||
self.all_data.append(item)
|
||||
|
||||
def load_data(self, path: str) -> None:
|
||||
"""
|
||||
Load data. Please refer to https://python.langchain.com/docs/modules/data_connection/document_loaders/
|
||||
for sepcific format requirements.
|
||||
Args:
|
||||
path: path to a file
|
||||
To load files with glob path, here are some examples.
|
||||
Load all file from directory: folder1/folder2/*
|
||||
Load all pdf file from directory: folder1/folder2/*.pdf
|
||||
"""
|
||||
files = []
|
||||
|
||||
# Handle glob expression
|
||||
try:
|
||||
files = glob.glob(path)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if len(files) == 0:
|
||||
raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
|
||||
elif len(files) == 1:
|
||||
path = files[0]
|
||||
else:
|
||||
for file in files:
|
||||
self.load_data(file)
|
||||
return
|
||||
|
||||
# Load data if the path is a file
|
||||
logger.info(f"load {path}", verbose=True)
|
||||
if path.endswith(".csv"):
|
||||
# Load csv
|
||||
loader = CSVLoader(file_path=path, encoding="utf8")
|
||||
data = loader.load()
|
||||
self.data[path] = data
|
||||
elif path.endswith(".txt"):
|
||||
# Load txt
|
||||
loader = TextLoader(path, encoding="utf8")
|
||||
data = loader.load()
|
||||
self.data[path] = data
|
||||
elif path.endswith("html"):
|
||||
# Load html
|
||||
loader = UnstructuredHTMLLoader(path, encoding="utf8")
|
||||
data = loader.load()
|
||||
self.data[path] = data
|
||||
elif path.endswith("json"):
|
||||
# Load json
|
||||
loader = JSONLoader(
|
||||
file_path=path,
|
||||
jq_schema=self.kwargs.get("jq_schema", ".data[]"),
|
||||
content_key=self.kwargs.get("content_key", "content"),
|
||||
metadata_func=self.kwargs.get("metadata_func", None),
|
||||
)
|
||||
|
||||
data = loader.load()
|
||||
self.data[path] = data
|
||||
elif path.endswith("jsonl"):
|
||||
# Load jsonl
|
||||
loader = JSONLoader(
|
||||
file_path=path, jq_schema=self.kwargs.get("jq_schema", ".data[].content"), json_lines=True
|
||||
)
|
||||
data = loader.load()
|
||||
self.data[path] = data
|
||||
elif path.endswith(".md"):
|
||||
# Load markdown
|
||||
loader = UnstructuredMarkdownLoader(path)
|
||||
data = loader.load()
|
||||
self.data[path] = data
|
||||
elif path.endswith(".pdf"):
|
||||
# Load pdf
|
||||
loader = PyPDFLoader(path)
|
||||
data = loader.load_and_split()
|
||||
self.data[path] = data
|
||||
else:
|
||||
if "." in path.split("/")[-1]:
|
||||
raise ValueError(f"Unsupported file format {path}. Supported formats: {SUPPORTED_DATA_FORMAT}")
|
||||
else:
|
||||
# May ba a directory, we strictly follow the glob path and will not load files in subdirectories
|
||||
pass
|
@@ -0,0 +1,119 @@
|
||||
'''
|
||||
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
|
||||
'''
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine
|
||||
from colossalqa.utils import drop_table
|
||||
from colossalqa.mylogging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
SUPPORTED_DATA_FORMAT = ['.csv','.xlsx', '.xls','.json','.html','.h5', '.hdf5','.parquet','.feather','.dta']
|
||||
|
||||
class TableLoader:
|
||||
'''
|
||||
Load tables from different files and serve a sql database for database operations
|
||||
'''
|
||||
def __init__(self, files: str,
|
||||
sql_path:str='sqlite:///mydatabase.db',
|
||||
verbose=False, **kwargs) -> None:
|
||||
'''
|
||||
Args:
|
||||
files: list of files (list[file path, name])
|
||||
sql_path: how to serve the sql database
|
||||
**kwargs: keyword type arguments, useful for certain document types
|
||||
'''
|
||||
self.data = {}
|
||||
self.verbose = verbose
|
||||
self.sql_path = sql_path
|
||||
self.kwargs = kwargs
|
||||
self.sql_engine = create_engine(self.sql_path)
|
||||
drop_table(self.sql_engine)
|
||||
|
||||
self.sql_engine = create_engine(self.sql_path)
|
||||
for item in files:
|
||||
path = item[0]
|
||||
dataset_name = item[1]
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"{path} doesn't exists")
|
||||
if not any([path.endswith(i) for i in SUPPORTED_DATA_FORMAT]):
|
||||
raise TypeError(f"{path} not supported. Supported type {SUPPORTED_DATA_FORMAT}")
|
||||
|
||||
logger.info("loading data", verbose=self.verbose)
|
||||
self.load_data(path)
|
||||
logger.info("data loaded", verbose=self.verbose)
|
||||
self.to_sql(path, dataset_name)
|
||||
|
||||
def load_data(self, path):
|
||||
'''
|
||||
Load data and serve the data as sql database.
|
||||
Data must be in pandas format
|
||||
'''
|
||||
files = []
|
||||
# Handle glob expression
|
||||
try:
|
||||
files = glob.glob(path)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if len(files)==0:
|
||||
raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
|
||||
elif len(files)==1:
|
||||
path = files[0]
|
||||
else:
|
||||
for file in files:
|
||||
self.load_data(file)
|
||||
|
||||
if path.endswith('.csv'):
|
||||
# Load csv
|
||||
self.data[path] = pd.read_csv(path)
|
||||
elif path.endswith('.xlsx') or path.endswith('.xls'):
|
||||
# Load excel
|
||||
self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
|
||||
elif path.endswith('.json'):
|
||||
# Load json
|
||||
self.data[path] = pd.read_json(path)
|
||||
elif path.endswith('.html'):
|
||||
# Load html
|
||||
html_tables = pd.read_html(path)
|
||||
# Choose the desired table from the list of DataFrame objects
|
||||
self.data[path] = html_tables[0] # You may need to adjust this index
|
||||
elif path.endswith('.h5') or path.endswith('.hdf5'):
|
||||
# Load h5
|
||||
self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed
|
||||
elif path.endswith('.parquet'):
|
||||
# Load parquet
|
||||
self.data[path] = pd.read_parquet(path, engine='fastparquet')
|
||||
elif path.endswith('.feather'):
|
||||
# Load feather
|
||||
self.data[path] = pd.read_feather(path)
|
||||
elif path.endswith('.dta'):
|
||||
# Load dta
|
||||
self.data[path] = pd.read_stata(path)
|
||||
else:
|
||||
raise ValueError("Unsupported file format")
|
||||
|
||||
def to_sql(self, path, table_name):
|
||||
'''
|
||||
Serve the data as sql database.
|
||||
'''
|
||||
self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False)
|
||||
logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
|
||||
return self.sql_path
|
||||
|
||||
def get_sql_path(self):
|
||||
return self.sql_path
|
||||
|
||||
def __del__(self):
|
||||
if self.sql_engine:
|
||||
drop_table(self.sql_engine)
|
||||
self.sql_engine.dispose()
|
||||
del self.data
|
||||
del self.sql_engine
|
||||
|
||||
|
||||
|
||||
|
125
applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
Normal file
125
applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
LLM wrapper for LLMs running on ColossalCloud Platform
|
||||
|
||||
Usage:
|
||||
|
||||
os.environ['URL'] = ""
|
||||
os.environ['HOST'] = ""
|
||||
|
||||
gen_config = {
|
||||
'max_new_tokens': 100,
|
||||
# 'top_k': 2,
|
||||
'top_p': 0.9,
|
||||
'temperature': 0.5,
|
||||
'repetition_penalty': 2,
|
||||
}
|
||||
|
||||
llm = ColossalCloudLLM(n=1)
|
||||
llm.set_auth_config()
|
||||
resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
|
||||
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
|
||||
|
||||
"""
|
||||
import json
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class ColossalCloudLLM(LLM):
|
||||
"""
|
||||
A custom LLM class that integrates LLMs running on the ColossalCloud Platform
|
||||
|
||||
"""
|
||||
n: int
|
||||
gen_config: dict = None
|
||||
auth_config: dict = None
|
||||
valid_gen_para: list = ['max_new_tokens', 'top_k',
|
||||
'top_p', 'temperature', 'repetition_penalty']
|
||||
|
||||
def __init__(self, gen_config=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
gen_config: config for generation,
|
||||
max_new_tokens: 50 by default
|
||||
top_k: (1, vocab_size)
|
||||
top_p: (0, 1) if not None
|
||||
temperature: (0, inf) if not None
|
||||
repetition_penalty: (1, inf) if not None
|
||||
"""
|
||||
super(ColossalCloudLLM, self).__init__(**kwargs)
|
||||
if gen_config is None:
|
||||
self.gen_config = {"max_new_tokens": 50}
|
||||
else:
|
||||
assert "max_new_tokens" in gen_config, "max_new_tokens is a compulsory key in the gen config"
|
||||
self.gen_config = gen_config
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"n": self.n}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return 'ColossalCloudLLM'
|
||||
|
||||
def set_auth_config(self, **kwargs):
|
||||
url = get_from_dict_or_env(kwargs, "url", "URL")
|
||||
host = get_from_dict_or_env(kwargs, "host", "HOST")
|
||||
|
||||
auth_config = {}
|
||||
auth_config['endpoint'] = url
|
||||
auth_config['Host'] = host
|
||||
self.auth_config = auth_config
|
||||
|
||||
def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of strings to stop generation when encountered
|
||||
|
||||
Returns:
|
||||
The string generated by the model
|
||||
"""
|
||||
# Update the generation arguments
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.valid_gen_para:
|
||||
raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}")
|
||||
if key in self.gen_config:
|
||||
self.gen_config[key] = value
|
||||
|
||||
resp_text = self.text_completion(prompt, self.gen_config, self.auth_config)
|
||||
# TODO: This may cause excessive tokens count
|
||||
if stop is not None:
|
||||
for stopping_words in stop:
|
||||
if stopping_words in resp_text:
|
||||
resp_text = resp_text.split(stopping_words)[0]
|
||||
return resp_text
|
||||
|
||||
|
||||
def text_completion(self, prompt, gen_config, auth_config):
|
||||
# Complusory Parameters
|
||||
endpoint = auth_config.pop('endpoint')
|
||||
max_new_tokens = gen_config.pop('max_new_tokens')
|
||||
# Optional Parameters
|
||||
optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional
|
||||
gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
|
||||
# Define the data payload
|
||||
data = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"history": [
|
||||
{"instruction": prompt, "response": ""}
|
||||
],
|
||||
**gen_config
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**auth_config # 'Host',
|
||||
}
|
||||
# Make the POST request
|
||||
response = requests.post(endpoint, headers=headers, data=json.dumps(data))
|
||||
response.raise_for_status() # raise error if return code is not 200(success)
|
||||
# Check the response
|
||||
return response.text
|
183
applications/ColossalQA/colossalqa/local/llm.py
Normal file
183
applications/ColossalQA/colossalqa/local/llm.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
API and LLM warpper class for running LLMs locally
|
||||
|
||||
Usage:
|
||||
|
||||
import os
|
||||
model_path = os.environ.get("ZH_MODEL_PATH")
|
||||
model_name = "chatglm2"
|
||||
colossal_api = ColossalAPI(model_name, model_path)
|
||||
llm = ColossalLLM(n=1, api=colossal_api)
|
||||
TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料峭峭,继而雨季开始,"
|
||||
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
|
||||
|
||||
"""
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
import torch
|
||||
from colossalqa.local.utils import get_response, post_http_request
|
||||
from colossalqa.mylogging import get_logger
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class ColossalAPI:
|
||||
"""
|
||||
API for calling LLM.generate
|
||||
"""
|
||||
|
||||
__instances = dict()
|
||||
|
||||
def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> None:
|
||||
"""
|
||||
Configurate model
|
||||
"""
|
||||
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
|
||||
return
|
||||
else:
|
||||
ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")] = self
|
||||
self.model_type = model_type
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
|
||||
|
||||
if ckpt_path is not None:
|
||||
state_dict = torch.load(ckpt_path)
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
|
||||
# Configurate tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
@staticmethod
|
||||
def get_api(model_type: str, model_path: str, ckpt_path: str = None):
|
||||
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
|
||||
return ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")]
|
||||
else:
|
||||
return ColossalAPI(model_type, model_path, ckpt_path)
|
||||
|
||||
def generate(self, input: str, **kwargs) -> str:
|
||||
"""
|
||||
Generate response given the prompt
|
||||
Args:
|
||||
input: input string
|
||||
**kwargs: language model keyword type arguments, such as top_k, top_p, temperature, max_new_tokens...
|
||||
Returns:
|
||||
output: output string
|
||||
"""
|
||||
if self.model_type in ["chatglm", "chatglm2"]:
|
||||
inputs = {
|
||||
k: v.to(torch.cuda.current_device()) for k, v in self.tokenizer(input, return_tensors="pt").items()
|
||||
}
|
||||
else:
|
||||
inputs = {
|
||||
"input_ids": self.tokenizer(input, return_tensors="pt")["input_ids"].to(torch.cuda.current_device())
|
||||
}
|
||||
|
||||
output = self.model.generate(**inputs, **kwargs)
|
||||
output = output.cpu()
|
||||
prompt_len = inputs["input_ids"].size(1)
|
||||
response = output[0, prompt_len:]
|
||||
output = self.tokenizer.decode(response, skip_special_tokens=True)
|
||||
return output
|
||||
|
||||
|
||||
class VllmAPI:
|
||||
def __init__(self, host: str = "localhost", port: int = 8077) -> None:
|
||||
# Configurate api for model served through web
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = f"http://{self.host}:{self.port}/generate"
|
||||
|
||||
def generate(self, input: str, **kwargs):
|
||||
output = get_response(post_http_request(input, self.url, **kwargs))[0]
|
||||
return output[len(input) :]
|
||||
|
||||
|
||||
class ColossalLLM(LLM):
|
||||
"""
|
||||
Langchain LLM wrapper for a local LLM
|
||||
"""
|
||||
|
||||
n: int
|
||||
api: Any
|
||||
kwargs = {"max_new_tokens": 100}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "custom"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
|
||||
for k in self.kwargs:
|
||||
if k not in kwargs:
|
||||
kwargs[k] = self.kwargs[k]
|
||||
|
||||
generate_args = {k: kwargs[k] for k in kwargs if k not in ["stop", "n"]}
|
||||
out = self.api.generate(prompt, **generate_args)
|
||||
if isinstance(stop, list) and len(stop) != 0:
|
||||
for stopping_words in stop:
|
||||
if stopping_words in out:
|
||||
out = out.split(stopping_words)[0]
|
||||
logger.info(f"{prompt}{out}", verbose=self.verbose)
|
||||
return out
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, int]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"n": self.n}
|
||||
|
||||
|
||||
class VllmLLM(LLM):
|
||||
"""
|
||||
Langchain LLM wrapper for a local LLM
|
||||
"""
|
||||
|
||||
n: int
|
||||
api: Any
|
||||
kwargs = {"max_new_tokens": 100}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "custom"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
for k in self.kwargs:
|
||||
if k not in kwargs:
|
||||
kwargs[k] = self.kwargs[k]
|
||||
logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
|
||||
generate_args = {k: kwargs[k] for k in kwargs if k in ["n", "max_tokens", "temperature", "stream"]}
|
||||
out = self.api.generate(prompt, **generate_args)
|
||||
if len(stop) != 0:
|
||||
for stopping_words in stop:
|
||||
if stopping_words in out:
|
||||
out = out.split(stopping_words)[0]
|
||||
logger.info(f"{prompt}{out}", verbose=self.verbose)
|
||||
return out
|
||||
|
||||
def set_host_port(self, host: str = "localhost", port: int = 8077, **kwargs) -> None:
|
||||
if "max_tokens" not in kwargs:
|
||||
kwargs["max_tokens"] = 100
|
||||
self.kwargs = kwargs
|
||||
self.api = VllmAPI(host=host, port=port)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, int]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"n": self.n}
|
||||
|
150
applications/ColossalQA/colossalqa/local/pangu_llm.py
Normal file
150
applications/ColossalQA/colossalqa/local/pangu_llm.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
LLM wrapper for Pangu
|
||||
|
||||
Usage:
|
||||
|
||||
# URL: “盘古大模型套件管理”->点击“服务管理”->“模型列表”->点击想要使用的模型的“复制路径”
|
||||
# USERNAME: 华为云控制台:“我的凭证”->“API凭证”下的“IAM用户名”,也就是你登录IAM账户的名字
|
||||
# PASSWORD: IAM用户的密码
|
||||
# DOMAIN_NAME: 华为云控制台:“我的凭证”->“API凭证”下的“用户名”,也就是公司管理IAM账户的总账户名
|
||||
|
||||
os.environ["URL"] = ""
|
||||
os.environ["URLNAME"] = ""
|
||||
os.environ["PASSWORD"] = ""
|
||||
os.environ["DOMAIN_NAME"] = ""
|
||||
|
||||
pg = Pangu(id=1)
|
||||
pg.set_auth_config()
|
||||
|
||||
res = pg('你是谁') # 您好,我是华为盘古大模型。我能够通过和您对话互动为您提供帮助。请问您有什么想问我的吗?
|
||||
"""
|
||||
|
||||
import http.client
|
||||
import json
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class Pangu(LLM):
|
||||
"""
|
||||
A custom LLM class that integrates pangu models
|
||||
|
||||
"""
|
||||
|
||||
n: int
|
||||
gen_config: dict = None
|
||||
auth_config: dict = None
|
||||
|
||||
def __init__(self, gen_config=None, **kwargs):
|
||||
super(Pangu, self).__init__(**kwargs)
|
||||
if gen_config is None:
|
||||
self.gen_config = {"user": "User", "max_tokens": 50, "temperature": 0.95, "n": 1}
|
||||
else:
|
||||
self.gen_config = gen_config
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"n": self.n}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "pangu"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of strings to stop generation when encountered
|
||||
|
||||
Returns:
|
||||
The string generated by the model
|
||||
"""
|
||||
# Update the generation arguments
|
||||
for key, value in kwargs.items():
|
||||
if key in self.gen_config:
|
||||
self.gen_config[key] = value
|
||||
|
||||
response = self.text_completion(prompt, self.gen_config, self.auth_config)
|
||||
text = response["choices"][0]["text"]
|
||||
if stop is not None:
|
||||
for stopping_words in stop:
|
||||
if stopping_words in text:
|
||||
text = text.split(stopping_words)[0]
|
||||
return text
|
||||
|
||||
def set_auth_config(self, **kwargs):
|
||||
url = get_from_dict_or_env(kwargs, "url", "URL")
|
||||
username = get_from_dict_or_env(kwargs, "username", "USERNAME")
|
||||
password = get_from_dict_or_env(kwargs, "password", "PASSWORD")
|
||||
domain_name = get_from_dict_or_env(kwargs, "domain_name", "DOMAIN_NAME")
|
||||
|
||||
region = url.split(".")[1]
|
||||
auth_config = {}
|
||||
auth_config["endpoint"] = url[url.find("https://") + 8 : url.find(".com") + 4]
|
||||
auth_config["resource_path"] = url[url.find(".com") + 4 :]
|
||||
auth_config["auth_token"] = self.get_latest_auth_token(region, username, password, domain_name)
|
||||
self.auth_config = auth_config
|
||||
|
||||
def get_latest_auth_token(self, region, username, password, domain_name):
|
||||
url = f"https://iam.{region}.myhuaweicloud.com/v3/auth/tokens"
|
||||
payload = json.dumps(
|
||||
{
|
||||
"auth": {
|
||||
"identity": {
|
||||
"methods": ["password"],
|
||||
"password": {"user": {"name": username, "password": password, "domain": {"name": domain_name}}},
|
||||
},
|
||||
"scope": {"project": {"name": region}},
|
||||
}
|
||||
}
|
||||
)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
return response.headers["X-Subject-Token"]
|
||||
|
||||
def text_completion(self, text, gen_config, auth_config):
|
||||
conn = http.client.HTTPSConnection(auth_config["endpoint"])
|
||||
payload = json.dumps(
|
||||
{
|
||||
"prompt": text,
|
||||
"user": gen_config["user"],
|
||||
"max_tokens": gen_config["max_tokens"],
|
||||
"temperature": gen_config["temperature"],
|
||||
"n": gen_config["n"],
|
||||
}
|
||||
)
|
||||
headers = {
|
||||
"X-Auth-Token": auth_config["auth_token"],
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
conn.request("POST", auth_config["resource_path"], payload, headers)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
data = json.loads(data.decode("utf-8"))
|
||||
return data
|
||||
|
||||
def chat_model(self, messages, gen_config, auth_config):
|
||||
conn = http.client.HTTPSConnection(auth_config["endpoint"])
|
||||
payload = json.dumps(
|
||||
{
|
||||
"messages": messages,
|
||||
"user": gen_config["user"],
|
||||
"max_tokens": gen_config["max_tokens"],
|
||||
"temperature": gen_config["temperature"],
|
||||
"n": gen_config["n"],
|
||||
}
|
||||
)
|
||||
headers = {
|
||||
"X-Auth-Token": auth_config["auth_token"],
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
conn.request("POST", auth_config["resource_path"], payload, headers)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
data = json.loads(data.decode("utf-8"))
|
||||
return data
|
29
applications/ColossalQA/colossalqa/local/utils.py
Normal file
29
applications/ColossalQA/colossalqa/local/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Generation utilities
|
||||
"""
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def post_http_request(
|
||||
prompt: str, api_url: str, n: int = 1, max_tokens: int = 100, temperature: float = 0.0, stream: bool = False
|
||||
) -> requests.Response:
|
||||
headers = {"User-Agent": "Test Client"}
|
||||
pload = {
|
||||
"prompt": prompt,
|
||||
"n": 1,
|
||||
"use_beam_search": False,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
response = requests.post(api_url, headers=headers, json=pload, stream=True, timeout=3)
|
||||
return response
|
||||
|
||||
|
||||
def get_response(response: requests.Response) -> List[str]:
|
||||
data = json.loads(response.content)
|
||||
output = data["text"]
|
||||
return output
|
168
applications/ColossalQA/colossalqa/memory.py
Normal file
168
applications/ColossalQA/colossalqa/memory.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Implement a memory class for storing conversation history
|
||||
Support long term and short term memory
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from colossalqa.chain.memory.summary import ConversationSummaryMemory
|
||||
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
|
||||
from langchain.schema import BaseChatMessageHistory
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class ConversationBufferWithSummary(ConversationSummaryMemory):
|
||||
"""Memory class for storing information about entities."""
|
||||
|
||||
# Define dictionary to store information about entities.
|
||||
# Store the most recent conversation history
|
||||
buffered_history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
|
||||
# Temp buffer
|
||||
summarized_history_temp: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "Assistant"
|
||||
buffer: str = "" # Formated conversation in str
|
||||
existing_summary: str = "" # Summarization of stale converstion in str
|
||||
# Define key to pass information about entities into prompt.
|
||||
memory_key: str = "chat_history"
|
||||
input_key: str = "question"
|
||||
retriever: BaseRetriever = None
|
||||
max_tokens: int = 2000
|
||||
chain: BaseCombineDocumentsChain = None
|
||||
input_chain_type_kwargs: List = {}
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
self.buffer = self.format_dialogue()
|
||||
return self.buffer
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.buffered_history.messages
|
||||
|
||||
def clear(self):
|
||||
"""Clear all the memory"""
|
||||
self.buffered_history.clear()
|
||||
self.summarized_history_temp.clear()
|
||||
|
||||
def initiate_document_retrieval_chain(
|
||||
self, llm: Any, prompt_template: Any, retriever: Any, chain_type_kwargs: Dict[str, Any] = {}
|
||||
) -> None:
|
||||
"""
|
||||
Since we need to calculate the length of the prompt, we need to initiate a retrieval chain
|
||||
to calculate the length of the prompt.
|
||||
Args:
|
||||
llm: the language model for the retrieval chain (we won't actually return the output)
|
||||
prompt_template: the prompt template for constructing the retrieval chain
|
||||
retriever: the retriever for the retrieval chain
|
||||
max_tokens: the max length of the prompt (not include the output)
|
||||
chain_type_kwargs: the kwargs for the retrieval chain
|
||||
memory_key: the key for the chat history
|
||||
input_key: the key for the input query
|
||||
"""
|
||||
self.retriever = retriever
|
||||
input_chain_type_kwargs = {k: v for k, v in chain_type_kwargs.items() if k not in [self.memory_key]}
|
||||
self.input_chain_type_kwargs = input_chain_type_kwargs
|
||||
self.chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_template, **self.input_chain_type_kwargs)
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Define the variables we are providing to the prompt."""
|
||||
return [self.memory_key]
|
||||
|
||||
def format_dialogue(self, lang: str = "en") -> str:
|
||||
"""Format memory into two parts--- summarization of historical conversation and most recent conversation"""
|
||||
if len(self.summarized_history_temp.messages) != 0:
|
||||
for i in range(int(len(self.summarized_history_temp.messages) / 2)):
|
||||
self.existing_summary = (
|
||||
self.predict_new_summary(
|
||||
self.summarized_history_temp.messages[i * 2 : i * 2 + 2], self.existing_summary, stop=["\n\n"]
|
||||
)
|
||||
.strip()
|
||||
.split("\n")[0]
|
||||
.strip()
|
||||
)
|
||||
for i in range(int(len(self.summarized_history_temp.messages) / 2)):
|
||||
self.summarized_history_temp.messages.pop(0)
|
||||
self.summarized_history_temp.messages.pop(0)
|
||||
conversation_buffer = []
|
||||
for t in self.buffered_history.messages:
|
||||
if t.type == "human":
|
||||
prefix = self.human_prefix
|
||||
else:
|
||||
prefix = self.ai_prefix
|
||||
conversation_buffer.append(prefix + ": " + t.content)
|
||||
conversation_buffer = "\n".join(conversation_buffer)
|
||||
if len(self.existing_summary) > 0:
|
||||
if lang == "en":
|
||||
message = f"A summarization of historical conversation:\n{self.existing_summary}\nMost recent conversation:\n{conversation_buffer}"
|
||||
elif lang == "zh":
|
||||
message = f"历史对话概要:\n{self.existing_summary}\n最近的对话:\n{conversation_buffer}"
|
||||
else:
|
||||
raise ValueError("Unsupported language")
|
||||
return message
|
||||
else:
|
||||
message = conversation_buffer
|
||||
return message
|
||||
|
||||
def get_conversation_length(self):
|
||||
"""Get the length of the formatted conversation"""
|
||||
prompt = self.format_dialogue()
|
||||
length = self.llm.get_num_tokens(prompt)
|
||||
return length
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load the memory variables.
|
||||
Summarize oversize conversation to fit into the length constraint defined by max_tokene
|
||||
Args:
|
||||
inputs: the kwargs of the chain of your definition
|
||||
Returns:
|
||||
a dict that maps from memory key to the formated dialogue
|
||||
the formated dialogue has the following format
|
||||
if conversation is too long:
|
||||
A summarization of historical conversation:
|
||||
{summarization}
|
||||
Most recent conversation:
|
||||
Human: XXX
|
||||
Assistant: XXX
|
||||
...
|
||||
otherwise
|
||||
Human: XXX
|
||||
Assistant: XXX
|
||||
...
|
||||
"""
|
||||
# Calculate remain length
|
||||
if "input_documents" in inputs:
|
||||
# Run in a retrieval qa chain
|
||||
docs = inputs["input_documents"]
|
||||
else:
|
||||
# For test
|
||||
docs = self.retriever.get_relevant_documents(inputs[self.input_key])
|
||||
inputs[self.memory_key] = ""
|
||||
inputs = {k: v for k, v in inputs.items() if k in [self.chain.input_key, self.input_key, self.memory_key]}
|
||||
prompt_length = self.chain.prompt_length(docs, **inputs)
|
||||
remain = self.max_tokens - prompt_length
|
||||
while self.get_conversation_length() > remain:
|
||||
if len(self.buffered_history.messages) <= 2:
|
||||
raise RuntimeError("Exeeed max_tokens, trunck size of retrieved documents is too large")
|
||||
temp = self.buffered_history.messages.pop(0)
|
||||
self.summarized_history_temp.messages.append(temp)
|
||||
temp = self.buffered_history.messages.pop(0)
|
||||
self.summarized_history_temp.messages.append(temp)
|
||||
return {self.memory_key: self.format_dialogue()}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.buffered_history.add_user_message(input_str.strip())
|
||||
self.buffered_history.add_ai_message(output_str.strip())
|
92
applications/ColossalQA/colossalqa/mylogging.py
Normal file
92
applications/ColossalQA/colossalqa/mylogging.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Class for logging with extra control for debugging
|
||||
"""
|
||||
import logging
|
||||
|
||||
|
||||
class ColossalQALogger:
|
||||
"""This is a distributed event logger class essentially based on :class:`logging`.
|
||||
|
||||
Args:
|
||||
name (str): The name of the logger.
|
||||
|
||||
Note:
|
||||
Logging types: ``info``, ``warning``, ``debug`` and ``error``
|
||||
"""
|
||||
|
||||
__instances = dict()
|
||||
|
||||
def __init__(self, name):
|
||||
if name in ColossalQALogger.__instances:
|
||||
raise ValueError("Logger with the same name has been created")
|
||||
else:
|
||||
self._name = name
|
||||
self._logger = logging.getLogger(name)
|
||||
|
||||
ColossalQALogger.__instances[name] = self
|
||||
|
||||
@staticmethod
|
||||
def get_instance(name: str):
|
||||
"""Get the unique single logger instance based on name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the logger.
|
||||
|
||||
Returns:
|
||||
DistributedLogger: A DistributedLogger object
|
||||
"""
|
||||
if name in ColossalQALogger.__instances:
|
||||
return ColossalQALogger.__instances[name]
|
||||
else:
|
||||
logger = ColossalQALogger(name=name)
|
||||
return logger
|
||||
|
||||
def info(self, message: str, verbose: bool = False) -> None:
|
||||
"""Log an info message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
verbose (bool): Whether to print the message to stdout.
|
||||
"""
|
||||
if verbose:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self._logger.info(message)
|
||||
|
||||
def warning(self, message: str, verbose: bool = False) -> None:
|
||||
"""Log a warning message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
verbose (bool): Whether to print the message to stdout.
|
||||
"""
|
||||
if verbose:
|
||||
self._logger.warning(message)
|
||||
|
||||
def debug(self, message: str, verbose: bool = False) -> None:
|
||||
"""Log a debug message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
verbose (bool): Whether to print the message to stdout.
|
||||
"""
|
||||
if verbose:
|
||||
self._logger.debug(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log an error message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
"""
|
||||
self._logger.error(message)
|
||||
|
||||
|
||||
def get_logger(name: str = None, level=logging.INFO) -> ColossalQALogger:
|
||||
"""
|
||||
Get the logger by name, if name is None, return the default logger
|
||||
"""
|
||||
if name:
|
||||
logger = ColossalQALogger.get_instance(name=name)
|
||||
else:
|
||||
logger = ColossalQALogger.get_instance(name="colossalqa")
|
||||
return logger
|
144
applications/ColossalQA/colossalqa/prompt/README.md
Normal file
144
applications/ColossalQA/colossalqa/prompt/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Prompt Design Guide
|
||||
|
||||
For the retriever conversation system, users can customize three prompts.
|
||||
|
||||
## The Retrieval QA Prompt
|
||||
This is the prompt for retrieval QA, the input is user's inputs, the retrieved documents, the historical conversation.
|
||||
|
||||
### Chinese
|
||||
```
|
||||
你是一个善于解答用户问题的AI助手。在保证安全的前提下,回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
|
||||
如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
|
||||
使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
|
||||
|
||||
背景信息:
|
||||
[retrieved documents]
|
||||
|
||||
聊天记录:
|
||||
[historical conversation, overlength chat history will be summarized]
|
||||
|
||||
用户: [question]
|
||||
Assistant:
|
||||
```
|
||||
|
||||
### English
|
||||
```
|
||||
[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
If the answer cannot be infered based on the given context, please don't share false information.<</SYS>>
|
||||
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
|
||||
|
||||
context:
|
||||
[retrieved documents]
|
||||
|
||||
chat history
|
||||
[historical conversation, overlength chat history will be summarized]
|
||||
|
||||
Human: {question}
|
||||
Assistant:
|
||||
```
|
||||
|
||||
## Summarization Prompt
|
||||
This prompt is used by the memory module to recursively summarize overlength conversation to shrink the length of the prompt.
|
||||
|
||||
## Disambiguity Prompt
|
||||
This prompt is used to perform zero-shot reference resolution to disambiguate entity references within user's questions.
|
||||
|
||||
## Final Prompt Examples
|
||||
Assume k=3 for the retriever.
|
||||
|
||||
### English
|
||||
Note that the "[INST] <<SYS>>...<</SYS>>" template is the specific prompt format used in LLaMA2.
|
||||
#### Normal Length
|
||||
```
|
||||
[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
If the answer cannot be infered based on the given context, please don't share false information.<</SYS>>
|
||||
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
|
||||
|
||||
context:
|
||||
[document 1]
|
||||
|
||||
[document 2]
|
||||
|
||||
[document 3]
|
||||
|
||||
chat history
|
||||
Human: XXX
|
||||
Assistant: XXX
|
||||
...
|
||||
|
||||
Human: {question}
|
||||
Assistant:
|
||||
```
|
||||
|
||||
#### Overlength
|
||||
```
|
||||
[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
If the answer cannot be infered based on the given context, please don't share false information.<</SYS>>
|
||||
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
|
||||
|
||||
context:
|
||||
[document 1]
|
||||
|
||||
[document 2]
|
||||
|
||||
[document 3]
|
||||
|
||||
chat history
|
||||
A summarization of historical conversation:
|
||||
[one line summary of historical conversation]
|
||||
Most recent conversation:
|
||||
Human: XXX
|
||||
Assistant: XXX
|
||||
...
|
||||
|
||||
Human: {question}
|
||||
Assistant:
|
||||
```
|
||||
|
||||
### Chinese
|
||||
#### Normal Length
|
||||
```
|
||||
你是一个善于解答用户问题的AI助手。在保证安全的前提下,回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
|
||||
如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
|
||||
使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
|
||||
|
||||
背景信息:
|
||||
[document 1]
|
||||
|
||||
[document 2]
|
||||
|
||||
[document 3]
|
||||
|
||||
聊天记录:
|
||||
用户: XXX
|
||||
Assistant: XXX
|
||||
...
|
||||
|
||||
用户: [question]
|
||||
Assistant:
|
||||
```
|
||||
|
||||
#### Overlength
|
||||
```
|
||||
你是一个善于解答用户问题的AI助手。在保证安全的前提下,回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
|
||||
如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
|
||||
使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
|
||||
|
||||
背景信息:
|
||||
[document 1]
|
||||
|
||||
[document 2]
|
||||
|
||||
[document 3]
|
||||
|
||||
聊天记录:
|
||||
历史对话概要:
|
||||
[one line summary of historical conversation]
|
||||
最近的对话:
|
||||
用户: XXX
|
||||
Assistant: XXX
|
||||
...
|
||||
|
||||
用户: [question]
|
||||
Assistant:
|
||||
```
|
124
applications/ColossalQA/colossalqa/prompt/prompt.py
Normal file
124
applications/ColossalQA/colossalqa/prompt/prompt.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
All custom prompt templates are defined here.
|
||||
"""
|
||||
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。
|
||||
|
||||
例1:
|
||||
已有的摘要:
|
||||
人类问Assistant对人工智能的看法。人工智能认为人工智能是一种善的力量。
|
||||
|
||||
新的对话内容:
|
||||
人类: 为什么你认为人工智能是一种好的力量?
|
||||
Assistant: 因为人工智能将帮助人类充分发挥潜力。
|
||||
|
||||
新的摘要:
|
||||
人类问Assistant对人工智能的看法。人工智能认为人工智能是一种积极的力量,因为它将帮助人类充分发挥潜力。
|
||||
示例结束
|
||||
|
||||
已有的摘要:
|
||||
{summary}
|
||||
|
||||
新的对话内容:
|
||||
{new_lines}
|
||||
|
||||
新的摘要:"""
|
||||
|
||||
|
||||
# Chinese retrieval qa prompt
|
||||
|
||||
_ZH_RETRIEVAL_QA_PROMPT = """<指令>根据下列支持文档和对话历史,简洁和专业地来回答问题。如果无法从支持文档中得到答案,请说 “根据已知信息无法回答该问题”。回答中请不要涉及支持文档中没有提及的信息,答案请使用中文。 </指令>
|
||||
|
||||
{context}
|
||||
|
||||
<对话历史>
|
||||
{chat_history}
|
||||
</对话历史>
|
||||
|
||||
<问题>{question}</问题>
|
||||
答案:"""
|
||||
|
||||
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS = ["无法回答该问题"]
|
||||
ZH_RETRIEVAL_QA_REJECTION_ANSWER = "抱歉,根据提供的信息无法回答该问题。"
|
||||
|
||||
|
||||
_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE = """使用提供的参考案例判断客户遇到的故障所属的故障原因分类。
|
||||
|
||||
背景信息:
|
||||
{context}
|
||||
|
||||
客服记录:
|
||||
{question}
|
||||
故障原因分类:"""
|
||||
|
||||
_ZH_DISAMBIGUATION_PROMPT = """你是一个乐于助人、恭敬而诚实的助手。你总是按照指示去做。
|
||||
请用聊天记录中提到的具体名称或实体名称替换给定句子中的任何模糊或有歧义的指代,如果没有提供聊天记录或句子中不包含模糊或有歧义的指代,则只输出原始句子。您的输出应该是消除歧义的句子本身(与“消除歧义的句子:”在同一行中),并且不包含任何其他内容。
|
||||
|
||||
下面是一个例子:
|
||||
聊天记录:
|
||||
用户: 我有一个朋友,张三。你认识他吗?
|
||||
Assistant: 我认识一个叫张三的人
|
||||
|
||||
句子: 他最喜欢的食物是什么?
|
||||
消除歧义的句子: 张三最喜欢的食物是什么?
|
||||
|
||||
聊天记录:
|
||||
{chat_history}
|
||||
|
||||
句子: {input}
|
||||
消除歧义的句子:"""
|
||||
|
||||
# English retrieval qa prompt
|
||||
|
||||
_EN_RETRIEVAL_QA_PROMPT = """[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist content.
|
||||
If the answer cannot be infered based on the given context, please say "I cannot answer the question based on the information given.".<</SYS>>
|
||||
Use the context and chat history to answer the question.
|
||||
|
||||
context:
|
||||
{context}
|
||||
|
||||
chat history
|
||||
{chat_history}
|
||||
|
||||
question: {question}
|
||||
answer:"""
|
||||
EN_RETRIEVAL_QA_TRIGGER_KEYWORDS = ["cannot answer the question"]
|
||||
EN_RETRIEVAL_QA_REJECTION_ANSWER = "Sorry, this question cannot be answered based on the information provided."
|
||||
|
||||
_EN_DISAMBIGUATION_PROMPT = """[INST] <<SYS>>You are a helpful, respectful and honest assistant. You always follow the instruction.<</SYS>>
|
||||
Please replace any ambiguous references in the given sentence with the specific names or entities mentioned in the chat history or just output the original sentence if no chat history is provided or if the sentence doesn't contain ambiguous references. Your output should be the disambiguated sentence itself (in the same line as "disambiguated sentence:") and contain nothing else.
|
||||
|
||||
Here is an example:
|
||||
Chat history:
|
||||
Human: I have a friend, Mike. Do you know him?
|
||||
Assistant: Yes, I know a person named Mike
|
||||
|
||||
sentence: What's his favorate food?
|
||||
disambiguated sentence: What's Mike's favorate food?
|
||||
[/INST]
|
||||
Chat history:
|
||||
{chat_history}
|
||||
|
||||
sentence: {input}
|
||||
disambiguated sentence:"""
|
||||
|
||||
|
||||
PROMPT_RETRIEVAL_QA_EN = PromptTemplate(
|
||||
template=_EN_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
|
||||
)
|
||||
|
||||
PROMPT_DISAMBIGUATE_EN = PromptTemplate(template=_EN_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])
|
||||
|
||||
SUMMARY_PROMPT_ZH = PromptTemplate(input_variables=["summary", "new_lines"], template=_CUSTOM_SUMMARIZER_TEMPLATE_ZH)
|
||||
|
||||
PROMPT_DISAMBIGUATE_ZH = PromptTemplate(template=_ZH_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])
|
||||
|
||||
PROMPT_RETRIEVAL_QA_ZH = PromptTemplate(
|
||||
template=_ZH_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
|
||||
)
|
||||
|
||||
PROMPT_RETRIEVAL_CLASSIFICATION_USE_CASE_ZH = PromptTemplate(
|
||||
template=_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE, input_variables=["question", "context"]
|
||||
)
|
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Script for Chinese retrieval based conversation system backed by ChatGLM
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
from colossalqa.chain.retrieval_qa.base import RetrievalQA
|
||||
from colossalqa.local.llm import ColossalAPI, ColossalLLM
|
||||
from colossalqa.memory import ConversationBufferWithSummary
|
||||
from colossalqa.mylogging import get_logger
|
||||
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_EN, PROMPT_RETRIEVAL_QA_EN
|
||||
from colossalqa.retriever import CustomRetriever
|
||||
from langchain import LLMChain
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class EnglishRetrievalConversation:
|
||||
"""
|
||||
Wrapper class for Chinese retrieval conversation system
|
||||
"""
|
||||
|
||||
def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
|
||||
"""
|
||||
Setup retrieval qa chain for Chinese retrieval based QA
|
||||
"""
|
||||
logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
|
||||
colossal_api = ColossalAPI.get_api(model_name, model_path)
|
||||
self.llm = ColossalLLM(n=1, api=colossal_api)
|
||||
|
||||
# Define the retriever
|
||||
self.retriever = retriever
|
||||
|
||||
# Define the chain to preprocess the input
|
||||
# Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
|
||||
# Prompt is summarization prompt
|
||||
self.llm_chain_disambiguate = LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=PROMPT_DISAMBIGUATE_EN,
|
||||
llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
|
||||
)
|
||||
|
||||
self.retriever.set_rephrase_handler(self.disambiguity)
|
||||
# Define memory with summarization ability
|
||||
self.memory = ConversationBufferWithSummary(
|
||||
llm=self.llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
|
||||
)
|
||||
self.memory.initiate_document_retrieval_chain(
|
||||
self.llm,
|
||||
PROMPT_RETRIEVAL_QA_EN,
|
||||
self.retriever,
|
||||
chain_type_kwargs={
|
||||
"chat_history": "",
|
||||
},
|
||||
)
|
||||
self.retrieval_chain = RetrievalQA.from_chain_type(
|
||||
llm=self.llm,
|
||||
verbose=False,
|
||||
chain_type="stuff",
|
||||
retriever=self.retriever,
|
||||
chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": self.memory},
|
||||
llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
|
||||
)
|
||||
|
||||
def disambiguity(self, input: str):
|
||||
out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
|
||||
return out.split("\n")[0]
|
||||
|
||||
@classmethod
|
||||
def from_retriever(
|
||||
cls, retriever: CustomRetriever, model_path: str, model_name: str
|
||||
) -> "EnglishRetrievalConversation":
|
||||
return cls(retriever, model_path, model_name)
|
||||
|
||||
def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
|
||||
if memory:
|
||||
# TODO add translation chain here
|
||||
self.memory.buffered_history.messages = memory.buffered_history.messages
|
||||
self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
|
||||
return (
|
||||
self.retrieval_chain.run(
|
||||
query=user_input,
|
||||
stop=[self.memory.human_prefix + ": "],
|
||||
rejection_trigger_keywrods=["cannot answer the question"],
|
||||
rejection_answer="Sorry, this question cannot be answered based on the information provided.",
|
||||
).split("\n")[0],
|
||||
self.memory,
|
||||
)
|
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Multilingual retrieval based conversation system
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from colossalqa.data_loader.document_loader import DocumentLoader
|
||||
from colossalqa.mylogging import get_logger
|
||||
from colossalqa.retrieval_conversation_en import EnglishRetrievalConversation
|
||||
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
|
||||
from colossalqa.retriever import CustomRetriever
|
||||
from colossalqa.text_splitter import ChineseTextSplitter
|
||||
from colossalqa.utils import detect_lang_naive
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class UniversalRetrievalConversation:
|
||||
"""
|
||||
Wrapper class for bilingual retrieval conversation system
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_model_path: str = "moka-ai/m3e-base",
|
||||
embedding_model_device: str = "cpu",
|
||||
zh_model_path: str = None,
|
||||
zh_model_name: str = None,
|
||||
en_model_path: str = None,
|
||||
en_model_name: str = None,
|
||||
sql_file_path: str = None,
|
||||
files_zh: List[List[str]] = None,
|
||||
files_en: List[List[str]] = None,
|
||||
text_splitter_chunk_size=100,
|
||||
text_splitter_chunk_overlap=10,
|
||||
) -> None:
|
||||
"""
|
||||
Warpper for multilingual retrieval qa class (Chinese + English)
|
||||
Args:
|
||||
embedding_model_path: local or huggingface embedding model
|
||||
embedding_model_device:
|
||||
files_zh: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for Chinese retrieval QA
|
||||
files_en: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for English retrieval QA
|
||||
"""
|
||||
self.embedding = HuggingFaceEmbeddings(
|
||||
model_name=embedding_model_path,
|
||||
model_kwargs={"device": embedding_model_device},
|
||||
encode_kwargs={"normalize_embeddings": False},
|
||||
)
|
||||
print("Select files for constructing Chinese retriever")
|
||||
docs_zh = self.load_supporting_docs(
|
||||
files=files_zh,
|
||||
text_splitter=ChineseTextSplitter(
|
||||
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
|
||||
),
|
||||
)
|
||||
# Create retriever
|
||||
self.information_retriever_zh = CustomRetriever(
|
||||
k=3, sql_file_path=sql_file_path.replace(".db", "_zh.db"), verbose=True
|
||||
)
|
||||
self.information_retriever_zh.add_documents(
|
||||
docs=docs_zh, cleanup="incremental", mode="by_source", embedding=self.embedding
|
||||
)
|
||||
|
||||
print("Select files for constructing English retriever")
|
||||
docs_en = self.load_supporting_docs(
|
||||
files=files_en,
|
||||
text_splitter=RecursiveCharacterTextSplitter(
|
||||
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
|
||||
),
|
||||
)
|
||||
# Create retriever
|
||||
self.information_retriever_en = CustomRetriever(
|
||||
k=3, sql_file_path=sql_file_path.replace(".db", "_en.db"), verbose=True
|
||||
)
|
||||
self.information_retriever_en.add_documents(
|
||||
docs=docs_en, cleanup="incremental", mode="by_source", embedding=self.embedding
|
||||
)
|
||||
|
||||
self.chinese_retrieval_conversation = ChineseRetrievalConversation.from_retriever(
|
||||
self.information_retriever_zh, model_path=zh_model_path, model_name=zh_model_name
|
||||
)
|
||||
self.english_retrieval_conversation = EnglishRetrievalConversation.from_retriever(
|
||||
self.information_retriever_en, model_path=en_model_path, model_name=en_model_name
|
||||
)
|
||||
self.memory = None
|
||||
|
||||
def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: TextSplitter = None):
|
||||
"""
|
||||
Load supporting documents, currently, all documents will be stored in one vector store
|
||||
"""
|
||||
documents = []
|
||||
if files:
|
||||
for file in files:
|
||||
retriever_data = DocumentLoader([[file["data_path"], file["name"]]]).all_data
|
||||
splits = text_splitter.split_documents(retriever_data)
|
||||
documents.extend(splits)
|
||||
else:
|
||||
while True:
|
||||
file = input("Select a file to load or press Enter to exit:")
|
||||
if file == "":
|
||||
break
|
||||
data_name = input("Enter a short description of the data:")
|
||||
separator = input(
|
||||
"Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n', press ENTER directly to skip:"
|
||||
)
|
||||
separator = separator if separator != "" else "\n\n"
|
||||
retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
|
||||
|
||||
# Split
|
||||
splits = text_splitter.split_documents(retriever_data)
|
||||
documents.extend(splits)
|
||||
return documents
|
||||
|
||||
def start_test_session(self):
|
||||
"""
|
||||
Simple multilingual session for testing purpose, with naive language selection mechanism
|
||||
"""
|
||||
while True:
|
||||
user_input = input("User: ")
|
||||
lang = detect_lang_naive(user_input)
|
||||
if "END" == user_input:
|
||||
print("Agent: Happy to chat with you :)")
|
||||
break
|
||||
agent_response = self.run(user_input, which_language=lang)
|
||||
print(f"Agent: {agent_response}")
|
||||
|
||||
def run(self, user_input: str, which_language=str):
|
||||
"""
|
||||
Generate the response given the user input and a str indicates the language requirement of the output string
|
||||
"""
|
||||
assert which_language in ["zh", "en"]
|
||||
if which_language == "zh":
|
||||
agent_response, self.memory = self.chinese_retrieval_conversation.run(user_input, self.memory)
|
||||
else:
|
||||
agent_response, self.memory = self.english_retrieval_conversation.run(user_input, self.memory)
|
||||
return agent_response.split("\n")[0]
|
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Script for Chinese retrieval based conversation system backed by ChatGLM
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
from colossalqa.chain.retrieval_qa.base import RetrievalQA
|
||||
from colossalqa.local.llm import ColossalAPI, ColossalLLM
|
||||
from colossalqa.memory import ConversationBufferWithSummary
|
||||
from colossalqa.mylogging import get_logger
|
||||
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH
|
||||
from colossalqa.retriever import CustomRetriever
|
||||
from langchain import LLMChain
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class ChineseRetrievalConversation:
|
||||
"""
|
||||
Wrapper class for Chinese retrieval conversation system
|
||||
"""
|
||||
|
||||
def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
|
||||
"""
|
||||
Setup retrieval qa chain for Chinese retrieval based QA
|
||||
"""
|
||||
# Local coati api
|
||||
logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
|
||||
colossal_api = ColossalAPI.get_api(model_name, model_path)
|
||||
self.llm = ColossalLLM(n=1, api=colossal_api)
|
||||
|
||||
# Define the retriever
|
||||
self.retriever = retriever
|
||||
|
||||
# Define the chain to preprocess the input
|
||||
# Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
|
||||
# Prompt is summarization prompt
|
||||
self.llm_chain_disambiguate = LLMChain(
|
||||
llm=self.llm,
|
||||
prompt=PROMPT_DISAMBIGUATE_ZH,
|
||||
llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
|
||||
)
|
||||
|
||||
self.retriever.set_rephrase_handler(self.disambiguity)
|
||||
# Define memory with summarization ability
|
||||
self.memory = ConversationBufferWithSummary(
|
||||
llm=self.llm,
|
||||
prompt=SUMMARY_PROMPT_ZH,
|
||||
human_prefix="用户",
|
||||
ai_prefix="Assistant",
|
||||
max_tokens=2000,
|
||||
llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
|
||||
)
|
||||
self.memory.initiate_document_retrieval_chain(
|
||||
self.llm,
|
||||
PROMPT_RETRIEVAL_QA_ZH,
|
||||
self.retriever,
|
||||
chain_type_kwargs={
|
||||
"chat_history": "",
|
||||
},
|
||||
)
|
||||
self.retrieval_chain = RetrievalQA.from_chain_type(
|
||||
llm=self.llm,
|
||||
verbose=False,
|
||||
chain_type="stuff",
|
||||
retriever=self.retriever,
|
||||
chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_ZH, "memory": self.memory},
|
||||
llm_kwargs={"max_new_tokens": 150, "temperature": 0.9, "do_sample": True},
|
||||
)
|
||||
|
||||
def disambiguity(self, input: str):
|
||||
out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
|
||||
return out.split("\n")[0]
|
||||
|
||||
@classmethod
|
||||
def from_retriever(
|
||||
cls, retriever: CustomRetriever, model_path: str, model_name: str
|
||||
) -> "ChineseRetrievalConversation":
|
||||
return cls(retriever, model_path, model_name)
|
||||
|
||||
def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
|
||||
if memory:
|
||||
# TODO add translation chain here
|
||||
self.memory.buffered_history.messages = memory.buffered_history.messages
|
||||
self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
|
||||
return (
|
||||
self.retrieval_chain.run(
|
||||
query=user_input,
|
||||
stop=["</答案>"],
|
||||
doc_prefix="支持文档",
|
||||
rejection_trigger_keywrods=["无法回答该问题"],
|
||||
rejection_answer="抱歉,根据提供的信息无法回答该问题。",
|
||||
).split("\n")[0],
|
||||
self.memory,
|
||||
)
|
166
applications/ColossalQA/colossalqa/retriever.py
Normal file
166
applications/ColossalQA/colossalqa/retriever.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Code for custom retriver with incremental update
|
||||
"""
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from colossalqa.mylogging import get_logger
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.indexes import SQLRecordManager, index
|
||||
from langchain.schema.retriever import BaseRetriever, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.chroma import Chroma
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class CustomRetriever(BaseRetriever):
|
||||
"""
|
||||
Custom retriever class with support for incremental update of indexes
|
||||
"""
|
||||
|
||||
vector_stores: Dict[str, VectorStore] = {}
|
||||
sql_index_database: Dict[str, str] = {}
|
||||
record_managers: Dict[str, SQLRecordManager] = {}
|
||||
sql_db_chains = []
|
||||
k = 3
|
||||
rephrase_handler: Callable = None
|
||||
buffer: Dict = []
|
||||
buffer_size: int = 5
|
||||
verbose: bool = False
|
||||
sql_file_path: str = None
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
documents: List[Document],
|
||||
embeddings: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> BaseRetriever:
|
||||
k = kwargs.pop("k", 3)
|
||||
cleanup = kwargs.pop("cleanup", "incremental")
|
||||
mode = kwargs.pop("mode", "by_source")
|
||||
ret = cls(k=k)
|
||||
ret.add_documents(documents, embedding=embeddings, cleanup=cleanup, mode=mode)
|
||||
return ret
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
docs: Dict[str, Document] = [],
|
||||
cleanup: str = "incremental",
|
||||
mode: str = "by_source",
|
||||
embedding: Embeddings = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add documents to retriever
|
||||
Args:
|
||||
docs: the documents to add
|
||||
cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destory and rebuild retriever)
|
||||
mode: choose from "by source" (documents are grouped by source) and "merge" (documents are merged into one vector store)
|
||||
"""
|
||||
if cleanup == "full":
|
||||
# Cleanup
|
||||
for source in self.vector_stores:
|
||||
os.remove(self.sql_index_database[source])
|
||||
# Add documents
|
||||
data_by_source = defaultdict(list)
|
||||
if mode == "by_source":
|
||||
for doc in docs:
|
||||
data_by_source[doc.metadata["source"]].append(doc)
|
||||
elif mode == "merge":
|
||||
data_by_source["merged"] = docs
|
||||
for source in data_by_source:
|
||||
if source not in self.vector_stores:
|
||||
hash_encoding = hashlib.sha3_224(source.encode()).hexdigest()
|
||||
if os.path.exists(f"{self.sql_file_path}/{hash_encoding}.db"):
|
||||
# Remove the stale file
|
||||
os.remove(f"{self.sql_file_path}/{hash_encoding}.db")
|
||||
# Create a new sql database to store indexes, sql files are stored in the same directory as the source file
|
||||
sql_path = f"sqlite:///{self.sql_file_path}/{hash_encoding}.db"
|
||||
self.vector_stores[source] = Chroma(embedding_function=embedding, collection_name=hash_encoding)
|
||||
self.sql_index_database[source] = f"{self.sql_file_path}/{hash_encoding}.db"
|
||||
self.record_managers[source] = SQLRecordManager(source, db_url=sql_path)
|
||||
self.record_managers[source].create_schema()
|
||||
index(
|
||||
data_by_source[source],
|
||||
self.record_managers[source],
|
||||
self.vector_stores[source],
|
||||
cleanup=cleanup,
|
||||
source_id_key="source",
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
for source in self.sql_index_database:
|
||||
if os.path.exists(self.sql_index_database[source]):
|
||||
os.remove(self.sql_index_database[source])
|
||||
|
||||
def set_sql_database_chain(self, db_chains) -> None:
|
||||
"""
|
||||
set sql agent chain to retrieve information from sql database
|
||||
Not used in this version
|
||||
"""
|
||||
self.sql_db_chains = db_chains
|
||||
|
||||
def set_rephrase_handler(self, handler: Callable = None) -> None:
|
||||
"""
|
||||
Set a handler to preprocess the input str before feed into the retriever
|
||||
"""
|
||||
self.rephrase_handler = handler
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun = None,
|
||||
score_threshold: float = None,
|
||||
return_scores: bool = False,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
This function is called by the retriever to get the relevant documents.
|
||||
recent vistied queries are stored in buffer, if the query is in buffer, return the documents directly
|
||||
|
||||
Args:
|
||||
query: the query to be searched
|
||||
run_manager: the callback manager for retriever run
|
||||
Returns:
|
||||
documents: the relevant documents
|
||||
"""
|
||||
for buffered_doc in self.buffer:
|
||||
if buffered_doc[0] == query:
|
||||
return buffered_doc[1]
|
||||
query_ = str(query)
|
||||
# Use your existing retriever to get the documents
|
||||
if self.rephrase_handler:
|
||||
query = self.rephrase_handler(query)
|
||||
documents = []
|
||||
for k in self.vector_stores:
|
||||
# Retrieve documents from each retriever
|
||||
vectorstore = self.vector_stores[k]
|
||||
documents.extend(vectorstore.similarity_search_with_score(query, self.k, score_threshold=score_threshold))
|
||||
# print(documents)
|
||||
# Return the top k documents among all retrievers
|
||||
documents = sorted(documents, key=lambda x: x[1], reverse=False)[: self.k]
|
||||
if return_scores:
|
||||
# Return score
|
||||
documents = copy.deepcopy(documents)
|
||||
for doc in documents:
|
||||
doc[0].metadata["score"] = doc[1]
|
||||
documents = [doc[0] for doc in documents]
|
||||
# Retrieve documents from sql database (not applicable for the local chains)
|
||||
for sql_chain in self.sql_db_chains:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=f"Query: {query} Answer: {sql_chain.run(query)}", metadata={"source": "sql_query"}
|
||||
)
|
||||
)
|
||||
if len(self.buffer) < self.buffer_size:
|
||||
self.buffer.append([query_, documents])
|
||||
else:
|
||||
self.buffer.pop(0)
|
||||
self.buffer.append([query_, documents])
|
||||
logger.info(f"retrieved documents:\n{str(documents)}", verbose=self.verbose)
|
||||
return documents
|
@@ -0,0 +1 @@
|
||||
from .chinese_text_splitter import ChineseTextSplitter
|
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Code for Chinese text splitter
|
||||
"""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from colossalqa.text_splitter.utils import get_cleaned_paragraph
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
class ChineseTextSplitter(RecursiveCharacterTextSplitter):
|
||||
def __init__(self, separators: Optional[List[str]] = None, is_separator_regrx: bool = False, **kwargs: Any):
|
||||
self._separators = separators or ["\n\n", "\n", ",", "。", "!", "?", "?"]
|
||||
if "chunk_size" not in kwargs:
|
||||
kwargs["chunk_size"] = 50
|
||||
if "chunk_overlap" not in kwargs:
|
||||
kwargs["chunk_overlap"] = 10
|
||||
super().__init__(separators=separators, keep_separator=True, **kwargs)
|
||||
self._is_separator_regex = is_separator_regrx
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Return the list of separated text chunks"""
|
||||
cleaned_paragraph = get_cleaned_paragraph(text)
|
||||
splitted = []
|
||||
for paragraph in cleaned_paragraph:
|
||||
segs = super().split_text(paragraph)
|
||||
for i in range(len(segs) - 1):
|
||||
if segs[i][-1] not in self._separators:
|
||||
pos = text.find(segs[i])
|
||||
pos_end = pos + len(segs[i])
|
||||
if i > 0:
|
||||
last_sentence_start = max([text.rfind(m, 0, pos) for m in ["。", "!", "?"]])
|
||||
pos = last_sentence_start + 1
|
||||
segs[i] = str(text[pos:pos_end])
|
||||
if i != len(segs) - 1:
|
||||
next_sentence_end = max([text.find(m, pos_end) for m in ["。", "!", "?"]])
|
||||
segs[i] = str(text[pos : next_sentence_end + 1])
|
||||
splitted.append(segs[i])
|
||||
if len(splitted) <= 1:
|
||||
return splitted
|
||||
splitted_text = []
|
||||
i = 1
|
||||
if splitted[0] not in splitted[1]:
|
||||
splitted_text.append([splitted[0], 0])
|
||||
if splitted[-1] not in splitted[-2]:
|
||||
splitted_text.append([splitted[-1], len(splitted) - 1])
|
||||
while i < len(splitted) - 1:
|
||||
if splitted[i] not in splitted[i + 1] and splitted[i] not in splitted[i - 1]:
|
||||
splitted_text.append([splitted[i], i])
|
||||
i += 1
|
||||
splitted_text = sorted(splitted_text, key=lambda x: x[1])
|
||||
splitted_text = [splitted_text[i][0] for i in range(len(splitted_text))]
|
||||
ret = []
|
||||
for s in splitted_text:
|
||||
if s not in ret:
|
||||
ret.append(s)
|
||||
return ret
|
19
applications/ColossalQA/colossalqa/text_splitter/utils.py
Normal file
19
applications/ColossalQA/colossalqa/text_splitter/utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_format(text: str) -> str:
|
||||
# if the accout of \t, \r, \v, \f is less than 3, replace \t, \r, \v, \f with space
|
||||
if len(re.findall(r"\s", text.replace(" ", ""))) > 3:
|
||||
# in case this is a line of a table
|
||||
return text
|
||||
return re.sub(r"\s", " ", text)
|
||||
|
||||
|
||||
# remove newlines
|
||||
def get_cleaned_paragraph(s: str) -> str:
|
||||
text = str(s)
|
||||
text = re.sub(r"\n{3,}", r"\n", text) # replace \n\n\n... with \n
|
||||
text = re.sub("\n\n", "", text)
|
||||
lines = text.split("\n")
|
||||
lines_remove_format = [remove_format(line) for line in lines]
|
||||
return lines_remove_format
|
61
applications/ColossalQA/colossalqa/utils.py
Normal file
61
applications/ColossalQA/colossalqa/utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from colossalqa.mylogging import get_logger
|
||||
from sqlalchemy import Engine, MetaData, create_engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def drop_table(engine: Engine) -> None:
|
||||
"""
|
||||
Drop all existing table
|
||||
"""
|
||||
Base = declarative_base()
|
||||
metadata = MetaData()
|
||||
metadata.reflect(bind=engine)
|
||||
for key in metadata.tables:
|
||||
table = metadata.tables[key]
|
||||
if table is not None:
|
||||
Base.metadata.drop_all(engine, [table], checkfirst=True)
|
||||
|
||||
|
||||
def create_empty_sql_database(database_uri):
|
||||
try:
|
||||
# Create an SQLAlchemy engine to connect to the database
|
||||
engine = create_engine(database_uri)
|
||||
|
||||
# Create the database
|
||||
engine.connect()
|
||||
|
||||
logger.info(f"Database created at {database_uri}")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error creating database: {str(e)}")
|
||||
return engine, database_uri
|
||||
|
||||
|
||||
def destroy_sql_database(sql_engine: Union[Engine, str]) -> None:
|
||||
"""
|
||||
Destroy an sql database
|
||||
"""
|
||||
if isinstance(sql_engine, str):
|
||||
sql_engine = create_engine(sql_engine)
|
||||
drop_table(sql_engine)
|
||||
sql_engine.dispose()
|
||||
sql_engine = None
|
||||
|
||||
|
||||
def detect_lang_naive(s):
|
||||
"""
|
||||
Naive function for language detection, should be replaced by an independant layer
|
||||
"""
|
||||
remove_nota = "[’·°–!\"#$%&'()*+,-./:;<=>?@,。?★、…【】()《》?“”‘’![\\]^_`{|}~]+"
|
||||
s = re.sub(remove_nota, "", s)
|
||||
s = re.sub("[0-9]", "", s).strip()
|
||||
res = re.sub("[a-zA-Z]", "", s).strip()
|
||||
if len(res) <= 0:
|
||||
return "en"
|
||||
else:
|
||||
return "zh"
|
Reference in New Issue
Block a user