diff --git a/.env.template b/.env.template
index 5bf746eaa..3d2240bff 100644
--- a/.env.template
+++ b/.env.template
@@ -17,6 +17,10 @@
 #*******************************************************************#
 #**                         LLM MODELS                            **#
 #*******************************************************************#
+LLM_MODEL=vicuna-13b
+MODEL_SERVER=http://your_model_server_url
+LIMIT_MODEL_CONCURRENCY=5
+MAX_POSITION_EMBEDDINGS=4096
 
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
@@ -36,10 +40,10 @@
 #*******************************************************************#
 #**                      DATABASE SETTINGS                        **#
 #*******************************************************************#
-DB_SETTINGS_MYSQL_USER=root
-DB_SETTINGS_MYSQL_PASSWORD=password
-DB_SETTINGS_MYSQL_HOST=localhost
-DB_SETTINGS_MYSQL_PORT=3306
+LOCAL_DB_USER=root
+LOCAL_DB_PASSWORD=password
+LOCAL_DB_HOST=localhost
+LOCAL_DB_PORT=3306
 
 ### MILVUS
diff --git a/.gitignore b/.gitignore
index c4c4a344e..cb21ee557 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ __pycache__/
 
 # C extensions
 *.so
+.env
 .idea
 .vscode
 .idea
diff --git a/README.en.md b/README.en.md
index 7435d29d2..03c3cbb3c 100644
--- a/README.en.md
+++ b/README.en.md
@@ -179,7 +179,7 @@ Run gradio webui
 ```bash
 $ python pilot/server/webserver.py
 ```
-Notice: the webserver need to connect llmserver, so you need change the pilot/configs/model_config.py file. change the VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" to your address. It's very important.
+Notice: the webserver needs to connect to the llmserver, so you must edit the .env file and change MODEL_SERVER = "http://127.0.0.1:8000" to your own server address. This is very important.
 
 ## Usage Instructions
 We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
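The four keys added to .env.template above replace constants that were hard-coded in pilot/configs/model_config.py. The patch itself never shows where the .env file gets read, so the following is only a minimal loading sketch, assuming python-dotenv (added to requirements.txt at the end of this diff) and a .env file in the working directory:

```python
# Hypothetical loading step; the diff adds python-dotenv but does not show
# where load_dotenv() is invoked, so treat this wiring as an assumption.
import os

from dotenv import load_dotenv

load_dotenv()  # copies key=value pairs from .env into os.environ,
               # without overriding variables that are already set

LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1:8000")
LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
```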
diff --git a/README.md b/README.md
index 3ef04884b..e0994acc9 100644
--- a/README.md
+++ b/README.md
@@ -178,7 +178,7 @@ python llmserver.py
 ```bash
 $ python webserver.py
 ```
-Note: before starting the webserver, you need to change VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" in pilot/configs/model_config.py, setting the address to your own server address.
+Note: before starting the webserver, you need to change MODEL_SERVER = "http://127.0.0.1:8000" in the .env configuration file, setting the address to your own server address.
 
 ## Usage Instructions
diff --git a/examples/embdserver.py b/examples/embdserver.py
index 6599a18ad..79140ba66 100644
--- a/examples/embdserver.py
+++ b/examples/embdserver.py
@@ -7,12 +7,15 @@ import time
 import uuid
 from urllib.parse import urljoin
 
 import gradio as gr
-from pilot.configs.model_config import *
+from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
 
+
 vicuna_stream_path = "generate_stream"
 
+CFG = Config()
+
 def generate(query):
 
     template_name = "conv_one_shot"
@@ -41,7 +44,7 @@ def generate(query):
     }
 
     response = requests.post(
-        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -54,7 +57,7 @@ def generate(query):
             yield(output)
 
 if __name__ == "__main__":
-    print(LLM_MODEL)
+    print(CFG.LLM_MODEL)
     with gr.Blocks() as demo:
         gr.Markdown("数据库SQL生成助手")
         with gr.Tab("SQL生成"):
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 5749a752d..9023bc061 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -2,24 +2,23 @@
 # -*- coding: utf-8 -*-
 
 import os
+import nltk
 from typing import List
 
 from auto_gpt_plugin_template import AutoGPTPluginTemplate
 from pilot.singleton import Singleton
 
+
 class Config(metaclass=Singleton):
     """Configuration class to store the state of bools for different scripts access"""
 
     def __init__(self) -> None:
         """Initialize the Config class"""
-        # TODO change model_config there
-
         self.debug_mode = False
         self.skip_reprompt = False
 
         self.temperature = float(os.getenv("TEMPERATURE", 0.7))
-        # TODO change model_config there
+
         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
@@ -46,17 +45,12 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
+
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
-        self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
+        self.image_provider = os.getenv("IMAGE_PROVIDER", True)
         self.image_size = int(os.getenv("IMAGE_SIZE", 256))
 
-        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
-        self.plugins: List[AutoGPTPluginTemplate] = []
-        self.plugins_openai = []
-
-        self.command_registry = []
-
         self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
         self.image_provider = os.getenv("IMAGE_PROVIDER")
         self.image_size = int(os.getenv("IMAGE_SIZE", 256))
@@ -68,6 +62,10 @@ class Config(metaclass=Singleton):
         )
 
         self.speak_mode = False
+
+        ### Related configuration of built-in commands
+        self.command_registry = []
+
         disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
         if disabled_command_categories:
             self.disabled_command_categories = disabled_command_categories.split(",")
@@ -78,6 +76,12 @@ class Config(metaclass=Singleton):
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
 
+
+        ### Plugin-related configuration that controls how plugins are loaded and used
+        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
+        self.plugins: List[AutoGPTPluginTemplate] = []
+        self.plugins_openai = []
+
         plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
         if plugins_allowlist:
             self.plugins_allowlist = plugins_allowlist.split(",")
@@ -89,7 +93,21 @@ class Config(metaclass=Singleton):
             self.plugins_denylist = plugins_denylist.split(",")
         else:
             self.plugins_denylist = []
-
+
+
+        ### Local database connection configuration
+        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
+        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
+        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
+        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
+
+        ### LLM Model Service Configuration
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
+        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
+        self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://121.41.167.183:8000")
+        self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
+
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
         self.debug_mode = value
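Config reads every setting from the environment and is built on the project's Singleton metaclass, so the modules touched below (conversation.py, vicuna_llm.py, llmserver.py, webserver.py) should all see one shared instance. A short usage sketch under that assumption:

```python
# Sketch of the shared-instance behavior assumed throughout this patch.
from pilot.configs.config import Config

cfg_a = Config()
cfg_b = Config()
assert cfg_a is cfg_b  # Singleton metaclass: one env-backed object per process

print(cfg_a.LLM_MODEL)     # "vicuna-13b" unless LLM_MODEL is set in .env
print(cfg_a.MODEL_SERVER)  # endpoint the webserver and examples POST to
```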
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 12c7e33da..af6c138b5 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!/usr/bin/env python3 
 # -*- coding:utf-8 -*-
 
 import torch
@@ -22,25 +22,18 @@ LLM_MODEL_CONFIG = {
     "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
-
-
-VECTOR_SEARCH_TOP_K = 3
-LLM_MODEL = "vicuna-13b"
-LIMIT_MODEL_CONCURRENCY = 5
-MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
-
 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False
 
-DB_SETTINGS = {
-    "user": "root",
-    "password": "aa123456",
-    "host": "127.0.0.1",
-    "port": 3306
-}
+VECTOR_SEARCH_TOP_K = 3
+# LLM_MODEL = "vicuna-13b"
+# LIMIT_MODEL_CONCURRENCY = 5
+# MAX_POSITION_EMBEDDINGS = 4096
+# VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
+
 
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
 KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")
\ No newline at end of file
diff --git a/pilot/conversation.py b/pilot/conversation.py
index 073f25f24..47414ec8a 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -4,8 +4,16 @@ import dataclasses
 from enum import auto, Enum
 from typing import List, Any
 
-from pilot.configs.model_config import DB_SETTINGS
+from pilot.configs.config import Config
 
+CFG = Config()
+
+DB_SETTINGS = {
+    "user": CFG.LOCAL_DB_USER,
+    "password": CFG.LOCAL_DB_PASSWORD,
+    "host": CFG.LOCAL_DB_HOST,
+    "port": CFG.LOCAL_DB_PORT
+}
 
 class SeparatorStyle(Enum):
     SINGLE = auto()
@@ -91,7 +99,7 @@ class Conversation:
 def gen_sqlgen_conversation(dbname):
     from pilot.connections.mysql import MySQLOperator
     mo = MySQLOperator(
-        **DB_SETTINGS
+        **(DB_SETTINGS)
     )
 
     message = ""
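conversation.py now assembles DB_SETTINGS from Config instead of importing a hard-coded dict from model_config.py. The patch does not show MySQLOperator's internals, so the sketch below uses pymysql purely as a hypothetical stand-in, to illustrate what those four settings ultimately feed:

```python
# Hypothetical: pymysql stands in for whatever driver MySQLOperator wraps.
import pymysql

from pilot.configs.config import Config

CFG = Config()

conn = pymysql.connect(
    user=CFG.LOCAL_DB_USER,
    password=CFG.LOCAL_DB_PASSWORD,
    host=CFG.LOCAL_DB_HOST,
    port=CFG.LOCAL_DB_PORT,  # Config already coerces this to int
)
with conn.cursor() as cursor:
    cursor.execute("SHOW DATABASES")
    print(cursor.fetchall())
conn.close()
```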
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index 2337a3bbf..63788a619 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -8,8 +8,9 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from pilot.configs.model_config import *
+from pilot.configs.config import Config
 
+CFG = Config()
 
 class VicunaLLM(LLM):
     vicuna_generate_path = "generate_stream"
@@ -22,7 +23,7 @@ class VicunaLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
+            url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
 
@@ -51,7 +52,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
             print("Sending prompt ", p)
 
             response = requests.post(
-                url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
+                url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
                 json={
                     "prompt": p
                 }
diff --git a/pilot/pturning/lora/finetune.py b/pilot/pturning/lora/finetune.py
index 6cd9935ed..91ec07d0a 100644
--- a/pilot/pturning/lora/finetune.py
+++ b/pilot/pturning/lora/finetune.py
@@ -17,14 +17,17 @@ from peft import (
 import torch
 from datasets import load_dataset
 import pandas as pd
+from pilot.configs.config import Config
 
-from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
+from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 CUTOFF_LEN = 50
 
 df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
 
+CFG = Config()
+
 def sentiment_score_to_name(score: float):
     if score > 0:
         return "Positive"
@@ -49,7 +52,7 @@ with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w")
 
 data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
 print(data["train"])
 
-BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
+BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
 model = LlamaForCausalLM.from_pretrained(
     BASE_MODEL,
     torch_dtype=torch.float16,
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 2860c3b77..2b29949a3 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -13,8 +13,11 @@ from pilot.model.inference import generate_output, get_embeddings
 
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *
+from pilot.configs.config import Config
 
-model_path = LLM_MODEL_CONFIG[LLM_MODEL]
+
+CFG = Config()
+model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
 
 global_counter = 0
 
@@ -60,7 +63,7 @@ def generate_stream_gate(params):
         tokenizer,
         params,
         DEVICE,
-        MAX_POSITION_EMBEDDINGS,
+        CFG.MAX_POSITION_EMBEDDINGS,
     ):
         print("output: ", output)
         ret = {
@@ -84,7 +87,7 @@ async def api_generate_stream(request: Request):
     print(model, tokenizer, params, DEVICE)
 
     if model_semaphore is None:
-        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
 
     generator = generate_stream_gate(params)
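llmserver.py caps in-flight generations with an asyncio.Semaphore sized by LIMIT_MODEL_CONCURRENCY. A self-contained sketch of that pattern, with hypothetical names, showing why the semaphore is created lazily inside the request handler:

```python
# Standalone illustration of the concurrency cap; names are hypothetical.
import asyncio

LIMIT_MODEL_CONCURRENCY = 5
model_semaphore = None

async def generate(prompt: str) -> str:
    global model_semaphore
    if model_semaphore is None:
        # Created lazily so the semaphore binds to the running event loop.
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    async with model_semaphore:  # at most 5 requests reach the model at once
        await asyncio.sleep(0.1)  # stand-in for the real model call
        return f"echo: {prompt}"

async def main():
    # 8 concurrent callers, but only 5 ever run the "model" simultaneously.
    print(await asyncio.gather(*(generate(f"q{i}") for i in range(8))))

asyncio.run(main())
```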
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 139caab4d..caacfdf61 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -14,13 +14,13 @@ from urllib.parse import urljoin
 
 from langchain import PromptTemplate
 
-from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
+from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
+from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 
 from pilot.plugins import scan_plugins
 from pilot.configs.config import Config
@@ -67,7 +67,15 @@ priority = {
     "vicuna-13b": "aaa"
 }
 
+# Initialize configuration
+CFG = Config()
+
+DB_SETTINGS = {
+    "user": CFG.LOCAL_DB_USER,
+    "password": CFG.LOCAL_DB_PASSWORD,
+    "host": CFG.LOCAL_DB_HOST,
+    "port": CFG.LOCAL_DB_PORT
+}
 def get_simlar(q):
     docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)
@@ -178,7 +186,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     print("是否是AUTO-GPT模式.", autogpt)
 
     start_tstamp = time.time()
-    model_name = LLM_MODEL
+    model_name = CFG.LLM_MODEL
     dbname = db_selector
     # TODO: this request needs to incorporate the existing knowledge base so answers are grounded in it; the prompt still needs further tuning
@@ -268,7 +276,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     logger.info(f"Requert: \n{payload}")
 
     if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
                                  headers=headers, json=payload, timeout=120)
 
     print(response.json())
@@ -316,7 +324,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
 
     try:
         # Stream output
-        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
+        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
                                  headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -595,9 +603,8 @@ if __name__ == "__main__":
 
     # dbs = get_database_list()
 
-    # Load plugins
+    # Initialize configuration
     cfg = Config()
-    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
     # Load plugin executable commands
diff --git a/requirements.txt b/requirements.txt
index 5654dba6f..5bbc34f4c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -57,6 +57,9 @@ pymdown-extensions
 mkdocs
 requests
 gTTS==2.3.1
+langchain
+nltk
+python-dotenv==1.0.0
 
 # Testing dependencies
 pytest
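Putting the pieces together, a client of the relocated MODEL_SERVER setting might look like the sketch below. The payload fields and the NUL delimiter are inferred from the webserver.py calls above, not from a documented API:

```python
# Hedged sketch of a generate_stream client; field names are assumptions
# drawn from this diff (note webserver.py's b"\0" iter_lines delimiter).
import json
from urllib.parse import urljoin

import requests

from pilot.configs.config import Config

CFG = Config()

payload = {
    "prompt": "Write a SQL query that counts rows in the users table.",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": None,
}
response = requests.post(
    urljoin(CFG.MODEL_SERVER, "generate_stream"),
    json=payload,
    stream=True,
    timeout=20,
)
# Chunks are assumed to be NUL-delimited JSON objects, matching webserver.py.
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode()).get("text", ""))
```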