Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-13 05:01:25 +00:00)
convert file address
examples/app.py (new file, 65 lines added)
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import gradio as gr
from langchain.agents import (
    load_tools,
    initialize_agent,
    AgentType,
)
from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM
from llama_index import (
    Document,
    GPTSimpleVectorIndex,
    LangchainEmbedding,
    LLMPredictor,
    ServiceContext,
)


def agent_demo():
    # Zero-shot ReAct agent backed by the Vicuna endpoint, armed with a Python REPL tool.
    llm = VicunaRequestLLM()

    tools = load_tools(["python_repl"], llm=llm)
    agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
    agent.run("Write a SQL script that queries 'select count(1)'.")


def knowledged_qa_demo(text_list):
    # Build an in-memory vector index over the given texts, using Vicuna for
    # both generation and embeddings.
    llm_predictor = LLMPredictor(llm=VicunaRequestLLM())
    hfemb = VicunaEmbeddingLLM()
    embed_model = LangchainEmbedding(hfemb)
    documents = [Document(t) for t in text_list]

    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
    index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
    return index


def get_answer(q):
    base_knowledge = """ """  # placeholder knowledge base; fill in real text
    text_list = [base_knowledge]
    index = knowledged_qa_demo(text_list)
    response = index.query(q)
    return response.response


def get_similar(q):
    from pilot.vector_store.extract_tovec import knownledge_tovec_st

    docsearch = knownledge_tovec_st("./datasets/plan.md")
    docs = docsearch.similarity_search_with_score(q, k=1)

    # Each hit is a (document, score) pair; stream the best match's text.
    for doc in docs:
        dc, s = doc
        print(s)
        yield dc.page_content


if __name__ == "__main__":
    # agent_demo()

    with gr.Blocks() as demo:
        gr.Markdown("数据库智能助手")  # "Database intelligence assistant"
        with gr.Tab("知识问答"):  # "Knowledge Q&A"
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button()

        text_button.click(get_similar, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
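The `get_similar` handler above streams because Gradio treats a generator function as an incremental output: every `yield` repaints the bound output component, and the queue enabled by `demo.queue(...)` is what makes that work. Below is a minimal self-contained sketch of the same pattern, assuming a Gradio 3.x runtime like the one these examples target; the word-by-word echo handler is illustrative, not part of the commit.

import gradio as gr

def stream_words(text):
    # Yield progressively longer prefixes; Gradio re-renders the output
    # component on every yield, producing a typewriter effect.
    words = []
    for w in text.split():
        words.append(w)
        yield " ".join(words)

with gr.Blocks() as demo:
    inp = gr.TextArea()
    outp = gr.TextArea()
    gr.Button("Run").click(stream_words, inputs=inp, outputs=outp)

demo.queue(concurrency_count=3).launch()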
examples/embdserver.py (new file, 70 lines added)
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import json
from urllib.parse import urljoin

import requests
import gradio as gr

from pilot.configs.model_config import *  # provides VICUNA_MODEL_SERVER, LLM_MODEL
from pilot.conversation import conv_qa_prompt_template, conv_templates
from langchain.prompts import PromptTemplate

vicuna_stream_path = "generate_stream"


def generate(query):
    template_name = "conv_one_shot"
    state = conv_templates[template_name].copy()

    pt = PromptTemplate(
        template=conv_qa_prompt_template,
        input_variables=["context", "question"],
    )

    result = pt.format(
        context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
        question=query,
    )

    print(result)

    state.append_message(state.roles[0], result)
    state.append_message(state.roles[1], None)

    prompt = state.get_prompt()
    params = {
        "model": "vicuna-13b",
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 1024,
        "stop": "###",
    }

    response = requests.post(
        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path),
        data=json.dumps(params),
        stream=True,  # stream so partial generations arrive as they are produced
    )

    # Length of the echoed prompt to strip from each partial generation.
    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                output = data["text"][skip_echo_len:].strip()
                state.messages[-1][-1] = output + "▌"  # "▌" marks in-progress output
                yield output


if __name__ == "__main__":
    print(LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("数据库SQL生成助手")  # "Database SQL generation assistant"
        with gr.Tab("SQL生成"):  # "SQL generation"
            text_input = gr.TextArea()
            text_output = gr.TextArea()
            text_button = gr.Button("提交")  # "Submit"

        text_button.click(generate, inputs=text_input, outputs=text_output)

    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
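For reference, the client above assumes the `generate_stream` endpoint replies with NUL-delimited JSON chunks, each carrying the accumulated generation in `text` plus an `error_code`; that is why it splits on `b"\0"` and slices off the echoed prompt. A minimal sketch of parsing that wire format, with illustrative payloads rather than captured server output:

import json

# Two illustrative chunks in the NUL-delimited format the client expects.
raw = b'{"text": "SELECT", "error_code": 0}\x00{"text": "SELECT count(1);", "error_code": 0}\x00'

for chunk in raw.split(b"\x00"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])  # the accumulated text so far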