mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-22 11:51:42 +00:00

fix conflict

This commit is contained in:
commit b9df562c84

72  .env.template  Normal file
@@ -0,0 +1,72 @@
#*******************************************************************#
#**             DB-GPT - GENERAL SETTINGS                         **#
#*******************************************************************#
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled. Each of the below is an option:
## pilot.commands.query_execute

## For example, to disable coding-related features, uncomment the next line
# DISABLED_COMMAND_CATEGORIES=


#*******************************************************************#
#**                       LLM PROVIDER                            **#
#*******************************************************************#

# TEMPERATURE=0

#*******************************************************************#
#**                        LLM MODELS                             **#
#*******************************************************************#

## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b


### EMBEDDINGS
## EMBEDDING_MODEL - Model to use for creating embeddings
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191


#*******************************************************************#
#**                     DATABASE SETTINGS                         **#
#*******************************************************************#
DB_SETTINGS_MYSQL_USER=root
DB_SETTINGS_MYSQL_PASSWORD=password
DB_SETTINGS_MYSQL_HOST=localhost
DB_SETTINGS_MYSQL_PORT=3306


### MILVUS
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)
## MILVUS_USERNAME - username for your Milvus database
## MILVUS_PASSWORD - password for your Milvus database
## MILVUS_SECURE - True to enable TLS. (Default: False)
## Setting MILVUS_ADDR to a `https://` URL will override this setting.
## MILVUS_COLLECTION - Milvus collection; change it if you want to start a new memory and retain the old one.
# MILVUS_ADDR=localhost:19530
# MILVUS_USERNAME=
# MILVUS_PASSWORD=
# MILVUS_SECURE=
# MILVUS_COLLECTION=dbgpt

#*******************************************************************#
#**                    ALLOWLISTED PLUGINS                        **#
#*******************************************************************#

#ALLOWLISTED_PLUGINS - Sets the listed plugins that are allowed (Example: plugin1,plugin2,plugin3)
#DENYLISTED_PLUGINS - Sets the listed plugins that are not allowed (Example: plugin1,plugin2,plugin3)
ALLOWLISTED_PLUGINS=
DENYLISTED_PLUGINS=


#*******************************************************************#
#**                   CHAT PLUGIN SETTINGS                        **#
#*******************************************************************#
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
# CHAT_MESSAGES_ENABLED=False
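For reference, a minimal sketch (not part of this template) of how these variables are typically consumed from Python, assuming the `python-dotenv` package and a `.env` copied from the template:

```
import os
from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv(".env")  # copy .env.template to .env and edit it first

# Names mirror the template above; defaults here are illustrative.
temperature = float(os.getenv("TEMPERATURE", "0.7"))
smart_model = os.getenv("SMART_LLM_MODEL", "vicuna-13b")
milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
allowlist = [p for p in os.getenv("ALLOWLISTED_PLUGINS", "").split(",") if p]
```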
23  .github/workflows/pylint.yml  vendored  Normal file
@@ -0,0 +1,23 @@
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint
      - name: Analysing the code with pylint
        run: |
          pylint $(git ls-files '*.py')
39  .github/workflows/python-publish.yml  vendored  Normal file
@@ -0,0 +1,39 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
      - name: Build package
        run: python -m build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
@@ -44,7 +44,7 @@ The generated SQL is runnable.

1. First you need to install python requirements.
```
python>=3.9
pip install -r requirements
pip install -r requirements.txt
```
or if you use a conda environment, you can use this command
```

@@ -63,7 +63,7 @@ The password is just for testing; you can change it if necessary

# Install
1. Download the base model
For the base model, you can merge the weights by following the [vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) tutorial.
If you have trouble with this step, you can also directly use a model from [Hugging Face](https://huggingface.co/) as a substitute. Substitute model: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b)
If you have trouble with this step, you can also directly use a model from [Hugging Face](https://huggingface.co/) as a substitute. [Substitute model](https://huggingface.co/Tribbiani/vicuna-7b)

2. Run model server
```

@@ -86,4 +86,4 @@ python webserver.py

# Contribute
[Contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING)
# Licence
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
@@ -7,11 +7,9 @@ dependencies:
  - python=3.9
  - cudatoolkit
  - pip
  - pytorch=1.12.1
  - pytorch-mutex=1.0=cuda
  - torchaudio=0.12.1
  - torchvision=0.13.1
  - pip:
    - pytorch
    - accelerate==0.16.0
    - aiohttp==3.8.4
    - aiosignal==1.3.1
@@ -57,11 +55,12 @@ dependencies:
    - sentence-transformers
    - umap-learn
    - notebook
    - gradio==3.24.1
    - gradio==3.23
    - gradio-client==0.0.8
    - wandb
    - fschat=0.1.10
    - llama-index=0.5.27
    - llama-index==0.5.27
    - pymysql
    - unstructured==0.6.3
    - pytesseract==0.3.10
    - markdown2
    - chromadb
19  examples/gradio_test.py  Normal file
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import gradio as gr

def change_tab():
    return gr.Tabs.update(selected=1)

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem("Train", id=0):
            t = gr.Textbox()
        with gr.TabItem("Inference", id=1):
            i = gr.Image()

    btn = gr.Button()
    btn.click(change_tab, None, tabs)

demo.launch()
0  pilot/commands/__init__.py  Normal file
36  pilot/commands/command.py  Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import functools
import importlib
import inspect
from typing import Any, Callable, Optional

class Command:
    """A class representing a command.

    Attributes:
        name (str): The name of the command.
        description (str): A brief description of what the command does.
        signature (str): The signature of the function that the command executes. Defaults to None.
    """

    def __init__(self,
                 name: str,
                 description: str,
                 method: Callable[..., Any],
                 signature: str = "",
                 enabled: bool = True,
                 disabled_reason: Optional[str] = None,
                 ) -> None:
        self.name = name
        self.description = description
        self.method = method
        self.signature = signature if signature else str(inspect.signature(self.method))
        self.enabled = enabled
        self.disabled_reason = disabled_reason

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        if not self.enabled:
            return f"Command '{self.name}' is disabled: {self.disabled_reason}"
        return self.method(*args, **kwds)
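A short usage sketch (not part of the commit) showing how Command wraps a callable, derives its signature via inspect, and refuses to run when disabled; the wrapped function here is hypothetical:

```
from pilot.commands.command import Command

def query_execute(sql: str) -> str:
    # Hypothetical command body, for illustration only.
    return f"executed: {sql}"

cmd = Command(name="query_execute",
              description="Execute a SQL statement",
              method=query_execute)
print(cmd.signature)    # "(sql: str) -> str", derived by inspect.signature
print(cmd("SELECT 1"))  # "executed: SELECT 1"

cmd.enabled = False
cmd.disabled_reason = "disabled via DISABLED_COMMAND_CATEGORIES"
print(cmd("SELECT 1"))  # "Command 'query_execute' is disabled: ..."
```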
@@ -12,26 +12,55 @@ class Config(metaclass=Singleton):
    def __init__(self) -> None:
        """Initialize the Config class"""

        # TODO change model_config there
        # TODO change model_config there

        self.debug_mode = False
        self.skip_reprompt = False

        self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
        self.plugins = List[AutoGPTPluginTemplate] = []
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))

        # TODO change model_config there
        self.execute_local_commands = (
            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
        )
        self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
        # User agent header to use when making HTTP requests
        # Some websites might just completely deny request with an error code if
        # no user agent was found.
        self.user_agent = os.getenv(
            "USER_AGENT",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
        )

        # milvus or zilliz cloud configuration
        self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
        self.milvus_username = os.getenv("MILVUS_USERNAME")
        self.milvus_password = os.getenv("MILVUS_PASSWORD")
        self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"

        self.plugins_dir = os.getenv("PLUGINS_DIR", 'plugins')
        self.plugins: List[AutoGPTPluginTemplate] = []
        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
        if plugins_allowlist:
            self.plugins_allowlist = plugins_allowlist.split(",")
        else:
            self.plugins_allowlist = []

        plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
        if plugins_denylist:
            self.plugins_denylist = plugins_denylist.split(",")
        else:
            self.plugins_denylist = []

    def set_debug_mode(self, value: bool) -> None:
        """Set the debug mode value."""
        """Set the debug mode value"""
        self.debug_mode = value

    def set_plugins(self, value: bool) -> None:
        """Set the plugins value."""
    def set_plugins(self, value: list) -> None:
        """Set the plugins value."""
        self.plugins = value

    def set_temperature(self, value: int) -> None:
        """Set the temperature value."""
    def set_templature(self, value: int) -> None:
        """Set the temperature value."""
        self.temperature = value
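Because Config uses the Singleton metaclass, every call site shares one settings object. A minimal usage sketch (the import path matches the one used by pilot/model/llm/llm_utils.py in this commit):

```
from pilot.configs.config import Config

cfg_a = Config()
cfg_b = Config()
assert cfg_a is cfg_b        # the Singleton metaclass hands back one instance

cfg_a.set_debug_mode(True)
print(cfg_b.debug_mode)      # True: state is shared process-wide
```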
@@ -5,6 +5,7 @@ import torch
import os
import nltk

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
@@ -26,9 +27,8 @@ LLM_MODEL_CONFIG = {
VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"

MAX_POSITION_EMBEDDINGS = 4096
VICUNA_MODEL_SERVER = "http://47.97.125.199:8000"

# Load model config
ISLOAD_8BIT = True
@@ -37,7 +37,7 @@ ISDEBUG = False

DB_SETTINGS = {
    "user": "root",
    "password": "aa123456",
    "password": "aa12345678",
    "host": "localhost",
    "port": 3306
}
@@ -10,6 +10,8 @@ class SeparatorStyle(Enum):

    SINGLE = auto()
    TWO = auto()
    THREE = auto()
    FOUR = auto()

@dataclasses.dataclass
class Conversation:
@@ -146,10 +148,103 @@ conv_vicuna_v1 = Conversation(
    sep2="</s>",
)

auto_dbgpt_one_shot = Conversation(
    system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
    "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
    roles=("USER", "ASSISTANT"),
    messages=(
        (
            "USER",
            """ Answer how many users does hackernews have by query mysql database
Constraints:
1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files.
2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
3. No user assistance
4. Exclusively use the commands listed in double quotes e.g. "command name"

conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:

Commands:
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
3. append_to_file: Append to file, args: "filename": "<filename>", "text": "<text>"
4. delete_file: Delete file, args: "filename": "<filename>"
5. list_files: List Files in Directory, args: "directory": "<directory>"
6. read_file: Read file, args: "filename": "<filename>"
7. write_to_file: Write to file, args: "filename": "<filename>", "text": "<text>"
8. tidb_sql_executor: "Execute SQL in TiDB Database.", args: "sql": "<sql>"

Resources:
1. Internet access for searches and information gathering.
2. Long Term memory management.
3. vicuna powered Agents for delegation of simple tasks.
4. File output.

Performance Evaluation:
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
2. Constructively self-criticize your big-picture behavior constantly.
3. Reflect on past decisions and strategies to refine your approach.
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
5. Write all code to a file.

You should only respond in JSON format as described below
Response Format:
{
    "thoughts": {
        "text": "thought",
        "reasoning": "reasoning",
        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
        "criticism": "constructive self-criticism",
        "speak": "thoughts summary to say to user"
    },
    "command": {
        "name": "command name",
        "args": {
            "arg name": "value"
        }
    }
}
Ensure the response can be parsed by Python json.loads
"""
        ),
        (
            "ASSISTANT",
            """
{
    "thoughts": {
        "text": "thought",
        "reasoning": "reasoning",
        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
        "criticism": "constructive self-criticism",
        "speak": "thoughts summary to say to user"
    },
    "command": {
        "name": "command name",
        "args": {
            "arg name": "value"
        }
    }
}
"""
        )
    ),
    offset=0,
    sep_style=SeparatorStyle.THREE,
    sep=" ",
    sep2="</s>",
)

auto_dbgpt_without_shot = Conversation(
    system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
    "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.FOUR,
    sep=" ",
    sep2="</s>",
)

conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题,
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议。
已知内容:
{context}
问题:
121  pilot/model/compression.py  Normal file
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import dataclasses

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F


@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""
    num_bits: int
    group_size: int
    group_dim: int
    symmetric: bool
    enabled: bool = True


default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)


class CLinear(nn.Module):
    """Compressed Linear Layer."""

    def __init__(self, weight, bias, device):
        super().__init__()

        self.weight = compress(weight.data.to(device), default_compression_config)
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        return F.linear(input, weight, self.bias)


def compress_module(module, target_device):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            setattr(module, attr_str,
                    CLinear(target_attr.weight, target_attr.bias, target_device))
    for name, child in module.named_children():
        compress_module(child, target_device)


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
                 original_shape[group_dim+1:])

    # Pad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
        tensor = torch.cat([
            tensor,
            torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim)
    data = tensor.view(new_shape)

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2 ** num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)

    # Dequantize
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim] +
            (original_shape[group_dim] + pad_len,) +
            original_shape[group_dim+1:])
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        return data[indices].contiguous()
    else:
        return data.view(original_shape)
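A round-trip sketch (not in the commit) of what the group-wise int8 scheme above does to a weight tensor; the tensor shape is arbitrary:

```
import torch
from pilot.model.compression import compress, decompress, default_compression_config

w = torch.randn(4, 1000)  # group_dim=1, group_size=256 -> padded to 1024, 4 groups
packed = compress(w, default_compression_config)       # (int8 data, scale, shape)
w_hat = decompress(packed, default_compression_config)
print(w_hat.shape)                       # torch.Size([4, 1000])
print((w - w_hat).abs().max().item())    # small per-group quantization error
```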
@@ -5,13 +5,13 @@ import torch

@torch.inference_mode()
def generate_stream(model, tokenizer, params, device,
                    context_len=2048, stream_interval=2):
                    context_len=4096, stream_interval=2):

    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
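Note how the two bumped defaults interact: whatever the new-token budget does not use is all the room the prompt gets. A sketch of the usual fastchat-style budget arithmetic (the slack constant is an assumption, not taken from this diff):

```
context_len = 4096      # new default above
max_new_tokens = 2048   # new default above
# fastchat-style streams truncate the prompt to roughly this many tokens;
# the -8 slack is illustrative.
max_src_len = context_len - max_new_tokens - 8
print(max_src_len)  # 2040
```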
34  pilot/model/llm/base.py  Normal file
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from dataclasses import dataclass, field
from typing import List, TypedDict


class Message(TypedDict):
    """Vicuna Message object containing a role and the message content"""

    role: str
    content: str


@dataclass
class ModelInfo:
    """Struct for model information.

    Would be lovely to eventually get this directly from APIs
    """
    name: str
    max_tokens: int


@dataclass
class LLMResponse:
    """Standard response struct for a response from a LLM model."""
    model_info = ModelInfo


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a response from an LLM model."""

    content: str = None
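A tiny construction sketch (values illustrative, not from the commit): ModelInfo is plain metadata, and ChatModelResponse carries the generated text:

```
from pilot.model.llm.base import ModelInfo, ChatModelResponse

info = ModelInfo(name="vicuna-13b", max_tokens=4096)  # illustrative values
resp = ChatModelResponse(content="SELECT COUNT(*) FROM users;")
print(info.name, resp.content)
```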
108  pilot/model/llm/llm_utils.py  Normal file
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import abc
import time
import functools
from typing import List, Optional
from pilot.model.llm.base import Message
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
from pilot.configs.config import Config


# TODO Rewrite this
def retry_stream_api(
    num_retries: int = 10,
    backoff_base: float = 2.0,
    warn_user: bool = True
):
    """Retry a Vicuna Server call.

    Args:
        num_retries int: Number of retries. Defaults to 10.
        backoff_base float: Base for exponential backoff. Defaults to 2.
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = f"Error: Reached rate limit, passing..."
    backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            user_warned = not warn_user
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if (e.http_status != 502) or (attempt == num_attempts):
                        raise

                backoff = backoff_base ** (attempt + 2)
                time.sleep(backoff)
        return _wrapped
    return _wrapper


# Overly simple abstraction until we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
    conv: Conversation,
    model: Optional[str] = None,
    temperature: float = None,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Create a chat completion using the Vicuna-13b

    Args:
        messages(List[Message]): The messages to send to the chat completion
        model (str, optional): The model to use. Defaults to None.
        temperature (float, optional): The temperature to use. Defaults to 0.7.
        max_tokens (int, optional): The max tokens to use. Defaults to None.

    Returns:
        str: The response from the chat completion
    """
    cfg = Config()
    if temperature is None:
        temperature = cfg.temperature

    # TODO request vicuna model get response
    # convert vicuna message to chat completion.
    for plugin in cfg.plugins:
        if plugin.can_handle_chat_completion():
            pass


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str) -> str:
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream, skip_echo_len: int):
        """Stream output."""


class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str) -> str:
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now

        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)
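For illustration only (no such call site exists in this commit), the decorator would wrap a server call like this; it assumes the raised exception exposes an `http_status` attribute:

```
@retry_stream_api(num_retries=3, backoff_base=2.0)
def call_vicuna_server(payload: dict) -> dict:
    # Hypothetical client function. A raised exception with http_status == 502
    # is retried with exponential backoff; anything else propagates at once.
    raise NotImplementedError
```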
@@ -2,15 +2,17 @@
# -*- coding: utf-8 -*-

import torch
from pilot.singleton import Singleton

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel
)

from fastchat.serve.compression import compress_module
from pilot.model.compression import compress_module

class ModelLoader:
class ModelLoader(metaclass=Singleton):
    """Model loader is a class for model load

    Args: model_path
@@ -10,8 +10,6 @@ from fastapi.responses import StreamingResponse
from pilot.model.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model

from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *

@@ -24,11 +24,9 @@ from pilot.conversation import (
    SeparatorStyle
)

from fastchat.utils import (
from pilot.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg
)

from pilot.server.gradio_css import code_highlight_css
@@ -45,6 +43,7 @@ enable_moderation = False
models = []
dbs = []
vs_list = ["新建知识库"] + get_vector_storelist()
autogpt = False

priority = {
    "vicuna-13b": "aaa"
@@ -58,8 +57,6 @@ def get_simlar(q):
    contents = [dc.page_content for dc, _ in docs]
    return "\n".join(contents)


def gen_sqlgen_conversation(dbname):
    mo = MySQLOperator(
        **DB_SETTINGS
@@ -118,6 +115,8 @@ def regenerate(state, request: gr.Request):
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def clear_history(request: gr.Request):

    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5
@@ -128,14 +127,9 @@ def add_text(state, text, request: gr.Request):
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    """ Default support 4000 tokens; if there are too many tokens, we will cut off """
    text = text[:4000]
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
@@ -152,6 +146,8 @@ def post_process_code(code):
    return code

def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):

    print("是否是AUTO-GPT模式.", autogpt)
    start_tstamp = time.time()
    model_name = LLM_MODEL

@@ -162,7 +158,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    # TODO when tab mode is AUTO_GPT, the prompt needs to be rebuilt.
    if len(state.messages) == state.offset + 2:
        # The first round of conversation needs the prompt added

@@ -251,29 +248,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)

def change_tab(tab):
    pass

def change_mode(mode):
    if mode in ["默认知识库对话", "LLM原生对话"]:
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)


def change_tab():
    autogpt = True

def build_single_model_ui():

    notice_markdown = """
@@ -301,16 +297,17 @@ def build_single_model_ui():

    max_output_tokens = gr.Slider(
        minimum=0,
        maximum=1024,
        value=512,
        maximum=4096,
        value=2048,
        step=64,
        interactive=True,
        label="最大输出Token数",
    )
    tabs = gr.Tabs()
    tabs= gr.Tabs()
    with tabs:
        with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
            # TODO A selector to choose database
        tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
        with tab_sql:
            # TODO A selector to choose database
            with gr.Row(elem_id="db_selector"):
                db_selector = gr.Dropdown(
                    label="请选择数据库",
@@ -318,9 +315,12 @@ def build_single_model_ui():
                    value=dbs[0] if len(models) > 0 else "",
                    interactive=True,
                    show_label=True).style(container=False)
        tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
        with tab_auto:
            gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")

        with gr.TabItem("知识问答", elem_id="QA"):

        tab_qa = gr.TabItem("知识问答", elem_id="QA")
        with tab_qa:
            mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
            vs_setting = gr.Accordion("配置知识库", open=False)
            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
@@ -360,9 +360,7 @@ def build_single_model_ui():
            regenerate_btn = gr.Button(value="重新生成", interactive=False)
            clear_btn = gr.Button(value="清理", interactive=False)

    gr.Markdown(learn_more_markdown)

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
@@ -434,9 +432,7 @@ if __name__ == "__main__":
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument(
        "--moderate", action="store_true", help="Enable content moderation"
    )

    args = parser.parse_args()
    logger.info(f"args: {args}")
109  pilot/utils.py
@@ -3,6 +3,20 @@

import torch

import datetime
import logging
import logging.handlers
import os
import sys

import requests

from pilot.configs.model_config import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

handler = None

def get_gpu_memory(max_gpus=None):
    gpu_memory = []
    num_gpus = (
@@ -20,3 +34,98 @@ def get_gpu_memory(max_gpus=None):
        available_memory = total_memory - allocated_memory
        gpu_memory.append(available_memory)
    return gpu_memory


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO, encoding='utf-8')
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True)
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
                self.logger.log(self.log_level, encoded_message.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
            self.logger.log(self.log_level, encoded_message.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
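A minimal usage sketch (illustrative): once build_logger has run, even bare print calls end up in the daily-rotating log file, because sys.stdout is replaced by a StreamToLogger:

```
logger = build_logger("webserver", "webserver.log")
logger.info("server starting")            # console + rotating file under LOGDIR
print("plain prints are captured too")    # routed through the "stdout" logger
```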
@@ -1,7 +1,5 @@
accelerate==0.16.0
torch==2.0.0
torchvision==0.13.1
torchaudio==0.12.1
accelerate==0.16.0
aiohttp==3.8.4
aiosignal==1.3.1
@@ -47,11 +45,12 @@ pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio==3.23
gradio-client==0.0.8
wandb
fschat=0.1.10
llama-index=0.5.27
llama-index==0.5.27
pymysql
unstructured==0.6.3
pytesseract==0.3.10
pytesseract==0.3.10
chromadb
markdown2