mirror of https://github.com/csunny/DB-GPT.git

style: format code style
@@ -23,9 +23,12 @@ from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
+
+
 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
 
 static_file_path = os.path.join(os.getcwd(), "server/static")
 
 CFG = Config()
@@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log")
 
 
 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
-        *args, **kwargs,
-        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
-        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+        *args,
+        **kwargs,
+        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
     )
 
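The patched function only changes where the Swagger UI assets load from. The hunk header's `from fastapi import FastAPI, applications` suggests it is installed by rebinding FastAPI's module-level docs helper; a minimal sketch of that wiring, assuming the assignment sits right after the definition (the assignment line itself is not shown in this diff):

from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html


def swagger_monkey_patch(*args, **kwargs):
    # Same signature as the helper it replaces; only the asset URLs differ,
    # pointing /docs at a CDN mirror of the Swagger UI bundle.
    return get_swagger_ui_html(
        *args,
        **kwargs,
        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
    )


# Assumption: FastAPI resolves get_swagger_ui_html through the applications
# module at request time, so rebinding it here affects every /docs page.
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()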
@@ -55,14 +59,16 @@ app.add_middleware(
 )
 
 app.mount("/static", StaticFiles(directory=static_file_path), name="static")
 app.add_route("/test", "static/test.html")
 app.include_router(knowledge_router)
 app.include_router(api_v1)
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument(
+        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+    )
 
     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
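The `add_argument` change is black's standard reflow of a line over its 88-column default: the arguments move to their own indented line and the closing paren gets its own line, with no semantic change. A tiny self-contained sketch (standalone script, not project code) showing the reflowed call parses exactly as before:

import argparse

parser = argparse.ArgumentParser()
# black wraps the long call; the parsed result is identical either way.
parser.add_argument(
    "--model_list_mode", type=str, default="once", choices=["once", "reload"]
)

args = parser.parse_args([])  # empty argv so the sketch runs anywhere
assert args.model_list_mode == "once"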
@@ -75,4 +81,5 @@ if __name__ == "__main__":
     server_init(args)
     CFG.NEW_SERVER_MODE = True
     import uvicorn
+
     uvicorn.run(app, host="0.0.0.0", port=5000)
@@ -9,7 +9,8 @@ import sys
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+
+# from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
 global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:
 
         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_sequence_length"
+                self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_position_embeddings"
+                self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings
 
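These pairs differ only in leading whitespace (black normalizes the continuation-line indent); the logic itself is the usual context-window sniffing over a Hugging Face style model config. A self-contained sketch of the same fallback, with a stand-in config object since loading a real model is beside the point here:

class _Cfg:
    # Stand-in for a transformers PretrainedConfig; real configs expose one
    # of these attributes depending on the architecture.
    max_position_embeddings = 4096


class _Model:
    config = _Cfg()


model = _Model()
context_len = 2048  # assumed fallback default; the diff does not show one

if hasattr(model, "config") and hasattr(model.config, "max_sequence_length"):
    # ChatGLM-style configs call it max_sequence_length.
    context_len = model.config.max_sequence_length
elif hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"):
    # LLaMA/GPT-style configs call it max_position_embeddings.
    context_len = model.config.max_position_embeddings

print(context_len)  # -> 4096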
@@ -60,22 +61,22 @@ class ModelWorker:
 
     def get_queue_length(self):
         if (
-                model_semaphore is None
-                or model_semaphore._value is None
-                or model_semaphore._waiters is None
+            model_semaphore is None
+            or model_semaphore._value is None
+            or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
-                    CFG.LIMIT_MODEL_CONCURRENCY
-                    - model_semaphore._value
-                    + len(model_semaphore._waiters)
+                CFG.LIMIT_MODEL_CONCURRENCY
+                - model_semaphore._value
+                + len(model_semaphore._waiters)
             )
 
     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
-                    self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
+                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
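One thing the reflow makes easy to spot: in the `else` branch the parenthesized expression is a bare statement, so its value is computed and discarded and `get_queue_length` implicitly returns `None`. A hedged, runnable sketch of what the method presumably intends (the added `return` and the standalone semaphore are not part of this commit):

import asyncio

LIMIT_MODEL_CONCURRENCY = 5  # stand-in for CFG.LIMIT_MODEL_CONCURRENCY
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)


def get_queue_length():
    # Sketch only: same logic as the diffed method, plus the `return` that
    # the committed code appears to drop.
    if (
        model_semaphore is None
        or model_semaphore._value is None
        or model_semaphore._waiters is None
    ):
        return 0
    # In-flight requests = limit minus free slots, plus callers queued up.
    return (
        LIMIT_MODEL_CONCURRENCY
        - model_semaphore._value
        + len(model_semaphore._waiters)
    )


print(get_queue_length())  # -> 0 when nothing is running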
@@ -107,23 +108,23 @@ worker = ModelWorker(
 )
 
 app = FastAPI()
-from pilot.openapi.knowledge.knowledge_controller import router
-
-app.include_router(router)
-
-origins = [
-    "http://localhost",
-    "http://localhost:8000",
-    "http://localhost:3000",
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# from pilot.openapi.knowledge.knowledge_controller import router
+#
+# app.include_router(router)
+#
+# origins = [
+#     "http://localhost",
+#     "http://localhost:8000",
+#     "http://localhost:3000",
+# ]
+#
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=origins,
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )
 
 
 class PromptRequest(BaseModel):
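The deleted block is kept verbatim as a comment rather than removed outright; the router and CORS wiring now appear to live in the webserver instead (the first file's hunks import knowledge_router and sit under an existing `app.add_middleware(` call). For reference, a sketch of the equivalent setup on the webserver side, assuming it mirrors the block commented out here:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Same origin list as the commented-out llmserver block; the assumption is
# that the webserver's add_middleware call carries equivalent settings.
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:3000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)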
@@ -40,6 +40,7 @@ def server_init(args):
     cfg = Config()
 
     from pilot.server.llmserver import worker
+
     worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
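The import inside server_init looks deliberately function-local: importing pilot.server.llmserver constructs the ModelWorker (and loads the model) at import time, so deferring it keeps module import cheap for anything else that pulls this file in. A runnable sketch of the pattern with a hypothetical stand-in module (heavy_worker is made up for illustration, not project code):

import signal
import sys
import types

# Hypothetical stand-in for pilot.server.llmserver: registering it in
# sys.modules simulates a module whose import is expensive (model loading).
heavy = types.ModuleType("heavy_worker")
heavy.start_check = lambda: print("worker alive")
sys.modules["heavy_worker"] = heavy


def server_init():
    # Deferred import, as in the diff: the worker module is only imported
    # (and the model only loaded) once the server actually starts.
    import heavy_worker

    heavy_worker.start_check()
    signal.signal(signal.SIGINT, signal.default_int_handler)


server_init()  # prints "worker alive"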