#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import os
import sys
from typing import List, Optional

import torch
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse

# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

global_counter = 0
model_semaphore = None

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.scene.base_message import ModelMessage

CFG = Config()


class ModelWorker:
    """Loads the configured LLM and exposes streaming generation and embeddings."""

    def __init__(self, model_path, model_name, device):
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        model_path = _get_model_real_path(model_name, model_path)
        # self.model_name = model_name or model_path.split("/")[-1]
        self.device = device
        print(
            f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please wait..."
        )
        self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
        self.model, self.tokenizer = self.ml.loader(
            load_8bit=CFG.IS_LOAD_8BIT,
            load_4bit=CFG.IS_LOAD_4BIT,
            debug=ISDEBUG,
            max_gpu_memory=CFG.MAX_GPU_MEMORY,
        )

        # Determine the model's context length from its config, falling back to 2048.
        if not isinstance(self.model, str):
            if hasattr(self.model, "config") and hasattr(
                self.model.config, "max_sequence_length"
            ):
                self.context_len = self.model.config.max_sequence_length
            elif hasattr(self.model, "config") and hasattr(
                self.model.config, "max_position_embeddings"
            ):
                self.context_len = self.model.config.max_position_embeddings
        else:
            self.context_len = 2048

        self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
            model_path
        )

    def start_check(self):
        print("LLM model loaded successfully!")

    def get_queue_length(self):
        # Number of requests currently holding or waiting on the model semaphore.
        if (
            model_semaphore is None
            or model_semaphore._value is None
            or model_semaphore._waiters is None
        ):
            return 0
        else:
            return (
                CFG.LIMIT_MODEL_CONCURRENCY
                - model_semaphore._value
                + len(model_semaphore._waiters)
            )

    def generate_stream_gate(self, params):
        try:
            # params adaptation
            params, model_context = self.llm_chat_adapter.model_adaptation(
                params, self.ml.model_path, prompt_template=self.ml.prompt_template
            )

            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
                # Do not enable this output in production!
                # The gpt4all thread shares stdout with the parent process,
                # and enabling it may affect the frontend output.
                print("output: ", output)
                # return some model context to dgt-server
                ret = {"text": output, "error_code": 0, "model_context": model_context}
                # Each chunk is a JSON object terminated by a NUL byte so the client can split the stream.
                yield json.dumps(ret).encode() + b"\0"

        except torch.cuda.CudaError:
            ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            ret = {
                "text": f"**LLMServer Generate Error, Please Check ErrorInfo.**: {e}",
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):
        return get_embeddings(self.model, self.tokenizer, prompt)


model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)

app = FastAPI()

# from pilot.openapi.knowledge.knowledge_controller import router
#
# app.include_router(router)
#
# origins = [
#     "http://localhost",
#     "http://localhost:8000",
#     "http://localhost:3000",
# ]
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=origins,
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )


class PromptRequest(BaseModel):
    messages: List[ModelMessage]
    prompt: str
    temperature: float
    max_new_tokens: int
    model: str
    stop: Optional[str] = None
    echo: bool = True


class StreamRequest(BaseModel):
    model: str
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: str


class EmbeddingRequest(BaseModel):
    prompt: str


def release_model_semaphore():
    model_semaphore.release()


@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    # Limit concurrent generations; the semaphore is released when the response finishes.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()

    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/generate")
def generate(prompt_request: PromptRequest) -> str:
    params = {
        "messages": prompt_request.messages,
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop,
        "echo": prompt_request.echo,
    }

    # Consume the stream and keep the last chunk as the final (non-streaming) response.
    rsp_str = ""
    output = worker.generate_stream_gate(params)
    for rsp in output:
        # rsp = rsp.decode("utf-8")
        rsp = rsp.replace(b"\0", b"")
        rsp_str = rsp.decode()

    return rsp_str


@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = worker.get_embeddings(params["prompt"])
    return {"response": [float(x) for x in output]}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")