Unified configuration

This commit is contained in:
yhjun1026
2023-05-18 16:30:57 +08:00
parent a68e164a5f
commit ba7e23d37f
13 changed files with 97 additions and 53 deletions

View File

@@ -13,8 +13,11 @@ from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *
from pilot.configs.config import Config
model_path = LLM_MODEL_CONFIG[LLM_MODEL]
CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
global_counter = 0
@@ -60,7 +63,7 @@ def generate_stream_gate(params):
tokenizer,
params,
DEVICE,
MAX_POSITION_EMBEDDINGS,
CFG.MAX_POSITION_EMBEDDINGS,
):
print("output: ", output)
ret = {
@@ -84,7 +87,7 @@ async def api_generate_stream(request: Request):
print(model, tokenizer, params, DEVICE)
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
generator = generate_stream_gate(params)

View File

@@ -14,13 +14,13 @@ from urllib.parse import urljoin
from langchain import PromptTemplate
from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
@@ -67,7 +67,15 @@ priority = {
"vicuna-13b": "aaa"
}
# 加载插件
CFG= Config()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT
}
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
@@ -178,7 +186,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
model_name = CFG.LLM_MODEL
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
@@ -268,7 +276,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
logger.info(f"Requert: \n{payload}")
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers, json=payload, timeout=120)
print(response.json())
@@ -316,7 +324,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
try:
# Stream output
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers, json=payload, stream=True, timeout=20)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
@@ -595,9 +603,8 @@ if __name__ == "__main__":
# dbs = get_database_list()
# 加载插件
# 配置初始化
cfg = Config()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# 加载插件可执行命令