add stream output

This commit is contained in:
csunny 2023-04-30 16:05:31 +08:00
parent 909045c0d6
commit d73cab9c4b
4 changed files with 174 additions and 3 deletions

View File

@ -7,6 +7,7 @@ import os
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")
LOGDIR = os.path.join(root_path, "logs")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@ -16,8 +17,8 @@ llm_model_config = {
}
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
vicuna_model_server = "http://192.168.31.114:8000/"

View File

@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

View File

@ -2,8 +2,12 @@
# -*- coding: utf-8 -*-
import uvicorn
import asyncio
import json
from typing import Optional, List
from fastapi import FastAPI
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastchat.serve.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
@ -11,6 +15,11 @@ from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *
model_path = llm_model_config[LLM_MODEL]
global_counter = 0
model_semaphore = None
# ml = ModerLoader(model_path=model_path)
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
@ -27,6 +36,47 @@ class PromptRequest(BaseModel):
class EmbeddingRequest(BaseModel):
prompt: str
def release_model_semaphore():
model_semaphore.release()
def generate_stream_gate(params):
try:
for output in generate_stream(
model,
tokenizer,
params,
DEVICE,
MAX_POSITION_EMBEDDINGS,
2,
):
ret = {
"text": output,
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError:
ret = {
"text": "**GPU OutOfMemory, Please Refresh.**",
"error_code": 0
}
yield json.dumps(ret).encode() + b"\0"
@app.post("/generate_stream")
async def api_generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
generator = generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(generator, background=background_tasks)
@app.post("/generate")
def generate(prompt_request: PromptRequest):

117
pilot/server/webserver.py Normal file
View File

@ -0,0 +1,117 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import uuid
import time
import gradio as gr
import datetime
from pilot.configs.model_config import LOGDIR
from pilot.conversation import (
get_default_conv_template,
compute_skip_echo_len,
SeparatorStyle
)
from fastchat.utils import (
build_logger,
server_error_msg,
violates_moderation,
moderation_msg
)
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css
logger = build_logger("webserver", "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=True)
enable_moderation = False
models = []
priority = {
"vicuna-13b": "aaa"
}
def set_global_vars(enable_moderation_, models_):
global enable_moderation, models
enable_moderation = enable_moderation_
models = models_
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
return name
def regenerate(state, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
def add_text(state, text, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
if state is None:
state = get_default_conv_template("vicuna").copy()
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
if enable_moderation:
flagged = violates_moderation(text)
if flagged:
logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
text = text[:1536] # ?
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
sep = "\n```"
if sep in code:
blocks = code.split(sep)
if len(blocks) % 2 == 1:
for i in range(1, len(blocks), 2):
blocks[i] = blocks[i].replace("\\_", "_")
code = sep.join(blocks)
return code
def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector
temperature = float(temperature)
max_new_tokens = int(max_new_tokens)
if state.skip_next:
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.message) == state.offset + 2:
new_state = get_default_conv_template(model_name).copy()
new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# TODO