mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
@@ -77,12 +77,16 @@ if __name__ == "__main__":
|
||||
colossal_api = ColossalAPI(model_name, all_config["model"]["model_path"])
|
||||
llm = ColossalLLM(n=1, api=colossal_api)
|
||||
elif all_config["model"]["mode"] == "api":
|
||||
all_config["chain"]["mem_llm_kwargs"] = None
|
||||
all_config["chain"]["disambig_llm_kwargs"] = None
|
||||
all_config["chain"]["gen_llm_kwargs"] = None
|
||||
if model_name == "pangu_api":
|
||||
from colossalqa.local.pangu_llm import Pangu
|
||||
llm = Pangu(id=1)
|
||||
|
||||
gen_config = {
|
||||
"user": "User",
|
||||
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
|
||||
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
|
||||
"n": 1 # the number of responses generated
|
||||
}
|
||||
llm = Pangu(gen_config=gen_config)
|
||||
llm.set_auth_config() # verify user's auth info here
|
||||
elif model_name == "chatgpt_api":
|
||||
from langchain.llms import OpenAI
|
||||
|
Reference in New Issue
Block a user