mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-28 14:27:20 +00:00

add stream output

This commit is contained in: parent 909045c0d6, commit d73cab9c4b
@@ -7,6 +7,7 @@ import os
 root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
+LOGDIR = os.path.join(root_path, "logs")
 
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,8 +17,8 @@ llm_model_config = {
 }
 
 LLM_MODEL = "vicuna-13b"
+LIMIT_MODEL_CONCURRENCY = 5
+MAX_POSITION_EMBEDDINGS = 2048
 vicuna_model_server = "http://192.168.31.114:8000/"

pilot/server/embdserver.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
@@ -2,8 +2,12 @@
 # -*- coding: utf-8 -*-
 
 import uvicorn
+import asyncio
+import json
 from typing import Optional, List
-from fastapi import FastAPI
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from fastchat.serve.inference import generate_stream
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
 from fastchat.serve.inference import load_model
@@ -11,6 +15,11 @@ from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
 model_path = llm_model_config[LLM_MODEL]
 
+
+global_counter = 0
+model_semaphore = None
+
+
 # ml = ModerLoader(model_path=model_path)
 # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
@@ -27,6 +36,47 @@ class PromptRequest(BaseModel):
 class EmbeddingRequest(BaseModel):
     prompt: str
 
+
+def release_model_semaphore():
+    model_semaphore.release()
+
+
+def generate_stream_gate(params):
+    try:
+        for output in generate_stream(
+            model,
+            tokenizer,
+            params,
+            DEVICE,
+            MAX_POSITION_EMBEDDINGS,
+            2,
+        ):
+            ret = {
+                "text": output,
+                "error_code": 0,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+    except torch.cuda.CudaError:
+        ret = {
+            "text": "**GPU OutOfMemory, Please Refresh.**",
+            "error_code": 0
+        }
+        yield json.dumps(ret).encode() + b"\0"
+
+
+@app.post("/generate_stream")
+async def api_generate_stream(request: Request):
+    global model_semaphore, global_counter
+    global_counter += 1
+    params = await request.json()
+
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
+    generator = generate_stream_gate(params)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
+
+
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
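
For reference, a minimal client sketch for the new /generate_stream endpoint follows; it is not part of this commit. The endpoint yields one JSON object per generation step, terminated by a NUL byte (b"\0"), so a caller can read the response as a stream and split on that delimiter. The server address, payload keys, and the cumulative-text behavior are assumptions drawn from the config hunk above and the usual fastchat generate_stream convention.

# Illustrative only: a client for the /generate_stream endpoint added above.
# The address mirrors vicuna_model_server from the config hunk; the payload
# keys are assumptions about what generate_stream consumes.
import json
import requests

def stream_generate(prompt, server="http://192.168.31.114:8000"):
    params = {"prompt": prompt, "temperature": 0.7, "max_new_tokens": 512}
    response = requests.post(server + "/generate_stream", json=params, stream=True)
    # generate_stream_gate emits json.dumps(ret).encode() + b"\0",
    # so each NUL-delimited chunk is a complete JSON object.
    for chunk in response.iter_lines(delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            yield data["text"]  # text generated so far at this step

if __name__ == "__main__":
    for text_so_far in stream_generate("Hello, DB-GPT"):
        print(text_so_far)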

pilot/server/webserver.py (new file, 117 lines)
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import uuid
+import time
+import gradio as gr
+import datetime
+
+from pilot.configs.model_config import LOGDIR
+
+from pilot.conversation import (
+    get_default_conv_template,
+    compute_skip_echo_len,
+    SeparatorStyle
+)
+
+from fastchat.utils import (
+    build_logger,
+    server_error_msg,
+    violates_moderation,
+    moderation_msg
+)
+
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_css import code_highlight_css
+
+logger = build_logger("webserver", "webserver.log")
+headers = {"User-Agent": "dbgpt Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
+enable_moderation = False
+models = []
+
+priority = {
+    "vicuna-13b": "aaa"
+}
+
+
+def set_global_vars(enable_moderation_, models_):
+    global enable_moderation, models
+    enable_moderation = enable_moderation_
+    models = models_
+
+
+def get_conv_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
+
+
+def regenerate(state, request: gr.Request):
+    logger.info(f"regenerate. ip: {request.client.host}")
+    state.messages[-1][-1] = None
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+    logger.info(f"clear_history. ip: {request.client.host}")
+    state = None
+    return (state, [], "") + (disable_btn,) * 5
+
+
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len:{len(text)}")
+
+    if state is None:
+        state = get_default_conv_template("vicuna").copy()
+
+    if len(text) <= 0:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+
+    if enable_moderation:
+        flagged = violates_moderation(text)
+        if flagged:
+            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
+            state.skip_next = True
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (no_change_btn,) * 5
+
+    text = text[:1536]  # ?
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def post_process_code(code):
+    sep = "\n```"
+    if sep in code:
+        blocks = code.split(sep)
+        if len(blocks) % 2 == 1:
+            for i in range(1, len(blocks), 2):
+                blocks[i] = blocks[i].replace("\\_", "_")
+        code = sep.join(blocks)
+    return code
+
+
+def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+    logger.info(f"http_bot. ip: {request.client.host}")
+    start_tstamp = time.time()
+
+    model_name = model_selector
+    temperature = float(temperature)
+    max_new_tokens = int(max_new_tokens)
+
+    if state.skip_next:
+        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        return
+
+    if len(state.messages) == state.offset + 2:
+        new_state = get_default_conv_template(model_name).copy()
+        new_state.conv_id = uuid.uuid4().hex
+        new_state.append_message(new_state.roles[0], state.messages[-2][1])
+        new_state.append_message(new_state.roles[1], None)
+        state = new_state
+
+    # TODO
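
The hunk ends at the # TODO, so the streaming call inside http_bot is not visible in this view. Purely as an illustrative sketch, and not the code of this commit, the usual continuation in the fastchat-style webserver this file mirrors is to render the prompt from the conversation state, POST it to the /generate_stream endpoint added above, and rewrite the last assistant message as NUL-delimited chunks arrive. The helper below assumes Conversation-style attributes (get_prompt, messages, sep), reuses the server address from the config hunk, and treats every payload key as an assumption.

# Hypothetical sketch of the streaming loop http_bot would run (not part of
# this commit). Conversation methods and payload keys are assumed.
import json
import requests

def stream_reply(state, temperature, max_new_tokens,
                 server="http://192.168.31.114:8000"):
    prompt = state.get_prompt()          # full conversation rendered as text
    skip_echo_len = len(prompt)          # the real code likely uses compute_skip_echo_len
    payload = {
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "stop": state.sep,
    }
    response = requests.post(server + "/generate_stream", json=payload, stream=True)
    for chunk in response.iter_lines(delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        # keep only the newly generated text, dropping the echoed prompt
        output = data["text"][skip_echo_len:].strip()
        state.messages[-1][-1] = output + "▌"   # trailing cursor while streaming
        yield state
    state.messages[-1][-1] = state.messages[-1][-1][:-1]  # drop the cursor
    yield state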