update config file

csunny 2023-05-05 20:06:24 +08:00
parent a7364c9fe3
commit a164d2f156
3 changed files with 16 additions and 16 deletions

View File

@@ -4,34 +4,34 @@
 import torch
 import os
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
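
The two files below switch their call sites over to the renamed uppercase constants. As a quick reference, a minimal sketch of how downstream code picks them up after this change (assuming the file above is pilot/configs/model_config.py, as the imports in the webserver diff below suggest):

# Hypothetical usage sketch; the module path is inferred from the imports in the last file of this commit.
from pilot.configs.model_config import LLM_MODEL, LLM_MODEL_CONFIG, VICUNA_MODEL_SERVER

model_dir = LLM_MODEL_CONFIG[LLM_MODEL]   # resolves to <ROOT_PATH>/models/vicuna-13b
print(f"Serving {LLM_MODEL} from {model_dir} via {VICUNA_MODEL_SERVER}")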

View File

@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
         print("Sending prompt ", p)
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
             json={
                 "prompt": p
             }
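
Both call sites above use urllib.parse.urljoin to build the endpoint URL from VICUNA_MODEL_SERVER. A small illustration of that behavior (the value of self.vicuna_generate_path is not shown in this diff; "generate_stream" below is taken from the webserver diff, and a plain relative path is assumed):

# Illustration only; any path other than "generate_stream" is an assumption.
from urllib.parse import urljoin

VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
print(urljoin(VICUNA_MODEL_SERVER, "generate_stream"))  # http://192.168.31.114:8000/generate_stream
# Note: urljoin replaces the last path segment of the base unless it ends with "/",
# which is why a bare host:port base composes cleanly here.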

View File

@@ -14,7 +14,7 @@ from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 from pilot.conversation import (
     default_conversation,
@@ -181,7 +181,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                                  headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
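
For context on the streaming loop above, a minimal sketch of consuming a NUL-delimited stream from VICUNA_MODEL_SERVER end to end. The payload fields and the JSON shape of each chunk ({"text": ...}) are assumptions in the style of FastChat-like servers, not something this diff shows:

import json
import requests
from urllib.parse import urljoin

VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
# Hypothetical payload; field names are assumed, not taken from this commit.
payload = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 256}

response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                         json=payload, stream=True, timeout=20)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))   # each chunk is assumed to be one JSON object
        print(data.get("text", ""))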