[colossalqa] fix pangu api (#5170)

* fix pangu api

* add comment
This commit is contained in:
Michelle
2023-12-11 14:08:11 +08:00
committed by GitHub
parent 21aa5de00b
commit b07a6f4e27
2 changed files with 9 additions and 5 deletions

View File

@@ -77,12 +77,16 @@ if __name__ == "__main__":
colossal_api = ColossalAPI(model_name, all_config["model"]["model_path"])
llm = ColossalLLM(n=1, api=colossal_api)
elif all_config["model"]["mode"] == "api":
all_config["chain"]["mem_llm_kwargs"] = None
all_config["chain"]["disambig_llm_kwargs"] = None
all_config["chain"]["gen_llm_kwargs"] = None
if model_name == "pangu_api":
from colossalqa.local.pangu_llm import Pangu
llm = Pangu(id=1)
gen_config = {
"user": "User",
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
"n": 1 # the number of responses generated
}
llm = Pangu(gen_config=gen_config)
llm.set_auth_config() # verify user's auth info here
elif model_name == "chatgpt_api":
from langchain.llms import OpenAI