diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 000000000..09a35ce9c
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,25 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "env": {"PYTHONPATH": "${workspaceFolder}"},
+            "envFile": "${workspaceFolder}/.env"
+        },
+        {
+            "name": "Python: Module",
+            "type": "python",
+            "request": "launch",
+            "module": "pilot",
+            "justMyCode": true
+        }
+    ]
+}
\ No newline at end of file
diff --git a/pilot/__init__.py b/pilot/__init__.py
index 9244e14db..f102a9cad 100644
--- a/pilot/__init__.py
+++ b/pilot/__init__.py
@@ -1,3 +1 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 __version__ = "0.0.1"
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index b0157ae19..58ad5cffe 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -8,10 +8,12 @@ root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 model_path = os.path.join(root_path, "models")
 vector_storepath = os.path.join(root_path, "vector_store")
 
-
 llm_model_config = {
     "flan-t5-base": os.path.join(model_path, "flan-t5-base"),
     "vicuna-13b": os.path.join(model_path, "vicuna-13b")
 }
 
-LLM_MODEL = "vicuna-13b"
\ No newline at end of file
+LLM_MODEL = "vicuna-13b"
+
+
+vicuna_model_server = "http://192.168.31.114:21000/"
\ No newline at end of file
diff --git a/pilot/model/inference.py b/pilot/model/inference.py
new file mode 100644
index 000000000..c3698fb1f
--- /dev/null
+++ b/pilot/model/inference.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import torch
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index 1cc0ca3c3..b3ecd079d 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -1,9 +1,40 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import json
+import requests
+from urllib.parse import urljoin
+from typing import Any, Mapping, Optional, List
 from transformers import pipeline
 from langchain.llms.base import LLM
 from configs.model_config import *
 
 class VicunaLLM(LLM):
-    model_name = llm_model_config[LLM_MODEL]
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        params = {
+            "model": "vicuna-13b",
+            "prompt": prompt,
+            "temperature": 0.7,
+            "max_new_tokens": 512,
+            "stop": "###"
+        }
+        # Assumes the FastChat-style worker streaming API used in pilot/server/chatbot.py.
+        response = requests.post(
+            url=urljoin(vicuna_model_server, "worker_generate_stream"), data=json.dumps(params)
+        )
+        output = ""
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"]
+        return output
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {}
diff --git a/pilot/server/chatbot.py b/pilot/server/chatbot.py
index 97206f2d5..6cc1b8904 100644
--- a/pilot/server/chatbot.py
+++ b/pilot/server/chatbot.py
@@ -1,3 +1,58 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import requests
+import json
+import time
+from urllib.parse import urljoin
+
+import gradio as gr
+
+from configs.model_config import *
+
+vicuna_base_uri = "http://192.168.31.114:21002/"
+vicuna_stream_path = "worker_generate_stream"
+vicuna_status_path = "worker_get_status"
+
+def generate(prompt):
+    params = {
+        "model": "vicuna-13b",
+        "prompt": prompt,
+        "temperature": 0.7,
+        "max_new_tokens": 512,
+        "stop": "###"
+    }
+
+    # Ping the worker's status endpoint before streaming.
+    sts_response = requests.post(
+        url=urljoin(vicuna_base_uri, vicuna_status_path)
+    )
+    print(sts_response.text)
+
+    response = requests.post(
+        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params), stream=True
+    )
+
+    # Length of the echoed prompt to strip from each streamed chunk.
+    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+        if chunk:
+            data = json.loads(chunk.decode())
+            if data["error_code"] == 0:
+                output = data["text"][skip_echo_len:].strip()
+                yield output
+
+        time.sleep(0.02)
+
+if __name__ == "__main__":
+    print(LLM_MODEL)
+    with gr.Blocks() as demo:
+        gr.Markdown("Database SQL Generation Assistant")
+        with gr.Tab("SQL Generation"):
+            text_input = gr.TextArea()
+            text_output = gr.TextArea()
+            text_button = gr.Button("Submit")
+
+        text_button.click(generate, inputs=text_input, outputs=text_output)
+
+    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
diff --git a/pilot/server/sqlgpt.py b/pilot/server/sqlgpt.py
index 52522e6bd..6dbf1bfc1 100644
--- a/pilot/server/sqlgpt.py
+++ b/pilot/server/sqlgpt.py
@@ -20,8 +20,8 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 def generate(prompt):
-    # compress_module(model, device)
-    # model.to(device)
+    compress_module(model, device)
+    model.to(device)
     print(model, tokenizer)
     params = {
         "model": "vicuna-13b",
@@ -31,13 +31,8 @@ def generate(prompt):
         "stop": "###"
     }
     for output in generate_stream(
-        model, tokenizer, params, device, context_len=2048, stream_interval=2):
-        ret = {
-            "text": output,
-            "error_code": 0
-        }
-
-        yield json.dumps(ret).decode() + b"\0"
+        model, tokenizer, params, device, context_len=2048, stream_interval=1):
+        yield output
 
 if __name__ == "__main__":
     with gr.Blocks() as demo:
@@ -50,7 +45,4 @@ if __name__ == "__main__":
 
     text_button.click(generate, inputs=text_input, outputs=text_output)
 
-    demo.queue(concurrency_count=3).launch(host="0.0.0.0")
-
-
-
+    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
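
With _call filled in, a quick smoke test of the wrapper could look like the sketch below. This is not part of the diff: it assumes the langchain 0.0.x LLM base class (whose __call__ delegates to _call), that imports resolve (vicuna_llm.py imports configs.model_config, so pilot/ itself must be on sys.path, e.g. when run as a module per launch.json), and that a worker is actually listening behind vicuna_model_server. The prompt string is illustrative.

    # Smoke test for the VicunaLLM wrapper (sketch; assumes langchain 0.0.x
    # and a running FastChat-style worker behind vicuna_model_server).
    # Run with pilot/ on sys.path so `configs.model_config` resolves.
    from model.vicuna_llm import VicunaLLM

    llm = VicunaLLM()
    print(llm._llm_type)  # -> "custom"
    # LLM.__call__ delegates to the _call implementation added in the diff.
    print(llm("Write a SQL query that returns all users."))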
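
Since the diff is building toward a SQL-generation assistant, the wrapper can also be composed into a standard chain. A hedged sketch follows, again assuming the langchain 0.0.x API (PromptTemplate, LLMChain, Chain.run); the template text is made up for illustration.

    # Sketch: VicunaLLM inside a simple SQL-generation chain (assumes langchain 0.0.x).
    from langchain.chains import LLMChain
    from langchain.prompts import PromptTemplate

    from model.vicuna_llm import VicunaLLM  # run with pilot/ on sys.path

    sql_prompt = PromptTemplate(
        input_variables=["question"],
        template="Write a single SQL statement for this request:\n{question}\n###",
    )
    chain = LLMChain(llm=VicunaLLM(), prompt=sql_prompt)
    print(chain.run("list all users"))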