From 1c90cbc64c2757616fe571aa807b9c7d10c9013d Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Fri, 13 Oct 2023 17:35:43 +0800 Subject: [PATCH] feat(Agent): ChatAgent And AgentHub 1.LLM TongYiQianWen WenXinYiYan ZhiPu support --- .../chat_excel/excel_analyze/chat.py | 13 ++-- pilot/server/dbgpt_server.py | 60 +++++++++---------- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 9d3cbd0d9..4611aa14d 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -42,11 +42,14 @@ class ChatExcel(BaseChat): def _generate_numbered_list(self) -> str: command_strings = [] if CFG.command_disply: - command_strings += [ - str(item) - for item in CFG.command_disply.commands.values() - if item.enabled - ] + for name, item in CFG.command_disply.commands.items(): + if item.enabled: + command_strings.append(f"{name}:{item.description}") + # command_strings += [ + # str(item) + # for item in CFG.command_disply.commands.values() + # if item.enabled + # ] return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) def generate_input_values(self): diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 4a679a63b..12aac586e 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -20,6 +20,7 @@ from pilot.server.component_configs import initialize_components from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, applications from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.openapi.utils import get_openapi from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from pilot.server.knowledge.api import router as knowledge_router @@ -56,11 +57,9 @@ def swagger_monkey_patch(*args, **kwargs): swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js", swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css" ) - - +app = FastAPI() applications.get_swagger_ui_html = swagger_monkey_patch -app = FastAPI() system_app = SystemApp(app) origins = ["*"] @@ -77,15 +76,27 @@ app.add_middleware( app.include_router(api_v1, prefix="/api", tags=["Chat"]) app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) -app.include_router(llm_manage_api, prefix="/api") -app.include_router(api_fb_v1, prefix="/api") +app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"]) +app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"]) app.include_router(knowledge_router, tags=["Knowledge"]) -# app.include_router(api_v1) -app.include_router(prompt_router) + +app.include_router(prompt_router, tags=["Prompt"]) # app.include_router(api_editor_route_v1) +@app.get("/openapi.json") +async def get_openapi_endpoint(): + return get_openapi( + title="Your API title", + version="1.0.0", + description="Your API description", + routes=app.routes, + ) + +@app.get("/docs") +async def get_docs(): + return get_swagger_ui_html(openapi_url="/openapi.json", title="API docs") def mount_static_files(app): os.makedirs(static_message_img_path, exist_ok=True) @@ -102,22 +113,17 @@ def mount_static_files(app): app.add_exception_handler(RequestValidationError, validation_exception_handler) - -def _get_webserver_params(args: List[str] = None): - from pilot.utils.parameter_utils import EnvArgumentParser - - parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option( - WebWerverParameters - ) - return WebWerverParameters(**vars(parser.parse_args(args=args))) - - def initialize_app(param: WebWerverParameters = None, args: List[str] = None): """Initialize app If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook. """ if not param: - param = _get_webserver_params(args) + from pilot.utils.parameter_utils import EnvArgumentParser + + parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option( + WebWerverParameters + ) + param = WebWerverParameters(**vars(parser.parse_args(args=args))) if not param.log_level: param.log_level = _get_logging_level() @@ -136,7 +142,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): model_start_listener = _create_model_start_listener(system_app) initialize_components(param, system_app, embedding_model_name, embedding_model_path) - model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL) + model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] if not param.light: print("Model Unified Deployment Mode!") if not param.remote_embedding: @@ -183,20 +189,8 @@ def run_uvicorn(param: WebWerverParameters): def run_webserver(param: WebWerverParameters = None): - if not param: - param = _get_webserver_params() - initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl")) - - with root_tracer.start_span( - "run_webserver", - span_type=SpanType.RUN, - metadata={ - "run_service": SpanTypeRunName.WEBSERVER, - "params": _get_dict_from_obj(param), - }, - ): - param = initialize_app(param) - run_uvicorn(param) + param = initialize_app(param) + run_uvicorn(param) if __name__ == "__main__":