update config file

csunny 2023-05-05 20:06:24 +08:00
parent a7364c9fe3
commit a164d2f156
3 changed files with 16 additions and 16 deletions

@@ -4,34 +4,34 @@
 import torch
 import os
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-llm_model_config = {
+LLM_MODEL_CONFIG = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 # Load model config
-isload_8bit = True
+ISLOAD_8BIT = True
-isdebug = False
+ISDEBUG = False
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
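The constants above are renamed to upper case so that downstream modules import them as module-level settings (see the webserver import further down). As a quick consistency sketch, not part of this commit, the following assumes only the module path and constant names shown in the diff; the two helper functions are hypothetical:

# Hypothetical consumer sketch; module path and constant names come from the
# diff above, the helper functions themselves are not part of the commit.
from pilot.configs.model_config import DEVICE, LLM_MODEL, LLM_MODEL_CONFIG, VECTORE_PATH

def default_model_path() -> str:
    # Resolve the checkpoint directory of the configured default model.
    return LLM_MODEL_CONFIG[LLM_MODEL]  # e.g. <ROOT_PATH>/models/vicuna-13b

def describe_runtime() -> str:
    # Small sanity helper: which model, device, and vector store path are in effect.
    return f"model={LLM_MODEL} device={DEVICE} vector_store={VECTORE_PATH}"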

@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
         response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
             print("Sending prompt ", p)
             response = requests.post(
-                url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+                url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
                 json={
                     "prompt": p
                 }
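Both call sites above build their endpoint with urljoin against the bare VICUNA_MODEL_SERVER host. A small illustration, not from the commit, of how that resolution behaves; the relative paths held in self.vicuna_generate_path and self.vicuna_embedding_path are not shown in this diff, so "generate_stream" from the webserver hunk below is used as the example path:

# Illustration of urljoin resolution against the bare server URL from model_config.py.
from urllib.parse import urljoin

VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"  # value taken from the config diff above

# With an empty base path, the relative endpoint is appended under "/".
print(urljoin(VICUNA_MODEL_SERVER, "generate_stream"))
# -> http://192.168.31.114:8000/generate_stream

# A base URL that already carries a path segment would have its last segment replaced.
print(urljoin("http://192.168.31.114:8000/api", "generate_stream"))
# -> http://192.168.31.114:8000/generate_stream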

@@ -14,7 +14,7 @@ from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 from pilot.conversation import (
     default_conversation,
@@ -181,7 +181,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
     try:
         # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                                  headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
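For context, a minimal standalone sketch of the streaming client pattern used in http_bot above; the URL handling, the stream=True/timeout arguments, and the b"\0"-delimited iter_lines loop mirror the diff, while the payload shape and the "text" field read from each chunk are assumptions:

# Standalone sketch of the streaming pattern above; the payload shape and the
# "text" field of each chunk are assumptions, not taken from the commit.
import json
from urllib.parse import urljoin

import requests

VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"  # from model_config.py above

def stream_generate(prompt: str, timeout: int = 20):
    payload = {"prompt": prompt}  # hypothetical request body
    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                             json=payload, stream=True, timeout=timeout)
    response.raise_for_status()
    # Chunks arrive as NUL-terminated JSON objects, hence the b"\0" delimiter.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            yield data.get("text", "")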