mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
add stream output
This commit is contained in:
parent
909045c0d6
commit
d73cab9c4b
@ -7,6 +7,7 @@ import os
|
||||
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
model_path = os.path.join(root_path, "models")
|
||||
vector_storepath = os.path.join(root_path, "vector_store")
|
||||
LOGDIR = os.path.join(root_path, "logs")
|
||||
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@ -16,8 +17,8 @@ llm_model_config = {
|
||||
}
|
||||
|
||||
LLM_MODEL = "vicuna-13b"
|
||||
|
||||
|
||||
LIMIT_MODEL_CONCURRENCY = 5
|
||||
MAX_POSITION_EMBEDDINGS = 2048
|
||||
vicuna_model_server = "http://192.168.31.114:8000/"
|
||||
|
||||
|
||||
|
3
pilot/server/embdserver.py
Normal file
3
pilot/server/embdserver.py
Normal file
@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
@ -2,8 +2,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import uvicorn
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional, List
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastchat.serve.inference import generate_stream
|
||||
from pydantic import BaseModel
|
||||
from pilot.model.inference import generate_output, get_embeddings
|
||||
from fastchat.serve.inference import load_model
|
||||
@ -11,6 +15,11 @@ from pilot.model.loader import ModerLoader
|
||||
from pilot.configs.model_config import *
|
||||
|
||||
model_path = llm_model_config[LLM_MODEL]
|
||||
|
||||
|
||||
global_counter = 0
|
||||
model_semaphore = None
|
||||
|
||||
# ml = ModerLoader(model_path=model_path)
|
||||
# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
|
||||
|
||||
@ -27,6 +36,47 @@ class PromptRequest(BaseModel):
|
||||
class EmbeddingRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
def release_model_semaphore():
|
||||
model_semaphore.release()
|
||||
|
||||
|
||||
def generate_stream_gate(params):
|
||||
try:
|
||||
for output in generate_stream(
|
||||
model,
|
||||
tokenizer,
|
||||
params,
|
||||
DEVICE,
|
||||
MAX_POSITION_EMBEDDINGS,
|
||||
2,
|
||||
):
|
||||
ret = {
|
||||
"text": output,
|
||||
"error_code": 0,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
except torch.cuda.CudaError:
|
||||
ret = {
|
||||
"text": "**GPU OutOfMemory, Please Refresh.**",
|
||||
"error_code": 0
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
|
||||
@app.post("/generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
global model_semaphore, global_counter
|
||||
global_counter += 1
|
||||
params = await request.json()
|
||||
|
||||
if model_semaphore is None:
|
||||
model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
|
||||
await model_semaphore.acquire()
|
||||
|
||||
generator = generate_stream_gate(params)
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(release_model_semaphore)
|
||||
return StreamingResponse(generator, background=background_tasks)
|
||||
|
||||
@app.post("/generate")
|
||||
def generate(prompt_request: PromptRequest):
|
||||
|
117
pilot/server/webserver.py
Normal file
117
pilot/server/webserver.py
Normal file
@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
import gradio as gr
|
||||
import datetime
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
from pilot.conversation import (
|
||||
get_default_conv_template,
|
||||
compute_skip_echo_len,
|
||||
SeparatorStyle
|
||||
)
|
||||
|
||||
from fastchat.utils import (
|
||||
build_logger,
|
||||
server_error_msg,
|
||||
violates_moderation,
|
||||
moderation_msg
|
||||
)
|
||||
|
||||
from fastchat.serve.gradio_patch import Chatbot as grChatbot
|
||||
from fastchat.serve.gradio_css import code_highlight_css
|
||||
|
||||
logger = build_logger("webserver", "webserver.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
|
||||
no_change_btn = gr.Button.update()
|
||||
enable_btn = gr.Button.update(interactive=True)
|
||||
disable_btn = gr.Button.update(interactive=True)
|
||||
|
||||
enable_moderation = False
|
||||
models = []
|
||||
|
||||
priority = {
|
||||
"vicuna-13b": "aaa"
|
||||
}
|
||||
|
||||
def set_global_vars(enable_moderation_, models_):
|
||||
global enable_moderation, models
|
||||
enable_moderation = enable_moderation_
|
||||
models = models_
|
||||
|
||||
def get_conv_log_filename():
|
||||
t = datetime.datetime.now()
|
||||
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
||||
return name
|
||||
|
||||
|
||||
def regenerate(state, request: gr.Request):
|
||||
logger.info(f"regenerate. ip: {request.client.host}")
|
||||
state.messages[-1][-1] = None
|
||||
state.skip_next = False
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
def clear_history(request: gr.Request):
|
||||
logger.info(f"clear_history. ip: {request.client.host}")
|
||||
state = None
|
||||
return (state, [], "") + (disable_btn,) * 5
|
||||
|
||||
def add_text(state, text, request: gr.Request):
|
||||
logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
|
||||
|
||||
if state is None:
|
||||
state = get_default_conv_template("vicuna").copy()
|
||||
|
||||
if len(text) <= 0:
|
||||
state.skip_next = True
|
||||
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
||||
|
||||
if enable_moderation:
|
||||
flagged = violates_moderation(text)
|
||||
if flagged:
|
||||
logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
|
||||
state.skip_next = True
|
||||
return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
|
||||
text = text[:1536] # ?
|
||||
state.append_message(state.roles[0], text)
|
||||
state.append_message(state.roles[1], None)
|
||||
state.skip_next = False
|
||||
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
def post_process_code(code):
|
||||
sep = "\n```"
|
||||
if sep in code:
|
||||
blocks = code.split(sep)
|
||||
if len(blocks) % 2 == 1:
|
||||
for i in range(1, len(blocks), 2):
|
||||
blocks[i] = blocks[i].replace("\\_", "_")
|
||||
code = sep.join(blocks)
|
||||
return code
|
||||
|
||||
def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
|
||||
logger.info(f"http_bot. ip: {request.client.host}")
|
||||
start_tstamp = time.time()
|
||||
|
||||
model_name = model_selector
|
||||
temperature = float(temperature)
|
||||
max_new_tokens = int(max_new_tokens)
|
||||
|
||||
if state.skip_next:
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
return
|
||||
|
||||
if len(state.message) == state.offset + 2:
|
||||
new_state = get_default_conv_template(model_name).copy()
|
||||
new_state.conv_id = uuid.uuid4().hex
|
||||
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
state = new_state
|
||||
|
||||
|
||||
# TODO
|
Loading…
Reference in New Issue
Block a user