fix conflict

commit b9df562c84
Author: csunny
Date: 2023-05-12 20:12:47 +08:00
20 changed files with 757 additions and 78 deletions

72
.env.template Normal file

@ -0,0 +1,72 @@
#*******************************************************************#
#** DB-GPT - GENERAL SETTINGS **#
#*******************************************************************#
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled. Each of the below are an option:
## pilot.commands.query_execute
## For example, to disable coding related features, uncomment the next line
# DISABLED_COMMAND_CATEGORIES=
#*******************************************************************#
#*** LLM PROVIDER ***#
#*******************************************************************#
# TEMPERATURE=0
#*******************************************************************#
#** LLM MODELS **#
#*******************************************************************#
## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b
### EMBEDDINGS
## EMBEDDING_MODEL - Model to use for creating embeddings
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191
#*******************************************************************#
#** DATABASE SETTINGS **#
#*******************************************************************#
DB_SETTINGS_MYSQL_USER=root
DB_SETTINGS_MYSQL_PASSWORD=password
DB_SETTINGS_MYSQL_HOST=localhost
DB_SETTINGS_MYSQL_PORT=3306
### MILVUS
## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530)
## MILVUS_USERNAME - username for your Milvus database
## MILVUS_PASSWORD - password for your Milvus database
## MILVUS_SECURE - True to enable TLS. (Default: False)
## Setting MILVUS_ADDR to a `https://` URL will override this setting.
## MILVUS_COLLECTION - Milvus collection, change it if you want to start a new memory and retain the old memory.
# MILVUS_ADDR=localhost:19530
# MILVUS_USERNAME=
# MILVUS_PASSWORD=
# MILVUS_SECURE=
# MILVUS_COLLECTION=dbgpt
#*******************************************************************#
#** ALLOWLISTED PLUGINS **#
#*******************************************************************#
#ALLOWLISTED_PLUGINS - Sets the listed plugins that are allowed (Example: plugin1,plugin2,plugin3)
#DENYLISTED_PLUGINS - Sets the listed plugins that are not allowed (Example: plugin1,plugin2,plugin3)
ALLOWLISTED_PLUGINS=
DENYLISTED_PLUGINS=
#*******************************************************************#
#** CHAT PLUGIN SETTINGS **#
#*******************************************************************#
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
# CHAT_MESSAGES_ENABLED=False
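
All of the settings above are plain environment variables; the `Config` class added later in this commit reads them with `os.getenv`. A minimal reading sketch (variable names come from the template, defaults mirror its comments):

```
# Minimal sketch: read the variables defined in .env.template above.
# Assumes they have been exported (or loaded from a copied .env file).
import os

smart_model = os.getenv("SMART_LLM_MODEL", "vicuna-13b")
fast_model = os.getenv("FAST_LLM_MODEL", "chatglm-6b")
mysql_user = os.getenv("DB_SETTINGS_MYSQL_USER", "root")
milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
allowlisted = [p for p in os.getenv("ALLOWLISTED_PLUGINS", "").split(",") if p]
print(smart_model, fast_model, mysql_user, milvus_addr, allowlisted)
```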

23
.github/workflows/pylint.yml vendored Normal file

@ -0,0 +1,23 @@
name: Pylint
on: [push]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')

39
.github/workflows/python-publish.yml vendored Normal file

@ -0,0 +1,39 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@ -44,7 +44,7 @@ The Generated SQL is runnable.
1. First you need to install python requirements.
```
python>=3.9
pip install -r requirements
pip install -r requirements.txt
```
or, if you use a conda environment, you can use this command
```
@ -63,7 +63,7 @@ The password is just for testing; you can change it if necessary
# Install
1. Base model download
For the base model, you can merge the weights yourself by following the [vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) tutorial.
If this step is difficult for you, you can instead use a model from [Hugging Face](https://huggingface.co/) directly. Alternative model: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b)
If this step is difficult for you, you can instead use a model from [Hugging Face](https://huggingface.co/) directly. [Alternative model](https://huggingface.co/Tribbiani/vicuna-7b)
2. Run model server
```
@ -86,4 +86,4 @@ python webserver.py
# Contribute
[Contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING)
# Licence
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)
[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE)

View File

@ -7,11 +7,9 @@ dependencies:
- python=3.9
- cudatoolkit
- pip
- pytorch=1.12.1
- pytorch-mutex=1.0=cuda
- torchaudio=0.12.1
- torchvision=0.13.1
- pip:
- pytorch
- accelerate==0.16.0
- aiohttp==3.8.4
- aiosignal==1.3.1
@ -57,11 +55,12 @@ dependencies:
- sentence-transformers
- umap-learn
- notebook
- gradio==3.24.1
- gradio==3.23
- gradio-client==0.0.8
- wandb
- fschat=0.1.10
- llama-index=0.5.27
- llama-index==0.5.27
- pymysql
- unstructured==0.6.3
- pytesseract==0.3.10
- markdown2
- chromadb

19
examples/gradio_test.py Normal file

@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import gradio as gr
def change_tab():
return gr.Tabs.update(selected=1)
with gr.Blocks() as demo:
with gr.Tabs() as tabs:
with gr.TabItem("Train", id=0):
t = gr.Textbox()
with gr.TabItem("Inference", id=1):
i = gr.Image()
btn = gr.Button()
btn.click(change_tab, None, tabs)
demo.launch()

View File

36
pilot/commands/command.py Normal file

@ -0,0 +1,36 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import functools
import importlib
import inspect
from typing import Any, Callable, Optional
class Command:
"""A class representing a command.
Attributes:
name (str): The name of the command.
description (str): A brief description of what the command does.
signature (str): The signature of the function that the command executes. Default to None.
"""
def __init__(self,
name: str,
description: str,
method: Callable[..., Any],
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
) -> None:
self.name = name
self.description = description
self.method = method
self.signature = signature if signature else str(inspect.signature(self.method))
self.enabled = enabled
self.disabled_reason = disabled_reason
def __call__(self, *args: Any, **kwds: Any) -> Any:
if not self.enabled:
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
return self.method(*args, **kwds)
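
A short usage sketch of the `Command` wrapper above; the `query_sql` function and the command names are hypothetical, for illustration only:

```
# Hypothetical usage of pilot.commands.command.Command; query_sql is made up.
from pilot.commands.command import Command

def query_sql(sql: str) -> str:
    return f"executed: {sql}"

cmd = Command(name="query_sql", description="Execute a SQL statement", method=query_sql)
print(cmd.signature)     # "(sql: str) -> str", derived via inspect.signature
print(cmd("SELECT 1"))   # dispatches to query_sql

disabled = Command("noop", "disabled example", query_sql,
                   enabled=False, disabled_reason="disabled in this sketch")
print(disabled("SELECT 1"))   # returns the "Command 'noop' is disabled: ..." string
```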

View File

@ -12,26 +12,55 @@ class Config(metaclass=Singleton):
def __init__(self) -> None:
"""Initialize the Config class"""
# TODO change model_config there
# TODO change model_config there
self.debug_mode = False
self.skip_reprompt = False
self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
self.plugins = List[AutoGPTPluginTemplate] = []
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
# TODO change model_config there
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
)
self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
# User agent header to use when making HTTP requests
# Some websites might just completely deny request with an error code if
# no user agent was found.
self.user_agent = os.getenv(
"USER_AGENT",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
" (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
)
# milvus or zilliz cloud configuration
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
self.milvus_username = os.getenv("MILVUS_USERNAME")
self.milvus_password = os.getenv("MILVUS_PASSWORD")
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
self.plugins_dir = os.getenv("PLUGINS_DIR", 'plugins')
self.plugins:List[AutoGPTPluginTemplate] = []
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
if plugins_allowlist:
self.plugins_allowlist = plugins_allowlist.split(",")
else:
self.plugins_allowlist = []
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
if plugins_denylist:
self.plugins_denylist = []
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value."""
"""Set the debug mode value"""
self.debug_mode = value
def set_plugins(self,value: bool) -> None:
"""Set the plugins value."""
def set_plugins(self, value: list) -> None:
"""Set the plugins value. """
self.plugins = value
def set_temperature(self, value: int) -> None:
""" Set the temperature value."""
def set_templature(self, value: int) -> None:
"""Set the temperature value."""
self.temperature = value
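
Since `Config` uses the `Singleton` metaclass, every `Config()` call returns the same instance. A minimal sketch of reading the values populated above (assumes the environment variables from `.env.template` are available):

```
# Sketch only; assumes the environment variables from .env.template are set.
from pilot.configs.config import Config

cfg = Config()
assert cfg is Config()    # Singleton metaclass: always the same instance
print(cfg.temperature)    # float(os.getenv("TEMPERATURE", "0.7"))
print(cfg.milvus_addr)    # "localhost:19530" unless MILVUS_ADDR is set
cfg.set_debug_mode(True)  # mutates the shared instance for every caller
```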

View File

@ -5,6 +5,7 @@ import torch
import os
import nltk
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
@ -26,9 +27,8 @@ LLM_MODEL_CONFIG = {
VECTOR_SEARCH_TOP_K = 3
LLM_MODEL = "vicuna-13b"
LIMIT_MODEL_CONCURRENCY = 5
MAX_POSITION_EMBEDDINGS = 2048
VICUNA_MODEL_SERVER = "http://192.168.31.114:8000"
MAX_POSITION_EMBEDDINGS = 4096
VICUNA_MODEL_SERVER = "http://47.97.125.199:8000"
# Load model config
ISLOAD_8BIT = True
@ -37,7 +37,7 @@ ISDEBUG = False
DB_SETTINGS = {
"user": "root",
"password": "aa123456",
"password": "aa12345678",
"host": "localhost",
"port": 3306
}
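
`DB_SETTINGS` uses the keyword arguments that `pymysql.connect` accepts, so a connection sketch looks like this (assumes a reachable MySQL server; the query is illustrative):

```
# Illustrative sketch: open a connection with the DB_SETTINGS dict defined above.
import pymysql

from pilot.configs.model_config import DB_SETTINGS

conn = pymysql.connect(**DB_SETTINGS)   # user / password / host / port
try:
    with conn.cursor() as cur:
        cur.execute("SHOW DATABASES")
        print(cur.fetchall())
finally:
    conn.close()
```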

View File

@ -10,6 +10,8 @@ class SeparatorStyle(Enum):
SINGLE = auto()
TWO = auto()
THREE = auto()
FOUR = auto()
@dataclasses.dataclass
class Conversation:
@ -146,10 +148,103 @@ conv_vicuna_v1 = Conversation(
sep2="</s>",
)
auto_dbgpt_one_shot = Conversation(
system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
roles=("USER", "ASSISTANT"),
messages=(
(
"USER",
""" Answer how many users does hackernews have by query mysql database
Constraints:
1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files.
2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
3. No user assistance
4. Exclusively use the commands listed in double quotes e.g. "command name"
conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
Commands:
1. analyze_code: Analyze Code, args: "code": "<full_code_string>"
2. execute_python_file: Execute Python File, args: "filename": "<filename>"
3. append_to_file: Append to file, args: "filename": "<filename>", "text": "<text>"
4. delete_file: Delete file, args: "filename": "<filename>"
5. list_files: List Files in Directory, args: "directory": "<directory>"
6. read_file: Read file, args: "filename": "<filename>"
7. write_to_file: Write to file, args: "filename": "<filename>", "text": "<text>"
8. tidb_sql_executor: "Execute SQL in TiDB Database.", args: "sql": "<sql>"
Resources:
1. Internet access for searches and information gathering.
2. Long Term memory management.
3. vicuna powered Agents for delegation of simple tasks.
4. File output.
Performance Evaluation:
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
2. Constructively self-criticize your big-picture behavior constantly.
3. Reflect on past decisions and strategies to refine your approach.
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
5. Write all code to a file.
You should only respond in JSON format as described below
Response Format:
{
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user"
},
"command": {
"name": "command name",
"args": {
"arg name": "value"
}
}
}
Ensure the response can be parsed by Python json.loads
"""
),
(
"ASSISTANT",
"""
{
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user"
},
"command": {
"name": "command name",
"args": {
"arg name": "value"
}
}
}
"""
)
),
offset=0,
sep_style=SeparatorStyle.THREE,
sep=" ",
sep2="</s>",
)
auto_dbgpt_without_shot = Conversation(
system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.FOUR,
sep=" ",
sep2="</s>",
)
conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题,
如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议
已知内容:
{context}
问题:
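
The auto-GPT prompts above require the model to answer in JSON that `json.loads` can parse; a minimal sketch of how a caller might extract the command name and arguments (field names come from the Response Format in the prompt above, the helper itself is hypothetical):

```
# Hypothetical helper; field names follow the Response Format in the prompt above.
import json
from typing import Dict, Tuple

def parse_auto_dbgpt_reply(reply: str) -> Tuple[str, Dict]:
    payload = json.loads(reply)    # the prompt demands json.loads-parseable output
    command = payload["command"]
    return command["name"], command.get("args", {})

name, args = parse_auto_dbgpt_reply(
    '{"thoughts": {"text": "t", "reasoning": "r", "plan": "p", '
    '"criticism": "c", "speak": "s"}, '
    '"command": {"name": "list_files", "args": {"directory": "."}}}'
)
print(name, args)   # list_files {'directory': '.'}
```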

121
pilot/model/compression.py Normal file

@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import dataclasses
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F
@dataclasses.dataclass
class CompressionConfig:
"""Group-wise quantization."""
num_bits: int
group_size: int
group_dim: int
symmetric: bool
enabled: bool = True
default_compression_config = CompressionConfig(
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)
class CLinear(nn.Module):
"""Compressed Linear Layer."""
def __init__(self, weight, bias, device):
super().__init__()
self.weight = compress(weight.data.to(device), default_compression_config)
self.bias = bias
def forward(self, input: Tensor) -> Tensor:
weight = decompress(self.weight, default_compression_config)
return F.linear(input, weight, self.bias)
def compress_module(module, target_device):
for attr_str in dir(module):
target_attr = getattr(module, attr_str)
if type(target_attr) == torch.nn.Linear:
setattr(module, attr_str,
CLinear(target_attr.weight, target_attr.bias, target_device))
for name, child in module.named_children():
compress_module(child, target_device)
def compress(tensor, config):
"""Simulate group-wise quantization."""
if not config.enabled:
return tensor
group_size, num_bits, group_dim, symmetric = (
config.group_size, config.num_bits, config.group_dim, config.symmetric)
assert num_bits <= 8
original_shape = tensor.shape
num_groups = (original_shape[group_dim] + group_size - 1) // group_size
new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
original_shape[group_dim+1:])
# Pad
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
if pad_len != 0:
pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
tensor = torch.cat([
tensor,
torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
dim=group_dim)
data = tensor.view(new_shape)
# Quantize
if symmetric:
B = 2 ** (num_bits - 1) - 1
scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
data = data * scale
data = data.clamp_(-B, B).round_().to(torch.int8)
return data, scale, original_shape
else:
B = 2 ** num_bits - 1
mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
scale = B / (mx - mn)
data = data - mn
data.mul_(scale)
data = data.clamp_(0, B).round_().to(torch.uint8)
return data, mn, scale, original_shape
def decompress(packed_data, config):
"""Simulate group-wise dequantization."""
if not config.enabled:
return packed_data
group_size, num_bits, group_dim, symmetric = (
config.group_size, config.num_bits, config.group_dim, config.symmetric)
# Dequantize
if symmetric:
data, scale, original_shape = packed_data
data = data / scale
else:
data, mn, scale, original_shape = packed_data
data = data / scale
data.add_(mn)
# Unpad
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
if pad_len:
padded_original_shape = (
original_shape[:group_dim] +
(original_shape[group_dim] + pad_len,) +
original_shape[group_dim+1:])
data = data.reshape(padded_original_shape)
indices = [slice(0, x) for x in original_shape]
return data[indices].contiguous()
else:
return data.view(original_shape)
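
A round-trip sketch of the group-wise quantization above; the tensor shape is arbitrary, and with the default symmetric 8-bit config the reconstruction error should be small but non-zero:

```
# Round-trip sketch for the group-wise quantization defined above.
import torch

from pilot.model.compression import compress, decompress, default_compression_config

x = torch.randn(4, 300)                            # group_dim=1, group_size=256 -> padded to 512
packed = compress(x, default_compression_config)   # (int8 data, scale, original shape)
x_hat = decompress(packed, default_compression_config)

print(x.shape, x_hat.shape)                        # both torch.Size([4, 300])
print((x - x_hat).abs().max())                     # small quantization error
```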

View File

@ -5,13 +5,13 @@ import torch
@torch.inference_mode()
def generate_stream(model, tokenizer, params, device,
context_len=2048, stream_interval=2):
context_len=4096, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 256))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
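
`generate_stream` is a generator driven by a `params` dict (`prompt`, `temperature`, `max_new_tokens`, `stop`). A sketch of a caller, assuming the model and tokenizer are already loaded elsewhere (e.g. via `ModelLoader`); the stop string here is illustrative:

```
# Sketch of driving generate_stream; model and tokenizer are assumed to be loaded.
from pilot.model.inference import generate_stream

def stream_reply(model, tokenizer, prompt, device="cuda"):
    params = {
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 2048,   # new default in this commit
        "stop": "###",            # illustrative stop string
    }
    output = ""
    for output in generate_stream(model, tokenizer, params, device):
        pass                      # each yield is the decoded text so far
    return output
```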

34
pilot/model/llm/base.py Normal file

@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dataclasses import dataclass, field
from typing import List, TypedDict
class Message(TypedDict):
"""Vicuna Message object containing a role and the message content """
role: str
content: str
@dataclass
class ModelInfo:
"""Struct for model information.
Would be lovely to eventually get this directly from APIs
"""
name: str
max_tokens: int
@dataclass
class LLMResponse:
"""Standard response struct for a response from a LLM model."""
model_info = ModelInfo
@dataclass
class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""
content: str = None
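
A tiny usage sketch of the structures above; the values are illustrative:

```
# Illustrative values only.
from pilot.model.llm.base import ChatModelResponse, Message, ModelInfo

msg: Message = {"role": "USER", "content": "How many users does hackernews have?"}
info = ModelInfo(name="vicuna-13b", max_tokens=4096)
resp = ChatModelResponse(content="Query the users table to count them.")
print(msg["role"], info.name, resp.content)
```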

View File

@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import abc
import time
import functools
from typing import List, Optional
from pilot.model.llm.base import Message
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
from pilot.configs.config import Config
# TODO Rewrite this
def retry_stream_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True
):
"""Retry an Vicuna Server call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"Error: Reached rate limit, passing..."
backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...")
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
except Exception as e:
if (e.http_status != 502) or (attempt == num_attempts):
raise
backoff = backoff_base ** (attempt + 2)
time.sleep(backoff)
return _wrapped
return _wrapper
# Overly simple abstraction until we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
conv: Conversation,
model: Optional[str] = None,
temperature: float = None,
max_new_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the Vicuna-13b
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Default to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None.
Returns:
str: The response from the chat completion
"""
cfg = Config()
if temperature is None:
temperature = cfg.temperature
# TODO request vicuna model get response
# convert vicuna message to chat completion.
for plugin in cfg.plugins:
if plugin.can_handle_chat_completion():
pass
class ChatIO(abc.ABC):
@abc.abstractmethod
def prompt_for_input(self, role: str) -> str:
"""Prompt for input from a role."""
@abc.abstractmethod
def prompt_for_output(self, role: str) -> str:
"""Prompt for output from a role."""
@abc.abstractmethod
def stream_output(self, output_stream, skip_echo_len: int):
"""Stream output."""
class SimpleChatIO(ChatIO):
def prompt_for_input(self, role: str) -> str:
return input(f"{role}: ")
def prompt_for_output(self, role: str) -> str:
print(f"{role}: ", end="", flush=True)
def stream_output(self, output_stream, skip_echo_len: int):
pre = 0
for outputs in output_stream:
outputs = outputs[skip_echo_len:].strip()
now = len(outputs) - 1
if now > pre:
print(" ".join(outputs[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(outputs[pre:]), flush=True)
return " ".join(outputs)

View File

@ -2,15 +2,17 @@
# -*- coding: utf-8 -*-
import torch
from pilot.singleton import Singleton
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModel
)
from fastchat.serve.compression import compress_module
from pilot.model.compression import compress_module
class ModelLoader:
class ModelLoader(metaclass=Singleton):
"""Model loader is a class for model load
Args: model_path

View File

@ -10,8 +10,6 @@ from fastapi.responses import StreamingResponse
from pilot.model.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *

View File

@ -24,11 +24,9 @@ from pilot.conversation import (
SeparatorStyle
)
from fastchat.utils import (
from pilot.utils import (
build_logger,
server_error_msg,
violates_moderation,
moderation_msg
)
from pilot.server.gradio_css import code_highlight_css
@ -45,6 +43,7 @@ enable_moderation = False
models = []
dbs = []
vs_list = ["新建知识库"] + get_vector_storelist()
autogpt = False
priority = {
"vicuna-13b": "aaa"
@ -58,8 +57,6 @@ def get_simlar(q):
contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
@ -118,6 +115,8 @@ def regenerate(state, request: gr.Request):
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
@ -128,14 +127,9 @@ def add_text(state, text, request: gr.Request):
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
if args.moderate:
flagged = violates_moderation(text)
if flagged:
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg) + (
no_change_btn,) * 5
text = text[:1536] # Hard cut-off
""" Default support 4000 tokens, if tokens too lang, we will cut off """
text = text[:4000]
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
@ -152,6 +146,8 @@ def post_process_code(code):
return code
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
@ -162,7 +158,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
if len(state.messages) == state.offset + 2:
# The first turn of the conversation needs to include the prompt
@ -251,29 +248,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
)
def change_tab(tab):
pass
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
def change_tab():
autogpt = True
def build_single_model_ui():
notice_markdown = """
@ -301,16 +297,17 @@ def build_single_model_ui():
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
maximum=4096,
value=2048,
step=64,
interactive=True,
label="最大输出Token数",
)
tabs = gr.Tabs()
tabs= gr.Tabs()
with tabs:
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
# TODO A selector to choose database
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
with tab_sql:
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label="请选择数据库",
@ -318,9 +315,12 @@ def build_single_model_ui():
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True).style(container=False)
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
with tab_auto:
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
with gr.TabItem("知识问答", elem_id="QA"):
tab_qa = gr.TabItem("知识问答", elem_id="QA")
with tab_qa:
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
@ -360,9 +360,7 @@ def build_single_model_ui():
regenerate_btn = gr.Button(value="重新生成", interactive=False)
clear_btn = gr.Button(value="清理", interactive=False)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
@ -434,9 +432,7 @@ if __name__ == "__main__":
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument("--share", default=False, action="store_true")
parser.add_argument(
"--moderate", action="store_true", help="Enable content moderation"
)
args = parser.parse_args()
logger.info(f"args: {args}")

View File

@ -3,6 +3,20 @@
import torch
import datetime
import logging
import logging.handlers
import os
import sys
import requests
from pilot.configs.model_config import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
handler = None
def get_gpu_memory(max_gpus=None):
gpu_memory = []
num_gpus = (
@ -20,3 +34,98 @@ def get_gpu_memory(max_gpus=None):
available_memory = total_memory - allocated_memory
gpu_memory.append(available_memory)
return gpu_memory
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO, encoding='utf-8')
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
self.logger.log(self.log_level, encoded_message.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
self.logger.log(self.log_level, encoded_message.rstrip())
self.linebuf = ''
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"

View File

@ -1,7 +1,5 @@
accelerate==0.16.0
torch==2.0.0
torchvision==0.13.1
torchaudio==0.12.1
accelerate==0.16.0
aiohttp==3.8.4
aiosignal==1.3.1
@ -47,11 +45,12 @@ pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio==3.23
gradio-client==0.0.8
wandb
fschat=0.1.10
llama-index=0.5.27
llama-index==0.5.27
pymysql
unstructured==0.6.3
pytesseract==0.3.10
pytesseract==0.3.10
chromadb
markdown2