mirror of https://github.com/csunny/DB-GPT.git
synced 2025-09-13 13:10:29 +00:00
fix merge problem
This commit is contained in:
@@ -1,3 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import requests
import json
import time
from urllib.parse import urljoin

import gradio as gr
from configs.model_config import *

vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"


def generate(prompt):
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }

    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path)
    )
    print(sts_response.text)

    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
    )

    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
                yield output

            time.sleep(0.02)


if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")  # "Database SQL generation assistant"
        with gr.Tab("SQL生成"):  # "SQL generation"
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("提交")  # "Submit"

        text_button.click(generate, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
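The hunk above computes skip_echo_len but never uses it. In FastChat-style streaming clients that offset is normally applied to strip the echoed prompt out of each streamed chunk before showing it. A minimal sketch of that pattern, reusing the same worker endpoint and payload as the code above; the helper name and the stream=True flag are additions for illustration, not part of the commit:

# Sketch only: stream from worker_generate_stream and trim the echoed prompt.
import json
import requests
from urllib.parse import urljoin

def generate_trimmed(prompt, base_uri="http://192.168.31.114:21002/"):
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###",
    }
    skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
    response = requests.post(
        url=urljoin(base_uri, "worker_generate_stream"),
        data=json.dumps(params),
        stream=True,  # assumed: lets iter_lines yield chunks as they arrive
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                # Each chunk repeats the prompt, so drop the first skip_echo_len characters.
                yield data["text"][skip_echo_len:].strip()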
@@ -7,7 +7,6 @@ import torch
import gradio as gr
from fastchat.serve.inference import generate_stream, compress_module

from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODE = "/home/magic/workspace/github/DB-GPT/models/vicuna-13b"
@@ -21,12 +20,12 @@ model = AutoModelForCausalLM.from_pretrained(
)

def generate(prompt):
    # compress_module(model, device)
    # model.to(device)
    compress_module(model, device)
    model.to(device)
    print(model, tokenizer)
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        # Chinese instruction, roughly: "This is a conversation between a user and an assistant.
        # The assistant is an expert in the database domain and can give very professional
        # answers to database questions. The user's question follows:"
        "prompt": "这是一个用户与助手之间的对话, 助手精通数据库领域的知识, 并能够对数据库领域知识做出非常专业的回答。以下是用户的问题:" + prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
@@ -36,9 +35,6 @@ def generate(prompt):
    for chunk in output:
        yield chunk
    # for chunk in output.iter_lines(decode_unicode=False, delimiter=b"\0"):
    #     if chunk:
    #         yield chunk

if __name__ == "__main__":
    with gr.Blocks() as demo:
@@ -53,5 +49,3 @@ if __name__ == "__main__":
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
pilot/server/vicuna_server.py (new file, 48 lines)
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import Optional, List
from fastapi import FastAPI
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *

model_path = llm_model_config[LLM_MODEL]
ml = ModerLoader(model_path=model_path)
model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)

app = FastAPI()


class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: Optional[List[str]] = None


class EmbeddingRequest(BaseModel):
    prompt: str


@app.post("/generate")
def generate(prompt_request: PromptRequest):
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop
    }

    print("Received prompt: ", params["prompt"])
    output = generate_output(model, tokenizer, params, DEVICE)
    print("Output: ", output)
    return {"response": output}


@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = get_embeddings(model, tokenizer, params["prompt"])
    return {"response": [float(x) for x in output]}
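To exercise the new endpoints once vicuna_server.py is running, a client along the following lines should work; the base URL, port, prompt text, and sampling values are assumptions for illustration, while the request and response shapes mirror the handlers above.

# Sketch only: minimal client for the /generate and /embedding routes above.
import requests

BASE_URL = "http://127.0.0.1:8000"  # assumed host/port for the FastAPI app

gen = requests.post(
    BASE_URL + "/generate",
    json={
        "prompt": "Show all tables in the sales database.",
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": ["###"],
    },
)
print(gen.json()["response"])

emb = requests.post(BASE_URL + "/embedding", json={"prompt": "database"})
print(len(emb.json()["response"]))  # embedding dimensionality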