mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
run a demo
This commit is contained in:
  parent c971f6ec58
  commit 4b2e3bf59a

.vscode/launch.json (vendored, new file, 25 lines)
@@ -0,0 +1,25 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "env": {"PYTHONPATH": "${workspaceFolder}"},
            "envFile": "${workspaceFolder}/.env"
        },
        {
            "name": "Python: Module",
            "type": "python",
            "request": "launch",
            "module": "pilot",
            "justMyCode": true,
        }
    ]
}
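
A note on the two launch configurations: "Python: Current File" injects PYTHONPATH=${workspaceFolder} and reads ${workspaceFolder}/.env before debugging the open file, while "Python: Module" runs the pilot package the way `python -m pilot` would, which presumes a pilot/__main__.py entry point. A rough Python equivalent of the module configuration, as a sketch only (the checkout path is hypothetical and the entry point is an assumption, not shown in this commit):

import runpy
import sys

# Mimic "env": {"PYTHONPATH": "${workspaceFolder}"}: put the repo root on sys.path.
sys.path.insert(0, "/path/to/DB-GPT")          # hypothetical checkout location
# Mimic "module": "pilot": execute the package as a script (needs pilot/__main__.py).
runpy.run_module("pilot", run_name="__main__")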

@@ -1,3 +1 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__version__ = "0.0.1"

@@ -8,10 +8,12 @@ root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")


llm_model_config = {
    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
    "vicuna-13b": os.path.join(model_path, "vicuna-13b")
}

LLM_MODEL = "vicuna-13b"


vicuna_model_server = "http://192.168.31.114:21000/"
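
These settings assume model weights live under models/ at the repository root (models/vicuna-13b, models/flan-t5-base) and that a Vicuna worker is reachable at vicuna_model_server. A small sanity check along those lines, as a sketch rather than anything this commit ships (the directory layout and running worker are assumptions):

import os

from configs.model_config import LLM_MODEL, llm_model_config, vicuna_model_server

# Confirm the weights for the selected model are where model_config expects them.
weights_dir = llm_model_config[LLM_MODEL]   # e.g. <repo_root>/models/vicuna-13b
if not os.path.isdir(weights_dir):
    raise FileNotFoundError(f"model weights not found: {weights_dir}")
print(f"Using {LLM_MODEL} from {weights_dir}; remote worker at {vicuna_model_server}")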

pilot/model/inference.py (new file, 4 lines)
@@ -0,0 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch

@@ -1,9 +1,29 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import requests
from typing import Any, Mapping, Optional, List
from transformers import pipeline
from langchain.llms.base import LLM
from configs.model_config import *


class VicunaLLM(LLM):
    model_name = llm_model_config[LLM_MODEL]

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        url = vicuna_model_server
        params = {
            "model": "vicuna-13b",
            "prompt": prompt,
            "temperature": 0.7,
            "max_new_tokens": 512,
            "stop": "###"
        }
        pass

    @property
    def _llm_type(self) -> str:
        return "custome"

    def _identifying_params(self) -> Mapping[str, Any]:
        return {}
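
As committed, _call builds the request parameters and then falls through to pass, so VicunaLLM never contacts the server. A sketch of one way to complete it, assuming the endpoint behind vicuna_model_server exposes the same worker_generate_stream route and null-delimited JSON chunk format that the Gradio demo below consumes (the route name and response shape are assumptions, not part of this commit):

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        params = {
            "model": "vicuna-13b",
            "prompt": prompt,
            "temperature": 0.7,
            "max_new_tokens": 512,
            "stop": "###"
        }
        # Assumed route, mirroring the streaming endpoint used by the demo client.
        response = requests.post(
            url=urljoin(vicuna_model_server, "worker_generate_stream"),
            data=json.dumps(params),
        )
        output = ""
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data.get("error_code", 0) == 0:
                    output = data["text"]   # each chunk carries the full text so far
        return output

This sketch would also need `import json` and `from urllib.parse import urljoin` alongside the imports already in the file.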

@@ -1,3 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import requests
import json
import time
from urllib.parse import urljoin
import gradio as gr
from configs.model_config import *

vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"


def generate(prompt):
    params = {
        "model": "vicuna-13b",
        "prompt": "给出一个查询用户的SQL",
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }

    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path)
    )
    print(sts_response.text)

    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
    )

    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
                yield(output)

            time.sleep(0.02)


if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")
        with gr.Tab("SQL生成"):
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("提交")

        text_button.click(generate, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
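
Two details of generate() above are worth flagging: the request hardcodes the prompt "给出一个查询用户的SQL" ("write a SQL statement that queries users") instead of passing through the prompt argument from the textbox, and skip_echo_len is computed but never used. In a FastChat-style stream the worker echoes the prompt at the start of each chunk, so the usual pattern is to slice it off before yielding; a sketch of that variant (the slicing convention is an assumption based on that protocol, not something this commit does):

    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                # Drop the echoed prompt so only newly generated text reaches the UI.
                output = data["text"][skip_echo_len:].strip()
                yield output
            time.sleep(0.02)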

@@ -20,8 +20,8 @@ model = AutoModelForCausalLM.from_pretrained(
)

def generate(prompt):
    # compress_module(model, device)
    # model.to(device)
    compress_module(model, device)
    model.to(device)
    print(model, tokenizer)
    params = {
        "model": "vicuna-13b",

@@ -31,13 +31,8 @@ def generate(prompt):
        "stop": "###"
    }
    for output in generate_stream(
        model, tokenizer, params, device, context_len=2048, stream_interval=2):
        ret = {
            "text": output,
            "error_code": 0
        }

        yield json.dumps(ret).decode() + b"\0"
        model, tokenizer, params, device, context_len=2048, stream_interval=1):
        yield output

if __name__ == "__main__":
    with gr.Blocks() as demo:
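
The edit above changes the server-side generate() from yielding null-terminated JSON byte frames (the worker wire format) to yielding the plain text chunks produced by generate_stream, which is what a Gradio event handler can stream directly; the removed `yield json.dumps(ret).decode() + b"\0"` would not have run anyway, since json.dumps returns str and has no decode(). A minimal illustration of how Gradio consumes such a generator when the queue is enabled (names here are illustrative, not from the commit):

import gradio as gr

def stream_demo(prompt):
    # Stand-in for generate(): each yield replaces the output textbox contents.
    for partial in ("SELECT", "SELECT *", "SELECT * FROM users;"):
        yield partial

with gr.Blocks() as demo:
    box_in = gr.TextArea()
    box_out = gr.TextArea()
    gr.Button("run").click(stream_demo, inputs=box_in, outputs=box_out)

demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")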

@@ -50,7 +45,4 @@ if __name__ == "__main__":

        text_button.click(generate, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(host="0.0.0.0")


    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")