From d73cab9c4b4f269e86be343daff4dee9651ce66f Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 30 Apr 2023 16:05:31 +0800
Subject: [PATCH] add stream output

---
 pilot/configs/model_config.py |   5 +-
 pilot/server/embdserver.py    |   3 +
 pilot/server/vicuna_server.py |  52 ++++++++++++++-
 pilot/server/webserver.py     | 117 ++++++++++++++++++++++++++++++++++
 4 files changed, 174 insertions(+), 3 deletions(-)
 create mode 100644 pilot/server/embdserver.py
 create mode 100644 pilot/server/webserver.py

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 2b9ae6df0..cf26ff7e1 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -7,6 +7,7 @@ import os
 root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
+LOGDIR = os.path.join(root_path, "logs")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -16,8 +17,8 @@ llm_model_config = {
 }
 
 LLM_MODEL = "vicuna-13b"
-
-
+LIMIT_MODEL_CONCURRENCY = 5
+MAX_POSITION_EMBEDDINGS = 2048
 
 vicuna_model_server = "http://192.168.31.114:8000/"
 
diff --git a/pilot/server/embdserver.py b/pilot/server/embdserver.py
new file mode 100644
index 000000000..97206f2d5
--- /dev/null
+++ b/pilot/server/embdserver.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index 9fa5ddbee..91aba5dec 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -2,8 +2,12 @@
 # -*- coding: utf-8 -*-
 
 import uvicorn
+import asyncio
+import json
 from typing import Optional, List
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from fastchat.serve.inference import generate_stream
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
 from fastchat.serve.inference import load_model
@@ -11,6 +15,11 @@ from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
 model_path = llm_model_config[LLM_MODEL]
+
+
+global_counter = 0
+model_semaphore = None
+
 
 # ml = ModerLoader(model_path=model_path)
 # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
@@ -27,6 +36,47 @@ class PromptRequest(BaseModel):
 class EmbeddingRequest(BaseModel):
     prompt: str
 
+def release_model_semaphore():
+    model_semaphore.release()
+
+
+def generate_stream_gate(params):
+    try:
+        for output in generate_stream(
+            model,
+            tokenizer,
+            params,
+            DEVICE,
+            MAX_POSITION_EMBEDDINGS,
+            2,
+        ):
+            ret = {
+                "text": output,
+                "error_code": 0,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+    except torch.cuda.CudaError:
+        ret = {
+            "text": "**GPU OutOfMemory, Please Refresh.**",
+            "error_code": 1
+        }
+        yield json.dumps(ret).encode() + b"\0"
+
+
+@app.post("/generate_stream")
+async def api_generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
+    generator = generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
new file mode 100644
index 000000000..147af8b3d
--- /dev/null
+++ b/pilot/server/webserver.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import uuid
+import time
+import gradio as gr
+import datetime
+
+from pilot.configs.model_config import LOGDIR
+
+from pilot.conversation import (
+    get_default_conv_template,
+    compute_skip_echo_len,
+    SeparatorStyle
+)
+
+from fastchat.utils import (
+    build_logger,
+    server_error_msg,
+    violates_moderation,
+    moderation_msg
+)
+
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_css import code_highlight_css
+
+logger = build_logger("webserver", "webserver.log")
+headers = {"User-Agent": "dbgpt Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
+enable_moderation = False
+models = []
+
+priority = {
+    "vicuna-13b": "aaa"
+}
+
+def set_global_vars(enable_moderation_, models_):
+    global enable_moderation, models
+    enable_moderation = enable_moderation_
+    models = models_
+
+def get_conv_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
+
+
+def regenerate(state, request: gr.Request):
+    logger.info(f"regenerate. ip: {request.client.host}")
+    state.messages[-1][-1] = None
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+def clear_history(request: gr.Request):
+    logger.info(f"clear_history. ip: {request.client.host}")
+    state = None
+    return (state, [], "") + (disable_btn,) * 5
+
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
+
+    if state is None:
+        state = get_default_conv_template("vicuna").copy()
+
+    if len(text) <= 0:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+
+    if enable_moderation:
+        flagged = violates_moderation(text)
+        if flagged:
+            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
+            state.skip_next = True
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
+    text = text[:1536] # ?
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+def post_process_code(code):
+    sep = "\n```"
+    if sep in code:
+        blocks = code.split(sep)
+        if len(blocks) % 2 == 1:
+            for i in range(1, len(blocks), 2):
+                blocks[i] = blocks[i].replace("\\_", "_")
+        code = sep.join(blocks)
+    return code
+
+def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+    logger.info(f"http_bot. ip: {request.client.host}")
+    start_tstamp = time.time()
+
+    model_name = model_selector
+    temperature = float(temperature)
+    max_new_tokens = int(max_new_tokens)
+
+    if state.skip_next:
+        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        return
+
+    if len(state.messages) == state.offset + 2:
+        new_state = get_default_conv_template(model_name).copy()
+        new_state.conv_id = uuid.uuid4().hex
+        new_state.append_message(new_state.roles[0], state.messages[-2][1])
+        new_state.append_message(new_state.roles[1], None)
+        state = new_state
+
+
+    # TODO
\ No newline at end of file
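
Reviewer note (not part of the patch): the new /generate_stream endpoint emits a stream of JSON objects, each terminated by a NUL byte (b"\0"), following the FastChat worker convention. The sketch below is one minimal way to exercise the endpoint by hand; it assumes the vicuna_model_server address from model_config.py and the request fields that fastchat.serve.inference.generate_stream reads in the version this patch builds on (prompt, temperature, max_new_tokens, stop) -- adjust those if the installed FastChat differs.

# Hand-testing sketch for the new /generate_stream endpoint (illustration only, not part of this patch).
import json

import requests

# vicuna_model_server from pilot/configs/model_config.py (defined there with a trailing slash).
BASE_URL = "http://192.168.31.114:8000/"

# Field names are assumed from the pinned fastchat generate_stream implementation.
params = {
    "prompt": "A chat between a user and an assistant. USER: Hello! ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": "</s>",  # assumed stop string for the vicuna-v1.1 template
}

response = requests.post(BASE_URL + "generate_stream", json=params, stream=True, timeout=60)

last = ""
# Each chunk is one JSON object terminated by b"\0".
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode("utf-8"))
    if data["error_code"] != 0:
        print("server error:", data["text"])
        break
    last = data["text"]  # cumulative output so far; in this FastChat version it includes the echoed prompt
print(last)

Because generate_stream yields the cumulative text including the prompt, the web UI is expected to strip the echo before updating the chatbot; webserver.py already imports compute_skip_echo_len for that purpose.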
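
A second reviewer note, on webserver.py's post_process_code: it un-escapes underscores that the model sometimes emits inside fenced code blocks and leaves prose outside the fences alone. A quick illustration, assuming the patch is applied and its dependencies (gradio, fastchat) are importable:

# Illustration of post_process_code (not part of the patch).
from pilot.server.webserver import post_process_code

raw = "Define it like this:\n```\nmy\\_var = 1\n```\nand use my\\_var later."
print(post_process_code(raw))
# Inside the fence "my\_var" becomes "my_var"; the escaped occurrence in the
# prose after the closing fence is left untouched.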