Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-09 21:08:59 +00:00)

Commit: Unified configuration

This commit replaces scattered module-level constants (VICUNA_MODEL_SERVER, LLM_MODEL, LIMIT_MODEL_CONCURRENCY, DB_SETTINGS, ...) with a single .env-driven Config class in pilot/configs/config.py, and updates every caller to read from that shared instance.
@@ -17,6 +17,10 @@
 #*******************************************************************#
 #**                         LLM MODELS                            **#
 #*******************************************************************#
+LLM_MODEL=vicuna-13b
+MODEL_SERVER=http://your_model_server_url
+LIMIT_MODEL_CONCURRENCY=5
+MAX_POSITION_EMBEDDINGS=4096
 
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)

@@ -36,10 +40,10 @@
 #*******************************************************************#
 #**                      DATABASE SETTINGS                        **#
 #*******************************************************************#
-DB_SETTINGS_MYSQL_USER=root
-DB_SETTINGS_MYSQL_PASSWORD=password
-DB_SETTINGS_MYSQL_HOST=localhost
-DB_SETTINGS_MYSQL_PORT=3306
+LOCAL_DB_USER=root
+LOCAL_DB_PASSWORD=password
+LOCAL_DB_HOST=localhost
+LOCAL_DB_PORT=3306
 
 
 ### MILVUS
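On the Python side these keys are read with `os.getenv` in pilot/configs/config.py (diffed below). A minimal sketch of the flow, assuming the `.env` file is loaded via python-dotenv, which this commit adds to requirements.txt:

```python
import os

from dotenv import load_dotenv  # python-dotenv==1.0.0, added to requirements below

load_dotenv()  # copies KEY=VALUE pairs from .env into os.environ

# Defaults mirror the ones in pilot/configs/config.py
llm_model = os.getenv("LLM_MODEL", "vicuna-13b")
limit_model_concurrency = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
max_position_embeddings = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
```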
.gitignore (vendored): 1 change

@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
 
+.env
 .idea
 .vscode
 .idea
@@ -179,7 +179,7 @@ Run gradio webui
 ```bash
 $ python pilot/server/webserver.py
 ```
-Notice: the webserver need to connect llmserver, so you need change the pilot/configs/model_config.py file. change the VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" to your address. It's very important.
+Notice: the webserver need to connect llmserver, so you need change the .env file. change the MODEL_SERVER = "http://127.0.0.1:8000" to your address. It's very important.
 
 ## Usage Instructions
 We provide a user interface for Gradio, which allows you to use DB-GPT through our user interface. Additionally, we have prepared several reference articles (written in Chinese) that introduce the code and principles related to our project.
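Concretely, if llmserver runs locally on port 8000, the matching `.env` entry is `MODEL_SERVER=http://127.0.0.1:8000`; the template value `http://your_model_server_url` introduced above is only a placeholder.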
@@ -178,7 +178,7 @@ python llmserver.py
 ```bash
 $ python webserver.py
 ```
-Note: before starting the webserver, you need to change VICUNA_MODEL_SERVER = "http://127.0.0.1:8000" in pilot/configs/model_config.py, setting the address to your own server address.
+Note: before starting the webserver, you need to change MODEL_SERVER = "http://127.0.0.1:8000" in the .env configuration file, setting the address to your own server address.
 
 ## Usage Instructions
 
@@ -7,12 +7,15 @@ import time
 import uuid
 from urllib.parse import urljoin
 import gradio as gr
-from pilot.configs.model_config import *
+from pilot.configs.config import Config
 from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
 
 
 vicuna_stream_path = "generate_stream"
 
+CFG = Config()
 
 def generate(query):
 
     template_name = "conv_one_shot"

@@ -41,7 +44,7 @@ def generate(query):
     }
 
     response = requests.post(
-        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3

@@ -54,7 +57,7 @@ def generate(query):
        yield(output)
 
 if __name__ == "__main__":
-    print(LLM_MODEL)
+    print(CFG.LLM_MODEL)
     with gr.Blocks() as demo:
         gr.Markdown("数据库SQL生成助手")
         with gr.Tab("SQL生成"):
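`urljoin` is what combines `CFG.MODEL_SERVER` with the endpoint name; a quick illustration with an assumed local server address:

```python
from urllib.parse import urljoin

# With MODEL_SERVER=http://127.0.0.1:8000 (as in the README notice), the
# streaming endpoint resolves to http://127.0.0.1:8000/generate_stream
print(urljoin("http://127.0.0.1:8000", "generate_stream"))
```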
@@ -2,24 +2,23 @@
 # -*- coding: utf-8 -*-
 
 import os
+import nltk
 from typing import List
 
 from auto_gpt_plugin_template import AutoGPTPluginTemplate
 from pilot.singleton import Singleton
 
 
 class Config(metaclass=Singleton):
     """Configuration class to store the state of bools for different scripts access"""
     def __init__(self) -> None:
         """Initialize the Config class"""
 
-        # TODO change model_config there
 
         self.debug_mode = False
         self.skip_reprompt = False
 
         self.temperature = float(os.getenv("TEMPERATURE", 0.7))
 
-        # TODO change model_config there
         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )

@@ -46,17 +45,12 @@ class Config(metaclass=Singleton):
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
         self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
 
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
-        self.image_provider = bool(os.getenv("IMAGE_PROVIDER", True))
+        self.image_provider = os.getenv("IMAGE_PROVIDER", True)
         self.image_size = int(os.getenv("IMAGE_SIZE", 256))
 
-        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
-        self.plugins: List[AutoGPTPluginTemplate] = []
-        self.plugins_openai = []
 
-        self.command_registry = []
 
         self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
         self.image_provider = os.getenv("IMAGE_PROVIDER")
         self.image_size = int(os.getenv("IMAGE_SIZE", 256))

@@ -68,6 +62,10 @@ class Config(metaclass=Singleton):
         )
         self.speak_mode = False
 
 
+        ### Related configuration of built-in commands
+        self.command_registry = []
 
         disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
         if disabled_command_categories:
             self.disabled_command_categories = disabled_command_categories.split(",")

@@ -78,6 +76,12 @@ class Config(metaclass=Singleton):
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
 
 
+        ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
+        self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins")
+        self.plugins: List[AutoGPTPluginTemplate] = []
+        self.plugins_openai = []
 
         plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
         if plugins_allowlist:
             self.plugins_allowlist = plugins_allowlist.split(",")

@@ -89,7 +93,21 @@ class Config(metaclass=Singleton):
             self.plugins_denylist = plugins_denylist.split(",")
         else:
             self.plugins_denylist = []
 
 
+        ### Local database connection configuration
+        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
+        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
+        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
+        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
 
+        ### LLM Model Service Configuration
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
+        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
+        self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://121.41.167.183:8000")
+        self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
 
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
         self.debug_mode = value
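Because `Config` uses the `Singleton` metaclass, every `Config()` call returns the same instance, which is what makes the configuration "unified" across modules. A minimal sketch:

```python
from pilot.configs.config import Config

cfg_a = Config()
cfg_b = Config()
assert cfg_a is cfg_b  # Singleton metaclass: one shared instance

print(cfg_a.LLM_MODEL)     # "vicuna-13b" unless LLM_MODEL is set in .env
print(cfg_a.MODEL_SERVER)  # address of the llmserver endpoint
```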
@@ -1,4 +1,4 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
 import torch

@@ -22,25 +22,18 @@ LLM_MODEL_CONFIG = {
     "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
 
-VECTOR_SEARCH_TOP_K = 3
-LLM_MODEL = "vicuna-13b"
-LIMIT_MODEL_CONCURRENCY = 5
-MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
 
 # Load model config
 ISLOAD_8BIT = True
 ISDEBUG = False
 
 
-DB_SETTINGS = {
-    "user": "root",
-    "password": "aa123456",
-    "host": "127.0.0.1",
-    "port": 3306
-}
+VECTOR_SEARCH_TOP_K = 3
+# LLM_MODEL = "vicuna-13b"
+# LIMIT_MODEL_CONCURRENCY = 5
+# MAX_POSITION_EMBEDDINGS = 4096
+# VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
 
 
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
 KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge")
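With the module-level constants removed, callers resolve the model path through the shared `Config` instance instead, mirroring the pattern used in llmserver further down:

```python
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG

CFG = Config()
# LLM_MODEL comes from .env (default "vicuna-13b"); the path comes from
# the LLM_MODEL_CONFIG mapping kept in model_config.py
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
```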
@@ -4,8 +4,16 @@
 import dataclasses
 from enum import auto, Enum
 from typing import List, Any
-from pilot.configs.model_config import DB_SETTINGS
+from pilot.configs.config import Config
 
+CFG = Config()
 
+DB_SETTINGS = {
+    "user": CFG.LOCAL_DB_USER,
+    "password": CFG.LOCAL_DB_PASSWORD,
+    "host": CFG.LOCAL_DB_HOST,
+    "port": CFG.LOCAL_DB_PORT
+}
 
 class SeparatorStyle(Enum):
     SINGLE = auto()

@@ -91,7 +99,7 @@ class Conversation:
 def gen_sqlgen_conversation(dbname):
     from pilot.connections.mysql import MySQLOperator
     mo = MySQLOperator(
-        **DB_SETTINGS
+        **(DB_SETTINGS)
     )
 
     message = ""
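`**(DB_SETTINGS)` (the extra parentheses are harmless) unpacks the dict into keyword arguments, so the call above is equivalent to this spelled-out sketch:

```python
from pilot.configs.config import Config
from pilot.connections.mysql import MySQLOperator

CFG = Config()
mo = MySQLOperator(
    user=CFG.LOCAL_DB_USER,
    password=CFG.LOCAL_DB_PASSWORD,
    host=CFG.LOCAL_DB_HOST,
    port=CFG.LOCAL_DB_PORT,
)
```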
@@ -8,8 +8,9 @@ from langchain.embeddings.base import Embeddings
 from pydantic import BaseModel
 from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
-from pilot.configs.model_config import *
+from pilot.configs.config import Config
 
+CFG = Config()
 class VicunaLLM(LLM):
 
     vicuna_generate_path = "generate_stream"

@@ -22,7 +23,7 @@ class VicunaLLM(LLM):
             "stop": stop
         }
         response = requests.post(
-            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
+            url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )

@@ -51,7 +52,7 @@ class VicunaEmbeddingLLM(BaseModel, Embeddings):
         print("Sending prompt ", p)
 
         response = requests.post(
-            url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_embedding_path),
+            url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path),
             json={
                 "prompt": p
             }
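Both classes now route requests through the single configured endpoint. A hypothetical usage sketch (the import path is assumed, since the file name is not preserved in this mirror):

```python
from pilot.model.vicuna_llm import VicunaLLM  # import path assumed

llm = VicunaLLM()
# As a langchain LLM subclass it is directly callable; under the hood this
# POSTs the prompt to urljoin(CFG.MODEL_SERVER, "generate_stream").
print(llm("Generate a SQL query that counts the rows of table t1"))
```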
@@ -17,14 +17,17 @@ from peft import (
 import torch
 from datasets import load_dataset
 import pandas as pd
+from pilot.configs.config import Config
 
 
-from pilot.configs.model_config import DATA_DIR, LLM_MODEL, LLM_MODEL_CONFIG
+from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG
 device = "cuda" if torch.cuda.is_available() else "cpu"
 CUTOFF_LEN = 50
 
 df = pd.read_csv(os.path.join(DATA_DIR, "BTC_Tweets_Updated.csv"))
 
+CFG = Config()
 
 def sentiment_score_to_name(score: float):
     if score > 0:
         return "Positive"

@@ -49,7 +52,7 @@ with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w")
 data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"))
 print(data["train"])
 
-BASE_MODEL = LLM_MODEL_CONFIG[LLM_MODEL]
+BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
 model = LlamaForCausalLM.from_pretrained(
     BASE_MODEL,
     torch_dtype=torch.float16,
@@ -13,8 +13,11 @@ from pilot.model.inference import generate_output, get_embeddings
 
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *
+from pilot.configs.config import Config
 
-model_path = LLM_MODEL_CONFIG[LLM_MODEL]
+CFG = Config()
+model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
 
 
 global_counter = 0

@@ -60,7 +63,7 @@ def generate_stream_gate(params):
         tokenizer,
         params,
         DEVICE,
-        MAX_POSITION_EMBEDDINGS,
+        CFG.MAX_POSITION_EMBEDDINGS,
     ):
         print("output: ", output)
         ret = {

@@ -84,7 +87,7 @@ async def api_generate_stream(request: Request):
     print(model, tokenizer, params, DEVICE)
 
     if model_semaphore is None:
-        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
     await model_semaphore.acquire()
 
     generator = generate_stream_gate(params)
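`LIMIT_MODEL_CONCURRENCY` caps how many requests run inference at once; later requests block on the semaphore until a slot frees. A toy illustration of the mechanism (the server acquires the semaphore manually; `async with` is used here for brevity):

```python
import asyncio

async def handle_request(sem: asyncio.Semaphore, i: int) -> None:
    async with sem:               # waits once 5 requests are in flight
        await asyncio.sleep(0.1)  # stand-in for model inference
        print(f"request {i} served")

async def main() -> None:
    sem = asyncio.Semaphore(5)    # LIMIT_MODEL_CONCURRENCY=5
    await asyncio.gather(*(handle_request(sem, i) for i in range(8)))

asyncio.run(main())
```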
@@ -14,13 +14,13 @@ from urllib.parse import urljoin
 
 from langchain import PromptTemplate
 
-from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
+from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
+from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 
 from pilot.plugins import scan_plugins
 from pilot.configs.config import Config

@@ -67,7 +67,15 @@ priority = {
     "vicuna-13b": "aaa"
 }
 
+# Load plugins
+CFG = Config()
 
+DB_SETTINGS = {
+    "user": CFG.LOCAL_DB_USER,
+    "password": CFG.LOCAL_DB_PASSWORD,
+    "host": CFG.LOCAL_DB_HOST,
+    "port": CFG.LOCAL_DB_PORT
+}
 def get_simlar(q):
     docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)

@@ -178,7 +186,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     print("是否是AUTO-GPT模式.", autogpt)
 
     start_tstamp = time.time()
-    model_name = LLM_MODEL
+    model_name = CFG.LLM_MODEL
 
     dbname = db_selector
     # TODO: the request here needs to incorporate the existing knowledge base so that answers are grounded in it; the prompt still needs optimization

@@ -268,7 +276,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     logger.info(f"Requert: \n{payload}")
 
     if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
+        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
                                  headers=headers, json=payload, timeout=120)
 
         print(response.json())

@@ -316,7 +324,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
 
     try:
         # Stream output
-        response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate_stream"),
+        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
                                  headers=headers, json=payload, stream=True, timeout=20)
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:

@@ -595,9 +603,8 @@ if __name__ == "__main__":
 
     # dbs = get_database_list()
 
-    # Load plugins
+    # Configuration initialization
    cfg = Config()
 
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
    # Load plugin executable commands
@@ -57,6 +57,9 @@ pymdown-extensions
 mkdocs
 requests
 gTTS==2.3.1
+langchain
+nltk
+python-dotenv==1.0.0
 
 # Testing dependencies
 pytest