Unified configuration

yhjun1026
2023-05-18 16:30:57 +08:00
parent a68e164a5f
commit ba7e23d37f
13 changed files with 97 additions and 53 deletions


@@ -17,6 +17,10 @@
#*******************************************************************#
#** LLM MODELS **#
#*******************************************************************#
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://your_model_server_url
LIMIT_MODEL_CONCURRENCY=5
MAX_POSITION_EMBEDDINGS=4096
## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
@@ -36,10 +40,10 @@
#*******************************************************************#
#** DATABASE SETTINGS **#
#*******************************************************************#
DB_SETTINGS_MYSQL_USER=root
DB_SETTINGS_MYSQL_PASSWORD=password
DB_SETTINGS_MYSQL_HOST=localhost
DB_SETTINGS_MYSQL_PORT=3306
LOCAL_DB_USER=root
LOCAL_DB_PASSWORD=password
LOCAL_DB_HOST=localhost
LOCAL_DB_PORT=3306
### MILVUS
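
Everything in a .env file arrives as a string, so numeric settings such as LIMIT_MODEL_CONCURRENCY or LOCAL_DB_PORT need an explicit conversion when read. A minimal sketch of the lookups (defaults mirror the template above; anything else is an assumption):

```python
import os

# .env values are always strings; convert numeric settings explicitly.
LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", "5"))
MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", "4096"))
LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", "3306"))
```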

.gitignore

@@ -6,6 +6,7 @@ __pycache__/
# C extensions
*.so
.env
.idea
.vscode
.idea


@@ -179,7 +179,7 @@ Run gradio webui
```bash
$ python pilot/server/webserver.py
```
Notice: the webserver needs to connect to the llmserver, so you must modify the pilot/configs/model_config.py file, changing VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" to your server address. This is very important.
Notice: the webserver needs to connect to the llmserver, so you must modify the .env file, changing MODEL_SERVER = "http://127.0.0.1:8000" to your server address. This is very important.
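
A quick way to confirm the address is correct before launching the webserver is a pre-flight request (a hypothetical check, not part of the project; requests is already a dependency):

```python
import os
import requests

server = os.getenv("MODEL_SERVER", "http://127.0.0.1:8000")
try:
    # Any HTTP response at all proves the llmserver address is reachable.
    requests.get(server, timeout=5)
    print(f"llmserver reachable at {server}")
except requests.RequestException as exc:
    raise SystemExit(f"cannot reach llmserver at {server}: {exc}")
```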
## Usage Instructions
We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.


@@ -178,7 +178,7 @@ python llmserver.py
```bash
$ python webserver.py
```
Note: before starting the webserver, you must modify VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" in pilot/configs/model_config.py, setting the address to your server's address.
Note: before starting the webserver, you must modify MODEL_SERVER = "http://127.0.0.1:8000" in the .env configuration file, setting the address to your server's address.
## Usage Instructions


@@ -7,12 +7,15 @@ import time
import uuid
from urllib.parse import urljoin
import gradio as gr
from pilot.configs.model_config import *
from pilot.configs.config import Config
from pilot.conversation import conv_qa_prompt_template, conv_templates
from langchain.prompts import PromptTemplate
vicuna_stream_path = "generate_stream"
CFG = Config()
def generate(query):
    template_name = "conv_one_shot"
@@ -41,7 +44,7 @@ def generate(query):
    }
    response = requests.post(
        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
    )
    skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -54,7 +57,7 @@ def generate(query):
        yield output
if __name__ == "__main__":
    print(LLM_MODEL)
    print(CFG.LLM_MODEL)
    with gr.Blocks() as demo:
        gr.Markdown("Database SQL Generation Assistant")
        with gr.Tab("SQL Generation"):
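
The hunk above posts to the generate_stream endpoint and computes skip_echo_len to trim the echoed prompt from each chunk. A minimal sketch of consuming that stream, following the NUL-delimited pattern the webserver uses later in this commit (the "text" field name is an assumption):

```python
import json

def read_stream(response, skip_echo_len: int):
    # Chunks are NUL-delimited JSON objects; "text" is assumed to hold the
    # accumulated completion, with the prompt echoed at the front.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            yield data["text"][skip_echo_len:].strip()
```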


@@ -2,24 +2,23 @@
# -*- coding: utf-8 -*-
import os
import nltk
from typing import List
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.singleton import Singleton
class Config(metaclass=Singleton):
    """Configuration class storing state shared by the project's scripts"""

    def __init__(self) -> None:
        """Initialize the Config class"""
        # TODO change model_config there
        self.debug_mode = False
        self.skip_reprompt = False
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))

        # TODO change model_config there
        self.execute_local_commands = (
            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
        )
@@ -46,17 +45,12 @@ class Config(metaclass=Singleton):
        self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"

        self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
        self.exit_key = os.getenv("EXIT_KEY", "n")
        self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
        self.image_provider = os.getenv("IMAGE_PROVIDER", True)
        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []
        self.command_registry = []
        self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
        self.image_provider = os.getenv("IMAGE_PROVIDER")
        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
@@ -68,6 +62,10 @@ class Config(metaclass=Singleton):
        )
        self.speak_mode = False

        ### Related configuration of built-in commands
        self.command_registry = []

        disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
        if disabled_command_categories:
            self.disabled_command_categories = disabled_command_categories.split(",")
@@ -78,6 +76,12 @@ class Config(metaclass=Singleton):
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
)
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = []
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
if plugins_allowlist:
self.plugins_allowlist = plugins_allowlist.split(",")
@@ -90,6 +94,20 @@ class Config(metaclass=Singleton):
        else:
            self.plugins_denylist = []

        ### Local database connection configuration
        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")

        ### LLM Model Service Configuration
        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
        self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://121.41.167.183:8000")
        self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"

    def set_debug_mode(self, value: bool) -> None:
        """Set the debug mode value"""
        self.debug_mode = value
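
Because Config uses the Singleton metaclass, every module-level `CFG = Config()` in the files below resolves to the same instance. A small illustration of the consequence:

```python
from pilot.configs.config import Config

a = Config()
b = Config()
assert a is b             # the Singleton metaclass returns one shared instance
a.LLM_MODEL = "vicuna-7b"
print(b.LLM_MODEL)        # "vicuna-7b": changes are visible process-wide
```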


@@ -22,25 +22,18 @@ LLM_MODEL_CONFIG = {
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
}
VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 4096
VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
# Load model config
ISLOAD_8BIT = True
ISDEBUG = False
DB_SETTINGS = {
"user": "root",
"password": "aa123456",
"host": "127.0.0.1",
"port": 3306
}
VECTOR_SEARCH_TOP_K = 3
# LLM_MODEL = "vicuna-13b"
# LIMIT_MODEL_CONCURRENCY = 5
# MAX_POSITION_EMBEDDINGS = 4096
# VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")


@@ -4,8 +4,16 @@
import dataclasses
from enum import auto, Enum
from typing import List, Any
from pilot.configs.model_config import DB_SETTINGS
from pilot.configs.config import Config
CFG = Config()
DB_SETTINGS = {
    "user": CFG.LOCAL_DB_USER,
    "password": CFG.LOCAL_DB_PASSWORD,
    "host": CFG.LOCAL_DB_HOST,
    "port": CFG.LOCAL_DB_PORT
}
class SeparatorStyle(Enum):
    SINGLE = auto()
@@ -91,7 +99,7 @@ class Conversation:
def gen_sqlgen_conversation(dbname):
    from pilot.connections.mysql import MySQLOperator

    mo = MySQLOperator(
        **DB_SETTINGS
        **(DB_SETTINGS)
    )

    message = ""
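
MySQLOperator itself is not shown in this diff; a minimal sketch of the shape the `**DB_SETTINGS` call implies, assuming a PyMySQL connection underneath (the real implementation may differ; constructor parameters must match the dict keys exactly):

```python
import pymysql

class MySQLOperator:
    """Hypothetical shape: the DB_SETTINGS keys map onto these kwargs."""

    def __init__(self, user: str, password: str, host: str, port: int):
        self.conn = pymysql.connect(user=user, password=password,
                                    host=host, port=port)

mo = MySQLOperator(**DB_SETTINGS)
```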


@@ -8,8 +8,9 @@ from langchain.embeddings.base import Embeddings
from pydantic import BaseModel
from typing import Any, Mapping, Optional, List
from langchain.llms.base import LLM
from pilot.configs.model_config import *
from pilot.configs.config import Config
CFG = Config()
class VicunaLLM(LLM):
    vicuna_generate_path = "generate_stream"
@@ -22,7 +23,7 @@ class VicunaLLM(LLM):
"stop": stop
}
response = requests.post(
url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
data=json.dumps(params),
)
@@ -51,7 +52,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
print("Sending prompt ", p)
response = requests.post(
url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
json={
"prompt": p
}
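
Since VicunaLLM subclasses LangChain's base LLM, it can be invoked like any other LLM once MODEL_SERVER is configured. A sketch under the mid-2023 LangChain calling convention (the prompt text is an assumption):

```python
from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()
# __call__ delegates to _call, which posts to CFG.MODEL_SERVER's
# generate_stream endpoint as shown in the diff above.
print(llm("Write a SQL statement that counts the rows in the users table."))
```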


@@ -17,14 +17,17 @@ from peft import (
import torch
from datasets import load_dataset
import pandas as pd
from pilot.configs.config import Config
from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG
device = "cuda" if torch.cuda.is_available() else "cpu"
CUTOFF_LEN = 50
df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
CFG = Config()
def sentiment_score_to_name(score: float):
    if score > 0:
        return "Positive"
@@ -49,7 +52,7 @@ with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w")
data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
print(data["train"])
BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,


@@ -13,8 +13,11 @@ from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *
from pilot.configs.config import Config
model_path = LLM_MODEL_CONFIG[LLM_MODEL]
CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
global_counter = 0
@@ -60,7 +63,7 @@ def generate_stream_gate(params):
        tokenizer,
        params,
        DEVICE,
        MAX_POSITION_EMBEDDINGS,
        CFG.MAX_POSITION_EMBEDDINGS,
    ):
        print("output: ", output)
        ret = {
@@ -84,7 +87,7 @@ async def api_generate_stream(request: Request):
    print(model, tokenizer, params, DEVICE)

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()

    generator = generate_stream_gate(params)
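
The semaphore caps in-flight generations at CFG.LIMIT_MODEL_CONCURRENCY. A condensed sketch of the gating pattern (release handling is simplified relative to the real server, which releases via a background task after streaming):

```python
import asyncio

model_semaphore = None  # created lazily, once the event loop exists

async def gated_generate(params, limit):
    """Run generate_stream_gate with at most `limit` concurrent callers."""
    global model_semaphore
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(limit)  # e.g. CFG.LIMIT_MODEL_CONCURRENCY
    async with model_semaphore:
        return generate_stream_gate(params)  # the function from the diff above
```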


@@ -14,13 +14,13 @@ from urllib.parse import urljoin
from langchain import PromptTemplate
from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
@@ -67,7 +67,15 @@ priority = {
"vicuna-13b": "aaa"
}
# 加载插件
CFG= Config()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT
}
def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
    docs = docsearch.similarity_search_with_score(q, k=1)
@@ -178,7 +186,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
model_name = CFG.LLM_MODEL
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
@@ -268,7 +276,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
logger.info(f"Requert: \n{payload}")
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers, json=payload, timeout=120)
print(response.json())
@@ -316,7 +324,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
    try:
        # Stream output
        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
                                 headers=headers, json=payload, stream=True, timeout=20)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
@@ -595,9 +603,8 @@ if __name__ == "__main__":
    # dbs = get_database_list()
    # Load plugins
    # Initialize configuration
    cfg = Config()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
    # Load plugin executable commands


@@ -57,6 +57,9 @@ pymdown-extensions
mkdocs
requests
gTTS==2.3.1
langchain
nltk
python-dotenv==1.0.0
# Testing dependencies
pytest
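
The new python-dotenv pin is what lets the unified Config pick its values up from .env. A minimal sketch of the startup wiring, assuming the .env file sits in the working directory:

```python
from dotenv import load_dotenv

load_dotenv()  # populate os.environ from .env before Config reads it

from pilot.configs.config import Config

CFG = Config()
print(CFG.MODEL_SERVER)  # now reflects the value from .env
```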