fix merge problem

This commit is contained in:
csunny
2023-04-29 21:29:29 +08:00
11 changed files with 558 additions and 255 deletions

View File

@@ -1,3 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
import json
import time
from urllib.parse import urljoin
import gradio as gr
from configs.model_config import *
vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"
def generate(prompt):
params = {
"model": "vicuna-13b",
"prompt": prompt,
"temperature": 0.7,
"max_new_tokens": 512,
"stop": "###"
}
sts_response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_status_path)
)
print(sts_response.text)
response = requests.post(
url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
)
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"]
yield(output)
time.sleep(0.02)
if __name__ == "__main__":
print(LLM_MODEL)
with gr.Blocks() as demo:
gr.Markdown("数据库SQL生成助手")
with gr.Tab("SQL生成"):
text_input = gr.TextArea()
text_output = gr.TextArea()
text_button = gr.Button("提交")
text_button.click(generate, inputs=text_input, outputs=text_output)
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

View File

@@ -7,7 +7,6 @@ import torch
import gradio as gr
from fastchat.serve.inference import generate_stream, compress_module
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
@@ -21,12 +20,12 @@ model = AutoModelForCausalLM.from_pretrained(
)
def generate(prompt):
# compress_module(model, device)
# model.to(device)
compress_module(model, device)
model.to(device)
print(model, tokenizer)
params = {
"model": "vicuna-13b",
"prompt": prompt,
"prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
"temperature": 0.7,
"max_new_tokens": 512,
"stop": "###"
@@ -36,9 +35,6 @@ def generate(prompt):
for chunk in output:
yield chunk
#for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
# if chunk:
# yield chunk
if __name__ == "__main__":
with gr.Blocks() as demo:
@@ -53,5 +49,3 @@ if __name__ == "__main__":
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from fastapi import FastAPI
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *
model_path = llm_model_config[LLM_MODEL]
ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
app = FastAPI()
class PromptRequest(BaseModel):
prompt: str
temperature: float
max_new_tokens: int
stop: Optional(List[str]) = None
class EmbeddingRequest(BaseModel):
prompt: str
@app.post("/generate")
def generate(prompt_request: PromptRequest):
params = {
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop
}
print("Receive prompt: ", params["prompt"])
output = generate_output(model, tokenizer, params, DEVICE)
print("Output: ", output)
return {"response": output}
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
params = {"prompt": prompt_request.prompt}
print("Received prompt: ", params["prompt"])
output = get_embeddings(model, tokenizer, params["prompt"])
return {"response": [float(x) for x in output]}