# Mirror of https://github.com/csunny/DB-GPT.git
# (synced 2025-10-08 23:24:27 +00:00; 125 lines, 3.5 KiB, Python)
import atexit
|
|
import traceback
|
|
import os
|
|
import shutil
|
|
import argparse
|
|
import sys
|
|
import logging
|
|
|
|
# Repository root: two levels above the directory that contains this file.
# It is appended to sys.path so that the ``pilot.*`` imports which follow can
# resolve even when this file is executed directly as a script.
ROOT_PATH = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
sys.path.append(ROOT_PATH)
|
|
import signal
|
|
from pilot.configs.config import Config
|
|
|
|
# from pilot.configs.model_config import (
|
|
# DATASETS_DIR,
|
|
# KNOWLEDGE_UPLOAD_ROOT_PATH,
|
|
# LLM_MODEL_CONFIG,
|
|
# LOGDIR,
|
|
# )
|
|
from pilot.utils import build_logger
|
|
|
|
from pilot.server.base import server_init
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi import FastAPI, applications
|
|
from fastapi.openapi.docs import get_swagger_ui_html
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pilot.server.knowledge.api import router as knowledge_router
|
|
|
|
|
|
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
|
|
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
|
|
|
logging.basicConfig(level=logging.INFO)

# Directory holding the pre-built web front-end that is mounted with
# StaticFiles. Resolve it relative to this file instead of the current working
# directory: the previous ``os.path.join(os.getcwd(), "server/static")`` only
# pointed at the right place when the process was started from the directory
# containing ``server/``, and StaticFiles raises at startup when the directory
# does not exist.
static_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
|
|
|
|
|
|
# Configuration object for this server process; the ``__main__`` section sets
# NEW_SERVER_MODE or SERVER_LIGHT_MODE on it depending on the --light flag.
CFG = Config()
# File logging is disabled; re-enabling it also requires the commented-out
# LOGDIR import near the top of the file.
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
|
|
|
|
|
def signal_handler(*_args):
    """Terminate the process immediately with exit status 0.

    ``os._exit`` is used instead of ``sys.exit`` so that ``atexit`` handlers
    are skipped; per the printed message, this works around a Chroma DB
    problem that otherwise occurs at interpreter exit.

    Any positional arguments are accepted and ignored, so the function can be
    registered with ``signal.signal`` (which calls ``handler(signum, frame)``)
    as well as called directly with no arguments.
    """
    # os._exit() does not flush stdio buffers, so flush explicitly or the
    # message can be lost when stdout is redirected to a file or pipe.
    print("in order to avoid chroma db atexit problem", flush=True)
    os._exit(0)
|
|
|
|
|
|
def swagger_monkey_patch(*args, **kwargs):
    """Build the Swagger UI page using Swagger UI 4.10.3 assets from cdn.bootcdn.net.

    Thin wrapper around FastAPI's ``get_swagger_ui_html`` that substitutes
    the bootcdn URLs for the JavaScript bundle and stylesheet. URLs passed
    explicitly by the caller take precedence; previously, supplying
    ``swagger_js_url`` or ``swagger_css_url`` raised ``TypeError`` because the
    keyword was given twice.

    Args:
        *args: Forwarded unchanged to ``get_swagger_ui_html``.
        **kwargs: Forwarded to ``get_swagger_ui_html``; the two asset URLs
            default to the bootcdn locations when absent.

    Returns:
        The value returned by ``get_swagger_ui_html``.
    """
    kwargs.setdefault(
        "swagger_js_url",
        "https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
    )
    kwargs.setdefault(
        "swagger_css_url",
        "https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
    )
    return get_swagger_ui_html(*args, **kwargs)


# Rebind the name looked up inside ``fastapi.applications`` so FastAPI's
# generated Swagger UI page is rendered through the wrapper above.
applications.get_swagger_ui_html = swagger_monkey_patch
|
|
|
|
app = FastAPI()
origins = ["*"]

# Cross-origin resource sharing: accept requests from any origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any web page issue credentialed requests to this server; consider an
# explicit origin list for non-local deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)

# API routers are registered without an extra path prefix. Routes are matched
# in registration order, so they must come before the catch-all "/" static
# mount below.
for _router in (api_v1, knowledge_router, api_editor_route_v1):
    app.include_router(_router)

# Front-end assets: the Next.js static bundle, then the site root (html=True
# serves index pages for directory requests).
_next_static_dir = static_file_path + "/_next/static"
app.mount("/_next/static", StaticFiles(directory=_next_static_dir))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")

app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
    )

    # old version server config
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=5000)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument(
        "-light",
        "--light",
        default=False,
        action="store_true",
        help="enable light mode",
    )

    # init server config
    args = parser.parse_args()
    server_init(args)

    if not args.light:
        print("Model Unified Deployment Mode!")
        from pilot.server.llmserver import worker

        # NOTE(review): start_check() presumably verifies/starts the model
        # worker inside this process -- confirm against pilot.server.llmserver.
        worker.start_check()
        CFG.NEW_SERVER_MODE = True
    else:
        CFG.SERVER_LIGHT_MODE = True

    import uvicorn

    # Root logging is configured once at module import time; calling
    # logging.basicConfig again here would be a no-op, so it is not repeated.
    # Honour --host, which was previously parsed but ignored; its default
    # ("0.0.0.0") matches the old hard-coded value.
    uvicorn.run(app, host=args.host, port=args.port, log_level=0)

    # uvicorn.run() returns only after the server has shut down. Exit hard via
    # signal_handler() (os._exit) to skip atexit handlers. The previous
    # ``signal.signal(signal.SIGINT, signal_handler())`` invoked the handler
    # instead of registering it, so this explicit call is its only real effect.
    signal_handler()
|