run a demo

csunny 2023-04-29 15:17:48 +08:00
parent c971f6ec58
commit 4b2e3bf59a
7 changed files with 112 additions and 18 deletions

.vscode/launch.json (new file)

@@ -0,0 +1,25 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "env": {"PYTHONPATH": "${workspaceFolder}"},
            "envFile": "${workspaceFolder}/.env"
        },
        {
            "name": "Python: Module",
            "type": "python",
            "request": "launch",
            "module": "pilot",
            "justMyCode": true
        }
    ]
}
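Both configurations lean on the "env"/"envFile" entries to put the repository root on PYTHONPATH, which is what lets absolute imports such as from configs.model_config import * in the files below resolve when a single file is launched. A quick sanity check, run from the repository root with PYTHONPATH=. (a sketch; the module names are simply the ones this commit uses):

# Verifies the PYTHONPATH wiring that launch.json provides; module
# names are taken from this commit's imports.
import importlib

for mod in ("pilot", "configs.model_config"):
    try:
        importlib.import_module(mod)
        print(mod, "OK")
    except ImportError as exc:
        print(mod, "failed:", exc)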


@@ -1,3 +1 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
__version__ = "0.0.1"


@@ -8,10 +8,12 @@ root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
model_path = os.path.join(root_path, "models")
vector_storepath = os.path.join(root_path, "vector_store")
llm_model_config = {
    "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
    "vicuna-13b": os.path.join(model_path, "vicuna-13b")
}
LLM_MODEL = "vicuna-13b"
vicuna_model_server = "http://192.168.31.114:21000/"

pilot/model/inference.py (new file)

@@ -0,0 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
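
pilot/model/inference.py lands as a stub holding only import torch. The server diff further down calls a FastChat-style generate_stream(model, tokenizer, params, device, context_len, stream_interval) helper, which presumably belongs here eventually. A minimal sketch of such a helper follows; it uses plain greedy decoding, and the sampling logic, KV-cache handling, and exact signature of the real FastChat version are not reproduced:

# A simplified, hedged sketch of a FastChat-style generate_stream:
# greedy decoding only, and it recomputes the full forward pass per
# token (a real implementation would reuse the KV cache).
import torch

@torch.inference_mode()
def generate_stream(model, tokenizer, params, device,
                    context_len=2048, stream_interval=2):
    prompt = params["prompt"]
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop")

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Leave room for new tokens inside the model's context window.
    output_ids = input_ids[:, -(context_len - max_new_tokens):]

    for step in range(max_new_tokens):
        logits = model(output_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        output_ids = torch.cat([output_ids, next_token], dim=-1)

        if step % stream_interval == 0 or step == max_new_tokens - 1:
            text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            if stop_str and stop_str in text:
                # Truncate at the stop sequence and finish.
                yield text[: text.index(stop_str)]
                return
            yield text

        if next_token.item() == tokenizer.eos_token_id:
            break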


@@ -1,9 +1,29 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
from typing import Any, Mapping, Optional, List
from transformers import pipeline
from langchain.llms.base import LLM
from configs.model_config import *
class VicunaLLM(LLM):
    model_name = llm_model_config[LLM_MODEL]

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        url = vicuna_model_server
        params = {
            "model": "vicuna-13b",
            "prompt": prompt,
            "temperature": 0.7,
            "max_new_tokens": 512,
            "stop": "###"
        }
        # The request is never actually sent; see the sketch below.
        pass

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}
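
As committed, _call builds its request parameters and then falls through with pass, so the wrapper never contacts the model server. A minimal sketch of one way to finish it, assuming the worker exposes the same worker_generate_stream endpoint and NUL-delimited {"text": ..., "error_code": ...} chunk format that the Gradio demo below relies on (both assumptions carried over from that file, not confirmed by this diff):

# A hedged sketch, not the committed implementation.
import json
from typing import List, Optional
from urllib.parse import urljoin

import requests

def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }
    response = requests.post(
        url=urljoin(vicuna_model_server, "worker_generate_stream"),
        data=json.dumps(params),
        stream=True,
    )
    output = ""
    # Keep the latest full text: the worker re-sends the whole
    # completion so far in each NUL-delimited JSON chunk.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
    return output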


@@ -1,3 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import requests
import json
import time
from urllib.parse import urljoin

import gradio as gr

from configs.model_config import *

vicuna_base_uri = "http://192.168.31.114:21002/"
vicuna_stream_path = "worker_generate_stream"
vicuna_status_path = "worker_get_status"


def generate(prompt):
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###"
    }

    # Confirm the worker is reachable before sending the generation request.
    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path)
    )
    print(sts_response.text)

    response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path),
        data=json.dumps(params),
        stream=True,
    )

    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3  # currently unused

    # The worker streams the whole completion so far in each
    # NUL-delimited JSON chunk.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"]
                yield output
            time.sleep(0.02)


if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("Database SQL generation assistant")
        with gr.Tab("SQL generation"):
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("Submit")
            text_button.click(generate, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
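
Once _call is filled in as sketched earlier, this demo could swap its hand-rolled requests plumbing for the VicunaLLM wrapper. A minimal, non-streaming sketch; the pilot.model.vicuna_llm import path is an assumption based on this commit's layout, not confirmed by the diff:

# Hypothetical: drive the same UI through the langchain wrapper.
import gradio as gr
from pilot.model.vicuna_llm import VicunaLLM  # assumed module path

llm = VicunaLLM()

def generate(prompt):
    # Non-streaming: langchain's LLM __call__ returns the final text.
    return llm(prompt)

with gr.Blocks() as demo:
    gr.Markdown("Database SQL generation assistant")
    text_input = gr.TextArea()
    text_output = gr.TextArea()
    text_button = gr.Button("Submit")
    text_button.click(generate, inputs=text_input, outputs=text_output)

demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")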


@@ -20,8 +20,8 @@ model = AutoModelForCausalLM.from_pretrained(
)

def generate(prompt):
-    # compress_module(model, device)
-    # model.to(device)
+    compress_module(model, device)
+    model.to(device)
    print(model, tokenizer)
    params = {
        "model": "vicuna-13b",
@@ -31,13 +31,8 @@
        "stop": "###"
    }

    for output in generate_stream(
-            model, tokenizer, params, device, context_len=2048, stream_interval=2):
-        ret = {
-            "text": output,
-            "error_code": 0
-        }
-        yield json.dumps(ret).decode() + b"\0"
+            model, tokenizer, params, device, context_len=2048, stream_interval=1):
+        yield output

if __name__ == "__main__":
    with gr.Blocks() as demo:
@@ -50,7 +45,4 @@ if __name__ == "__main__":
        text_button.click(generate, inputs=text_input, outputs=text_output)

-    demo.queue(concurrency_count=3).launch(host="0.0.0.0")
+    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")