Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-10 04:37:59 +00:00

Commit 4b2e3bf59a: "run a demo" (parent c971f6ec58)
.vscode/launch.json (vendored, new file, +25 lines)

@@ -0,0 +1,25 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "env": {"PYTHONPATH": "${workspaceFolder}"},
+            "envFile": "${workspaceFolder}/.env"
+        },
+        {
+            "name": "Python: Module",
+            "type": "python",
+            "request": "launch",
+            "module": "pilot",
+            "justMyCode": true
+        }
+    ]
+}
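Both configurations assume the repository root is the VS Code workspace folder; the first entry additionally puts the repo root on PYTHONPATH so that absolute imports such as "from configs.model_config import *" resolve. As a rough sketch (an assumption, not part of the commit), the "Python: Module" entry behaves like running the pilot package yourself from the repo root, provided pilot exposes a __main__ entry point:

    # Hypothetical stand-in for the "Python: Module" launch configuration;
    # run from the repository root.
    import runpy
    import sys

    sys.path.insert(0, ".")  # mirrors PYTHONPATH=${workspaceFolder}
    runpy.run_module("pilot", run_name="__main__")  # mirrors "module": "pilot"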
@@ -1,3 +1 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 __version__ = "0.0.1"
@@ -8,10 +8,12 @@ root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
 
 
 llm_model_config = {
     "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
     "vicuna-13b": os.path.join(model_path, "vicuna-13b")
 }
 
 LLM_MODEL = "vicuna-13b"
+
+vicuna_model_server = "http://192.168.31.114:21000/"
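For orientation (a sketch, not part of the commit): model weights are expected under <repo>/models/<model-name>, and vicuna_model_server points at a separately started model worker. Client code resolves both through the config module:

    # Illustrative lookup using the values defined in the hunk above.
    from configs.model_config import LLM_MODEL, llm_model_config, vicuna_model_server

    weights_dir = llm_model_config[LLM_MODEL]  # .../models/vicuna-13b
    print(weights_dir, vicuna_model_server)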
pilot/model/inference.py (new file, +4 lines)

@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import torch
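As committed, this file is only a stub: it imports torch and defines nothing else.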
@@ -1,9 +1,29 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import requests
+from typing import Any, Mapping, Optional, List
 from transformers import pipeline
 from langchain.llms.base import LLM
 from configs.model_config import *
 
 
 class VicunaLLM(LLM):
+    model_name = llm_model_config[LLM_MODEL]
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        url = vicuna_model_server
+        params = {
+            "model": "vicuna-13b",
+            "prompt": prompt,
+            "temperature": 0.7,
+            "max_new_tokens": 512,
+            "stop": "###"
+        }
+        pass  # stub: the request to the worker is never actually sent
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {}
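As committed, _call assembles the payload and then falls through, so the wrapper never contacts the server or returns text. A minimal completion sketch, assuming the worker behind vicuna_model_server speaks the same streaming protocol as the webserver demo below (a "worker_generate_stream" endpoint returning b"\0"-delimited JSON chunks with "text" and "error_code" fields; the endpoint name and response shape are assumptions carried over from that demo):

    # Hedged sketch of a completed _call; not part of the commit.
    import json
    from urllib.parse import urljoin

    import requests

    def _call(self, prompt, stop=None):
        params = {
            "model": "vicuna-13b",
            "prompt": prompt,
            "temperature": 0.7,
            "max_new_tokens": 512,
            "stop": "###",
        }
        response = requests.post(
            url=urljoin(vicuna_model_server, "worker_generate_stream"),  # endpoint name assumed
            data=json.dumps(params),
            stream=True,
        )
        output = ""
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"]  # each chunk carries the full text so far
        return output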
@@ -1,3 +1,56 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import requests
+import json
+import time
+from urllib.parse import urljoin
+
+import gradio as gr
+from configs.model_config import *
+
+vicuna_base_uri = "http://192.168.31.114:21002/"
+vicuna_stream_path = "worker_generate_stream"
+vicuna_status_path = "worker_get_status"
+
+def generate(prompt):
+    params = {
+        "model": "vicuna-13b",
+        "prompt": "给出一个查询用户的SQL",  # "give a SQL statement that queries users"; NOTE: hardcoded, the prompt argument is ignored
+        "temperature": 0.7,
+        "max_new_tokens": 512,
+        "stop": "###"
+    }
+
+    sts_response = requests.post(
+        url=urljoin(vicuna_base_uri, vicuna_status_path)
+    )
+    print(sts_response.text)
+
+    response = requests.post(
+        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+    )
+
+    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3  # computed but not used yet
+    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+        if chunk:
+            data = json.loads(chunk.decode())
+            if data["error_code"] == 0:
+                output = data["text"]
+                yield output
+
+            time.sleep(0.02)
+
+if __name__ == "__main__":
+    print(LLM_MODEL)
+    with gr.Blocks() as demo:
+        gr.Markdown("数据库SQL生成助手")  # "Database SQL generation assistant"
+        with gr.Tab("SQL生成"):  # "SQL generation"
+            text_input = gr.TextArea()
+            text_output = gr.TextArea()
+            text_button = gr.Button("提交")  # "Submit"
+
+        text_button.click(generate, inputs=text_input, outputs=text_output)
+
+    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
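Because the committed generate() hardcodes its demo prompt, whatever the user types into the Gradio textbox has no effect on the request. A short sketch of the presumably intended wiring (an assumption, not in the commit):

    # Sketch: pass the user's input through instead of the hardcoded demo prompt.
    def generate(prompt):
        params = {
            "model": "vicuna-13b",
            "prompt": prompt,  # use the textbox contents
            "temperature": 0.7,
            "max_new_tokens": 512,
            "stop": "###"
        }
        ...  # remainder unchanged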
@@ -20,8 +20,8 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 def generate(prompt):
-    # compress_module(model, device)
-    # model.to(device)
+    compress_module(model, device)
+    model.to(device)
     print(model, tokenizer)
     params = {
         "model": "vicuna-13b",
@@ -31,13 +31,8 @@ def generate(prompt):
         "stop": "###"
     }
     for output in generate_stream(
-        model, tokenizer, params, device, context_len=2048, stream_interval=2):
-        ret = {
-            "text": output,
-            "error_code": 0
-        }
-        yield json.dumps(ret).decode() + b"\0"
+        model, tokenizer, params, device, context_len=2048, stream_interval=1):
+        yield output
 
 if __name__ == "__main__":
     with gr.Blocks() as demo:
@@ -50,7 +45,4 @@ if __name__ == "__main__":
 
     text_button.click(generate, inputs=text_input, outputs=text_output)
 
-    demo.queue(concurrency_count=3).launch(host="0.0.0.0")
-
-
-
+    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
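Two of the fixes in this last file are worth calling out. The removed yield json.dumps(ret).decode() + b"\0" could never have worked: json.dumps returns a str, which has no .decode(), and concatenating str with bytes raises TypeError, so yielding the raw output from generate_stream is the straightforward repair. Likewise, Gradio's launch() accepts server_name (and server_port) rather than host, so the old .launch(host="0.0.0.0") would have failed with an unexpected-keyword-argument error; server_name="0.0.0.0" is the supported way to bind all interfaces.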