mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-25 04:53:36 +00:00)
fix confilict
This commit is contained in:
commit b9df562c84
72  .env.template  Normal file
@@ -0,0 +1,72 @@
#*******************************************************************#
#** DB-GPT - GENERAL SETTINGS **#
#*******************************************************************#
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled. Each of the below are an option:
## pilot.commands.query_execute

## For example, to disable coding related features, uncomment the next line
# DISABLED_COMMAND_CATEGORIES=


#*******************************************************************#
#*** LLM PROVIDER ***#
#*******************************************************************#

# TEMPERATURE=0

#*******************************************************************#
#** LLM MODELS **#
#*******************************************************************#

## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b


### EMBEDDINGS
## EMBEDDING_MODEL - Model to use for creating embeddings
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191


#*******************************************************************#
#** DATABASE SETTINGS **#
#*******************************************************************#
DB_SETTINGS_MYSQL_USER=root
DB_SETTINGS_MYSQL_PASSWORD=password
DB_SETTINGS_MYSQL_HOST=localhost
DB_SETTINGS_MYSQL_PORT=3306


### MILVUS
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)
## MILVUS_USERNAME - username for your Milvus database
## MILVUS_PASSWORD - password for your Milvus database
## MILVUS_SECURE - True to enable TLS. (Default: False)
## Setting MILVUS_ADDR to a `https://` URL will override this setting.
## MILVUS_COLLECTION - Milvus collection, change it if you want to start a new memory and retain the old memory.
# MILVUS_ADDR=localhost:19530
# MILVUS_USERNAME=
# MILVUS_PASSWORD=
# MILVUS_SECURE=
# MILVUS_COLLECTION=dbgpt

#*******************************************************************#
#** ALLOWLISTED PLUGINS **#
#*******************************************************************#

#ALLOWLISTED_PLUGINS - Sets the listed plugins that are allowed (Example: plugin1,plugin2,plugin3)
#DENYLISTED_PLUGINS - Sets the listed plugins that are not allowed (Example: plugin1,plugin2,plugin3)
ALLOWLISTED_PLUGINS=
DENYLISTED_PLUGINS=


#*******************************************************************#
#** CHAT PLUGIN SETTINGS **#
#*******************************************************************#
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
# CHAT_MESSAGES_ENABLED=False
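For orientation, a minimal sketch of reading these settings from a `.env` file in Python. Using python-dotenv here is an assumption (it is not listed in this commit's dependencies); DB-GPT's own loading lives in pilot/configs/config.py further below.

```
import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed

# Copy .env.template to .env and edit the values before loading
load_dotenv(".env")

temperature = float(os.getenv("TEMPERATURE", "0.7"))
milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
allowlist = [p for p in os.getenv("ALLOWLISTED_PLUGINS", "").split(",") if p]
print(temperature, milvus_addr, allowlist)
```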
23  .github/workflows/pylint.yml  vendored  Normal file
@@ -0,0 +1,23 @@
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint
    - name: Analysing the code with pylint
      run: |
        pylint $(git ls-files '*.py')
39  .github/workflows/python-publish.yml  vendored  Normal file
@@ -0,0 +1,39 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
README.md
@@ -44,7 +44,7 @@ The Generated SQL is runable.
 1. First you need to install python requirements.
 ```
 python>=3.9
-pip install -r requirements
+pip install -r requirements.txt
 ```
 or if you use conda envirenment, you can use this command
 ```
@@ -63,7 +63,7 @@ The password just for test, you can change this if necessary
 # Install
 1. 基础模型下载
 关于基础模型, 可以根据[vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights)合成教程进行合成。
-如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代。 替代模型: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b)
+如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代. [替代模型](https://huggingface.co/Tribbiani/vicuna-7b)
 
 2. Run model server
 ```
@@ -86,4 +86,4 @@ python webserver.py
 # Contribute
 [Contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING)
 # Licence
 [MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
environment.yml
@@ -7,11 +7,9 @@ dependencies:
   - python=3.9
   - cudatoolkit
   - pip
-  - pytorch=1.12.1
   - pytorch-mutex=1.0=cuda
-  - torchaudio=0.12.1
-  - torchvision=0.13.1
   - pip:
+    - pytorch
     - accelerate==0.16.0
     - aiohttp==3.8.4
     - aiosignal==1.3.1
@@ -57,11 +55,12 @@ dependencies:
     - sentence-transformers
     - umap-learn
     - notebook
-    - gradio==3.24.1
+    - gradio==3.23
     - gradio-client==0.0.8
     - wandb
-    - fschat=0.1.10
-    - llama-index=0.5.27
+    - llama-index==0.5.27
     - pymysql
     - unstructured==0.6.3
     - pytesseract==0.3.10
+    - markdown2
+    - chromadb
19  examples/gradio_test.py  Normal file
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import gradio as gr

def change_tab():
    return gr.Tabs.update(selected=1)

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem("Train", id=0):
            t = gr.Textbox()
        with gr.TabItem("Inference", id=1):
            i = gr.Image()

    btn = gr.Button()
    btn.click(change_tab, None, tabs)

demo.launch()
0  pilot/commands/__init__.py  Normal file
36  pilot/commands/command.py  Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import functools
import importlib
import inspect
from typing import Any, Callable, Optional


class Command:
    """A class representing a command.

    Attributes:
        name (str): The name of the command.
        description (str): A brief description of what the command does.
        signature (str): The signature of the function that the command executes. Default to None.
    """

    def __init__(self,
                 name: str,
                 description: str,
                 method: Callable[..., Any],
                 signature: str = "",
                 enabled: bool = True,
                 disabled_reason: Optional[str] = None,
                 ) -> None:
        self.name = name
        self.description = description
        self.method = method
        self.signature = signature if signature else str(inspect.signature(self.method))
        self.enabled = enabled
        self.disabled_reason = disabled_reason

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        if not self.enabled:
            return f"Command '{self.name}' is disabled: {self.disabled_reason}"
        return self.method(*args, **kwds)
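A minimal sketch of how this Command wrapper could be used. The `run_sql` helper and its registration are illustrative only and are not part of this commit.

```
from pilot.commands.command import Command

def run_sql(sql: str) -> str:
    # hypothetical helper, stands in for a real executor
    return f"executed: {sql}"

cmd = Command(
    name="query_execute",
    description="Execute a SQL statement and return the result",
    method=run_sql,
)

print(cmd.signature)                        # derived via inspect.signature(run_sql)
print(cmd("SELECT COUNT(*) FROM users"))    # calls run_sql

cmd.enabled = False
cmd.disabled_reason = "disabled via DISABLED_COMMAND_CATEGORIES"
print(cmd("SELECT 1"))                      # returns the disabled message instead of running
```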
pilot/configs/config.py
@@ -12,26 +12,55 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""
 
         # TODO change model_config there
 
         self.debug_mode = False
+        self.skip_reprompt = False
+
+        self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
+        self.plugins = List[AutoGPTPluginTemplate] = []
+        self.temperature = float(os.getenv("TEMPERATURE", 0.7))
+
+        # TODO change model_config there
         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
-        self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
-
-        self.plugins_dir = os.getenv("PLUGINS_DIR", 'plugins')
-        self.plugins:List[AutoGPTPluginTemplate] = []
+        # User agent header to use when making HTTP requests
+        # Some websites might just completely deny request with an error code if
+        # no user agent was found.
+        self.user_agent = os.getenv(
+            "USER_AGENT",
+            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
+            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
+        )
+
+        # milvus or zilliz cloud configuration
+        self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
+        self.milvus_username = os.getenv("MILVUS_USERNAME")
+        self.milvus_password = os.getenv("MILVUS_PASSWORD")
+        self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
+        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
+
+        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
+        if plugins_allowlist:
+            self.plugins_allowlist = plugins_allowlist.split(",")
+        else:
+            self.plugins_allowlist = []
+
+        plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
+        if plugins_denylist:
+            self.plugins_denylist = []
 
     def set_debug_mode(self, value: bool) -> None:
-        """Set the debug mode value."""
+        """Set the debug mode value"""
         self.debug_mode = value
 
-    def set_plugins(self,value: bool) -> None:
-        """Set the plugins value."""
+    def set_plugins(self, value: list) -> None:
+        """Set the plugins value. """
         self.plugins = value
 
-    def set_temperature(self, value: int) -> None:
-        """ Set the temperature value."""
+    def set_templature(self, value: int) -> None:
+        """Set the temperature value."""
         self.temperature = value
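For orientation, a sketch of how these new settings could be read elsewhere in the codebase, assuming the Config singleton shown above and the variable names from .env.template:

```
from pilot.configs.config import Config

cfg = Config()  # Singleton metaclass: repeated calls should return the same instance

# Values populated from the environment / .env file
print(cfg.temperature)        # TEMPERATURE, defaults to 0.7
print(cfg.milvus_addr)        # MILVUS_ADDR, defaults to "localhost:19530"
print(cfg.plugins_allowlist)  # parsed from ALLOWLISTED_PLUGINS ("p1,p2" -> ["p1", "p2"])

cfg.set_debug_mode(True)
```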
pilot/configs/model_config.py
@@ -5,6 +5,7 @@ import torch
 import os
 import nltk
 
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
@@ -26,9 +27,8 @@ LLM_MODEL_CONFIG = {
 VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
-MAX_POSITION_EMBEDDINGS = 2048
-VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
+MAX_POSITION_EMBEDDINGS = 4096
+VICUNA_MODEL_SERVER = "http://47.97.125.199:8000"
-
 
 # Load model config
 ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False
 
 DB_SETTINGS = {
     "user": "root",
-    "password": "aa123456",
+    "password": "aa12345678",
     "host": "localhost",
     "port": 3306
 }
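A sketch of how DB_SETTINGS is typically consumed (the webserver passes it to a MySQL operator; using pymysql directly here is only an assumption, though pymysql is listed in the dependencies):

```
import pymysql

from pilot.configs.model_config import DB_SETTINGS

# The keys match pymysql.connect()'s keyword arguments
conn = pymysql.connect(**DB_SETTINGS)
with conn.cursor() as cursor:
    cursor.execute("SHOW DATABASES")
    print(cursor.fetchall())
conn.close()
```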
pilot/conversation.py
@@ -10,6 +10,8 @@ class SeparatorStyle(Enum):
 
     SINGLE = auto()
     TWO = auto()
+    THREE = auto()
+    FOUR = auto()
 
 @dataclasses.dataclass
 class Conversation:
@@ -146,10 +148,103 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )
 
-conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
-如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
+auto_dbgpt_one_shot = Conversation(
+    system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
+    "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
+    roles=("USER", "ASSISTANT"),
+    messages=(
+        (
+            "USER",
+            """ Answer how many users does hackernews have by query mysql database
+Constraints:
+1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files.
+2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
+3. No user assistance
+4. Exclusively use the commands listed in double quotes e.g. "command name"
+
+Commands:
+1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
+2. execute_python_file: Execute Python File, args: "filename": "<filename>"
+3. append_to_file: Append to file, args: "filename": "<filename>", "text": "<text>"
+4. delete_file: Delete file, args: "filename": "<filename>"
+5. list_files: List Files in Directory, args: "directory": "<directory>"
+6. read_file: Read file, args: "filename": "<filename>"
+7. write_to_file: Write to file, args: "filename": "<filename>", "text": "<text>"
+8. tidb_sql_executor: "Execute SQL in TiDB Database.", args: "sql": "<sql>"
+
+Resources:
+1. Internet access for searches and information gathering.
+2. Long Term memory management.
+3. vicuna powered Agents for delegation of simple tasks.
+4. File output.
+
+Performance Evaluation:
+1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
+2. Constructively self-criticize your big-picture behavior constantly.
+3. Reflect on past decisions and strategies to refine your approach.
+4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
+5. Write all code to a file.
+
+You should only respond in JSON format as described below
+Response Format:
+{
+    "thoughts": {
+        "text": "thought",
+        "reasoning": "reasoning",
+        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
+        "criticism": "constructive self-criticism",
+        "speak": "thoughts summary to say to user"
+    },
+    "command": {
+        "name": "command name",
+        "args": {
+            "arg name": "value"
+        }
+    }
+}
+Ensure the response can be parsed by Python json.loads
+"""
+        ),
+        (
+            "ASSISTANT",
+            """
+{
+    "thoughts": {
+        "text": "thought",
+        "reasoning": "reasoning",
+        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
+        "criticism": "constructive self-criticism",
+        "speak": "thoughts summary to say to user"
+    },
+    "command": {
+        "name": "command name",
+        "args": {
+            "arg name": "value"
+        }
+    }
+}
+"""
+        )
+    ),
+    offset=0,
+    sep_style=SeparatorStyle.THREE,
+    sep=" ",
+    sep2="</s>",
+)
+
+auto_dbgpt_without_shot = Conversation(
+    system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
+    "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
+    roles=("USER", "ASSISTANT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.FOUR,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题,
+如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议。
 已知内容:
 {context}
 问题:
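Since the one-shot prompt above instructs the model to answer in that JSON shape and states it must be parseable by json.loads, here is a small sketch of the consuming side. The helper name is hypothetical and not part of this commit.

```
import json

def parse_auto_dbgpt_reply(reply: str) -> dict:
    """Parse the model's JSON reply into its thoughts and command parts."""
    data = json.loads(reply)
    thoughts = data["thoughts"]   # text / reasoning / plan / criticism / speak
    command = data["command"]     # {"name": ..., "args": {...}}
    return {"speak": thoughts["speak"], "command": command}

example = ('{"thoughts": {"text": "t", "reasoning": "r", "plan": "p", '
           '"criticism": "c", "speak": "s"}, '
           '"command": {"name": "read_file", "args": {"filename": "notes.txt"}}}')
print(parse_auto_dbgpt_reply(example))
```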
121  pilot/model/compression.py  Normal file
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import dataclasses

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F


@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""
    num_bits: int
    group_size: int
    group_dim: int
    symmetric: bool
    enabled: bool = True


default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)


class CLinear(nn.Module):
    """Compressed Linear Layer."""

    def __init__(self, weight, bias, device):
        super().__init__()

        self.weight = compress(weight.data.to(device), default_compression_config)
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        return F.linear(input, weight, self.bias)


def compress_module(module, target_device):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            setattr(module, attr_str,
                    CLinear(target_attr.weight, target_attr.bias, target_device))
    for name, child in module.named_children():
        compress_module(child, target_device)


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
                 original_shape[group_dim+1:])

    # Pad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
        tensor = torch.cat([
            tensor,
            torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim)
    data = tensor.view(new_shape)

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2 ** num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)

    # Dequantize
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim] +
            (original_shape[group_dim] + pad_len,) +
            original_shape[group_dim+1:])
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        return data[indices].contiguous()
    else:
        return data.view(original_shape)
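A short round-trip sketch of the quantization helpers above; the tensor size is arbitrary and the error magnitude depends on the config:

```
import torch

from pilot.model.compression import compress, decompress, default_compression_config

x = torch.randn(4, 512)  # group_dim=1, group_size=256 -> two groups per row
packed = compress(x, default_compression_config)     # (int8 data, scale, original_shape)
restored = decompress(packed, default_compression_config)

print(restored.shape)              # torch.Size([4, 512])
print((x - restored).abs().max())  # small quantization error
```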
pilot/model/inference.py
@@ -5,13 +5,13 @@ import torch
 
 @torch.inference_mode()
 def generate_stream(model, tokenizer, params, device,
-                    context_len=2048, stream_interval=2):
+                    context_len=4096, stream_interval=2):
 
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
     prompt = params["prompt"]
     l_prompt = len(prompt)
     temperature = float(params.get("temperature", 1.0))
-    max_new_tokens = int(params.get("max_new_tokens", 256))
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)
 
     input_ids = tokenizer(prompt).input_ids
34  pilot/model/llm/base.py  Normal file
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from dataclasses import dataclass, field
from typing import List, TypedDict


class Message(TypedDict):
    """Vicuna Message object containing a role and the message content """

    role: str
    content: str


@dataclass
class ModelInfo:
    """Struct for model information.

    Would be lovely to eventually get this directly from APIs
    """
    name: str
    max_tokens: int


@dataclass
class LLMResponse:
    """Standard response struct for a response from a LLM model."""
    model_info = ModelInfo


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a response from an LLM model."""

    content: str = None
108  pilot/model/llm/llm_utils.py  Normal file
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import abc
import time
import functools
from typing import List, Optional
from pilot.model.llm.base import Message
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
from pilot.configs.config import Config


# TODO Rewrite this
def retry_stream_api(
    num_retries: int = 10,
    backoff_base: float = 2.0,
    warn_user: bool = True
):
    """Retry an Vicuna Server call.

    Args:
        num_retries int: Number of retries. Defaults to 10.
        backoff_base float: Base for exponential backoff. Defaults to 2.
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = f"Error: Reached rate limit, passing..."
    backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            user_warned = not warn_user
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if (e.http_status != 502) or (attempt == num_attempts):
                        raise

                backoff = backoff_base ** (attempt + 2)
                time.sleep(backoff)
        return _wrapped
    return _wrapper


# Overly simple abstraction util we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
    conv: Conversation,
    model: Optional[str] = None,
    temperature: float = None,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Create a chat completion using the Vicuna-13b

    Args:
        messages(List[Message]): The messages to send to the chat completion
        model (str, optional): The model to use. Default to None.
        temperature (float, optional): The temperature to use. Defaults to 0.7.
        max_tokens (int, optional): The max tokens to use. Defaults to None.

    Returns:
        str: The response from the chat completion
    """
    cfg = Config()
    if temperature is None:
        temperature = cfg.temperature

    # TODO request vicuna model get response
    # convert vicuna message to chat completion.
    for plugin in cfg.plugins:
        if plugin.can_handle_chat_completion():
            pass


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str) -> str:
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream, skip_echo_len: int):
        """Stream output."""


class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str) -> str:
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now

        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)
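A sketch of how the retry decorator above can be exercised. Note that its except-branch reads an `http_status` attribute (OpenAI-style errors), so the stand-in exception below carries one; the GatewayError class and the flaky function are hypothetical demo code, not part of this commit.

```
from pilot.model.llm.llm_utils import retry_stream_api

class GatewayError(Exception):
    """Stand-in error type; the decorator inspects `.http_status`."""
    def __init__(self, http_status):
        super().__init__(f"HTTP {http_status}")
        self.http_status = http_status

calls = {"n": 0}

@retry_stream_api(num_retries=3, backoff_base=0.0)  # backoff_base=0.0 keeps the demo fast
def flaky_vicuna_call(prompt: str) -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise GatewayError(502)  # 502 and not the last attempt -> retried
    return f"answer to: {prompt}"

print(flaky_vicuna_call("How many users does HackerNews have?"))  # succeeds on the 3rd try
```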
pilot/model/loader.py
@@ -2,15 +2,17 @@
 # -*- coding: utf-8 -*-
 
 import torch
+from pilot.singleton import Singleton
+
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     AutoModel
 )
 
-from fastchat.serve.compression import compress_module
+from pilot.model.compression import compress_module
 
-class ModelLoader:
+class ModelLoader(metaclass=Singleton):
     """Model loader is a class for model load
 
       Args: model_path
pilot/server/webserver.py
@@ -10,8 +10,6 @@ from fastapi.responses import StreamingResponse
 from pilot.model.inference import generate_stream
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
-from fastchat.serve.inference import load_model
-
 
 from pilot.model.loader import ModelLoader
 from pilot.configs.model_config import *

@@ -24,11 +24,9 @@ from pilot.conversation import (
     SeparatorStyle
 )
 
-from fastchat.utils import (
+from pilot.utils import (
     build_logger,
     server_error_msg,
-    violates_moderation,
-    moderation_msg
 )
 
 from pilot.server.gradio_css import code_highlight_css

@@ -45,6 +43,7 @@ enable_moderation = False
 models = []
 dbs = []
 vs_list = ["新建知识库"] + get_vector_storelist()
+autogpt = False
 
 priority = {
     "vicuna-13b": "aaa"

@@ -58,8 +57,6 @@ def get_simlar(q):
     contents = [dc.page_content for dc, _ in docs]
     return "\n".join(contents)
 
-
-
 def gen_sqlgen_conversation(dbname):
     mo = MySQLOperator(
         **DB_SETTINGS

@@ -118,6 +115,8 @@ def regenerate(state, request: gr.Request):
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
 
 def clear_history(request: gr.Request):
+
+
     logger.info(f"clear_history. ip: {request.client.host}")
     state = None
     return (state, [], "") + (disable_btn,) * 5

@@ -128,14 +127,9 @@ def add_text(state, text, request: gr.Request):
     if len(text) <= 0:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
-    if args.moderate:
-        flagged = violates_moderation(text)
-        if flagged:
-            state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (
-                no_change_btn,) * 5
 
-    text = text[:1536]  # Hard cut-off
+    """ Default support 4000 tokens, if tokens too lang, we will cut off """
+    text = text[:4000]
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False

@@ -152,6 +146,8 @@ def post_process_code(code):
     return code
 
 def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+
+    print("是否是AUTO-GPT模式.", autogpt)
     start_tstamp = time.time()
     model_name = LLM_MODEL
 

@@ -162,7 +158,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
+    # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
     if len(state.messages) == state.offset + 2:
         # 第一轮对话需要加入提示Prompt
 

@@ -251,29 +248,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
 block_css = (
     code_highlight_css
     + """
 pre {
 white-space: pre-wrap; /* Since CSS 2.1 */
 white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
 white-space: -pre-wrap; /* Opera 4-6 */
 white-space: -o-pre-wrap; /* Opera 7 */
 word-wrap: break-word; /* Internet Explorer 5.5+ */
 }
 #notice_markdown th {
 display: none;
 }
 """
 )
 
-def change_tab(tab):
-    pass
 
 def change_mode(mode):
     if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
     else:
         return gr.update(visible=True)
 
+def change_tab():
+    autogpt = True
+
 def build_single_model_ui():
 
     notice_markdown = """

@@ -301,16 +297,17 @@ def build_single_model_ui():
 
         max_output_tokens = gr.Slider(
             minimum=0,
-            maximum=1024,
-            value=512,
+            maximum=4096,
+            value=2048,
             step=64,
             interactive=True,
             label="最大输出Token数",
         )
-        tabs = gr.Tabs()
+        tabs= gr.Tabs()
         with tabs:
-            with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
-                # TODO A selector to choose database
+            tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
+            with tab_sql:
+                # TODO A selector to choose database
                 with gr.Row(elem_id="db_selector"):
                     db_selector = gr.Dropdown(
                         label="请选择数据库",

@@ -318,9 +315,12 @@ def build_single_model_ui():
                         value=dbs[0] if len(models) > 0 else "",
                         interactive=True,
                         show_label=True).style(container=False)
+            tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
+            with tab_auto:
+                gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
 
-            with gr.TabItem("知识问答", elem_id="QA"):
+            tab_qa = gr.TabItem("知识问答", elem_id="QA")
+            with tab_qa:
                 mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
                 vs_setting = gr.Accordion("配置知识库", open=False)
                 mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)

@@ -360,9 +360,7 @@ def build_single_model_ui():
             regenerate_btn = gr.Button(value="重新生成", interactive=False)
             clear_btn = gr.Button(value="清理", interactive=False)
-
-
     gr.Markdown(learn_more_markdown)
 
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,

@@ -434,9 +432,7 @@ if __name__ == "__main__":
         "--model-list-mode", type=str, default="once", choices=["once", "reload"]
     )
     parser.add_argument("--share", default=False, action="store_true")
-    parser.add_argument(
-        "--moderate", action="store_true", help="Enable content moderation"
-    )
     args = parser.parse_args()
     logger.info(f"args: {args}")
 
109  pilot/utils.py
@@ -3,6 +3,20 @@
 
 import torch
 
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from pilot.configs.model_config import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+
+handler = None
+
 def get_gpu_memory(max_gpus=None):
     gpu_memory = []
     num_gpus = (

@@ -20,3 +34,98 @@ def get_gpu_memory(max_gpus=None):
         available_memory = total_memory - allocated_memory
         gpu_memory.append(available_memory)
     return gpu_memory
+
+
+def build_logger(logger_name, logger_filename):
+    global handler
+
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Set the format of root handlers
+    if not logging.getLogger().handlers:
+        logging.basicConfig(level=logging.INFO, encoding='utf-8')
+    logging.getLogger().handlers[0].setFormatter(formatter)
+
+    # Redirect stdout and stderr to loggers
+    stdout_logger = logging.getLogger("stdout")
+    stdout_logger.setLevel(logging.INFO)
+    sl = StreamToLogger(stdout_logger, logging.INFO)
+    sys.stdout = sl
+
+    stderr_logger = logging.getLogger("stderr")
+    stderr_logger.setLevel(logging.ERROR)
+    sl = StreamToLogger(stderr_logger, logging.ERROR)
+    sys.stderr = sl
+
+    # Get logger
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+
+    # Add a file handler for all loggers
+    if handler is None:
+        os.makedirs(LOGDIR, exist_ok=True)
+        filename = os.path.join(LOGDIR, logger_filename)
+        handler = logging.handlers.TimedRotatingFileHandler(
+            filename, when='D', utc=True)
+        handler.setFormatter(formatter)
+
+        for name, item in logging.root.manager.loggerDict.items():
+            if isinstance(item, logging.Logger):
+                item.addHandler(handler)
+
+    return logger
+
+
+class StreamToLogger(object):
+    """
+    Fake file-like stream object that redirects writes to a logger instance.
+    """
+    def __init__(self, logger, log_level=logging.INFO):
+        self.terminal = sys.stdout
+        self.logger = logger
+        self.log_level = log_level
+        self.linebuf = ''
+
+    def __getattr__(self, attr):
+        return getattr(self.terminal, attr)
+
+    def write(self, buf):
+        temp_linebuf = self.linebuf + buf
+        self.linebuf = ''
+        for line in temp_linebuf.splitlines(True):
+            # From the io.TextIOWrapper docs:
+            # On output, if newline is None, any '\n' characters written
+            # are translated to the system default line separator.
+            # By default sys.stdout.write() expects '\n' newlines and then
+            # translates them so this is still cross platform.
+            if line[-1] == '\n':
+                encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
+                self.logger.log(self.log_level, encoded_message.rstrip())
+            else:
+                self.linebuf += line
+
+    def flush(self):
+        if self.linebuf != '':
+            encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
+            self.logger.log(self.log_level, encoded_message.rstrip())
+        self.linebuf = ''
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def pretty_print_semaphore(semaphore):
+    if semaphore is None:
+        return "None"
+    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
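A small sketch of how the new logging helpers are typically used; the logger name and file name below are placeholders:

```
from pilot.utils import build_logger

# Creates LOGDIR if needed, attaches a timed-rotating file handler,
# and redirects stdout/stderr into the logging system via StreamToLogger.
logger = build_logger("webserver", "webserver.log")
logger.info("server starting")
print("this also ends up in the log file through the redirected stdout")
```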
requirements.txt
@@ -1,7 +1,5 @@
 accelerate==0.16.0
 torch==2.0.0
-torchvision==0.13.1
-torchaudio==0.12.1
 accelerate==0.16.0
 aiohttp==3.8.4
 aiosignal==1.3.1

@@ -47,11 +45,12 @@ pycocoevalcap
 sentence-transformers
 umap-learn
 notebook
-gradio==3.24.1
+gradio==3.23
 gradio-client==0.0.8
 wandb
-fschat=0.1.10
-llama-index=0.5.27
+llama-index==0.5.27
 pymysql
 unstructured==0.6.3
 pytesseract==0.3.10
+chromadb
+markdown2