fetch top3 similar answer

csunny
2023-05-07 17:32:10 +08:00
parent e4899ff7dd
commit 56e9cde86e
6 changed files with 61 additions and 52 deletions
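The diff below swaps the hardcoded worker endpoints for the VICUNA_MODEL_SERVER setting and builds the prompt through LangChain's PromptTemplate, but the retrieval context passed to conv_qa_prompt_template is still a hardcoded Chroma-docs snippet. A minimal sketch of the "fetch top3 similar answer" step the commit title describes, assuming a LangChain Chroma vector store (the persist directory, embedding backend, and helper name are illustrative assumptions, not part of this commit):

# Hypothetical helper (not in this commit): build the `context` string for
# conv_qa_prompt_template from a top-3 similarity search instead of the
# hardcoded snippet used in generate() below.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

embeddings = HuggingFaceEmbeddings()  # assumed embedding backend
store = Chroma(persist_directory="./chroma_store", embedding_function=embeddings)

def top3_context(query: str) -> str:
    # k=3 matches the "top3" in the commit message
    docs = store.similarity_search(query, k=3)
    return "\n".join(d.page_content for d in docs)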


@@ -4,29 +4,44 @@
 import requests
 import json
 import time
 import uuid
 from urllib.parse import urljoin
 
 import gradio as gr
 from pilot.configs.model_config import *
-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate
 
-def generate(prompt):
+vicuna_stream_path = "generate_stream"
+
+def generate(query):
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+
+    prompt = state.get_prompt()
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }
 
-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + "▌"
                 yield(output)
 
         time.sleep(0.02)
 
 if __name__ == "__main__":
     print(LLM_MODEL)
 
     with gr.Blocks() as demo:
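With this change generate() is a generator: the worker streams JSON chunks, skip_echo_len strips the echoed prompt (adjusting for "</s>" separators) from each partial result, and the trimmed text is yielded as it arrives. A minimal usage sketch, assuming the model server configured as VICUNA_MODEL_SERVER is running; the query string is illustrative:

# Each `partial` is the answer accumulated so far, prompt echo removed.
for partial in generate("How do I set up Chroma with LangChain?"):
    print(partial)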