diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index ded0a1b19..805cacb3d 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -15,7 +15,7 @@ class BaseChatAdpter:
     def get_generate_stream_func(self):
         """Return the generate stream handler func"""
         pass
-    
+
 
 llm_model_chat_adapters: List[BaseChatAdpter] = []
 
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 33d3d545d..fa1da5608 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -23,19 +23,67 @@
 from pilot.model.inference import generate_output, get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *
 from pilot.configs.config import Config
+from pilot.server.chat_adapter import get_llm_chat_adapter
 
 CFG = Config()
 
 model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
-
-ml = ModelLoader(model_path=model_path)
-model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
+print(model_path)
 
 class ModelWorker:
-    def __init__(self):
-        pass
-    # TODO
+    def __init__(self, model_path, model_name, device, num_gpus=1):
+
+        if model_path.endswith("/"):
+            model_path = model_path[:-1]
+        self.model_name = model_name or model_path.split("/")[-1]
+        self.device = device
+
+        self.ml = ModelLoader(model_path=model_path)
+        self.model, self.tokenizer = self.ml.loader(num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
+
+        if hasattr(self.model.config, "max_sequence_length"):
+            self.context_len = self.model.config.max_sequence_length
+        elif hasattr(self.model.config, "max_position_embeddings"):
+            self.context_len = self.model.config.max_position_embeddings
+
+        else:
+            self.context_len = 2048
+
+        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
+
+    def get_queue_length(self):
+        if model_semaphore is None or model_semaphore._value is None or model_semaphore._waiters is None:
+            return 0
+        else:
+            return CFG.LIMIT_MODEL_CONCURRENCY - model_semaphore._value + len(model_semaphore._waiters)
+
+    def generate_stream_gate(self, params):
+        try:
+            for output in self.generate_stream_func(
+                self.model,
+                self.tokenizer,
+                params,
+                DEVICE,
+                CFG.MAX_POSITION_EMBEDDINGS
+            ):
+                print("output: ", output)
+                ret = {
+                    "text": output,
+                    "error_code": 0,
+                }
+                yield json.dumps(ret).encode() + b"\0"
+
+        except torch.cuda.CudaError:
+            ret = {
+                "text": "**GPU OutOfMemory, Please Refresh.**",
+                "error_code": 0
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+    def get_embeddings(self, prompt):
+        return get_embeddings(self.model, self.tokenizer, prompt)
 
 
 app = FastAPI()
@@ -60,41 +108,17 @@ def release_model_semaphore():
     model_semaphore.release()
 
 
-def generate_stream_gate(params):
-    try:
-        for output in generate_stream(
-            model,
-            tokenizer,
-            params,
-            DEVICE,
-            CFG.MAX_POSITION_EMBEDDINGS,
-        ):
-            print("output: ", output)
-            ret = {
-                "text": output,
-                "error_code": 0,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-    except torch.cuda.CudaError:
-        ret = {
-            "text": "**GPU OutOfMemory, Please Refresh.**",
-            "error_code": 0
-        }
-        yield json.dumps(ret).encode() + b"\0"
-
-
 @app.post("/generate_stream")
 async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
     params = await request.json()
-    print(model, tokenizer, params, DEVICE)
 
     if model_semaphore is None:
         model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
 
-    generator = generate_stream_gate(params)
+    generator = worker.generate_stream_gate(params)
     background_tasks = BackgroundTasks()
     background_tasks.add_task(release_model_semaphore)
     return StreamingResponse(generator, background=background_tasks)
@@ -110,7 +134,7 @@ def generate(prompt_request: PromptRequest):
 
     response = []
     rsp_str = ""
-    output = generate_stream_gate(params)
+    output = worker.generate_stream_gate(params)
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")
@@ -124,9 +148,18 @@
 def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
-    output = get_embeddings(model, tokenizer, params["prompt"])
+    output = worker.get_embeddings(params["prompt"])
     return {"response": [float(x) for x in output]}
 
 
 if __name__ == "__main__":
+
+
+    worker = ModelWorker(
+        model_path=model_path,
+        model_name=CFG.LLM_MODEL,
+        device=DEVICE,
+        num_gpus=1
+    )
+
     uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
\ No newline at end of file
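
For reference, a minimal client-side sketch of how the new /generate_stream endpoint could be consumed. It is not part of this diff: the host, port, and any request fields beyond "prompt" are assumptions (the real port comes from CFG.MODEL_PORT and the expected fields depend on the chat adapter); only the b"\0" delimiter and the "text"/"error_code" keys come from ModelWorker.generate_stream_gate above.

# Hypothetical client sketch (not part of this change).
import json
import requests

MODEL_SERVER = "http://127.0.0.1:8000"      # assumed address; the real port is CFG.MODEL_PORT
params = {"prompt": "Hello, who are you?"}  # adapters may expect extra sampling fields

with requests.post(f"{MODEL_SERVER}/generate_stream", json=params, stream=True) as resp:
    # generate_stream_gate yields json.dumps(ret).encode() + b"\0", so split
    # the byte stream on NUL and decode each chunk as a JSON object.
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        if data["error_code"] == 0:
            print(data["text"])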