From 473cdf6017f95cd72e4a8f7b8e89510f72d18e44 Mon Sep 17 00:00:00 2001 From: "magic.chen" Date: Sat, 29 Apr 2023 23:30:31 +0800 Subject: [PATCH 01/13] Create pylint.yml --- .github/workflows/pylint.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/pylint.yml diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 000000000..383e65cd0 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,23 @@ +name: Pylint + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') From c755beba8e0a3180dc4b037a2b504574a501d2fe Mon Sep 17 00:00:00 2001 From: "magic.chen" Date: Sat, 29 Apr 2023 23:30:55 +0800 Subject: [PATCH 02/13] Create python-publish.yml --- .github/workflows/python-publish.yml | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/python-publish.yml diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 000000000..bdaab28a4 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,39 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} From bfbbf0ba88732f1457c9dec3ba37db8d6cca48de Mon Sep 17 00:00:00 2001 From: csunny Date: Tue, 9 May 2023 21:48:47 +0800 Subject: [PATCH 03/13] update conversation --- pilot/conversation.py | 2 -- pilot/model/llm/llm_utils.py | 40 +++++++++++++++++++++++++++++++++++- pilot/model/loader.py | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/pilot/conversation.py b/pilot/conversation.py index d52a51b41..877e61a80 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -251,8 +251,6 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回 {question} """ - - default_conversation = conv_one_shot conversation_types = { diff --git a/pilot/model/llm/llm_utils.py b/pilot/model/llm/llm_utils.py index 7a1fb47bd..a68860ee6 100644 --- a/pilot/model/llm/llm_utils.py +++ b/pilot/model/llm/llm_utils.py @@ -2,11 +2,47 @@ # -*- coding: utf-8 -*- import abc +import time +import functools from typing import List, Optional from pilot.model.llm.base import Message from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot from pilot.configs.config import Config + +# TODO Rewrite this +def retry_stream_api( + num_retries: int = 10, + backoff_base: float = 2.0, + warn_user: bool = True +): + """Retry an Vicuna Server call. + + Args: + num_retries int: Number of retries. Defaults to 10. + backoff_base float: Base for exponential backoff. Defaults to 2. + warn_user bool: Whether to warn the user. Defaults to True. + """ + retry_limit_msg = f"Error: Reached rate limit, passing..." + backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...") + + def _wrapper(func): + @functools.wraps(func) + def _wrapped(*args, **kwargs): + user_warned = not warn_user + num_attempts = num_retries + 1 # +1 for the first attempt + for attempt in range(1, num_attempts + 1): + try: + return func(*args, **kwargs) + except Exception as e: + if (e.http_status != 502) or (attempt == num_attempts): + raise + + backoff = backoff_base ** (attempt + 2) + time.sleep(backoff) + return _wrapped + return _wrapper + # Overly simple abstraction util we create something better # simple retry mechanism when getting a rate error or a bad gateway def create_chat_competion( @@ -31,8 +67,10 @@ def create_chat_competion( temperature = cfg.temperature # TODO request vicuna model get response + # convert vicuna message to chat completion. for plugin in cfg.plugins: - pass + if plugin.can_handle_chat_completion(): + pass class ChatIO(abc.ABC): diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 747585fa4..55458ff4a 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -10,7 +10,7 @@ from transformers import ( from fastchat.serve.compression import compress_module -class ModelLoader: +class ModelLoader(): """Model loader is a class for model load Args: model_path From 8a1e68ea513967fa145c9fbe8fe0f917665806f5 Mon Sep 17 00:00:00 2001 From: "magic.chen" Date: Tue, 9 May 2023 23:49:37 +0800 Subject: [PATCH 04/13] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新readme file, add requirement txt --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index edf0e71cf..7378799d3 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ The Generated SQL is runable. 1. First you need to install python requirements. ``` python>=3.9 -pip install -r requirements +pip install -r requirements.txt ``` or if you use conda envirenment, you can use this command ``` @@ -63,7 +63,7 @@ The password just for test, you can change this if necessary # Install 1. 基础模型下载 关于基础模型, 可以根据[vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights)合成教程进行合成。 -如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代。 替代模型: [vicuna-13b](https://huggingface.co/Tribbiani/vicuna-13b) +如果此步有困难的同学,也可以直接使用[Hugging Face](https://huggingface.co/)上的模型进行替代. [替代模型](https://huggingface.co/Tribbiani/vicuna-7b) 2. Run model server ``` @@ -86,4 +86,4 @@ python webserver.py # Contribute [Contribute](https://github.com/csunny/DB-GPT/blob/main/CONTRIBUTING) # Licence -[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE) \ No newline at end of file +[MIT](https://github.com/csunny/DB-GPT/blob/main/LICENSE) From fd8bc8d169d0f84a43811584f77df5f33535b0d8 Mon Sep 17 00:00:00 2001 From: csunny Date: Wed, 10 May 2023 10:53:48 +0800 Subject: [PATCH 05/13] modelLoader use singleton --- pilot/model/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 55458ff4a..c5788d2ab 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- import torch +from pilot.singleton import Singleton + from transformers import ( AutoTokenizer, AutoModelForCausalLM, @@ -10,7 +12,7 @@ from transformers import ( from fastchat.serve.compression import compress_module -class ModelLoader(): +class ModelLoader(metaclass=Singleton): """Model loader is a class for model load Args: model_path From eb069f1a456f59ce717a29654226967147bd29d7 Mon Sep 17 00:00:00 2001 From: csunny Date: Wed, 10 May 2023 20:52:31 +0800 Subject: [PATCH 06/13] update --- pilot/commands/__init__.py | 0 pilot/commands/command.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 pilot/commands/__init__.py create mode 100644 pilot/commands/command.py diff --git a/pilot/commands/__init__.py b/pilot/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/commands/command.py b/pilot/commands/command.py new file mode 100644 index 000000000..6a987d723 --- /dev/null +++ b/pilot/commands/command.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import functools +import importlib +import inspect +from typing import Any, Callable, Optional + +class Command: + """A class representing a command. + + Attributes: + name (str): The name of the command. + description (str): A brief description of what the command does. + signature (str): The signature of the function that the command executes. Default to None. + """ + + def __init__(self, + name: str, + description: str, + method: Callable[..., Any], + signature: str = "", + enabled: bool = True, + disabled_reason: Optional[str] = None, + ) -> None: + self.name = name + self.description = description + self.method = method + self.signature = signature if signature else str(inspect.signature(self.method)) + self.enabled = enabled + self.disabled_reason = disabled_reason + + def __call__(self, *args: Any, **kwds: Any) -> Any: + if not self.enabled: + return f"Command '{self.name}' is disabled: {self.disabled_reason}" + return self.method(*args, **kwds) \ No newline at end of file From 9b813a801a8e238f2d1b59fd3db2301ef1e5cd23 Mon Sep 17 00:00:00 2001 From: csunny Date: Wed, 10 May 2023 21:02:50 +0800 Subject: [PATCH 07/13] fix and update --- pilot/conversation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pilot/conversation.py b/pilot/conversation.py index 877e61a80..1f1fc50a3 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -227,7 +227,7 @@ auto_dbgpt_one_shot = Conversation( ) ), offset=0, - sep_style=SeparatorStyle.THREE(), + sep_style=SeparatorStyle.THREE, sep=" ", sep2="", ) @@ -238,7 +238,7 @@ auto_dbgpt_without_shot = Conversation( roles=("USER", "ASSISTANT"), messages=(), offset=0, - sep_style=SeparatorStyle.FOUR(), + sep_style=SeparatorStyle.FOUR, sep=" ", sep2="", ) From ff50936076efaec0a623a62db89945a72f9c935d Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 11 May 2023 01:29:04 +0800 Subject: [PATCH 08/13] fix --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b22b2d2ad..d8219ad3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ accelerate==0.16.0 torch==2.0.0 -torchvision==0.13.1 torchaudio==0.12.1 accelerate==0.16.0 aiohttp==3.8.4 From 75fbf7f504282cf7e4c334e22569b1ee38b38f06 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 11 May 2023 10:51:37 +0800 Subject: [PATCH 09/13] update --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d8219ad3d..cdf08db8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ accelerate==0.16.0 torch==2.0.0 -torchaudio==0.12.1 accelerate==0.16.0 aiohttp==3.8.4 aiosignal==1.3.1 From 6d76825a109db6a8d0bc14cf9ccf291ea5ced5b7 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 11 May 2023 10:59:08 +0800 Subject: [PATCH 10/13] rm fschat relay --- environment.yml | 5 +- pilot/configs/model_config.py | 1 + pilot/model/compression.py | 121 +++++++++++++++++++++++++++++++ pilot/model/loader.py | 2 +- pilot/server/webserver.py | 2 +- pilot/utils.py | 131 ++++++++++++++++++++++++++++++++++ requirements.txt | 1 - 7 files changed, 256 insertions(+), 7 deletions(-) create mode 100644 pilot/model/compression.py diff --git a/environment.yml b/environment.yml index ce2c7339f..db3929e23 100644 --- a/environment.yml +++ b/environment.yml @@ -7,11 +7,9 @@ dependencies: - python=3.9 - cudatoolkit - pip - - pytorch=1.12.1 - pytorch-mutex=1.0=cuda - - torchaudio=0.12.1 - - torchvision=0.13.1 - pip: + - pytorch - accelerate==0.16.0 - aiohttp==3.8.4 - aiosignal==1.3.1 @@ -60,7 +58,6 @@ dependencies: - gradio==3.23 - gradio-client==0.0.8 - wandb - - fschat==0.1.10 - llama-index==0.5.27 - pymysql - unstructured==0.6.3 diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 7b4269fb6..eee05f3e1 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -5,6 +5,7 @@ import torch import os import nltk + ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) MODEL_PATH = os.path.join(ROOT_PATH, "models") PILOT_PATH = os.path.join(ROOT_PATH, "pilot") diff --git a/pilot/model/compression.py b/pilot/model/compression.py new file mode 100644 index 000000000..9c8c25d08 --- /dev/null +++ b/pilot/model/compression.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import dataclasses + +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn import functional as F + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight, bias, device): + super().__init__() + + self.weight = compress(weight.data.to(device), default_compression_config) + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + return F.linear(input, weight, self.bias) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr(module, attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device)) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, config.num_bits, config.group_dim, config.symmetric) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = (original_shape[:group_dim] + (num_groups, group_size) + + original_shape[group_dim+1:]) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:] + tensor = torch.cat([ + tensor, + torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2 ** num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, config.num_bits, config.group_dim, config.symmetric) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim+1:]) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index c5788d2ab..3c0b9a6a7 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -10,7 +10,7 @@ from transformers import ( AutoModel ) -from fastchat.serve.compression import compress_module +from pilot.model.compression import compress_module class ModelLoader(metaclass=Singleton): """Model loader is a class for model load diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index bbe710667..1ca3cee20 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -24,7 +24,7 @@ from pilot.conversation import ( SeparatorStyle ) -from fastchat.utils import ( +from pilot.utils import ( build_logger, server_error_msg, violates_moderation, diff --git a/pilot/utils.py b/pilot/utils.py index 093b14f99..b2505eba1 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -3,6 +3,21 @@ import torch +import datetime +import logging +import logging.handlers +import os +import sys + +import requests + +from pilot.configs.model_config import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." + +handler = None + def get_gpu_memory(max_gpus=None): gpu_memory = [] num_gpus = ( @@ -20,3 +35,119 @@ def get_gpu_memory(max_gpus=None): available_memory = total_memory - allocated_memory gpu_memory.append(available_memory) return gpu_memory + + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO, encoding='utf-8') + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == '\n': + encoded_message = line.encode('utf-8', 'ignore').decode('utf-8') + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8') + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = '' + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + diff --git a/requirements.txt b/requirements.txt index cdf08db8f..3e86cd311 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,7 +48,6 @@ notebook gradio==3.23 gradio-client==0.0.8 wandb -fschat==0.1.10 llama-index==0.5.27 pymysql unstructured==0.6.3 From 144c6e0148d994a08c9ec313270d1304cc11cfa9 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 11 May 2023 11:15:58 +0800 Subject: [PATCH 11/13] fix --- pilot/server/vicuna_server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py index 868e8b6d9..95781b69b 100644 --- a/pilot/server/vicuna_server.py +++ b/pilot/server/vicuna_server.py @@ -10,8 +10,6 @@ from fastapi.responses import StreamingResponse from pilot.model.inference import generate_stream from pydantic import BaseModel from pilot.model.inference import generate_output, get_embeddings -from fastchat.serve.inference import load_model - from pilot.model.loader import ModelLoader from pilot.configs.model_config import * From e1329801275271e83945298b9757c1b4f9d6c337 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 11 May 2023 13:30:57 +0800 Subject: [PATCH 12/13] add token size --- pilot/configs/model_config.py | 2 +- pilot/model/inference.py | 4 ++-- pilot/server/webserver.py | 4 ++-- requirements.txt | 3 ++- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index eee05f3e1..da21e4ac8 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -27,7 +27,7 @@ LLM_MODEL_CONFIG = { VECTOR_SEARCH_TOP_K = 3 LLM_MODEL = "vicuna-13b" LIMIT_MODEL_CONCURRENCY = 5 -MAX_POSITION_EMBEDDINGS = 2048 +MAX_POSITION_EMBEDDINGS = 4096 VICUNA_MODEL_SERVER = "http://192.168.31.114:8000" diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 66766b3b3..532be9c33 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -5,13 +5,13 @@ import torch @torch.inference_mode() def generate_stream(model, tokenizer, params, device, - context_len=2048, stream_interval=2): + context_len=4096, stream_interval=2): """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """ prompt = params["prompt"] l_prompt = len(prompt) temperature = float(params.get("temperature", 1.0)) - max_new_tokens = int(params.get("max_new_tokens", 256)) + max_new_tokens = int(params.get("max_new_tokens", 2048)) stop_str = params.get("stop", None) input_ids = tokenizer(prompt).input_ids diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 1ca3cee20..5c09d7b85 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -301,8 +301,8 @@ def build_single_model_ui(): max_output_tokens = gr.Slider( minimum=0, - maximum=1024, - value=512, + maximum=4096, + value=2048, step=64, interactive=True, label="最大输出Token数", diff --git a/requirements.txt b/requirements.txt index 3e86cd311..d3bcf1bec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -51,4 +51,5 @@ wandb llama-index==0.5.27 pymysql unstructured==0.6.3 -pytesseract==0.3.10 \ No newline at end of file +pytesseract==0.3.10 +chromadb \ No newline at end of file From ca50c9fe47b6035b478ab65833c66f3cf4b54ee2 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 11 May 2023 15:40:12 +0800 Subject: [PATCH 13/13] fix --- examples/gradio_test.py | 19 ++++++++++++ pilot/configs/model_config.py | 3 +- pilot/server/webserver.py | 57 +++++++++++++++++++---------------- 3 files changed, 51 insertions(+), 28 deletions(-) create mode 100644 examples/gradio_test.py diff --git a/examples/gradio_test.py b/examples/gradio_test.py new file mode 100644 index 000000000..f39a1ca9e --- /dev/null +++ b/examples/gradio_test.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- + +import gradio as gr + +def change_tab(): + return gr.Tabs.update(selected=1) + +with gr.Blocks() as demo: + with gr.Tabs() as tabs: + with gr.TabItem("Train", id=0): + t = gr.Textbox() + with gr.TabItem("Inference", id=1): + i = gr.Image() + + btn = gr.Button() + btn.click(change_tab, None, tabs) + +demo.launch() \ No newline at end of file diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index da21e4ac8..1e78494ce 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -28,8 +28,7 @@ VECTOR_SEARCH_TOP_K = 3 LLM_MODEL = "vicuna-13b" LIMIT_MODEL_CONCURRENCY = 5 MAX_POSITION_EMBEDDINGS = 4096 -VICUNA_MODEL_SERVER = "http://192.168.31.114:8000" - +VICUNA_MODEL_SERVER = "http://47.97.125.199:8000" # Load model config ISLOAD_8BIT = True diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 5c09d7b85..b75a44e04 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -45,6 +45,7 @@ enable_moderation = False models = [] dbs = [] vs_list = ["新建知识库"] + get_vector_storelist() +autogpt = False priority = { "vicuna-13b": "aaa" @@ -58,8 +59,6 @@ def get_simlar(q): contents = [dc.page_content for dc, _ in docs] return "\n".join(contents) - - def gen_sqlgen_conversation(dbname): mo = MySQLOperator( **DB_SETTINGS @@ -118,6 +117,8 @@ def regenerate(state, request: gr.Request): return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 def clear_history(request: gr.Request): + + logger.info(f"clear_history. ip: {request.client.host}") state = None return (state, [], "") + (disable_btn,) * 5 @@ -135,7 +136,7 @@ def add_text(state, text, request: gr.Request): return (state, state.to_gradio_chatbot(), moderation_msg) + ( no_change_btn,) * 5 - text = text[:1536] # Hard cut-off + text = text[:4000] # Hard cut-off state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False @@ -152,6 +153,8 @@ def post_process_code(code): return code def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request): + + print("是否是AUTO-GPT模式.", autogpt) start_tstamp = time.time() model_name = LLM_MODEL @@ -162,7 +165,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr. yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return - + + # TODO when tab mode is AUTO_GPT, Prompt need to rebuild. if len(state.messages) == state.offset + 2: # 第一轮对话需要加入提示Prompt @@ -251,29 +255,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr. block_css = ( code_highlight_css + """ -pre { - white-space: pre-wrap; /* Since CSS 2.1 */ - white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ - white-space: -pre-wrap; /* Opera 4-6 */ - white-space: -o-pre-wrap; /* Opera 7 */ - word-wrap: break-word; /* Internet Explorer 5.5+ */ -} -#notice_markdown th { - display: none; -} - """ + pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ + } + #notice_markdown th { + display: none; + } + """ ) -def change_tab(tab): - pass - def change_mode(mode): if mode in ["默认知识库对话", "LLM原生对话"]: return gr.update(visible=False) else: return gr.update(visible=True) - +def change_tab(): + autogpt = True + def build_single_model_ui(): notice_markdown = """ @@ -307,10 +310,11 @@ def build_single_model_ui(): interactive=True, label="最大输出Token数", ) - tabs = gr.Tabs() + tabs= gr.Tabs() with tabs: - with gr.TabItem("SQL生成与诊断", elem_id="SQL"): - # TODO A selector to choose database + tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") + with tab_sql: + # TODO A selector to choose database with gr.Row(elem_id="db_selector"): db_selector = gr.Dropdown( label="请选择数据库", @@ -318,9 +322,12 @@ def build_single_model_ui(): value=dbs[0] if len(models) > 0 else "", interactive=True, show_label=True).style(container=False) + tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto") + with tab_auto: + gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力") - with gr.TabItem("知识问答", elem_id="QA"): - + tab_qa = gr.TabItem("知识问答", elem_id="QA") + with tab_qa: mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话") vs_setting = gr.Accordion("配置知识库", open=False) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) @@ -360,9 +367,7 @@ def build_single_model_ui(): regenerate_btn = gr.Button(value="重新生成", interactive=False) clear_btn = gr.Button(value="清理", interactive=False) - gr.Markdown(learn_more_markdown) - btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot,