update:merge

This commit is contained in:
aries-ckt
2023-05-18 19:52:59 +08:00
parent e50c04ead2
commit e59c3834eb
20 changed files with 271 additions and 507 deletions

View File

@@ -7,12 +7,15 @@ import time
import uuid
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.model_config import *
from pilot.configs.config import Config
from pilot.conversation import conv_qa_prompt_template, conv_templates
from langchain.prompts import PromptTemplate
vicuna_stream_path = "generate_stream"
CFG = Config()
def generate(query):
template_name = "conv_one_shot"
@@ -41,7 +44,7 @@ def generate(query):
}
response = requests.post(
url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
)
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -54,7 +57,7 @@ def generate(query):
yield(output)
if __name__ == "__main__":
print(LLM_MODEL)
print(CFG.LLM_MODEL)
with gr.Blocks() as demo:
gr.Markdown("数据库SQL生成助手")
with gr.Tab("SQL生成"):