Unified configuration

This commit is contained in:
yhjun1026
2023-05-18 16:30:57 +08:00
parent a68e164a5f
commit ba7e23d37f
13 changed files with 97 additions and 53 deletions

@@ -17,6 +17,10 @@
 #*******************************************************************#
 #**                        LLM MODELS                             **#
 #*******************************************************************#
+LLM_MODEL=vicuna-13b
+MODEL_SERVER=http://your_model_server_url
+LIMIT_MODEL_CONCURRENCY=5
+MAX_POSITION_EMBEDDINGS=4096
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
@@ -36,10 +40,10 @@
 #*******************************************************************#
 #**                     DATABASE SETTINGS                         **#
 #*******************************************************************#
-DB_SETTINGS_MYSQL_USER=root
-DB_SETTINGS_MYSQL_PASSWORD=password
-DB_SETTINGS_MYSQL_HOST=localhost
-DB_SETTINGS_MYSQL_PORT=3306
+LOCAL_DB_USER=root
+LOCAL_DB_PASSWORD=password
+LOCAL_DB_HOST=localhost
+LOCAL_DB_PORT=3306
 ### MILVUS
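The DB_SETTINGS_MYSQL_* keys are renamed here, so an existing .env written for the old names stops working. A hypothetical stdlib-only shim (not part of this commit) that carries old values forward:

```python
# Map the retired DB_SETTINGS_MYSQL_* keys onto the new LOCAL_DB_* names.
# Purely illustrative; the commit itself only renames the template entries.
import os

RENAMED_KEYS = {
    "DB_SETTINGS_MYSQL_USER": "LOCAL_DB_USER",
    "DB_SETTINGS_MYSQL_PASSWORD": "LOCAL_DB_PASSWORD",
    "DB_SETTINGS_MYSQL_HOST": "LOCAL_DB_HOST",
    "DB_SETTINGS_MYSQL_PORT": "LOCAL_DB_PORT",
}

for old, new in RENAMED_KEYS.items():
    if old in os.environ and new not in os.environ:
        os.environ[new] = os.environ[old]  # carry the old value forward
```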

.gitignore vendored
@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
+.env
 .idea
 .vscode
 .idea

@@ -179,7 +179,7 @@ Run gradio webui
 ```bash
 $ python pilot/server/webserver.py
 ```
-Notice: the webserver need to connect llmserver, so you need change the pilot/configs/model_config.py file. change the VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" to your address. It's very important.
+Notice: the webserver needs to connect to llmserver, so you must edit the .env file and change MODEL_SERVER = "http://127.0.0.1:8000" to your server's address. This is very important.
 ## Usage Instructions
 We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
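For example, the edited line in .env might read (the host and port here are placeholders):

```
MODEL_SERVER=http://your-llmserver-host:8000
```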

@@ -178,7 +178,7 @@ python llmserver.py
 ```bash
 $ python webserver.py
 ```
-Note: before starting the webserver, you need to modify VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" in pilot/configs/model_config.py, setting the address to your server's address.
+Note: before starting the webserver, you need to modify MODEL_SERVER = "http://127.0.0.1:8000" in the .env configuration file, setting the address to your server's address.
 ## Usage Instructions

@@ -7,12 +7,15 @@ import time
 import uuid
 from urllib.parse import urljoin

 import gradio as gr
-from pilot.configs.model_config import *
+from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate

 vicuna_stream_path = "generate_stream"

+CFG = Config()
+

 def generate(query):
     template_name = "conv_one_shot"
@@ -41,7 +44,7 @@ def generate(query):
     }
     response = requests.post(
-        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )

     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -54,7 +57,7 @@ def generate(query):
             yield(output)

 if __name__ == "__main__":
-    print(LLM_MODEL)
+    print(CFG.LLM_MODEL)
     with gr.Blocks() as demo:
         gr.Markdown("数据库SQL生成助手")
         with gr.Tab("SQL生成"):
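A self-contained sketch of the request generate() now sends, with the endpoint coming from Config instead of a model_config constant; the params values are illustrative, and the chunk handling assumes the NUL-delimited stream format used by the webserver below:

```python
import json
from urllib.parse import urljoin

import requests

from pilot.configs.config import Config

CFG = Config()

params = {"prompt": "Generate a SQL query ...", "temperature": 0.7, "max_new_tokens": 512}
response = requests.post(
    url=urljoin(CFG.MODEL_SERVER, "generate_stream"),
    data=json.dumps(params),
    stream=True,
)
# Each NUL-delimited chunk is one serialized response frame.
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        print(chunk.decode())
```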

@@ -2,24 +2,23 @@
 # -*- coding: utf-8 -*-

 import os
+import nltk
 from typing import List

 from auto_gpt_plugin_template import AutoGPTPluginTemplate
 from pilot.singleton import Singleton


 class Config(metaclass=Singleton):
     """Configuration class to store the state of bools for different scripts access"""

     def __init__(self) -> None:
         """Initialize the Config class"""
-        # TODO change model_config there
-
         self.debug_mode = False
         self.skip_reprompt = False
-
         self.temperature = float(os.getenv("TEMPERATURE", 0.7))
+        # TODO change model_config there

         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
@@ -46,17 +45,12 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"

         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
-        self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
+        self.image_provider = os.getenv("IMAGE_PROVIDER", True)
         self.image_size = int(os.getenv("IMAGE_SIZE", 256))
-        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
-        self.plugins: List[AutoGPTPluginTemplate] = []
-        self.plugins_openai = []
-        self.command_registry = []
-
         self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
         self.image_provider = os.getenv("IMAGE_PROVIDER")
         self.image_size = int(os.getenv("IMAGE_SIZE", 256))
@@ -68,6 +62,10 @@ class Config(metaclass=Singleton):
         )
         self.speak_mode = False
+
+        ### Related configuration of built-in commands
+        self.command_registry = []
+
         disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
         if disabled_command_categories:
             self.disabled_command_categories = disabled_command_categories.split(",")
@@ -78,6 +76,12 @@ class Config(metaclass=Singleton):
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
+
+        ### Plugin-related configuration parameters: control how plug-ins are loaded and used
+        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
+        self.plugins: List[AutoGPTPluginTemplate] = []
+        self.plugins_openai = []
+
         plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
         if plugins_allowlist:
             self.plugins_allowlist = plugins_allowlist.split(",")
@@ -90,6 +94,20 @@ class Config(metaclass=Singleton):
         else:
             self.plugins_denylist = []
+
+        ### Local database connection configuration
+        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
+        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
+        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
+        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
+
+        ### LLM Model Service Configuration
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
+        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
+        self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://121.41.167.183:8000")
+        self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
+
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
         self.debug_mode = value
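Because Config uses the Singleton metaclass, every module that calls Config() shares one instance, which is what makes these env-backed attributes a process-wide source of truth. A minimal sketch of that behaviour (assuming pilot.singleton.Singleton is the usual instance-caching metaclass):

```python
from pilot.configs.config import Config

a = Config()
b = Config()
assert a is b                # the metaclass returns the cached instance

a.LLM_MODEL = "vicuna-7b"    # hypothetical runtime override
print(b.LLM_MODEL)           # "vicuna-7b": the change is visible everywhere
```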

@@ -22,25 +22,18 @@ LLM_MODEL_CONFIG = {
     "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }

-VECTOR_SEARCH_TOP_K = 3
-LLM_MODEL = "vicuna-13b"
-LIMIT_MODEL_CONCURRENCY = 5
-MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"

 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False

-DB_SETTINGS = {
-    "user": "root",
-    "password": "aa123456",
-    "host": "127.0.0.1",
-    "port": 3306
-}
+VECTOR_SEARCH_TOP_K = 3
+# LLM_MODEL = "vicuna-13b"
+# LIMIT_MODEL_CONCURRENCY = 5
+# MAX_POSITION_EMBEDDINGS = 4096
+# VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"

 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
 KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")
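Static assets (model paths) stay in model_config.py while tunables move behind Config; the resulting lookup pattern, used by the LLM server and the fine-tuning script below, is roughly:

```python
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG  # static path table

CFG = Config()
# The dynamic, env-driven choice indexes into the static table:
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
```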

@@ -4,8 +4,16 @@
 import dataclasses
 from enum import auto, Enum
 from typing import List, Any

-from pilot.configs.model_config import DB_SETTINGS
+from pilot.configs.config import Config
+
+CFG = Config()
+DB_SETTINGS = {
+    "user": CFG.LOCAL_DB_USER,
+    "password": CFG.LOCAL_DB_PASSWORD,
+    "host": CFG.LOCAL_DB_HOST,
+    "port": CFG.LOCAL_DB_PORT
+}

 class SeparatorStyle(Enum):
     SINGLE = auto()
@@ -91,7 +99,7 @@ class Conversation:
 def gen_sqlgen_conversation(dbname):
     from pilot.connections.mysql import MySQLOperator
     mo = MySQLOperator(
-        **DB_SETTINGS
+        **(DB_SETTINGS)
     )

     message = ""

@@ -8,8 +8,9 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from pilot.configs.model_config import *
+from pilot.configs.config import Config

+CFG = Config()

 class VicunaLLM(LLM):
     vicuna_generate_path = "generate_stream"
@@ -22,7 +23,7 @@ class VicunaLLM:
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
+            url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
@@ -51,7 +52,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
         print("Sending prompt ", p)

         response = requests.post(
-            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
+            url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
             json={
                 "prompt": p
             }
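The pattern here, reduced to a sketch: a custom LangChain LLM whose endpoint comes from Config rather than a module-level constant. The field names and the response schema below are assumptions for illustration, not the project's exact code:

```python
import json
from typing import List, Optional
from urllib.parse import urljoin

import requests
from langchain.llms.base import LLM

from pilot.configs.config import Config

CFG = Config()

class RemoteVicuna(LLM):
    generate_path: str = "generate"  # hypothetical non-streaming endpoint

    @property
    def _llm_type(self) -> str:
        return "remote_vicuna"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        params = {"prompt": prompt, "stop": stop}
        response = requests.post(
            url=urljoin(CFG.MODEL_SERVER, self.generate_path),
            data=json.dumps(params),
        )
        return response.json()["response"]  # assumed response field
```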

@@ -17,14 +17,17 @@ from peft import (
 import torch
 from datasets import load_dataset
 import pandas as pd

-from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
+from pilot.configs.config import Config
+from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG

 device = "cuda" if torch.cuda.is_available() else "cpu"
 CUTOFF_LEN = 50

 df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
+
+CFG = Config()

 def sentiment_score_to_name(score: float):
     if score > 0:
         return "Positive"
@@ -49,7 +52,7 @@ with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w")
 data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
 print(data["train"])

-BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
+BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
 model = LlamaForCausalLM.from_pretrained(
     BASE_MODEL,
     torch_dtype=torch.float16,

@@ -13,8 +13,11 @@ from pilot.model.inference import generate_output, get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *
+from pilot.configs.config import Config

-model_path = LLM_MODEL_CONFIG[LLM_MODEL]
+CFG = Config()
+
+model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]

 global_counter = 0
@@ -60,7 +63,7 @@ def generate_stream_gate(params):
         tokenizer,
         params,
         DEVICE,
-        MAX_POSITION_EMBEDDINGS,
+        CFG.MAX_POSITION_EMBEDDINGS,
     ):
         print("output: ", output)
         ret = {
@@ -84,7 +87,7 @@ async def api_generate_stream(request: Request):
     print(model, tokenizer, params, DEVICE)
     if model_semaphore is None:
-        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()

     generator = generate_stream_gate(params)
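The semaphore is what LIMIT_MODEL_CONCURRENCY actually controls: at most that many requests run generation concurrently while the rest wait. A standalone sketch of the mechanism:

```python
import asyncio

LIMIT_MODEL_CONCURRENCY = 5  # the value CFG.LIMIT_MODEL_CONCURRENCY supplies

async def handle_request(sem: asyncio.Semaphore, i: int) -> None:
    async with sem:               # at most 5 tasks inside this block at once
        await asyncio.sleep(0.1)  # stand-in for generate_stream work
        print(f"request {i} served")

async def main() -> None:
    sem = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await asyncio.gather(*(handle_request(sem, i) for i in range(20)))

asyncio.run(main())
```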

@@ -14,13 +14,13 @@ from urllib.parse import urljoin
 from langchain import PromptTemplate

-from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
+from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
-from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
+from pilot.configs.model_config import LOGDIR, DATASETS_DIR

 from pilot.plugins import scan_plugins
 from pilot.configs.config import Config
@@ -67,7 +67,15 @@ priority = {
     "vicuna-13b": "aaa"
 }

+# Load plugins
+CFG = Config()
+DB_SETTINGS = {
+    "user": CFG.LOCAL_DB_USER,
+    "password": CFG.LOCAL_DB_PASSWORD,
+    "host": CFG.LOCAL_DB_HOST,
+    "port": CFG.LOCAL_DB_PORT
+}

 def get_simlar(q):
     docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)
@@ -178,7 +186,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     print("是否是AUTO-GPT模式.", autogpt)

     start_tstamp = time.time()
-    model_name = LLM_MODEL
+    model_name = CFG.LLM_MODEL
     dbname = db_selector
     # TODO: the request here needs to be combined with the existing knowledge base so answers are grounded in it, so the prompt still needs tuning
@@ -268,7 +276,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     logger.info(f"Requert: \n{payload}")

     if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
                                  headers=headers, json=payload, timeout=120)

         print(response.json())
@@ -316,7 +324,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     try:
         # Stream output
-        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
+        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
                                  headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -595,9 +603,8 @@ if __name__ == "__main__":
     # dbs = get_database_list()

-    # Load plugins
+    # Initialize configuration
     cfg = Config()
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

     # Load plugin executable commands

@@ -57,6 +57,9 @@ pymdown-extensions
 mkdocs
 requests
 gTTS==2.3.1
+langchain
+nltk
+python-dotenv==1.0.0

 # Testing dependencies
 pytest
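python-dotenv is what makes the new .env file reachable from os.getenv. A sketch of the wiring (this excerpt does not show where the project calls load_dotenv, so treat the placement as an assumption):

```python
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory into os.environ
print(os.getenv("MODEL_SERVER", "http://127.0.0.1:8000"))
```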