diff --git a/.env.template b/.env.template new file mode 100644 index 000000000..5bf746eaa --- /dev/null +++ b/.env.template @@ -0,0 +1,72 @@ +#*******************************************************************# +#** DB-GPT - GENERAL SETTINGS **# +#*******************************************************************# +## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled. Each of the below are an option: +## pilot.commands.query_execute + +## For example, to disable coding related features, uncomment the next line +# DISABLED_COMMAND_CATEGORIES= + + +#*******************************************************************# +#*** LLM PROVIDER ***# +#*******************************************************************# + +# TEMPERATURE=0 + +#*******************************************************************# +#** LLM MODELS **# +#*******************************************************************# + +## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b) +## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b) +# SMART_LLM_MODEL=vicuna-13b +# FAST_LLM_MODEL=chatglm-6b + + +### EMBEDDINGS +## EMBEDDING_MODEL - Model to use for creating embeddings +## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs +## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs +# EMBEDDING_MODEL=all-MiniLM-L6-v2 +# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2 +# EMBEDDING_TOKEN_LIMIT=8191 + + +#*******************************************************************# +#** DATABASE SETTINGS **# +#*******************************************************************# +DB_SETTINGS_MYSQL_USER=root +DB_SETTINGS_MYSQL_PASSWORD=password +DB_SETTINGS_MYSQL_HOST=localhost +DB_SETTINGS_MYSQL_PORT=3306 + + +### MILVUS +## MILVUS_ADDR - Milvus remote address (e.g. localhost:19530) +## MILVUS_USERNAME - username for your Milvus database +## MILVUS_PASSWORD - password for your Milvus database +## MILVUS_SECURE - True to enable TLS. (Default: False) +## Setting MILVUS_ADDR to a `https://` URL will override this setting. +## MILVUS_COLLECTION - Milvus collection, change it if you want to start a new memory and retain the old memory. 
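+## Example (illustrative values only; milvus.example.com is a placeholder host): a TLS
+## connection can be declared either as
+##   MILVUS_ADDR=https://milvus.example.com:19530
+## or as
+##   MILVUS_ADDR=milvus.example.com:19530 plus MILVUS_SECURE=True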
+# MILVUS_ADDR=localhost:19530 +# MILVUS_USERNAME= +# MILVUS_PASSWORD= +# MILVUS_SECURE= +# MILVUS_COLLECTION=dbgpt + +#*******************************************************************# +#** ALLOWLISTED PLUGINS **# +#*******************************************************************# + +#ALLOWLISTED_PLUGINS - Sets the listed plugins that are allowed (Example: plugin1,plugin2,plugin3) +#DENYLISTED_PLUGINS - Sets the listed plugins that are not allowed (Example: plugin1,plugin2,plugin3) +ALLOWLISTED_PLUGINS= +DENYLISTED_PLUGINS= + + +#*******************************************************************# +#** CHAT PLUGIN SETTINGS **# +#*******************************************************************# +# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False) +# CHAT_MESSAGES_ENABLED=False diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 000000000..383e65cd0 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,23 @@ +name: Pylint + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 000000000..bdaab28a4 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,39 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/README.md b/README.md index edf0e71cf..7378799d3 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ The Generated SQL is runable. 1. First you need to install python requirements. ``` python>=3.9 -pip install -r requirements +pip install -r requirements.txt ``` or if you use conda envirenment, you can use this command ``` @@ -63,7 +63,7 @@ The password just for test, you can change this if necessary # Install 1. 基础模型下载 关于基础模型, 可以根据[vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights)合成教程进行合成。 -如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代。 替代模型: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b) +如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代. [替代模型](https://huggingface.co/Tribbiani/vicuna-7b) 2. 
Run model server ``` @@ -86,4 +86,4 @@ python webserver.py # Contribute [Contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING) # Licence -[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE) \ No newline at end of file +[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE) diff --git a/environment.yml b/environment.yml index ea7415df0..d5a68abb2 100644 --- a/environment.yml +++ b/environment.yml @@ -7,11 +7,9 @@ dependencies: - python=3.9 - cudatoolkit - pip - - pytorch=1.12.1 - pytorch-mutex=1.0=cuda - - torchaudio=0.12.1 - - torchvision=0.13.1 - pip: + - pytorch - accelerate==0.16.0 - aiohttp==3.8.4 - aiosignal==1.3.1 @@ -57,11 +55,12 @@ dependencies: - sentence-transformers - umap-learn - notebook - - gradio==3.24.1 + - gradio==3.23 - gradio-client==0.0.8 - wandb - - fschat=0.1.10 - - llama-index=0.5.27 + - llama-index==0.5.27 - pymysql - unstructured==0.6.3 - pytesseract==0.3.10 + - markdown2 + - chromadb \ No newline at end of file diff --git a/examples/gradio_test.py b/examples/gradio_test.py new file mode 100644 index 000000000..f39a1ca9e --- /dev/null +++ b/examples/gradio_test.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + +import gradio as gr + +def change_tab(): + return gr.Tabs.update(selected=1) + +with gr.Blocks() as demo: + with gr.Tabs() as tabs: + with gr.TabItem("Train", id=0): + t = gr.Textbox() + with gr.TabItem("Inference", id=1): + i = gr.Image() + + btn = gr.Button() + btn.click(change_tab, None, tabs) + +demo.launch() \ No newline at end of file diff --git a/pilot/commands/__init__.py b/pilot/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/commands/command.py b/pilot/commands/command.py new file mode 100644 index 000000000..6a987d723 --- /dev/null +++ b/pilot/commands/command.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import functools +import importlib +import inspect +from typing import Any, Callable, Optional + +class Command: + """A class representing a command. + + Attributes: + name (str): The name of the command. + description (str): A brief description of what the command does. + signature (str): The signature of the function that the command executes. Default to None. 
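+        method (Callable): The function that is executed when the command is called.
+        enabled (bool): Whether the command can currently be executed. Defaults to True.
+        disabled_reason (Optional[str]): The reason reported when a disabled command is called. Defaults to None.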
+    """
+
+    def __init__(self,
+                 name: str,
+                 description: str,
+                 method: Callable[..., Any],
+                 signature: str = "",
+                 enabled: bool = True,
+                 disabled_reason: Optional[str] = None,
+                ) -> None:
+        self.name = name
+        self.description = description
+        self.method = method
+        self.signature = signature if signature else str(inspect.signature(self.method))
+        self.enabled = enabled
+        self.disabled_reason = disabled_reason
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        if not self.enabled:
+            return f"Command '{self.name}' is disabled: {self.disabled_reason}"
+        return self.method(*args, **kwds)
\ No newline at end of file
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 709de4b69..adb5e47b7 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -12,26 +12,55 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""
-        # TODO change model_config there
+        # TODO change model_config there
+        self.debug_mode = False
+        self.skip_reprompt = False
+
+        self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
+        self.plugins: List[AutoGPTPluginTemplate] = []
+        self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
+
+        # TODO change model_config there
         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
         )
-        self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
+        # User agent header to use when making HTTP requests
+        # Some websites might just completely deny request with an error code if
+        # no user agent was found.
+        self.user_agent = os.getenv(
+            "USER_AGENT",
+            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
+            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
+        )
+
+        # milvus or zilliz cloud configuration
+        self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
+        self.milvus_username = os.getenv("MILVUS_USERNAME")
+        self.milvus_password = os.getenv("MILVUS_PASSWORD")
+        self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
+        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
 
-        self.plugins_dir = os.getenv("PLUGINS_DIR", 'plugins')
-        self.plugins:List[AutoGPTPluginTemplate] = []
+        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
+        if plugins_allowlist:
+            self.plugins_allowlist = plugins_allowlist.split(",")
+        else:
+            self.plugins_allowlist = []
+        plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
+        self.plugins_denylist = (
+            plugins_denylist.split(",") if plugins_denylist else [])
+
     def set_debug_mode(self, value: bool) -> None:
-        """Set the debug mode value."""
+        """Set the debug mode value"""
         self.debug_mode = value
 
-    def set_plugins(self,value: bool) -> None:
-        """Set the plugins value."""
+    def set_plugins(self, value: list) -> None:
+        """Set the plugins value.
""" self.plugins = value - def set_temperature(self, value: int) -> None: - """ Set the temperature value.""" + def set_templature(self, value: int) -> None: + """Set the temperature value.""" self.temperature = value - \ No newline at end of file + \ No newline at end of file diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index a5c27d9d2..1e78494ce 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -5,6 +5,7 @@ import torch import os import nltk + ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) MODEL_PATH = os.path.join(ROOT_PATH, "models") PILOT_PATH = os.path.join(ROOT_PATH, "pilot") @@ -26,9 +27,8 @@ LLM_MODEL_CONFIG = { VECTOR_SEARCH_TOP_K = 3 LLM_MODEL = "vicuna-13b" LIMIT_MODEL_CONCURRENCY = 5 -MAX_POSITION_EMBEDDINGS = 2048 -VICUNA_MODEL_SERVER = "http://192.168.31.114:8000" - +MAX_POSITION_EMBEDDINGS = 4096 +VICUNA_MODEL_SERVER = "http://47.97.125.199:8000" # Load model config ISLOAD_8BIT = True @@ -37,7 +37,7 @@ ISDEBUG = False DB_SETTINGS = { "user": "root", - "password": "aa123456", + "password": "aa12345678", "host": "localhost", "port": 3306 } \ No newline at end of file diff --git a/pilot/conversation.py b/pilot/conversation.py index e1715e427..1f1fc50a3 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -10,6 +10,8 @@ class SeparatorStyle(Enum): SINGLE = auto() TWO = auto() + THREE = auto() + FOUR = auto() @dataclasses.dataclass class Conversation: @@ -146,10 +148,103 @@ conv_vicuna_v1 = Conversation( sep2="", ) +auto_dbgpt_one_shot = Conversation( + system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. " + "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", + roles=("USER", "ASSISTANT"), + messages=( + ( + "USER", + """ Answer how many users does hackernews have by query mysql database + Constraints: + 1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files. + 2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember. + 3. No user assistance + 4. Exclusively use the commands listed in double quotes e.g. "command name" -conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。 - 如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议: - + Commands: + 1. analyze_code: Analyze Code, args: "code": "" + 2. execute_python_file: Execute Python File, args: "filename": "" + 3. append_to_file: Append to file, args: "filename": "", "text": "" + 4. delete_file: Delete file, args: "filename": "" + 5. list_files: List Files in Directory, args: "directory": "" + 6. read_file: Read file, args: "filename": "" + 7. write_to_file: Write to file, args: "filename": "", "text": "" + 8. tidb_sql_executor: "Execute SQL in TiDB Database.", args: "sql": "" + + Resources: + 1. Internet access for searches and information gathering. + 2. Long Term memory management. + 3. vicuna powered Agents for delegation of simple tasks. + 4. File output. + + Performance Evaluation: + 1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities. + 2. Constructively self-criticize your big-picture behavior constantly. + 3. Reflect on past decisions and strategies to refine your approach. + 4. 
Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps. + 5. Write all code to a file. + + You should only respond in JSON format as described below + Response Format: + { + "thoughts": { + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user" + }, + "command": { + "name": "command name", + "args": { + "arg name": "value" + } + } + } + Ensure the response can be parsed by Python json.loads + """ + ), + ( + "ASSISTANT", + """ + { + "thoughts": { + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user" + }, + "command": { + "name": "command name", + "args": { + "arg name": "value" + } + } + } + """ + ) + ), + offset=0, + sep_style=SeparatorStyle.THREE, + sep=" ", + sep2="", +) + +auto_dbgpt_without_shot = Conversation( + system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. " + "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.FOUR, + sep=" ", + sep2="", +) + +conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题, + 如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议。 已知内容: {context} 问题: diff --git a/pilot/model/compression.py b/pilot/model/compression.py new file mode 100644 index 000000000..9c8c25d08 --- /dev/null +++ b/pilot/model/compression.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import dataclasses + +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn import functional as F + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight, bias, device): + super().__init__() + + self.weight = compress(weight.data.to(device), default_compression_config) + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + return F.linear(input, weight, self.bias) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr(module, attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device)) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, config.num_bits, config.group_dim, config.symmetric) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = (original_shape[:group_dim] + (num_groups, group_size) + + original_shape[group_dim+1:]) + + # Pad + pad_len = (group_size - 
original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:] + tensor = torch.cat([ + tensor, + torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2 ** num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, config.num_bits, config.group_dim, config.symmetric) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim+1:]) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 66766b3b3..532be9c33 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -5,13 +5,13 @@ import torch @torch.inference_mode() def generate_stream(model, tokenizer, params, device, - context_len=2048, stream_interval=2): + context_len=4096, stream_interval=2): """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """ prompt = params["prompt"] l_prompt = len(prompt) temperature = float(params.get("temperature", 1.0)) - max_new_tokens = int(params.get("max_new_tokens", 256)) + max_new_tokens = int(params.get("max_new_tokens", 2048)) stop_str = params.get("stop", None) input_ids = tokenizer(prompt).input_ids diff --git a/pilot/model/llm/base.py b/pilot/model/llm/base.py new file mode 100644 index 000000000..435cc0d5f --- /dev/null +++ b/pilot/model/llm/base.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from dataclasses import dataclass, field +from typing import List, TypedDict + + +class Message(TypedDict): + """Vicuna Message object containing a role and the message content """ + + role: str + content: str + + +@dataclass +class ModelInfo: + """Struct for model information. 
+ + Would be lovely to eventually get this directly from APIs + """ + name: str + max_tokens: int + +@dataclass +class LLMResponse: + """Standard response struct for a response from a LLM model.""" + model_info = ModelInfo + + +@dataclass +class ChatModelResponse(LLMResponse): + """Standard response struct for a response from an LLM model.""" + + content: str = None \ No newline at end of file diff --git a/pilot/model/llm/llm_utils.py b/pilot/model/llm/llm_utils.py new file mode 100644 index 000000000..a68860ee6 --- /dev/null +++ b/pilot/model/llm/llm_utils.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import abc +import time +import functools +from typing import List, Optional +from pilot.model.llm.base import Message +from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot +from pilot.configs.config import Config + + +# TODO Rewrite this +def retry_stream_api( + num_retries: int = 10, + backoff_base: float = 2.0, + warn_user: bool = True +): + """Retry an Vicuna Server call. + + Args: + num_retries int: Number of retries. Defaults to 10. + backoff_base float: Base for exponential backoff. Defaults to 2. + warn_user bool: Whether to warn the user. Defaults to True. + """ + retry_limit_msg = f"Error: Reached rate limit, passing..." + backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...") + + def _wrapper(func): + @functools.wraps(func) + def _wrapped(*args, **kwargs): + user_warned = not warn_user + num_attempts = num_retries + 1 # +1 for the first attempt + for attempt in range(1, num_attempts + 1): + try: + return func(*args, **kwargs) + except Exception as e: + if (e.http_status != 502) or (attempt == num_attempts): + raise + + backoff = backoff_base ** (attempt + 2) + time.sleep(backoff) + return _wrapped + return _wrapper + +# Overly simple abstraction util we create something better +# simple retry mechanism when getting a rate error or a bad gateway +def create_chat_competion( + conv: Conversation, + model: Optional[str] = None, + temperature: float = None, + max_new_tokens: Optional[int] = None, +) -> str: + """Create a chat completion using the Vicuna-13b + + Args: + messages(List[Message]): The messages to send to the chat completion + model (str, optional): The model to use. Default to None. + temperature (float, optional): The temperature to use. Defaults to 0.7. + max_tokens (int, optional): The max tokens to use. Defaults to None. + + Returns: + str: The response from the chat completion + """ + cfg = Config() + if temperature is None: + temperature = cfg.temperature + + # TODO request vicuna model get response + # convert vicuna message to chat completion. 
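+    # Rough sketch of the intended flow (not implemented yet): render `conv` into a
+    # prompt string, send it to the model server configured as VICUNA_MODEL_SERVER in
+    # pilot.configs.model_config, and let any plugin whose can_handle_chat_completion()
+    # returns True handle the request first.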
+ for plugin in cfg.plugins: + if plugin.can_handle_chat_completion(): + pass + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str) -> str: + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream, skip_echo_len: int): + """Stream output.""" + + +class SimpleChatIO(ChatIO): + def prompt_for_input(self, role: str) -> str: + return input(f"{role}: ") + + def prompt_for_output(self, role: str) -> str: + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream, skip_echo_len: int): + pre = 0 + for outputs in output_stream: + outputs = outputs[skip_echo_len:].strip() + now = len(outputs) - 1 + if now > pre: + print(" ".join(outputs[pre:now]), end=" ", flush=True) + pre = now + + print(" ".join(outputs[pre:]), flush=True) + return " ".join(outputs) + diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 747585fa4..3c0b9a6a7 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -2,15 +2,17 @@ # -*- coding: utf-8 -*- import torch +from pilot.singleton import Singleton + from transformers import ( AutoTokenizer, AutoModelForCausalLM, AutoModel ) -from fastchat.serve.compression import compress_module +from pilot.model.compression import compress_module -class ModelLoader: +class ModelLoader(metaclass=Singleton): """Model loader is a class for model load Args: model_path diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 868e8b6d9..95781b69b 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -10,8 +10,6 @@ from fastapi.responses import StreamingResponse from pilot.model.inference import generate_stream from pydantic import BaseModel from pilot.model.inference import generate_output, get_embeddings -from fastchat.serve.inference import load_model - from pilot.model.loader import ModelLoader from pilot.configs.model_config import * diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index bbe710667..c34ecd934 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -24,11 +24,9 @@ from pilot.conversation import ( SeparatorStyle ) -from fastchat.utils import ( +from pilot.utils import ( build_logger, server_error_msg, - violates_moderation, - moderation_msg ) from pilot.server.gradio_css import code_highlight_css @@ -45,6 +43,7 @@ enable_moderation = False models = [] dbs = [] vs_list = ["新建知识库"] + get_vector_storelist() +autogpt = False priority = { "vicuna-13b": "aaa" @@ -58,8 +57,6 @@ def get_simlar(q): contents = [dc.page_content for dc, _ in docs] return "\n".join(contents) - - def gen_sqlgen_conversation(dbname): mo = MySQLOperator( **DB_SETTINGS @@ -118,6 +115,8 @@ def regenerate(state, request: gr.Request): return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 def clear_history(request: gr.Request): + + logger.info(f"clear_history. 
ip: {request.client.host}") state = None return (state, [], "") + (disable_btn,) * 5 @@ -128,14 +127,9 @@ def add_text(state, text, request: gr.Request): if len(text) <= 0: state.skip_next = True return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 - if args.moderate: - flagged = violates_moderation(text) - if flagged: - state.skip_next = True - return (state, state.to_gradio_chatbot(), moderation_msg) + ( - no_change_btn,) * 5 - text = text[:1536] # Hard cut-off + """ Default support 4000 tokens, if tokens too lang, we will cut off """ + text = text[:4000] state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False @@ -152,6 +146,8 @@ def post_process_code(code): return code def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request): + + print("是否是AUTO-GPT模式.", autogpt) start_tstamp = time.time() model_name = LLM_MODEL @@ -162,7 +158,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr. yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return - + + # TODO when tab mode is AUTO_GPT, Prompt need to rebuild. if len(state.messages) == state.offset + 2: # 第一轮对话需要加入提示Prompt @@ -251,29 +248,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr. block_css = ( code_highlight_css + """ -pre { - white-space: pre-wrap; /* Since CSS 2.1 */ - white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ - white-space: -pre-wrap; /* Opera 4-6 */ - white-space: -o-pre-wrap; /* Opera 7 */ - word-wrap: break-word; /* Internet Explorer 5.5+ */ -} -#notice_markdown th { - display: none; -} - """ + pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ + } + #notice_markdown th { + display: none; + } + """ ) -def change_tab(tab): - pass - def change_mode(mode): if mode in ["默认知识库对话", "LLM原生对话"]: return gr.update(visible=False) else: return gr.update(visible=True) - +def change_tab(): + autogpt = True + def build_single_model_ui(): notice_markdown = """ @@ -301,16 +297,17 @@ def build_single_model_ui(): max_output_tokens = gr.Slider( minimum=0, - maximum=1024, - value=512, + maximum=4096, + value=2048, step=64, interactive=True, label="最大输出Token数", ) - tabs = gr.Tabs() + tabs= gr.Tabs() with tabs: - with gr.TabItem("SQL生成与诊断", elem_id="SQL"): - # TODO A selector to choose database + tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") + with tab_sql: + # TODO A selector to choose database with gr.Row(elem_id="db_selector"): db_selector = gr.Dropdown( label="请选择数据库", @@ -318,9 +315,12 @@ def build_single_model_ui(): value=dbs[0] if len(models) > 0 else "", interactive=True, show_label=True).style(container=False) + tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto") + with tab_auto: + gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力") - with gr.TabItem("知识问答", elem_id="QA"): - + tab_qa = gr.TabItem("知识问答", elem_id="QA") + with tab_qa: mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话") vs_setting = gr.Accordion("配置知识库", open=False) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) @@ -360,9 +360,7 @@ def build_single_model_ui(): regenerate_btn = gr.Button(value="重新生成", interactive=False) clear_btn = gr.Button(value="清理", interactive=False) - gr.Markdown(learn_more_markdown) - btn_list = [regenerate_btn, clear_btn] 
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, @@ -434,9 +432,7 @@ if __name__ == "__main__": "--model-list-mode", type=str, default="once", choices=["once", "reload"] ) parser.add_argument("--share", default=False, action="store_true") - parser.add_argument( - "--moderate", action="store_true", help="Enable content moderation" - ) + args = parser.parse_args() logger.info(f"args: {args}") diff --git a/pilot/utils.py b/pilot/utils.py index 093b14f99..0179d12c2 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -3,6 +3,20 @@ import torch +import datetime +import logging +import logging.handlers +import os +import sys + +import requests + +from pilot.configs.model_config import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" + +handler = None + def get_gpu_memory(max_gpus=None): gpu_memory = [] num_gpus = ( @@ -20,3 +34,98 @@ def get_gpu_memory(max_gpus=None): available_memory = total_memory - allocated_memory gpu_memory.append(available_memory) return gpu_memory + + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO, encoding='utf-8') + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. 
+ if line[-1] == '\n': + encoded_message = line.encode('utf-8', 'ignore').decode('utf-8') + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8') + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = '' + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + diff --git a/requirements.txt b/requirements.txt index 0582f5a41..7102191ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,5 @@ accelerate==0.16.0 torch==2.0.0 -torchvision==0.13.1 -torchaudio==0.12.1 accelerate==0.16.0 aiohttp==3.8.4 aiosignal==1.3.1 @@ -47,11 +45,12 @@ pycocoevalcap sentence-transformers umap-learn notebook -gradio==3.24.1 +gradio==3.23 gradio-client==0.0.8 wandb -fschat=0.1.10 -llama-index=0.5.27 +llama-index==0.5.27 pymysql unstructured==0.6.3 -pytesseract==0.3.10 \ No newline at end of file +pytesseract==0.3.10 +chromadb +markdown2 \ No newline at end of file
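The group-wise quantization added in `pilot/model/compression.py` is easiest to sanity-check with a round trip. The snippet below is an illustration rather than part of the patch; it assumes the repository root is on `PYTHONPATH` so that `pilot` is importable, and it pairs `compress` with `decompress` the same way `CLinear.forward` does for each layer.
```
# Round-trip check for the group-wise quantization in pilot/model/compression.py.
import torch

from pilot.model.compression import (
    compress,
    decompress,
    default_compression_config as config,
)

# A fake Linear weight; with group_dim=1 the 512-wide dimension is split into groups of 256.
weight = torch.randn(1024, 512)

packed = compress(weight, config)      # int8 payload, per-group scales, original shape
restored = decompress(packed, config)  # dequantized back to the original shape

assert restored.shape == weight.shape
# 8-bit symmetric quantization is lossy, but the reconstruction error should stay small.
print("max abs error:", (restored - weight).abs().max().item())
```
If the error looks reasonable, the same default config is what `compress_module` applies to every `nn.Linear` it finds in a loaded model.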