mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 06:29:09 +00:00
parent
21aa5de00b
commit
b07a6f4e27
@ -154,7 +154,7 @@ class ConversationBufferWithSummary(ConversationSummaryMemory):
|
|||||||
remain = self.max_tokens - prompt_length
|
remain = self.max_tokens - prompt_length
|
||||||
while self.get_conversation_length() > remain:
|
while self.get_conversation_length() > remain:
|
||||||
if len(self.buffered_history.messages) <= 2:
|
if len(self.buffered_history.messages) <= 2:
|
||||||
raise RuntimeError("Exeeed max_tokens, trunck size of retrieved documents is too large")
|
raise RuntimeError("Exceed max_tokens, trunk size of retrieved documents is too large")
|
||||||
temp = self.buffered_history.messages.pop(0)
|
temp = self.buffered_history.messages.pop(0)
|
||||||
self.summarized_history_temp.messages.append(temp)
|
self.summarized_history_temp.messages.append(temp)
|
||||||
temp = self.buffered_history.messages.pop(0)
|
temp = self.buffered_history.messages.pop(0)
|
||||||
|
@ -77,12 +77,16 @@ if __name__ == "__main__":
|
|||||||
colossal_api = ColossalAPI(model_name, all_config["model"]["model_path"])
|
colossal_api = ColossalAPI(model_name, all_config["model"]["model_path"])
|
||||||
llm = ColossalLLM(n=1, api=colossal_api)
|
llm = ColossalLLM(n=1, api=colossal_api)
|
||||||
elif all_config["model"]["mode"] == "api":
|
elif all_config["model"]["mode"] == "api":
|
||||||
all_config["chain"]["mem_llm_kwargs"] = None
|
|
||||||
all_config["chain"]["disambig_llm_kwargs"] = None
|
|
||||||
all_config["chain"]["gen_llm_kwargs"] = None
|
|
||||||
if model_name == "pangu_api":
|
if model_name == "pangu_api":
|
||||||
from colossalqa.local.pangu_llm import Pangu
|
from colossalqa.local.pangu_llm import Pangu
|
||||||
llm = Pangu(id=1)
|
|
||||||
|
gen_config = {
|
||||||
|
"user": "User",
|
||||||
|
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
|
||||||
|
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
|
||||||
|
"n": 1 # the number of responses generated
|
||||||
|
}
|
||||||
|
llm = Pangu(gen_config=gen_config)
|
||||||
llm.set_auth_config() # verify user's auth info here
|
llm.set_auth_config() # verify user's auth info here
|
||||||
elif model_name == "chatgpt_api":
|
elif model_name == "chatgpt_api":
|
||||||
from langchain.llms import OpenAI
|
from langchain.llms import OpenAI
|
||||||
|
Loading…
Reference in New Issue
Block a user