mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
* add langchain * add langchain * Add files via upload * add langchain * fix style * fix style: remove extra space * add pytest; modified retriever * add pytest; modified retriever * add tests to build_on_pr.yml * fix build_on_pr.yml * fix build on pr; fix environ vars * seperate unit tests for colossalqa from build from pr * fix container setting; fix environ vars * commented dev code * add incremental update * remove stale code * fix style * change to sha3 224 * fix retriever; fix style; add unit test for document loader * fix ci workflow config * fix ci workflow config * add set cuda visible device script in ci * fix doc string * fix style; update readme; refactored * add force log info * change build on pr, ignore colossalqa * fix docstring, captitalize all initial letters * fix indexing; fix text-splitter * remove debug code, update reference * reset previous commit * update LICENSE update README add key-value mode, fix bugs * add files back * revert force push * remove junk file * add test files * fix retriever bug, add intent classification * change conversation chain design * rewrite prompt and conversation chain * add ui v1 * ui v1 * fix atavar * add header * Refactor the RAG Code and support Pangu * Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo. * resolved conversation. tested scripts under examples. web demo still buggy * fix ci tests * Some modifications to add ChatGPT api * modify llm.py and remove unnecessary files * Delete applications/ColossalQA/examples/ui/test_frontend_input.json * Remove OpenAI api key * add colossalqa * move files * move files * move files * move files * fix style * Add Readme and fix some bugs. 
* Add something to readme and modify some code * modify a directory name for clarity * remove redundant directory * Correct a type in llm.py * fix AI prefix * fix test_memory.py * fix conversation * fix some erros and typos * Fix a missing import in RAG_ChatBot.py * add colossalcloud LLM wrapper, correct issues in code review --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
126 lines
4.9 KiB
Python
126 lines
4.9 KiB
Python
"""
Script for a multilingual, conversation-based experimental AI agent.
It uses ChatGPT as the language model.
You need an OpenAI API key to run this script.
"""
|
||
|
||
import argparse
|
||
import os
|
||
|
||
from colossalqa.data_loader.document_loader import DocumentLoader
|
||
from colossalqa.data_loader.table_dataloader import TableLoader
|
||
from langchain import LLMChain, OpenAI
|
||
from langchain.agents import Tool, ZeroShotAgent
|
||
from langchain.agents.agent import AgentExecutor
|
||
from langchain.agents.agent_toolkits import create_retriever_tool
|
||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||
from langchain.llms import OpenAI
|
||
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
|
||
from langchain.memory.chat_memory import ChatMessageHistory
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain.utilities import SQLDatabase
|
||
from langchain.vectorstores import Chroma
|
||
from langchain_experimental.sql import SQLDatabaseChain
|
||
|
||
# Prompt pieces for the ZeroShotAgent. The suffix must expose the
# {chat_history}, {input} and {agent_scratchpad} variables used below.
_AGENT_PREFIX = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools. If none of the tools can be used to answer the question, do not share an uncertain answer unless you think answering the question doesn't need any background information. In that case, try to answer the question directly."""
_AGENT_SUFFIX = """You are provided with the following background knowledge:
Begin!

{chat_history}
Question: {input}
{agent_scratchpad}"""


def _read_openai_key(key_path):
    """Return the OpenAI API key stored in the plain-text file at ``key_path``.

    Surrounding whitespace is stripped so that the trailing newline most
    editors add does not become part of the key.
    """
    with open(key_path, encoding="utf-8") as key_file:
        return key_file.read().strip()


def _is_exit_command(user_input):
    """Return True when the user typed ``end`` as a standalone word.

    Matching whitespace-separated tokens (instead of the substring ``" end "``)
    also accepts ``end`` at the start or end of the line, e.g. a bare ``end``.
    Every input containing ``" end "`` is still accepted.
    """
    return "end" in user_input.split()


def _build_sql_tools(llm, table_loaders):
    """Interactively load tabular files into SQLite and expose each as a SQL-chain Tool.

    Args:
        llm: language model used by each ``SQLDatabaseChain``.
        table_loaders: list that every created ``TableLoader`` is appended to as
            soon as it exists, so the caller can dispose of its ``sql_engine``
            even if a later step raises.

    Returns:
        list of ``Tool`` objects, one per loaded file (empty if none).
    """
    tools = []
    print("Select files for constructing sql database")
    while True:
        file = input("Select a file to load or press Enter to exit:")
        if file == "":
            break
        data_name = input("Enter a short description of the data:")
        table_name = data_name.replace(" ", "_")

        table_loader = TableLoader([[file, table_name]], sql_path=f"sqlite:///{table_name}.db")
        table_loaders.append(table_loader)
        sql_path = table_loader.get_sql_path()

        # Create sql database
        db = SQLDatabase.from_uri(sql_path)
        print(db.get_table_info())

        db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
        name = f"Query the SQL database regarding {data_name}"
        description = (
            f"useful for when you need to answer questions based on data stored on a SQL database regarding {data_name}"
        )
        tools.append(Tool(name=name, func=db_chain.run, description=description))
        print(f"Added sql dataset\n\tname={name}\n\tdescription:{description}")
    return tools


def _build_retriever_tools(embedding):
    """Interactively index documents into Chroma and expose each index as a retriever tool.

    Args:
        embedding: embedding model used to build every vector store.

    Returns:
        list of retriever tools, one per loaded file (empty if none).
    """
    tools = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
    print("Select files for constructing retriever")
    while True:
        file = input("Select a file to load or press Enter to exit:")
        if file == "":
            break
        data_name = input("Enter a short description of the data:")
        retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data

        # Split
        splits = text_splitter.split_documents(retriever_data)

        # Give every file its own collection. Without an explicit name each call
        # uses LangChain's default collection name, so (depending on the chromadb
        # version) one retriever may also return chunks from other files.
        vectordb = Chroma.from_documents(
            documents=splits, embedding=embedding, collection_name=f"retriever_data_{len(tools)}"
        )
        retriever = vectordb.as_retriever(
            search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.5, "k": 5}
        )
        # create_retriever_tool(retriever, name, description)
        description = f"Searches and returns documents regarding {data_name}."
        tools.append(create_retriever_tool(retriever, data_name, description))
    return tools


def _build_agent(tools):
    """Create a ZeroShotAgent executor over ``tools`` with buffered chat memory."""
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=_AGENT_PREFIX,
        suffix=_AGENT_SUFFIX,
        input_variables=["input", "chat_history", "agent_scratchpad"],
    )
    memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=ChatMessageHistory())
    llm_chain = LLMChain(llm=OpenAI(temperature=0.7), prompt=prompt)
    agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
    return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)


def _chat(agent_chain):
    """Run the interactive loop until the user types ``end`` or closes stdin."""
    while True:
        try:
            user_input = input("User: ")
        except EOFError:
            # Treat Ctrl-D / closed stdin like an explicit "end".
            user_input = "end"
        if _is_exit_command(user_input):
            print("Agent: Happy to chat with you :)")
            break
        agent_response = agent_chain.run(user_input)
        print(f"Agent: {agent_response}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Experimental AI agent powered by ChatGPT")
    parser.add_argument(
        "--open_ai_key_path",
        type=str,
        default=None,
        help="path to the plain text open_ai_key file; if omitted, the OPENAI_API_KEY env var is used",
    )
    args = parser.parse_args()

    # Setup openai key: load it from a file if given, otherwise require the env var.
    if args.open_ai_key_path is not None:
        os.environ["OPENAI_API_KEY"] = _read_openai_key(args.open_ai_key_path)
    elif not os.environ.get("OPENAI_API_KEY"):
        parser.error("no OpenAI key: set the OPENAI_API_KEY environment variable or pass --open_ai_key_path")

    table_loaders = []
    try:
        # Load data served on sql
        llm = OpenAI(temperature=0.0)
        tools = _build_sql_tools(llm, table_loaders)

        # VectorDB
        embedding = OpenAIEmbeddings()
        tools.extend(_build_retriever_tools(embedding))

        agent_chain = _build_agent(tools)
        _chat(agent_chain)
    finally:
        # Release every SQL engine that was created, not just the last one.
        for table_loader in table_loaders:
            table_loader.sql_engine.dispose()