Unified configuration

This commit is contained in:
yhjun1026
2023-05-18 16:30:57 +08:00
parent a68e164a5f
commit ba7e23d37f
13 changed files with 97 additions and 53 deletions

View File

@@ -13,8 +13,11 @@ from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *
from pilot.configs.config import Config
model_path = LLM_MODEL_CONFIG[LLM_MODEL]
CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
global_counter = 0
@@ -60,7 +63,7 @@ def generate_stream_gate(params):
tokenizer,
params,
DEVICE,
MAX_POSITION_EMBEDDINGS,
CFG.MAX_POSITION_EMBEDDINGS,
):
print("output: ", output)
ret = {
@@ -84,7 +87,7 @@ async def api_generate_stream(request: Request):
print(model, tokenizer, params, DEVICE)
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
generator = generate_stream_gate(params)