add stream output

csunny 2023-04-30 16:05:31 +08:00
parent 909045c0d6
commit d73cab9c4b
4 changed files with 174 additions and 3 deletions


@@ -7,6 +7,7 @@ import os
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")
LOGDIR = os.path.join(root_path, "logs")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,8 +17,8 @@ llm_model_config = {
}
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
vicuna_model_server = "http://192.168.31.114:8000/"


@@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-


@@ -2,8 +2,12 @@
# -*- coding: utf-8 -*-
import uvicorn
import asyncio
import json
from typing import Optional, List
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastchat.serve.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
@@ -11,6 +15,11 @@ from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *
model_path = llm_model_config[LLM_MODEL]
global_counter = 0
model_semaphore = None
# ml = ModerLoader(model_path=model_path)
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
@@ -27,6 +36,47 @@ class PromptRequest(BaseModel):
class EmbeddingRequest(BaseModel):
    prompt: str
def release_model_semaphore():
    model_semaphore.release()


def generate_stream_gate(params):
    try:
        # Stream partial outputs from fastchat's generate_stream; each chunk is a
        # JSON object terminated by a null byte so the client can split the stream.
        for output in generate_stream(
            model,
            tokenizer,
            params,
            DEVICE,
            MAX_POSITION_EMBEDDINGS,
            2,
        ):
            ret = {
                "text": output,
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"
    except torch.cuda.CudaError:
        ret = {
            "text": "**GPU OutOfMemory, Please Refresh.**",
            "error_code": 0
        }
        yield json.dumps(ret).encode() + b"\0"


@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    # Limit concurrent generations with a semaphore; it is released by a
    # background task once the streaming response has finished.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()

    generator = generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)
@app.post("/generate") @app.post("/generate")
def generate(prompt_request: PromptRequest): def generate(prompt_request: PromptRequest):
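For context, a minimal client-side sketch (not part of this commit) of how the new /generate_stream endpoint could be consumed. It assumes the server above is reachable at the vicuna_model_server address from model_config, and that the request fields follow what fastchat's generate_stream reads (prompt, temperature, max_new_tokens, stop); the values shown are illustrative only. Chunk parsing mirrors the b"\0" delimiter emitted by generate_stream_gate.

import json

import requests


def stream_generate(prompt, server="http://192.168.31.114:8000"):
    # Hypothetical request payload; field names follow fastchat's generate_stream params.
    params = {
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###",
    }
    response = requests.post(server + "/generate_stream", json=params, stream=True)
    # The server terminates each JSON chunk with a null byte, so split on b"\0".
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            yield data["text"]


if __name__ == "__main__":
    for partial_output in stream_generate("Hello, who are you?"):
        print(partial_output)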

pilot/server/webserver.py (new file, 117 lines)

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import uuid
import time
import gradio as gr
import datetime
from pilot.configs.model_config import LOGDIR
from pilot.conversation import (
    get_default_conv_template,
    compute_skip_echo_len,
    SeparatorStyle
)
from fastchat.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg
)
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css
logger = build_logger("webserver", "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
enable_moderation = False
models = []
priority = {
"vicuna-13b": "aaa"
}
def set_global_vars(enable_moderation_, models_):
    global enable_moderation, models
    enable_moderation = enable_moderation_
    models = models_
def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name
def regenerate(state, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
def add_text(state, text, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
if state is None:
state = get_default_conv_template("vicuna").copy()
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
if enable_moderation:
flagged = violates_moderation(text)
if flagged:
logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
text = text[:1536] # ?
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code
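# Illustration (not part of this file): post_process_code unescapes "\_" sequences
# that appear inside fenced code blocks of a model reply, e.g.
#   post_process_code("reply:\n```\nmy\\_var = 1\n```")
# returns "reply:\n```\nmy_var = 1\n```".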
def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector
temperature = float(temperature)
max_new_tokens = int(max_new_tokens)
if state.skip_next:
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.message) == state.offset + 2:
new_state = get_default_conv_template(model_name).copy()
new_state.conv_id = uuid.uuid4().hex
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# TODO