diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 1238d1bcb..675c51b66 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -4,34 +4,34 @@
 import torch
 import os
 
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 
 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
\ No newline at end of file
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index 26673344f..eba2834ae 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
             print("Sending prompt ", p)
 
             response = requests.post(
-                url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+                url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
                 json={
                     "prompt": p
                 }
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index c13a5331f..4468082d8 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -14,7 +14,7 @@
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
 
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 
 from pilot.conversation import (
     default_conversation,
@@ -181,7 +181,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
 
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
            headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
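
For reviewers, a minimal sketch of how the renamed constants are consumed by a caller such as VicunaRequestLLM. The "generate" path and the exact payload fields other than "prompt" and "stop" are illustrative assumptions, not confirmed by this diff.

# Minimal usage sketch, not the project's exact API: the "generate" endpoint
# and the "model" field in the payload are assumptions made for illustration.
import json
from urllib.parse import urljoin

import requests

from pilot.configs.model_config import LLM_MODEL, VICUNA_MODEL_SERVER


def generate(prompt: str) -> str:
    # Mirror the requests.post call to the Vicuna worker shown in the diff.
    params = {"model": LLM_MODEL, "prompt": prompt, "stop": None}
    response = requests.post(
        url=urljoin(VICUNA_MODEL_SERVER, "generate"),
        data=json.dumps(params),
    )
    response.raise_for_status()
    return response.text


if __name__ == "__main__":
    print(generate("Hello, who are you?"))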