style: format code style

format code style
This commit is contained in:
aries_ckt
2023-06-29 13:52:53 +08:00
parent 359babecdc
commit 4029f48d5f
12 changed files with 205 additions and 109 deletions

View File

@@ -23,9 +23,12 @@ from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
@@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log")
def swagger_monkey_patch(*args, **kwargs):
    """Render the Swagger UI page with its JS/CSS assets pinned to bootcdn URLs.

    All positional and keyword arguments are forwarded unchanged to
    ``get_swagger_ui_html``; ``swagger_js_url`` and ``swagger_css_url`` are
    supplied here, so callers must not pass them as well.
    """
    cdn_assets = {
        "swagger_js_url": "https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
        "swagger_css_url": "https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
    }
    return get_swagger_ui_html(*args, **kwargs, **cdn_assets)
@@ -55,14 +59,16 @@ app.add_middleware(
)
# Serve the files under static_file_path at the /static URL prefix.
app.mount("/static", StaticFiles(directory=static_file_path), name="static")
# NOTE(review): add_route is given a path string where a callable endpoint is
# normally expected -- confirm that /test actually serves static/test.html.
# The two identical lines below appear to be the before/after sides of a
# whitespace-only change in this diff view.
app.add_route("/test", "static/test.html")
app.add_route("/test", "static/test.html")
# Attach the knowledge and v1 API routers to the application.
app.include_router(knowledge_router)
app.include_router(api_v1)
# Send request-validation errors to the project's own handler.
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -75,4 +81,5 @@ if __name__ == "__main__":
server_init(args)
CFG.NEW_SERVER_MODE = True
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)

View File

@@ -9,7 +9,8 @@ import sys
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:
if not isinstance(self.model, str):
if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings
@@ -60,22 +61,22 @@ class ModelWorker:
def get_queue_length(self):
    """Return the current queue length derived from ``model_semaphore``.

    Returns:
        0 when ``model_semaphore`` is ``None`` or either of its ``_value`` /
        ``_waiters`` attributes is ``None``; otherwise
        ``CFG.LIMIT_MODEL_CONCURRENCY - model_semaphore._value
        + len(model_semaphore._waiters)``.

    NOTE(review): presumably ``model_semaphore`` is an ``asyncio.Semaphore``
    sized by ``CFG.LIMIT_MODEL_CONCURRENCY``, making the result "requests
    running + requests waiting" -- confirm against where it is created.
    """
    if (
        model_semaphore is None
        or model_semaphore._value is None
        or model_semaphore._waiters is None
    ):
        return 0
    # Fix: this expression used to be evaluated and discarded, so the method
    # returned None whenever the semaphore existed.
    return (
        CFG.LIMIT_MODEL_CONCURRENCY
        - model_semaphore._value
        + len(model_semaphore._waiters)
    )
def generate_stream_gate(self, params):
try:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
@@ -107,23 +108,23 @@ worker = ModelWorker(
)
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router)
origins = [
"http://localhost",
"http://localhost:8000",
"http://localhost:3000",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# from pilot.openapi.knowledge.knowledge_controller import router
#
# app.include_router(router)
#
# origins = [
# "http://localhost",
# "http://localhost:8000",
# "http://localhost:3000",
# ]
#
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
class PromptRequest(BaseModel):

View File

@@ -40,6 +40,7 @@ def server_init(args):
cfg = Config()
from pilot.server.llmserver import worker
worker.start_check()
load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)