[Feature] Add document retrieval QA (#5020)

* add langchain

* add langchain

* Add files via upload

* add langchain

* fix style

* fix style: remove extra space

* add pytest; modified retriever

* add pytest; modified retriever

* add tests to build_on_pr.yml

* fix build_on_pr.yml

* fix build on pr; fix environ vars

* separate unit tests for colossalqa from build_on_pr

* fix container setting; fix environ vars

* commented dev code

* add incremental update

* remove stale code

* fix style

* change to sha3 224

* fix retriever; fix style; add unit test for document loader

* fix ci workflow config

* fix ci workflow config

* add set cuda visible device script in ci

* fix doc string

* fix style; update readme; refactored

* add force log info

* change build on pr, ignore colossalqa

* fix docstring, capitalize all initial letters

* fix indexing; fix text-splitter

* remove debug code, update reference

* reset previous commit

* update LICENSE update README add key-value mode, fix bugs

* add files back

* revert force push

* remove junk file

* add test files

* fix retriever bug, add intent classification

* change conversation chain design

* rewrite prompt and conversation chain

* add ui v1

* ui v1

* fix avatar

* add header

* Refactor the RAG Code and support Pangu

* Refactor the ColossalQA chain into an object-oriented design and refactor the UI demo.

* resolved conversation. tested scripts under examples. web demo still buggy

* fix ci tests

* Some modifications to add ChatGPT api

* modify llm.py and remove unnecessary files

* Delete applications/ColossalQA/examples/ui/test_frontend_input.json

* Remove OpenAI api key

* add colossalqa

* move files

* move files

* move files

* move files

* fix style

* Add Readme and fix some bugs.

* Add something to readme and modify some code

* modify a directory name for clarity

* remove redundant directory

* Correct a typo in llm.py

* fix AI prefix

* fix test_memory.py

* fix conversation

* fix some errors and typos

* Fix a missing import in RAG_ChatBot.py

* add colossalcloud LLM wrapper, correct issues in code review

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
YeAnbang
2023-11-23 10:33:48 +08:00
committed by GitHub
parent 3acbf6d496
commit e53e729d8e
69 changed files with 6758 additions and 0 deletions


@@ -0,0 +1,103 @@
"""
Custom SummarizerMixin base class and ConversationSummaryMemory class
Modified from Original Source
This code is based on LangChain AI's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations
from typing import Any, Dict, List, Type
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import BaseChatMessageHistory, BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
class SummarizerMixin(BaseModel):
"""
Mixin for summarizer.
"""
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
prompt: BasePromptTemplate = SUMMARY_PROMPT
summary_message_cls: Type[BaseMessage] = SystemMessage
llm_kwargs: Dict = {}
def predict_new_summary(self, messages: List[BaseMessage], existing_summary: str, stop: List = []) -> str:
"""
Recursively summarize a conversation by generating a new summary using
the last round of conversation and the existing summary.
"""
new_lines = get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
return chain.predict(summary=existing_summary, new_lines=new_lines, stop=stop)
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""Conversation summarizer to chat memory."""
buffer: str = ""
memory_key: str = "history"
@classmethod
def from_messages(
cls,
llm: BaseLanguageModel,
chat_memory: BaseChatMessageHistory,
summarize_step: int = 2,
**kwargs: Any,
) -> ConversationSummaryMemory:
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
for i in range(0, len(obj.chat_memory.messages), summarize_step):
obj.buffer = obj.predict_new_summary(obj.chat_memory.messages[i : i + summarize_step], obj.buffer)
return obj
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables."""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
if self.return_messages:
buffer: Any = [self.summary_message_cls(content=self.buffer)]
else:
buffer = self.buffer
return {self.memory_key: buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self.buffer = self.predict_new_summary(self.chat_memory.messages[-2:], self.buffer)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""
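Below is a minimal usage sketch of the summarizer memory above. It assumes `llm` is an already-constructed, LangChain-compatible LLM (for example, the `ColossalLLM` wrapper added later in this PR), and the conversation content is illustrative only:

```
from langchain.memory import ChatMessageHistory

# Build a short chat history to summarize (illustrative content)
history = ChatMessageHistory()
history.add_user_message("What is ColossalQA?")
history.add_ai_message("ColossalQA is a retrieval-based question answering application.")

# `llm` is assumed to be an existing LangChain LLM instance
memory = ConversationSummaryMemory.from_messages(llm=llm, chat_memory=history, summarize_step=2)
print(memory.buffer)  # running summary produced by predict_new_summary
```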


@@ -0,0 +1,214 @@
"""
Chain for question-answering against a vector database.
Modified from Original Source
This code is based on LangChain AI's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations
import copy
import inspect
from typing import Any, Dict, List, Optional
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
class CustomBaseRetrievalQA(BaseRetrievalQA):
"""Base class for question-answering chains."""
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Initialize from LLM."""
llm_kwargs = kwargs.pop("llm_kwargs", {})
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
document_prompt = kwargs.get(
"document_prompt", PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")
)
combine_documents_chain = CustomStuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
callbacks=callbacks,
)
return cls(
combine_documents_chain=combine_documents_chain,
callbacks=callbacks,
**kwargs,
)
@classmethod
def from_chain_type(
cls,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Load chain from chain type."""
llm_kwargs = kwargs.pop("llm_kwargs", {})
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
if accepts_run_manager:
docs = self._get_docs(question, run_manager=_run_manager)
else:
docs = self._get_docs(question) # type: ignore[call-arg]
kwargs = {
k: v
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
answers = []
if self.combine_documents_chain.memory is not None:
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
self.combine_documents_chain.memory.buffered_history
), copy.deepcopy(self.combine_documents_chain.memory.summarized_history_temp)
else:
buffered_history_backup = None
summarized_history_temp_backup = None
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
if summarized_history_temp_backup is not None and buffered_history_backup is not None:
(
self.combine_documents_chain.memory.buffered_history,
self.combine_documents_chain.memory.summarized_history_temp,
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
# if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None
if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None:
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
if accepts_run_manager:
docs = await self._aget_docs(question, run_manager=_run_manager)
else:
docs = await self._aget_docs(question) # type: ignore[call-arg]
kwargs = {
k: v
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
# if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None
if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
class RetrievalQA(CustomBaseRetrievalQA):
"""Chain for question-answering against an index.
Example:
.. code-block:: python
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.faiss import FAISS
from langchain.vectorstores.base import VectorStoreRetriever
retriever = VectorStoreRetriever(vectorstore=FAISS(...))
retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
"""
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(
self,
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())
async def _aget_docs(
self,
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "retrieval_qa"
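A hedged usage sketch for the chain above. `llm`, `retriever`, `prompt` and `memory` are assumed to be already-constructed objects (see the LLM wrappers, retriever and memory added elsewhere in this PR); the input key spelled `rejection_trigger_keywrods` is copied verbatim from the key that `_call`/`_acall` actually read:

```
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt, "memory": memory},
)
res = qa({
    "query": "What is ColossalAI?",
    "rejection_trigger_keywrods": ["cannot answer the question"],
    "rejection_answer": "Sorry, this question cannot be answered based on the information provided.",
})
print(res["result"])  # final answer, or the rejection answer if a trigger keyword was generated
```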


@@ -0,0 +1,87 @@
"""
Load question answering chains.
For now, only the "stuff" chain is modified
Modified from Original Source
This code is based on LangChain AI's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, Mapping, Optional, Protocol
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import stuff_prompt
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.prompt_template import BasePromptTemplate
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLanguageModel, **kwargs: Any) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_stuff_chain(
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> CustomStuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
if "llm_kwargs" in kwargs:
llm_kwargs = copy.deepcopy(kwargs["llm_kwargs"])
del kwargs["llm_kwargs"]
else:
llm_kwargs = {}
llm_chain = LLMChain(
llm=llm,
prompt=_prompt,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
llm_kwargs=llm_kwargs,
)
return CustomStuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
**kwargs,
)
def load_qa_chain(
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
Args:
llm: Language Model to use in the chain.
chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", "map_rerank", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.
Returns:
A chain to use for question answering.
"""
loader_mapping: Mapping[str, LoadingCallable] = {"stuff": _load_stuff_chain}
if chain_type not in loader_mapping:
raise ValueError(f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}")
return loader_mapping[chain_type](llm, verbose=verbose, callback_manager=callback_manager, **kwargs)
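A brief sketch of calling the loader directly. It assumes `llm` is a LangChain LLM and `docs` is a list of `Document` objects, and it reuses the English retrieval prompt defined later in this PR:

```
from colossalqa.prompt.prompt import PROMPT_RETRIEVAL_QA_EN

chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT_RETRIEVAL_QA_EN,
                      llm_kwargs={"max_new_tokens": 50})
# chat_history must be supplied because the prompt template lists it as an input variable
answer = chain.run(input_documents=docs, question="What is ColossalQA?", chat_history="")
```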


@@ -0,0 +1,91 @@
"""
Chain that combines documents by stuffing into context
Modified from Original Source
This code is based on LangChain AI's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, List
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.schema import format_document
class CustomStuffDocumentsChain(StuffDocumentsChain):
"""Chain that combines documents by stuffing into context.
This chain takes a list of documents and first combines them into a single string.
It does this by formatting each document into a string with the `document_prompt`
and then joining them together with `document_separator`. It then adds that new
string to the inputs with the variable name set by `document_variable_name`.
Those inputs are then passed to the `llm_chain`.
Example:
.. code-block:: python
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
"""
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.
Format and then join all the documents together into one input with the name
`self.document_variable_name`. Then pluck any additional variables
from **kwargs.
Args:
docs: List of documents to format and then join into single input
**kwargs: additional inputs to chain, will pluck any other required
arguments from here.
Returns:
dictionary of inputs to LLMChain
"""
# Format each document according to the prompt
# If the document is in key-value format, i.e. its metadata has 'is_key_value_mapping'=True and contains a 'value' field,
# use the value to replace the key (the original page content)
doc_prefix = kwargs.get("doc_prefix", "Supporting Document")
docs_ = []
for id, doc in enumerate(docs):
doc_ = copy.deepcopy(doc)
if doc_.metadata.get("is_key_value_mapping", False) and "value" in doc_.metadata:
doc_.page_content = str(doc_.metadata["value"])
prefix = doc_prefix + str(id)
doc_.page_content = str(prefix + ":" + (" " if doc_.page_content[0] != " " else "") + doc_.page_content)
docs_.append(doc_)
doc_strings = [format_document(doc, self.document_prompt) for doc in docs_]
arg_list = ["stop", "temperature", "top_k", "top_p", "max_new_tokens"]
arg_list.extend(self.llm_chain.prompt.input_variables)
# Join the documents together to put them in the prompt.
inputs = {k: v for k, v in kwargs.items() if k in arg_list}
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs
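A small sketch of the key-value behaviour described above. `chain` is assumed to be a `CustomStuffDocumentsChain` built by `load_qa_chain` with a QA prompt whose document variable is named "context"; the document content is illustrative:

```
from langchain.docstore.document import Document

kv_doc = Document(
    page_content="error code E01",  # the "key"
    metadata={"is_key_value_mapping": True, "value": "E01 indicates a sensor fault; check the wiring."},
)
# The page content is replaced by metadata["value"] and prefixed with "Supporting Document0"
inputs = chain._get_inputs([kv_doc], question="What does E01 mean?", doc_prefix="Supporting Document")
print(inputs["context"])
```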


@@ -0,0 +1,128 @@
"""
Class for loading document type data
"""
import glob
from typing import List
from colossalqa.mylogging import get_logger
from langchain.document_loaders import (
JSONLoader,
PyPDFLoader,
TextLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
logger = get_logger()
SUPPORTED_DATA_FORMAT = [".csv", ".json", ".html", ".md", ".pdf", ".txt", ".jsonl"]
class DocumentLoader:
"""
Load documents from different files into list of langchain Documents
"""
def __init__(self, files: List, **kwargs) -> None:
"""
Args:
files: list of files (list[file path, name])
**kwargs: keyword type arguments, useful for certain document types
"""
self.data = {}
self.kwargs = kwargs
for item in files:
path = item[0] if isinstance(item, list) else item
logger.info(f"Loading data from {path}")
self.load_data(path)
logger.info("Data loaded")
self.all_data = []
for key in self.data:
if isinstance(self.data[key], list):
for item in self.data[key]:
if isinstance(item, list):
self.all_data.extend(item)
else:
self.all_data.append(item)
def load_data(self, path: str) -> None:
"""
Load data. Please refer to https://python.langchain.com/docs/modules/data_connection/document_loaders/
for specific format requirements.
Args:
path: path to a file
To load files with a glob path, here are some examples.
Load all files from a directory: folder1/folder2/*
Load all pdf files from a directory: folder1/folder2/*.pdf
"""
files = []
# Handle glob expression
try:
files = glob.glob(path)
except Exception as e:
logger.error(e)
if len(files) == 0:
raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
elif len(files) == 1:
path = files[0]
else:
for file in files:
self.load_data(file)
return
# Load data if the path is a file
logger.info(f"load {path}", verbose=True)
if path.endswith(".csv"):
# Load csv
loader = CSVLoader(file_path=path, encoding="utf8")
data = loader.load()
self.data[path] = data
elif path.endswith(".txt"):
# Load txt
loader = TextLoader(path, encoding="utf8")
data = loader.load()
self.data[path] = data
elif path.endswith("html"):
# Load html
loader = UnstructuredHTMLLoader(path, encoding="utf8")
data = loader.load()
self.data[path] = data
elif path.endswith("json"):
# Load json
loader = JSONLoader(
file_path=path,
jq_schema=self.kwargs.get("jq_schema", ".data[]"),
content_key=self.kwargs.get("content_key", "content"),
metadata_func=self.kwargs.get("metadata_func", None),
)
data = loader.load()
self.data[path] = data
elif path.endswith("jsonl"):
# Load jsonl
loader = JSONLoader(
file_path=path, jq_schema=self.kwargs.get("jq_schema", ".data[].content"), json_lines=True
)
data = loader.load()
self.data[path] = data
elif path.endswith(".md"):
# Load markdown
loader = UnstructuredMarkdownLoader(path)
data = loader.load()
self.data[path] = data
elif path.endswith(".pdf"):
# Load pdf
loader = PyPDFLoader(path)
data = loader.load_and_split()
self.data[path] = data
else:
if "." in path.split("/")[-1]:
raise ValueError(f"Unsupported file format {path}. Supported formats: {SUPPORTED_DATA_FORMAT}")
else:
# May be a directory; we strictly follow the glob path and will not load files in subdirectories
pass
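A short usage sketch for the loader above; the glob path and dataset name are hypothetical:

```
# Load every .txt file under data/ (glob expression) into LangChain Documents
loader = DocumentLoader([["data/*.txt", "company_faq"]])
docs = loader.all_data
print(f"Loaded {len(docs)} documents")
```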


@@ -0,0 +1,119 @@
'''
Class for loading table-type data. Please refer to the pandas Input/Output documentation for file format details.
'''
import os
import glob
import pandas as pd
from sqlalchemy import create_engine
from colossalqa.utils import drop_table
from colossalqa.mylogging import get_logger
logger = get_logger()
SUPPORTED_DATA_FORMAT = ['.csv','.xlsx', '.xls','.json','.html','.h5', '.hdf5','.parquet','.feather','.dta']
class TableLoader:
'''
Load tables from different files and serve an SQL database for database operations
'''
def __init__(self, files: str,
sql_path:str='sqlite:///mydatabase.db',
verbose=False, **kwargs) -> None:
'''
Args:
files: list of files (list[file path, name])
sql_path: how to serve the sql database
**kwargs: keyword type arguments, useful for certain document types
'''
self.data = {}
self.verbose = verbose
self.sql_path = sql_path
self.kwargs = kwargs
self.sql_engine = create_engine(self.sql_path)
drop_table(self.sql_engine)
self.sql_engine = create_engine(self.sql_path)
for item in files:
path = item[0]
dataset_name = item[1]
if not os.path.exists(path):
raise FileNotFoundError(f"{path} doesn't exist")
if not any([path.endswith(i) for i in SUPPORTED_DATA_FORMAT]):
raise TypeError(f"{path} not supported. Supported type {SUPPORTED_DATA_FORMAT}")
logger.info("loading data", verbose=self.verbose)
self.load_data(path)
logger.info("data loaded", verbose=self.verbose)
self.to_sql(path, dataset_name)
def load_data(self, path):
'''
Load data and serve it as an SQL database.
Data must be in a format readable by pandas.
'''
files = []
# Handle glob expression
try:
files = glob.glob(path)
except Exception as e:
logger.error(e)
if len(files)==0:
raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
elif len(files)==1:
path = files[0]
else:
for file in files:
self.load_data(file)
# Stop after recursing into each matched file (mirrors DocumentLoader.load_data)
return
if path.endswith('.csv'):
# Load csv
self.data[path] = pd.read_csv(path)
elif path.endswith('.xlsx') or path.endswith('.xls'):
# Load excel
self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
elif path.endswith('.json'):
# Load json
self.data[path] = pd.read_json(path)
elif path.endswith('.html'):
# Load html
html_tables = pd.read_html(path)
# Choose the desired table from the list of DataFrame objects
self.data[path] = html_tables[0] # You may need to adjust this index
elif path.endswith('.h5') or path.endswith('.hdf5'):
# Load h5
self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed
elif path.endswith('.parquet'):
# Load parquet
self.data[path] = pd.read_parquet(path, engine='fastparquet')
elif path.endswith('.feather'):
# Load feather
self.data[path] = pd.read_feather(path)
elif path.endswith('.dta'):
# Load dta
self.data[path] = pd.read_stata(path)
else:
raise ValueError("Unsupported file format")
def to_sql(self, path, table_name):
'''
Serve the data as sql database.
'''
self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False)
logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
return self.sql_path
def get_sql_path(self):
return self.sql_path
def __del__(self):
if self.sql_engine:
drop_table(self.sql_engine)
self.sql_engine.dispose()
del self.data
del self.sql_engine
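A short usage sketch for the table loader above; the CSV path and table name are hypothetical:

```
table_loader = TableLoader([["data/sales.csv", "sales"]], sql_path="sqlite:///mydatabase.db")
sql_path = table_loader.get_sql_path()
# sql_path can now be handed to an SQL chain or agent that queries the "sales" table
```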


@@ -0,0 +1,125 @@
"""
LLM wrapper for LLMs running on ColossalCloud Platform
Usage:
os.environ['URL'] = ""
os.environ['HOST'] = ""
gen_config = {
'max_new_tokens': 100,
# 'top_k': 2,
'top_p': 0.9,
'temperature': 0.5,
'repetition_penalty': 2,
}
llm = ColossalCloudLLM(n=1)
llm.set_auth_config()
resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""
import json
from typing import Any, List, Mapping, Optional
import requests
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
class ColossalCloudLLM(LLM):
"""
A custom LLM class that integrates LLMs running on the ColossalCloud Platform
"""
n: int
gen_config: dict = None
auth_config: dict = None
valid_gen_para: list = ['max_new_tokens', 'top_k',
'top_p', 'temperature', 'repetition_penalty']
def __init__(self, gen_config=None, **kwargs):
"""
Args:
gen_config: config for generation,
max_new_tokens: 50 by default
top_k: (1, vocab_size)
top_p: (0, 1) if not None
temperature: (0, inf) if not None
repetition_penalty: (1, inf) if not None
"""
super(ColossalCloudLLM, self).__init__(**kwargs)
if gen_config is None:
self.gen_config = {"max_new_tokens": 50}
else:
assert "max_new_tokens" in gen_config, "max_new_tokens is a compulsory key in the gen config"
self.gen_config = gen_config
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"n": self.n}
@property
def _llm_type(self) -> str:
return 'ColossalCloudLLM'
def set_auth_config(self, **kwargs):
url = get_from_dict_or_env(kwargs, "url", "URL")
host = get_from_dict_or_env(kwargs, "host", "HOST")
auth_config = {}
auth_config['endpoint'] = url
auth_config['Host'] = host
self.auth_config = auth_config
def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
"""
Args:
prompt: The prompt to pass into the model.
stop: A list of strings to stop generation when encountered
Returns:
The string generated by the model
"""
# Update the generation arguments
for key, value in kwargs.items():
if key not in self.valid_gen_para:
raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}")
if key in self.gen_config:
self.gen_config[key] = value
resp_text = self.text_completion(prompt, self.gen_config, self.auth_config)
# TODO: This may cause an excessive token count
if stop is not None:
for stopping_words in stop:
if stopping_words in resp_text:
resp_text = resp_text.split(stopping_words)[0]
return resp_text
def text_completion(self, prompt, gen_config, auth_config):
# Compulsory parameters
endpoint = auth_config.pop('endpoint')
max_new_tokens = gen_config.pop('max_new_tokens')
# Optional Parameters
optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional
gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
# Define the data payload
data = {
"max_new_tokens": max_new_tokens,
"history": [
{"instruction": prompt, "response": ""}
],
**gen_config
}
headers = {
"Content-Type": "application/json",
**auth_config # 'Host',
}
# Make the POST request
response = requests.post(endpoint, headers=headers, data=json.dumps(data))
response.raise_for_status() # raise error if return code is not 200(success)
# Check the response
return response.text


@@ -0,0 +1,183 @@
"""
API and LLM wrapper classes for running LLMs locally
Usage:
import os
model_path = os.environ.get("ZH_MODEL_PATH")
model_name = "chatglm2"
colossal_api = ColossalAPI(model_name, model_path)
llm = ColossalLLM(n=1, api=colossal_api)
TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料峭峭,继而雨季开始,"
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
"""
from typing import Any, List, Mapping, Optional
import torch
from colossalqa.local.utils import get_response, post_http_request
from colossalqa.mylogging import get_logger
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer
logger = get_logger()
class ColossalAPI:
"""
API for calling LLM.generate
"""
__instances = dict()
def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> None:
"""
Configure the model
"""
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
return
else:
ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")] = self
self.model_type = model_type
self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
if ckpt_path is not None:
state_dict = torch.load(ckpt_path)
self.model.load_state_dict(state_dict)
self.model.to(torch.cuda.current_device())
# Configure the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model.eval()
@staticmethod
def get_api(model_type: str, model_path: str, ckpt_path: str = None):
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
return ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")]
else:
return ColossalAPI(model_type, model_path, ckpt_path)
def generate(self, input: str, **kwargs) -> str:
"""
Generate response given the prompt
Args:
input: input string
**kwargs: language model keyword type arguments, such as top_k, top_p, temperature, max_new_tokens...
Returns:
output: output string
"""
if self.model_type in ["chatglm", "chatglm2"]:
inputs = {
k: v.to(torch.cuda.current_device()) for k, v in self.tokenizer(input, return_tensors="pt").items()
}
else:
inputs = {
"input_ids": self.tokenizer(input, return_tensors="pt")["input_ids"].to(torch.cuda.current_device())
}
output = self.model.generate(**inputs, **kwargs)
output = output.cpu()
prompt_len = inputs["input_ids"].size(1)
response = output[0, prompt_len:]
output = self.tokenizer.decode(response, skip_special_tokens=True)
return output
class VllmAPI:
def __init__(self, host: str = "localhost", port: int = 8077) -> None:
# Configure the API for a model served over the web
self.host = host
self.port = port
self.url = f"http://{self.host}:{self.port}/generate"
def generate(self, input: str, **kwargs):
output = get_response(post_http_request(input, self.url, **kwargs))[0]
return output[len(input) :]
class ColossalLLM(LLM):
"""
Langchain LLM wrapper for a local LLM
"""
n: int
api: Any
kwargs = {"max_new_tokens": 100}
@property
def _llm_type(self) -> str:
return "custom"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
for k in self.kwargs:
if k not in kwargs:
kwargs[k] = self.kwargs[k]
generate_args = {k: kwargs[k] for k in kwargs if k not in ["stop", "n"]}
out = self.api.generate(prompt, **generate_args)
if isinstance(stop, list) and len(stop) != 0:
for stopping_words in stop:
if stopping_words in out:
out = out.split(stopping_words)[0]
logger.info(f"{prompt}{out}", verbose=self.verbose)
return out
@property
def _identifying_params(self) -> Mapping[str, int]:
"""Get the identifying parameters."""
return {"n": self.n}
class VllmLLM(LLM):
"""
Langchain LLM wrapper for a local LLM
"""
n: int
api: Any
kwargs = {"max_new_tokens": 100}
@property
def _llm_type(self) -> str:
return "custom"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
for k in self.kwargs:
if k not in kwargs:
kwargs[k] = self.kwargs[k]
logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
generate_args = {k: kwargs[k] for k in kwargs if k in ["n", "max_tokens", "temperature", "stream"]}
out = self.api.generate(prompt, **generate_args)
if isinstance(stop, list) and len(stop) != 0:
for stopping_words in stop:
if stopping_words in out:
out = out.split(stopping_words)[0]
logger.info(f"{prompt}{out}", verbose=self.verbose)
return out
def set_host_port(self, host: str = "localhost", port: int = 8077, **kwargs) -> None:
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = 100
self.kwargs = kwargs
self.api = VllmAPI(host=host, port=port)
@property
def _identifying_params(self) -> Mapping[str, int]:
"""Get the identifying parameters."""
return {"n": self.n}
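A hedged sketch for the vLLM-backed wrapper above, assuming a vLLM server is already serving `/generate` on localhost:8077:

```
vllm_api = VllmAPI(host="localhost", port=8077)
vllm_llm = VllmLLM(n=1, api=vllm_api)
# max_tokens is forwarded to the vLLM /generate endpoint; stop words are trimmed client-side
print(vllm_llm("What is the capital of France?", stop=["\n"], max_tokens=100))
```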


@@ -0,0 +1,150 @@
"""
LLM wrapper for Pangu
Usage:
# URL: In the Pangu large-model suite console, go to "Service Management" -> "Model List" and click "Copy Path" for the model you want to use
# USERNAME: Huawei Cloud console: the "IAM User Name" under "My Credentials" -> "API Credentials", i.e. the name used to log in to the IAM account
# PASSWORD: the IAM user's password
# DOMAIN_NAME: Huawei Cloud console: the account "Name" under "My Credentials" -> "API Credentials", i.e. the parent account that manages the IAM users
os.environ["URL"] = ""
os.environ["USERNAME"] = ""
os.environ["PASSWORD"] = ""
os.environ["DOMAIN_NAME"] = ""
pg = Pangu(id=1)
pg.set_auth_config()
res = pg('你是谁') # 您好,我是华为盘古大模型。我能够通过和您对话互动为您提供帮助。请问您有什么想问我的吗?
"""
import http.client
import json
from typing import Any, List, Mapping, Optional
import requests
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
class Pangu(LLM):
"""
A custom LLM class that integrates pangu models
"""
n: int
gen_config: dict = None
auth_config: dict = None
def __init__(self, gen_config=None, **kwargs):
super(Pangu, self).__init__(**kwargs)
if gen_config is None:
self.gen_config = {"user": "User", "max_tokens": 50, "temperature": 0.95, "n": 1}
else:
self.gen_config = gen_config
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"n": self.n}
@property
def _llm_type(self) -> str:
return "pangu"
def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
"""
Args:
prompt: The prompt to pass into the model.
stop: A list of strings to stop generation when encountered
Returns:
The string generated by the model
"""
# Update the generation arguments
for key, value in kwargs.items():
if key in self.gen_config:
self.gen_config[key] = value
response = self.text_completion(prompt, self.gen_config, self.auth_config)
text = response["choices"][0]["text"]
if stop is not None:
for stopping_words in stop:
if stopping_words in text:
text = text.split(stopping_words)[0]
return text
def set_auth_config(self, **kwargs):
url = get_from_dict_or_env(kwargs, "url", "URL")
username = get_from_dict_or_env(kwargs, "username", "USERNAME")
password = get_from_dict_or_env(kwargs, "password", "PASSWORD")
domain_name = get_from_dict_or_env(kwargs, "domain_name", "DOMAIN_NAME")
region = url.split(".")[1]
auth_config = {}
auth_config["endpoint"] = url[url.find("https://") + 8 : url.find(".com") + 4]
auth_config["resource_path"] = url[url.find(".com") + 4 :]
auth_config["auth_token"] = self.get_latest_auth_token(region, username, password, domain_name)
self.auth_config = auth_config
def get_latest_auth_token(self, region, username, password, domain_name):
url = f"https://iam.{region}.myhuaweicloud.com/v3/auth/tokens"
payload = json.dumps(
{
"auth": {
"identity": {
"methods": ["password"],
"password": {"user": {"name": username, "password": password, "domain": {"name": domain_name}}},
},
"scope": {"project": {"name": region}},
}
}
)
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload)
return response.headers["X-Subject-Token"]
def text_completion(self, text, gen_config, auth_config):
conn = http.client.HTTPSConnection(auth_config["endpoint"])
payload = json.dumps(
{
"prompt": text,
"user": gen_config["user"],
"max_tokens": gen_config["max_tokens"],
"temperature": gen_config["temperature"],
"n": gen_config["n"],
}
)
headers = {
"X-Auth-Token": auth_config["auth_token"],
"Content-Type": "application/json",
}
conn.request("POST", auth_config["resource_path"], payload, headers)
res = conn.getresponse()
data = res.read()
data = json.loads(data.decode("utf-8"))
return data
def chat_model(self, messages, gen_config, auth_config):
conn = http.client.HTTPSConnection(auth_config["endpoint"])
payload = json.dumps(
{
"messages": messages,
"user": gen_config["user"],
"max_tokens": gen_config["max_tokens"],
"temperature": gen_config["temperature"],
"n": gen_config["n"],
}
)
headers = {
"X-Auth-Token": auth_config["auth_token"],
"Content-Type": "application/json",
}
conn.request("POST", auth_config["resource_path"], payload, headers)
res = conn.getresponse()
data = res.read()
data = json.loads(data.decode("utf-8"))
return data


@@ -0,0 +1,29 @@
"""
Generation utilities
"""
import json
from typing import List
import requests
def post_http_request(
prompt: str, api_url: str, n: int = 1, max_tokens: int = 100, temperature: float = 0.0, stream: bool = False
) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": 1,
"use_beam_search": False,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
response = requests.post(api_url, headers=headers, json=pload, stream=True, timeout=3)
return response
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output
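A minimal sketch of the two helpers above, again assuming a vLLM-style server on localhost:8077:

```
resp = post_http_request("Hello, world", "http://localhost:8077/generate", max_tokens=50)
texts = get_response(resp)  # list of generated strings returned by the server
print(texts[0])
```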


@@ -0,0 +1,168 @@
"""
Implement a memory class for storing conversation history
Support long term and short term memory
"""
from typing import Any, Dict, List
from colossalqa.chain.memory.summary import ConversationSummaryMemory
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage
from langchain.schema.retriever import BaseRetriever
from pydantic import Field
class ConversationBufferWithSummary(ConversationSummaryMemory):
"""Memory class for storing information about entities."""
# Define dictionary to store information about entities.
# Store the most recent conversation history
buffered_history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
# Temp buffer
summarized_history_temp: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
buffer: str = ""  # Formatted conversation as a string
existing_summary: str = ""  # Summarization of stale conversation as a string
# Define the key used to pass the chat history into the prompt.
memory_key: str = "chat_history"
input_key: str = "question"
retriever: BaseRetriever = None
max_tokens: int = 2000
chain: BaseCombineDocumentsChain = None
input_chain_type_kwargs: Dict = {}
@property
def buffer(self) -> Any:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
self.buffer = self.format_dialogue()
return self.buffer
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return self.buffered_history.messages
def clear(self):
"""Clear all the memory"""
self.buffered_history.clear()
self.summarized_history_temp.clear()
def initiate_document_retrieval_chain(
self, llm: Any, prompt_template: Any, retriever: Any, chain_type_kwargs: Dict[str, Any] = {}
) -> None:
"""
Initiate a document retrieval chain so that the length of the final prompt
(retrieved documents plus chat history) can be calculated.
Args:
llm: the language model for the retrieval chain (we won't actually return the output)
prompt_template: the prompt template for constructing the retrieval chain
retriever: the retriever for the retrieval chain
max_tokens: the max length of the prompt (not including the output)
chain_type_kwargs: the kwargs for the retrieval chain
memory_key: the key for the chat history
input_key: the key for the input query
"""
self.retriever = retriever
input_chain_type_kwargs = {k: v for k, v in chain_type_kwargs.items() if k not in [self.memory_key]}
self.input_chain_type_kwargs = input_chain_type_kwargs
self.chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_template, **self.input_chain_type_kwargs)
@property
def memory_variables(self) -> List[str]:
"""Define the variables we are providing to the prompt."""
return [self.memory_key]
def format_dialogue(self, lang: str = "en") -> str:
"""Format memory into two parts--- summarization of historical conversation and most recent conversation"""
if len(self.summarized_history_temp.messages) != 0:
for i in range(int(len(self.summarized_history_temp.messages) / 2)):
self.existing_summary = (
self.predict_new_summary(
self.summarized_history_temp.messages[i * 2 : i * 2 + 2], self.existing_summary, stop=["\n\n"]
)
.strip()
.split("\n")[0]
.strip()
)
for i in range(int(len(self.summarized_history_temp.messages) / 2)):
self.summarized_history_temp.messages.pop(0)
self.summarized_history_temp.messages.pop(0)
conversation_buffer = []
for t in self.buffered_history.messages:
if t.type == "human":
prefix = self.human_prefix
else:
prefix = self.ai_prefix
conversation_buffer.append(prefix + ": " + t.content)
conversation_buffer = "\n".join(conversation_buffer)
if len(self.existing_summary) > 0:
if lang == "en":
message = f"A summarization of historical conversation:\n{self.existing_summary}\nMost recent conversation:\n{conversation_buffer}"
elif lang == "zh":
message = f"历史对话概要:\n{self.existing_summary}\n最近的对话:\n{conversation_buffer}"
else:
raise ValueError("Unsupported language")
return message
else:
message = conversation_buffer
return message
def get_conversation_length(self):
"""Get the length of the formatted conversation"""
prompt = self.format_dialogue()
length = self.llm.get_num_tokens(prompt)
return length
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load the memory variables.
Summarize oversized conversation to fit into the length constraint defined by max_tokens
Args:
inputs: the kwargs of the chain of your definition
Returns:
a dict that maps from the memory key to the formatted dialogue
the formatted dialogue has the following format
if conversation is too long:
A summarization of historical conversation:
{summarization}
Most recent conversation:
Human: XXX
Assistant: XXX
...
otherwise
Human: XXX
Assistant: XXX
...
"""
# Calculate remain length
if "input_documents" in inputs:
# Run in a retrieval qa chain
docs = inputs["input_documents"]
else:
# For test
docs = self.retriever.get_relevant_documents(inputs[self.input_key])
inputs[self.memory_key] = ""
inputs = {k: v for k, v in inputs.items() if k in [self.chain.input_key, self.input_key, self.memory_key]}
prompt_length = self.chain.prompt_length(docs, **inputs)
remain = self.max_tokens - prompt_length
while self.get_conversation_length() > remain:
if len(self.buffered_history.messages) <= 2:
raise RuntimeError("Exceeded max_tokens; chunk size of retrieved documents is too large")
temp = self.buffered_history.messages.pop(0)
self.summarized_history_temp.messages.append(temp)
temp = self.buffered_history.messages.pop(0)
self.summarized_history_temp.messages.append(temp)
return {self.memory_key: self.format_dialogue()}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.buffered_history.add_user_message(input_str.strip())
self.buffered_history.add_ai_message(output_str.strip())
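A hedged sketch of wiring this memory into a retrieval chain, mirroring how the conversation wrappers later in this PR use it; `llm` and `retriever` are assumed to exist already:

```
from colossalqa.prompt.prompt import PROMPT_RETRIEVAL_QA_EN

memory = ConversationBufferWithSummary(
    llm=llm,
    max_tokens=2000,
    llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
)
# Needed so the memory can measure prompt length (documents + history) before summarizing old turns
memory.initiate_document_retrieval_chain(llm, PROMPT_RETRIEVAL_QA_EN, retriever,
                                          chain_type_kwargs={"chat_history": ""})
memory.save_context({"question": "What is ColossalQA?"}, {"output": "A retrieval-based QA application."})
```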


@@ -0,0 +1,92 @@
"""
Class for logging with extra control for debugging
"""
import logging
class ColossalQALogger:
"""This is a distributed event logger class essentially based on :class:`logging`.
Args:
name (str): The name of the logger.
Note:
Logging types: ``info``, ``warning``, ``debug`` and ``error``
"""
__instances = dict()
def __init__(self, name):
if name in ColossalQALogger.__instances:
raise ValueError("Logger with the same name has been created")
else:
self._name = name
self._logger = logging.getLogger(name)
ColossalQALogger.__instances[name] = self
@staticmethod
def get_instance(name: str):
"""Get the unique single logger instance based on name.
Args:
name (str): The name of the logger.
Returns:
ColossalQALogger: A ColossalQALogger object
"""
if name in ColossalQALogger.__instances:
return ColossalQALogger.__instances[name]
else:
logger = ColossalQALogger(name=name)
return logger
def info(self, message: str, verbose: bool = False) -> None:
"""Log an info message.
Args:
message (str): The message to be logged.
verbose (bool): Whether to print the message to stdout.
"""
if verbose:
logging.basicConfig(level=logging.INFO)
self._logger.info(message)
def warning(self, message: str, verbose: bool = False) -> None:
"""Log a warning message.
Args:
message (str): The message to be logged.
verbose (bool): Whether to print the message to stdout.
"""
if verbose:
self._logger.warning(message)
def debug(self, message: str, verbose: bool = False) -> None:
"""Log a debug message.
Args:
message (str): The message to be logged.
verbose (bool): Whether to print the message to stdout.
"""
if verbose:
self._logger.debug(message)
def error(self, message: str) -> None:
"""Log an error message.
Args:
message (str): The message to be logged.
"""
self._logger.error(message)
def get_logger(name: str = None, level=logging.INFO) -> ColossalQALogger:
"""
Get the logger by name, if name is None, return the default logger
"""
if name:
logger = ColossalQALogger.get_instance(name=name)
else:
logger = ColossalQALogger.get_instance(name="colossalqa")
return logger
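A short usage sketch of the verbose-gated logger:

```
logger = get_logger()
logger.info("loading data", verbose=True)  # verbose controls whether the message is printed to stdout
```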


@@ -0,0 +1,144 @@
# Prompt Design Guide
For the retriever conversation system, users can customize three prompts.
## The Retrieval QA Prompt
This is the prompt for retrieval QA. Its inputs are the user's question, the retrieved documents, and the historical conversation.
### Chinese
```
你是一个善于解答用户问题的AI助手。在保证安全的前提下回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
背景信息:
[retrieved documents]
聊天记录:
[historical conversation, overlength chat history will be summarized]
用户: [question]
Assistant:
```
### English
```
[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If the answer cannot be infered based on the given context, please don't share false information.<</SYS>>
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
context:
[retrieved documents]
chat history
[historical conversation, overlength chat history will be summarized]
Human: {question}
Assistant:
```
## Summarization Prompt
This prompt is used by the memory module to recursively summarize an overlength conversation and shrink the length of the prompt.
## Disambiguation Prompt
This prompt is used to perform zero-shot reference resolution to disambiguate entity references within the user's questions.
## Final Prompt Examples
Assume k=3 for the retriever.
### English
Note that the "[INST] <<SYS>>...<</SYS>>" template is the specific prompt format used in LLaMA2.
#### Normal Length
```
[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If the answer cannot be infered based on the given context, please don't share false information.<</SYS>>
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
context:
[document 1]
[document 2]
[document 3]
chat history
Human: XXX
Assistant: XXX
...
Human: {question}
Assistant:
```
#### Overlength
```
[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If the answer cannot be infered based on the given context, please don't share false information.<</SYS>>
Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
context:
[document 1]
[document 2]
[document 3]
chat history
A summarization of historical conversation:
[one line summary of historical conversation]
Most recent conversation:
Human: XXX
Assistant: XXX
...
Human: {question}
Assistant:
```
### Chinese
#### Normal Length
```
你是一个善于解答用户问题的AI助手。在保证安全的前提下回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
背景信息:
[document 1]
[document 2]
[document 3]
聊天记录:
用户: XXX
Assistant: XXX
...
用户: [question]
Assistant:
```
#### Overlength
```
你是一个善于解答用户问题的AI助手。在保证安全的前提下回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
背景信息:
[document 1]
[document 2]
[document 3]
聊天记录:
历史对话概要:
[one line summary of historical conversation]
最近的对话:
用户: XXX
Assistant: XXX
...
用户: [question]
Assistant:
```


@@ -0,0 +1,124 @@
"""
All custom prompt templates are defined here.
"""
from langchain.prompts.prompt import PromptTemplate
_CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。
例1:
已有的摘要:
人类问Assistant对人工智能的看法。人工智能认为人工智能是一种善的力量。
新的对话内容:
人类: 为什么你认为人工智能是一种好的力量?
Assistant: 因为人工智能将帮助人类充分发挥潜力。
新的摘要:
人类问Assistant对人工智能的看法。人工智能认为人工智能是一种积极的力量因为它将帮助人类充分发挥潜力。
示例结束
已有的摘要:
{summary}
新的对话内容:
{new_lines}
新的摘要:"""
# Chinese retrieval qa prompt
_ZH_RETRIEVAL_QA_PROMPT = """<指令>根据下列支持文档和对话历史,简洁和专业地来回答问题。如果无法从支持文档中得到答案,请说 “根据已知信息无法回答该问题”。回答中请不要涉及支持文档中没有提及的信息,答案请使用中文。 </指令>
{context}
<对话历史>
{chat_history}
</对话历史>
<问题>{question}</问题>
答案:"""
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS = ["无法回答该问题"]
ZH_RETRIEVAL_QA_REJECTION_ANSWER = "抱歉,根据提供的信息无法回答该问题。"
_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE = """使用提供的参考案例判断客户遇到的故障所属的故障原因分类。
背景信息:
{context}
客服记录:
{question}
故障原因分类:"""
_ZH_DISAMBIGUATION_PROMPT = """你是一个乐于助人、恭敬而诚实的助手。你总是按照指示去做。
请用聊天记录中提到的具体名称或实体名称替换给定句子中的任何模糊或有歧义的指代,如果没有提供聊天记录或句子中不包含模糊或有歧义的指代,则只输出原始句子。您的输出应该是消除歧义的句子本身(与“消除歧义的句子:”在同一行中),并且不包含任何其他内容。
下面是一个例子:
聊天记录:
用户: 我有一个朋友,张三。你认识他吗?
Assistant: 我认识一个叫张三的人
句子: 他最喜欢的食物是什么?
消除歧义的句子: 张三最喜欢的食物是什么?
聊天记录:
{chat_history}
句子: {input}
消除歧义的句子:"""
# English retrieval qa prompt
_EN_RETRIEVAL_QA_PROMPT = """[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist content.
If the answer cannot be infered based on the given context, please say "I cannot answer the question based on the information given.".<</SYS>>
Use the context and chat history to answer the question.
context:
{context}
chat history
{chat_history}
question: {question}
answer:"""
EN_RETRIEVAL_QA_TRIGGER_KEYWORDS = ["cannot answer the question"]
EN_RETRIEVAL_QA_REJECTION_ANSWER = "Sorry, this question cannot be answered based on the information provided."
_EN_DISAMBIGUATION_PROMPT = """[INST] <<SYS>>You are a helpful, respectful and honest assistant. You always follow the instruction.<</SYS>>
Please replace any ambiguous references in the given sentence with the specific names or entities mentioned in the chat history or just output the original sentence if no chat history is provided or if the sentence doesn't contain ambiguous references. Your output should be the disambiguated sentence itself (in the same line as "disambiguated sentence:") and contain nothing else.
Here is an example:
Chat history:
Human: I have a friend, Mike. Do you know him?
Assistant: Yes, I know a person named Mike
sentence: What's his favorate food?
disambiguated sentence: What's Mike's favorate food?
[/INST]
Chat history:
{chat_history}
sentence: {input}
disambiguated sentence:"""
PROMPT_RETRIEVAL_QA_EN = PromptTemplate(
template=_EN_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
)
PROMPT_DISAMBIGUATE_EN = PromptTemplate(template=_EN_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])
SUMMARY_PROMPT_ZH = PromptTemplate(input_variables=["summary", "new_lines"], template=_CUSTOM_SUMMARIZER_TEMPLATE_ZH)
PROMPT_DISAMBIGUATE_ZH = PromptTemplate(template=_ZH_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])
PROMPT_RETRIEVAL_QA_ZH = PromptTemplate(
template=_ZH_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
)
PROMPT_RETRIEVAL_CLASSIFICATION_USE_CASE_ZH = PromptTemplate(
template=_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE, input_variables=["question", "context"]
)
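A quick sketch showing how one of the templates above is filled in; the argument values are placeholders:

```
filled = PROMPT_RETRIEVAL_QA_EN.format(
    context="[retrieved documents]",
    chat_history="",
    question="What is ColossalQA?",
)
print(filled)
```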


@@ -0,0 +1,87 @@
"""
Script for an English retrieval-based conversation system backed by a local LLM
"""
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.mylogging import get_logger
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_EN, PROMPT_RETRIEVAL_QA_EN
from colossalqa.retriever import CustomRetriever
from langchain import LLMChain
logger = get_logger()
class EnglishRetrievalConversation:
"""
Wrapper class for the English retrieval conversation system
"""
def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
"""
Set up the retrieval QA chain for English retrieval-based QA
"""
logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
colossal_api = ColossalAPI.get_api(model_name, model_path)
self.llm = ColossalLLM(n=1, api=colossal_api)
# Define the retriever
self.retriever = retriever
# Define the chain to preprocess the input
# Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
# The prompt used here is the disambiguation prompt
self.llm_chain_disambiguate = LLMChain(
llm=self.llm,
prompt=PROMPT_DISAMBIGUATE_EN,
llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
)
self.retriever.set_rephrase_handler(self.disambiguity)
# Define memory with summarization ability
self.memory = ConversationBufferWithSummary(
llm=self.llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
)
self.memory.initiate_document_retrieval_chain(
self.llm,
PROMPT_RETRIEVAL_QA_EN,
self.retriever,
chain_type_kwargs={
"chat_history": "",
},
)
self.retrieval_chain = RetrievalQA.from_chain_type(
llm=self.llm,
verbose=False,
chain_type="stuff",
retriever=self.retriever,
chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": self.memory},
llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
)
def disambiguity(self, input: str):
out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
return out.split("\n")[0]
@classmethod
def from_retriever(
cls, retriever: CustomRetriever, model_path: str, model_name: str
) -> "EnglishRetrievalConversation":
return cls(retriever, model_path, model_name)
def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
if memory:
# TODO add translation chain here
self.memory.buffered_history.messages = memory.buffered_history.messages
self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
return (
self.retrieval_chain.run(
query=user_input,
stop=[self.memory.human_prefix + ": "],
rejection_trigger_keywrods=["cannot answer the question"],
rejection_answer="Sorry, this question cannot be answered based on the information provided.",
).split("\n")[0],
self.memory,
)
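# --- Hypothetical usage sketch (illustration only; the path and model name below are placeholders) ---
# Assuming a CustomRetriever has already been populated with documents, the chain is built
# through the factory method above and queried turn by turn, threading the memory through.
#
#   conversation = EnglishRetrievalConversation.from_retriever(
#       retriever, model_path="/path/to/local/llm", model_name="llama"
#   )
#   answer, memory = conversation.run("What is ColossalAI?", memory=None)
#   answer, memory = conversation.run("Who maintains it?", memory=memory)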

View File

@@ -0,0 +1,138 @@
"""
Multilingual retrieval based conversation system
"""
from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.mylogging import get_logger
from colossalqa.retrieval_conversation_en import EnglishRetrievalConversation
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from colossalqa.text_splitter import ChineseTextSplitter
from colossalqa.utils import detect_lang_naive
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
logger = get_logger()
class UniversalRetrievalConversation:
"""
Wrapper class for bilingual retrieval conversation system
"""
def __init__(
self,
embedding_model_path: str = "moka-ai/m3e-base",
embedding_model_device: str = "cpu",
zh_model_path: str = None,
zh_model_name: str = None,
en_model_path: str = None,
en_model_name: str = None,
sql_file_path: str = None,
files_zh: List[List[str]] = None,
files_en: List[List[str]] = None,
text_splitter_chunk_size=100,
text_splitter_chunk_overlap=10,
) -> None:
"""
Wrapper for the multilingual retrieval QA class (Chinese + English)
Args:
embedding_model_path: local path or HuggingFace repo id of the embedding model
embedding_model_device: device to run the embedding model on, e.g. "cpu" or "cuda"
files_zh: list of dicts with keys "data_path" and "name", defining the files used as supporting documents for Chinese retrieval QA
files_en: list of dicts with keys "data_path" and "name", defining the files used as supporting documents for English retrieval QA
"""
self.embedding = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={"device": embedding_model_device},
encode_kwargs={"normalize_embeddings": False},
)
print("Select files for constructing Chinese retriever")
docs_zh = self.load_supporting_docs(
files=files_zh,
text_splitter=ChineseTextSplitter(
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
),
)
# Create retriever
self.information_retriever_zh = CustomRetriever(
k=3, sql_file_path=sql_file_path.replace(".db", "_zh.db"), verbose=True
)
self.information_retriever_zh.add_documents(
docs=docs_zh, cleanup="incremental", mode="by_source", embedding=self.embedding
)
print("Select files for constructing English retriever")
docs_en = self.load_supporting_docs(
files=files_en,
text_splitter=RecursiveCharacterTextSplitter(
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
),
)
# Create retriever
self.information_retriever_en = CustomRetriever(
k=3, sql_file_path=sql_file_path.replace(".db", "_en.db"), verbose=True
)
self.information_retriever_en.add_documents(
docs=docs_en, cleanup="incremental", mode="by_source", embedding=self.embedding
)
self.chinese_retrieval_conversation = ChineseRetrievalConversation.from_retriever(
self.information_retriever_zh, model_path=zh_model_path, model_name=zh_model_name
)
self.english_retrieval_conversation = EnglishRetrievalConversation.from_retriever(
self.information_retriever_en, model_path=en_model_path, model_name=en_model_name
)
self.memory = None
def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: TextSplitter = None):
"""
Load supporting documents. Currently, all documents will be stored in one vector store.
"""
documents = []
if files:
for file in files:
retriever_data = DocumentLoader([[file["data_path"], file["name"]]]).all_data
splits = text_splitter.split_documents(retriever_data)
documents.extend(splits)
else:
while True:
file = input("Select a file to load or press Enter to exit:")
if file == "":
break
data_name = input("Enter a short description of the data:")
separator = input(
"Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n', press ENTER directly to skip:"
)
separator = separator if separator != "" else "\n\n"
retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
# Split
splits = text_splitter.split_documents(retriever_data)
documents.extend(splits)
return documents
def start_test_session(self):
"""
Simple multilingual session for testing purposes, with a naive language-selection mechanism
"""
while True:
user_input = input("User: ")
lang = detect_lang_naive(user_input)
if "END" == user_input:
print("Agent: Happy to chat with you )")
break
agent_response = self.run(user_input, which_language=lang)
print(f"Agent: {agent_response}")
def run(self, user_input: str, which_language: str):
"""
Generate a response given the user input and a string ("zh" or "en") indicating the required language of the output
"""
assert which_language in ["zh", "en"]
if which_language == "zh":
agent_response, self.memory = self.chinese_retrieval_conversation.run(user_input, self.memory)
else:
agent_response, self.memory = self.english_retrieval_conversation.run(user_input, self.memory)
return agent_response.split("\n")[0]
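# --- Hypothetical usage sketch (illustration only; every path and model name below is a placeholder) ---
# Point the wrapper at local model weights for each language, the supporting documents, and a
# writable sqlite path used to build the per-language indexes, then query it in either language.
#
#   session = UniversalRetrievalConversation(
#       zh_model_path="/path/to/chatglm2-6b", zh_model_name="chatglm2",
#       en_model_path="/path/to/llama2-7b", en_model_name="llama",
#       sql_file_path="./retrieval_index.db",
#       files_zh=[{"data_path": "./data/company_zh.txt", "name": "company_intro_zh"}],
#       files_en=[{"data_path": "./data/company_en.txt", "name": "company_intro_en"}],
#   )
#   print(session.run("What products does the company offer?", which_language="en"))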

View File

@@ -0,0 +1,94 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.mylogging import get_logger
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH
from colossalqa.retriever import CustomRetriever
from langchain import LLMChain
logger = get_logger()
class ChineseRetrievalConversation:
"""
Wrapper class for Chinese retrieval conversation system
"""
def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
"""
Setup retrieval qa chain for Chinese retrieval based QA
"""
# Local coati api
logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
colossal_api = ColossalAPI.get_api(model_name, model_path)
self.llm = ColossalLLM(n=1, api=colossal_api)
# Define the retriever
self.retriever = retriever
# Define the chain to preprocess the input
# Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
# The chain uses the disambiguation prompt
self.llm_chain_disambiguate = LLMChain(
llm=self.llm,
prompt=PROMPT_DISAMBIGUATE_ZH,
llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
)
self.retriever.set_rephrase_handler(self.disambiguity)
# Define memory with summarization ability
self.memory = ConversationBufferWithSummary(
llm=self.llm,
prompt=SUMMARY_PROMPT_ZH,
human_prefix="用户",
ai_prefix="Assistant",
max_tokens=2000,
llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
)
self.memory.initiate_document_retrieval_chain(
self.llm,
PROMPT_RETRIEVAL_QA_ZH,
self.retriever,
chain_type_kwargs={
"chat_history": "",
},
)
self.retrieval_chain = RetrievalQA.from_chain_type(
llm=self.llm,
verbose=False,
chain_type="stuff",
retriever=self.retriever,
chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_ZH, "memory": self.memory},
llm_kwargs={"max_new_tokens": 150, "temperature": 0.9, "do_sample": True},
)
def disambiguity(self, input: str):
out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
return out.split("\n")[0]
@classmethod
def from_retriever(
cls, retriever: CustomRetriever, model_path: str, model_name: str
) -> "ChineseRetrievalConversation":
return cls(retriever, model_path, model_name)
def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
if memory:
# TODO add translation chain here
self.memory.buffered_history.messages = memory.buffered_history.messages
self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
return (
self.retrieval_chain.run(
query=user_input,
stop=["</答案>"],
doc_prefix="支持文档",
rejection_trigger_keywrods=["无法回答该问题"],
rejection_answer="抱歉,根据提供的信息无法回答该问题。",
).split("\n")[0],
self.memory,
)
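# --- Hypothetical usage sketch (illustration only; the path and model name are placeholders) ---
# Construction mirrors the English chain:
#
#   conversation = ChineseRetrievalConversation.from_retriever(
#       retriever, model_path="/path/to/chatglm2-6b", model_name="chatglm2"
#   )
#   answer, memory = conversation.run("公司的主要产品是什么?", memory=None)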

View File

@@ -0,0 +1,166 @@
"""
Code for a custom retriever with incremental index update
"""
import copy
import hashlib
import os
from collections import defaultdict
from typing import Any, Callable, Dict, List
from colossalqa.mylogging import get_logger
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.indexes import SQLRecordManager, index
from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.chroma import Chroma
logger = get_logger()
class CustomRetriever(BaseRetriever):
"""
Custom retriever class with support for incremental update of indexes
"""
vector_stores: Dict[str, VectorStore] = {}
sql_index_database: Dict[str, str] = {}
record_managers: Dict[str, SQLRecordManager] = {}
sql_db_chains = []
k = 3
rephrase_handler: Callable = None
buffer: List = []
buffer_size: int = 5
verbose: bool = False
sql_file_path: str = None
@classmethod
def from_documents(
cls,
documents: List[Document],
embeddings: Embeddings,
**kwargs: Any,
) -> BaseRetriever:
k = kwargs.pop("k", 3)
cleanup = kwargs.pop("cleanup", "incremental")
mode = kwargs.pop("mode", "by_source")
ret = cls(k=k)
ret.add_documents(documents, embedding=embeddings, cleanup=cleanup, mode=mode)
return ret
def add_documents(
self,
docs: List[Document] = [],
cleanup: str = "incremental",
mode: str = "by_source",
embedding: Embeddings = None,
) -> None:
"""
Add documents to the retriever
Args:
docs: the documents to add
cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destroy and rebuild the retriever)
mode: choose from "by_source" (documents are grouped by source) and "merge" (documents are merged into one vector store)
"""
if cleanup == "full":
# Cleanup
for source in self.vector_stores:
os.remove(self.sql_index_database[source])
# Add documents
data_by_source = defaultdict(list)
if mode == "by_source":
for doc in docs:
data_by_source[doc.metadata["source"]].append(doc)
elif mode == "merge":
data_by_source["merged"] = docs
for source in data_by_source:
if source not in self.vector_stores:
hash_encoding = hashlib.sha3_224(source.encode()).hexdigest()
if os.path.exists(f"{self.sql_file_path}/{hash_encoding}.db"):
# Remove the stale file
os.remove(f"{self.sql_file_path}/{hash_encoding}.db")
# Create a new SQL record-manager database for this source; the .db files are stored under sql_file_path
sql_path = f"sqlite:///{self.sql_file_path}/{hash_encoding}.db"
self.vector_stores[source] = Chroma(embedding_function=embedding, collection_name=hash_encoding)
self.sql_index_database[source] = f"{self.sql_file_path}/{hash_encoding}.db"
self.record_managers[source] = SQLRecordManager(source, db_url=sql_path)
self.record_managers[source].create_schema()
index(
data_by_source[source],
self.record_managers[source],
self.vector_stores[source],
cleanup=cleanup,
source_id_key="source",
)
def __del__(self):
for source in self.sql_index_database:
if os.path.exists(self.sql_index_database[source]):
os.remove(self.sql_index_database[source])
def set_sql_database_chain(self, db_chains) -> None:
"""
Set the SQL agent chains used to retrieve information from SQL databases
Not used in this version
"""
self.sql_db_chains = db_chains
def set_rephrase_handler(self, handler: Callable = None) -> None:
"""
Set a handler to preprocess the input string before it is fed into the retriever
"""
self.rephrase_handler = handler
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun = None,
score_threshold: float = None,
return_scores: bool = False,
) -> List[Document]:
"""
This function is called by the retriever to get the relevant documents.
Recently visited queries are stored in a buffer; if the query is already in the buffer, the cached documents are returned directly.
Args:
query: the query to be searched
run_manager: the callback manager for retriever run
Returns:
documents: the relevant documents
"""
for buffered_doc in self.buffer:
if buffered_doc[0] == query:
return buffered_doc[1]
query_ = str(query)
# Use your existing retriever to get the documents
if self.rephrase_handler:
query = self.rephrase_handler(query)
documents = []
for k in self.vector_stores:
# Retrieve documents from each retriever
vectorstore = self.vector_stores[k]
documents.extend(vectorstore.similarity_search_with_score(query, self.k, score_threshold=score_threshold))
# Return the top k documents among all retrievers
documents = sorted(documents, key=lambda x: x[1], reverse=False)[: self.k]
if return_scores:
# Return score
documents = copy.deepcopy(documents)
for doc in documents:
doc[0].metadata["score"] = doc[1]
documents = [doc[0] for doc in documents]
# Retrieve documents from sql database (not applicable for the local chains)
for sql_chain in self.sql_db_chains:
documents.append(
Document(
page_content=f"Query: {query} Answer: {sql_chain.run(query)}", metadata={"source": "sql_query"}
)
)
if len(self.buffer) < self.buffer_size:
self.buffer.append([query_, documents])
else:
self.buffer.pop(0)
self.buffer.append([query_, documents])
logger.info(f"retrieved documents:\n{str(documents)}", verbose=self.verbose)
return documents
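# --- Hypothetical usage sketch (illustration only; the embedding model, paths and `splits` are placeholders) ---
# Typical flow: split the source documents into a list of LangChain Documents (`splits` below),
# then add them with cleanup="incremental" so unchanged embeddings are skipped on the next run.
#
#   from langchain.embeddings import HuggingFaceEmbeddings
#   embedding = HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
#   retriever = CustomRetriever(k=3, sql_file_path="./index_dir", verbose=True)
#   retriever.add_documents(docs=splits, cleanup="incremental", mode="by_source", embedding=embedding)
#   top_docs = retriever.get_relevant_documents("What is ColossalAI?")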

View File

@@ -0,0 +1 @@
from .chinese_text_splitter import ChineseTextSplitter

View File

@@ -0,0 +1,56 @@
"""
Code for Chinese text splitter
"""
from typing import Any, List, Optional
from colossalqa.text_splitter.utils import get_cleaned_paragraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
class ChineseTextSplitter(RecursiveCharacterTextSplitter):
def __init__(self, separators: Optional[List[str]] = None, is_separator_regex: bool = False, **kwargs: Any):
self._separators = separators or ["\n\n", "\n", ",", "。", "!", "?", "?"]
if "chunk_size" not in kwargs:
kwargs["chunk_size"] = 50
if "chunk_overlap" not in kwargs:
kwargs["chunk_overlap"] = 10
super().__init__(separators=self._separators, keep_separator=True, **kwargs)
self._is_separator_regex = is_separator_regex
def split_text(self, text: str) -> List[str]:
"""Return the list of separated text chunks"""
cleaned_paragraph = get_cleaned_paragraph(text)
splitted = []
for paragraph in cleaned_paragraph:
segs = super().split_text(paragraph)
for i in range(len(segs) - 1):
if segs[i][-1] not in self._separators:
pos = text.find(segs[i])
pos_end = pos + len(segs[i])
if i > 0:
last_sentence_start = max([text.rfind(m, 0, pos) for m in ["。", "!", "?"]])
pos = last_sentence_start + 1
segs[i] = str(text[pos:pos_end])
if i != len(segs) - 1:
next_sentence_end = max([text.find(m, pos_end) for m in ["。", "!", "?"]])
segs[i] = str(text[pos : next_sentence_end + 1])
splitted.append(segs[i])
if len(splitted) <= 1:
return splitted
splitted_text = []
i = 1
if splitted[0] not in splitted[1]:
splitted_text.append([splitted[0], 0])
if splitted[-1] not in splitted[-2]:
splitted_text.append([splitted[-1], len(splitted) - 1])
while i < len(splitted) - 1:
if splitted[i] not in splitted[i + 1] and splitted[i] not in splitted[i - 1]:
splitted_text.append([splitted[i], i])
i += 1
splitted_text = sorted(splitted_text, key=lambda x: x[1])
splitted_text = [splitted_text[i][0] for i in range(len(splitted_text))]
ret = []
for s in splitted_text:
if s not in ret:
ret.append(s)
return ret
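# --- Hypothetical usage sketch (illustration only; chunk sizes are arbitrary) ---
# The splitter is used like any LangChain text splitter, either on raw strings or on Documents.
#
#   splitter = ChineseTextSplitter(chunk_size=50, chunk_overlap=10)
#   chunks = splitter.split_text("ColossalQA是一个检索增强的问答系统。它同时支持中文与英文文档。")
#   # For LangChain Documents: splitter.split_documents(docs)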

View File

@@ -0,0 +1,19 @@
import re
from typing import List
def remove_format(text: str) -> str:
# if the line contains more than 3 whitespace characters other than plain spaces, keep it as-is (it is probably a table row); otherwise normalize every whitespace character to a single space
if len(re.findall(r"\s", text.replace(" ", ""))) > 3:
# in case this is a line of a table
return text
return re.sub(r"\s", " ", text)
def get_cleaned_paragraph(s: str) -> List[str]:
"""Remove redundant newlines and return the cleaned lines as a list"""
text = str(s)
text = re.sub(r"\n{3,}", r"\n", text) # replace \n\n\n... with \n
text = re.sub("\n\n", "", text)
lines = text.split("\n")
lines_remove_format = [remove_format(line) for line in lines]
return lines_remove_format
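# --- Hypothetical example (illustration only) ---
#   get_cleaned_paragraph("第一行\n\n\n\n第二行\t表格外文本")
#   # -> ["第一行", "第二行 表格外文本"]: runs of newlines are collapsed and the remaining
#   #    whitespace inside each line is normalized to single spaces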

View File

@@ -0,0 +1,61 @@
import re
from typing import Union
from colossalqa.mylogging import get_logger
from sqlalchemy import Engine, MetaData, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import declarative_base
logger = get_logger()
def drop_table(engine: Engine) -> None:
"""
Drop all existing tables
"""
Base = declarative_base()
metadata = MetaData()
metadata.reflect(bind=engine)
for key in metadata.tables:
table = metadata.tables[key]
if table is not None:
Base.metadata.drop_all(engine, [table], checkfirst=True)
def create_empty_sql_database(database_uri):
try:
# Create an SQLAlchemy engine to connect to the database
engine = create_engine(database_uri)
# Create the database
engine.connect()
logger.info(f"Database created at {database_uri}")
except SQLAlchemyError as e:
logger.error(f"Error creating database: {str(e)}")
return engine, database_uri
def destroy_sql_database(sql_engine: Union[Engine, str]) -> None:
"""
Destroy an SQL database
"""
if isinstance(sql_engine, str):
sql_engine = create_engine(sql_engine)
drop_table(sql_engine)
sql_engine.dispose()
sql_engine = None
def detect_lang_naive(s):
"""
Naive function for language detection; should be replaced by a dedicated language-detection component
"""
remove_nota = "[’·°–!\"#$%&'()*+,-./:;<=>?@,。?★、…【】()《》?“”‘’![\\]^_`{|}~]+"
s = re.sub(remove_nota, "", s)
s = re.sub("[0-9]", "", s).strip()
res = re.sub("[a-zA-Z]", "", s).strip()
if len(res) <= 0:
return "en"
else:
return "zh"