mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-24 19:08:58 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			56 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			56 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # -*- coding:utf-8 -*-
 | |
| 
 | |
| import requests
 | |
| import json
 | |
| import time
 | |
| from urllib.parse import urljoin
 | |
| import gradio as gr
 | |
| from pilot.configs.model_config import *
 | |
# Base URL of the Vicuna model worker; generate() joins the endpoint paths
# below onto it with urljoin. NOTE(review): hard-coded private address --
# confirm it matches the deployment before running.
vicuna_base_uri = "http://192.168.31.114:21002/"
# Path (relative to vicuna_base_uri) that generate() POSTs the prompt to and
# reads the chunked generation result from.
vicuna_stream_path = "worker_generate_stream"
# Path (relative to vicuna_base_uri) that generate() POSTs to before each
# generation; the response text is printed.
vicuna_status_path = "worker_get_status"
 | |
| 
 | |
def generate(prompt):
    """Stream generated text for *prompt* from the Vicuna model worker.

    The worker status endpoint is queried first and its raw response text is
    printed. The prompt is then POSTed to the streaming endpoint and the
    response is consumed as NUL-delimited JSON chunks.

    Args:
        prompt: The prompt text forwarded to the model worker.

    Yields:
        The ``text`` field of every chunk whose ``error_code`` is 0. Chunks
        with a non-zero ``error_code`` are skipped.

    Raises:
        requests.RequestException: If either HTTP request fails or times out.
        json.JSONDecodeError: If a chunk is not valid JSON.
    """
    # (connect, read) timeouts in seconds so that an unreachable or hung
    # worker cannot block a queue slot forever.
    status_timeout = 10
    stream_timeout = (10, 100)

    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###",
    }

    sts_response = requests.post(
        url=urljoin(vicuna_base_uri, vicuna_status_path),
        timeout=status_timeout,
    )
    print(sts_response.text)

    # stream=True is required for incremental output: without it requests
    # downloads the entire body before iter_lines() yields the first chunk.
    # The context manager releases the connection even if the consumer stops
    # iterating early.
    with requests.post(
        url=urljoin(vicuna_base_uri, vicuna_stream_path),
        data=json.dumps(params),
        stream=True,
        timeout=stream_timeout,
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if not chunk:
                continue

            data = json.loads(chunk.decode())
            # NOTE(review): the full "text" value is yielded as-is; if the
            # worker echoes the prompt, the echo is included -- confirm
            # whether it should be stripped.
            if data["error_code"] == 0:
                yield data["text"]

            # Brief pause between chunks, as in the original implementation.
            time.sleep(0.02)
 | |
| 
 | |
if __name__ == "__main__":
    # LLM_MODEL is not defined in this file; presumably it comes from the
    # star import of pilot.configs.model_config -- printed at startup.
    print(LLM_MODEL)

    # Single-tab UI: one input text area, one output text area and a submit
    # button that feeds the input through generate().
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")

        with gr.Tab("SQL生成"):
            prompt_box = gr.TextArea()
            result_box = gr.TextArea()
            submit_button = gr.Button("提交")

        submit_button.click(fn=generate, inputs=prompt_box, outputs=result_box)

    # Enable the request queue with 3 concurrent workers, then serve on all
    # network interfaces.
    queued_app = demo.queue(concurrency_count=3)
    queued_app.launch(server_name="0.0.0.0")
 | |
| 
 | |
|      |