DB-GPT mirror: https://github.com/csunny/DB-GPT.git

commit a164d2f156 (parent a7364c9fe3)

    update config file
@@ -4,34 +4,34 @@
 import torch
 import os
 
-root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-model_path = os.path.join(root_path, "models")
-vector_storepath = os.path.join(root_path, "vector_store")
-LOGDIR = os.path.join(root_path, "logs")
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+MODEL_PATH = os.path.join(ROOT_PATH, "models")
+VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
+LOGDIR = os.path.join(ROOT_PATH, "logs")
 
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-llm_model_config = {
-    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
-    "vicuna-13b": os.path.join(model_path, "vicuna-13b"),
-    "sentence-transforms": os.path.join(model_path, "all-MiniLM-L6-v2")
+LLM_MODEL_CONFIG = {
+    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
+    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
 
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
-vicuna_model_server = "http://192.168.31.114:8000"
+VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
 
 
 # Load model config
-isload_8bit = True
-isdebug = False
+ISLOAD_8BIT = True
+ISDEBUG = False
 
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "********",
+    "password": "aa123456",
     "host": "localhost",
     "port": 3306
 }
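Note (not part of the commit): a minimal sketch of how downstream code picks up the renamed constants. The import path pilot.configs.model_config is taken from the webserver hunk further below; the snippet itself is illustrative only.

    import os

    from pilot.configs.model_config import DEVICE, LLM_MODEL, LLM_MODEL_CONFIG

    # Resolve the weights directory for the configured default model,
    # e.g. <ROOT_PATH>/models/vicuna-13b, and report the target device.
    weights_dir = LLM_MODEL_CONFIG[LLM_MODEL]
    print(f"Loading {LLM_MODEL} from {weights_dir} on {DEVICE}")
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"Download the model weights into {weights_dir} first")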
@@ -25,7 +25,7 @@ class VicunaRequestLLM(LLM):
            "stop": stop
        }
        response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_generate_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
            data=json.dumps(params),
        )
        response.raise_for_status()
@@ -55,7 +55,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
        print("Sending prompt ", p)

        response = requests.post(
-            url=urljoin(vicuna_model_server, self.vicuna_embedding_path),
+            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
            json={
                "prompt": p
            }
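Note (illustrative only): both URL changes above touch the same request pattern. A hedged sketch of that pattern follows; the endpoint path and payload fields are assumptions for illustration, not taken from this diff.

    import json
    from urllib.parse import urljoin

    import requests

    from pilot.configs.model_config import VICUNA_MODEL_SERVER

    # Hypothetical payload; the real fields are built inside VicunaRequestLLM.
    params = {"prompt": "Say hello.", "stop": None}
    response = requests.post(
        url=urljoin(VICUNA_MODEL_SERVER, "generate"),  # endpoint path assumed
        data=json.dumps(params),
    )
    response.raise_for_status()
    print(response.text)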
@@ -14,7 +14,7 @@ from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
 
 
-from pilot.configs.model_config import LOGDIR, vicuna_model_server, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
 
 from pilot.conversation import (
     default_conversation,
@@ -181,7 +181,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):

    try:
        # Stream output
-        response = requests.post(urljoin(vicuna_model_server, "generate_stream"),
+        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                                 headers=headers, json=payload, stream=True, timeout=20)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
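Note (illustrative only): the streaming consumption below these lines is unchanged by the commit. A sketch of the full round trip under the renamed constant, assuming payload field names and a per-chunk JSON format that this diff does not show:

    import json
    from urllib.parse import urljoin

    import requests

    from pilot.configs.model_config import VICUNA_MODEL_SERVER

    headers = {"User-Agent": "dbgpt-client"}  # placeholder headers
    payload = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 512}  # assumed fields

    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
                             headers=headers, json=payload, stream=True, timeout=20)
    # The server separates chunks with NUL bytes; each non-empty chunk is
    # assumed here to be a JSON object carrying the partial completion.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            print(data.get("text", ""), end="", flush=True)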