From 6d76825a109db6a8d0bc14cf9ccf291ea5ced5b7 Mon Sep 17 00:00:00 2001
From: csunny
Date: Thu, 11 May 2023 10:59:08 +0800
Subject: [PATCH] rm fschat relay

---
 environment.yml               |   5 +-
 pilot/configs/model_config.py |   1 +
 pilot/model/compression.py    | 121 +++++++++++++++++++++++++++++++
 pilot/model/loader.py         |   2 +-
 pilot/server/webserver.py     |   2 +-
 pilot/utils.py                | 131 ++++++++++++++++++++++++++++++++++
 requirements.txt              |   1 -
 7 files changed, 256 insertions(+), 7 deletions(-)
 create mode 100644 pilot/model/compression.py

diff --git a/environment.yml b/environment.yml
index ce2c7339f..db3929e23 100644
--- a/environment.yml
+++ b/environment.yml
@@ -7,11 +7,9 @@ dependencies:
   - python=3.9
   - cudatoolkit
   - pip
-  - pytorch=1.12.1
   - pytorch-mutex=1.0=cuda
-  - torchaudio=0.12.1
-  - torchvision=0.13.1
   - pip:
+    - torch
     - accelerate==0.16.0
     - aiohttp==3.8.4
     - aiosignal==1.3.1
@@ -60,7 +58,6 @@ dependencies:
     - gradio==3.23
     - gradio-client==0.0.8
     - wandb
-    - fschat==0.1.10
     - llama-index==0.5.27
     - pymysql
     - unstructured==0.6.3
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 7b4269fb6..eee05f3e1 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -5,6 +5,7 @@
 import torch
 import os
 import nltk
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
diff --git a/pilot/model/compression.py b/pilot/model/compression.py
new file mode 100644
index 000000000..9c8c25d08
--- /dev/null
+++ b/pilot/model/compression.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import dataclasses
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+@dataclasses.dataclass
+class CompressionConfig:
+    """Group-wise quantization."""
+    num_bits: int
+    group_size: int
+    group_dim: int
+    symmetric: bool
+    enabled: bool = True
+
+
+default_compression_config = CompressionConfig(
+    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)
+
+
+class CLinear(nn.Module):
+    """Compressed Linear Layer."""
+
+    def __init__(self, weight, bias, device):
+        super().__init__()
+
+        self.weight = compress(weight.data.to(device), default_compression_config)
+        self.bias = bias
+
+    def forward(self, input: Tensor) -> Tensor:
+        weight = decompress(self.weight, default_compression_config)
+        return F.linear(input, weight, self.bias)
+
+
+def compress_module(module, target_device):
+    for attr_str in dir(module):
+        target_attr = getattr(module, attr_str)
+        if type(target_attr) == torch.nn.Linear:
+            setattr(module, attr_str,
+                CLinear(target_attr.weight, target_attr.bias, target_device))
+    for name, child in module.named_children():
+        compress_module(child, target_device)
+
+
+def compress(tensor, config):
+    """Simulate group-wise quantization."""
+    if not config.enabled:
+        return tensor
+
+    group_size, num_bits, group_dim, symmetric = (
+        config.group_size, config.num_bits, config.group_dim, config.symmetric)
+    assert num_bits <= 8
+
+    original_shape = tensor.shape
+    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
+    new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
+                 original_shape[group_dim+1:])
+
+    # Pad
+    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
+    if pad_len != 0:
+        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
+        tensor = torch.cat([
+            tensor,
+            torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
+            dim=group_dim)
+    data = tensor.view(new_shape)
+
+    # Quantize
+    if symmetric:
+        B = 2 ** (num_bits - 1) - 1
+        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
+        data = data * scale
+        data = data.clamp_(-B, B).round_().to(torch.int8)
+        return data, scale, original_shape
+    else:
+        B = 2 ** num_bits - 1
+        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
+        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
+
+        scale = B / (mx - mn)
+        data = data - mn
+        data.mul_(scale)
+
+        data = data.clamp_(0, B).round_().to(torch.uint8)
+        return data, mn, scale, original_shape
+
+
+def decompress(packed_data, config):
+    """Simulate group-wise dequantization."""
+    if not config.enabled:
+        return packed_data
+
+    group_size, num_bits, group_dim, symmetric = (
+        config.group_size, config.num_bits, config.group_dim, config.symmetric)
+
+    # Dequantize
+    if symmetric:
+        data, scale, original_shape = packed_data
+        data = data / scale
+    else:
+        data, mn, scale, original_shape = packed_data
+        data = data / scale
+        data.add_(mn)
+
+    # Unpad
+    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
+    if pad_len:
+        padded_original_shape = (
+            original_shape[:group_dim] +
+            (original_shape[group_dim] + pad_len,) +
+            original_shape[group_dim+1:])
+        data = data.reshape(padded_original_shape)
+        indices = [slice(0, x) for x in original_shape]
+        return data[indices].contiguous()
+    else:
+        return data.view(original_shape)
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index c5788d2ab..3c0b9a6a7 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -10,7 +10,7 @@ from transformers import (
     AutoModel
 )
 
-from fastchat.serve.compression import compress_module
+from pilot.model.compression import compress_module
 
 class ModelLoader(metaclass=Singleton):
     """Model loader is a class for model load
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index bbe710667..1ca3cee20 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -24,7 +24,7 @@ from pilot.conversation import (
     SeparatorStyle
 )
 
-from fastchat.utils import (
+from pilot.utils import (
     build_logger,
     server_error_msg,
     violates_moderation,
diff --git a/pilot/utils.py b/pilot/utils.py
index 093b14f99..b2505eba1 100644
--- a/pilot/utils.py
+++ b/pilot/utils.py
@@ -3,6 +3,21 @@
 
 import torch
 
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from pilot.configs.model_config import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
 def get_gpu_memory(max_gpus=None):
     gpu_memory = []
     num_gpus = (
@@ -20,3 +35,119 @@ def get_gpu_memory(max_gpus=None):
         available_memory = total_memory - allocated_memory
         gpu_memory.append(available_memory)
     return gpu_memory
+
+
+
+def build_logger(logger_name, logger_filename):
+    global handler
+
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Set the format of root handlers
+    if not logging.getLogger().handlers:
+        logging.basicConfig(level=logging.INFO, encoding='utf-8')
+    logging.getLogger().handlers[0].setFormatter(formatter)
+
+    # Redirect stdout and stderr to loggers
+    stdout_logger = logging.getLogger("stdout")
+    stdout_logger.setLevel(logging.INFO)
+    sl = StreamToLogger(stdout_logger, logging.INFO)
+    sys.stdout = sl
+
+    stderr_logger = logging.getLogger("stderr")
+    stderr_logger.setLevel(logging.ERROR)
+    sl = StreamToLogger(stderr_logger, logging.ERROR)
+    sys.stderr = sl
+
+    # Get logger
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+
+    # Add a file handler for all loggers
+    if handler is None:
+        os.makedirs(LOGDIR, exist_ok=True)
+        filename = os.path.join(LOGDIR, logger_filename)
+        handler = logging.handlers.TimedRotatingFileHandler(
+            filename, when='D', utc=True)
+        handler.setFormatter(formatter)
+
+        for name, item in logging.root.manager.loggerDict.items():
+            if isinstance(item, logging.Logger):
+                item.addHandler(handler)
+
+    return logger
+
+
+class StreamToLogger(object):
+    """
+    Fake file-like stream object that redirects writes to a logger instance.
+    """
+    def __init__(self, logger, log_level=logging.INFO):
+        self.terminal = sys.stdout
+        self.logger = logger
+        self.log_level = log_level
+        self.linebuf = ''
+
+    def __getattr__(self, attr):
+        return getattr(self.terminal, attr)
+
+    def write(self, buf):
+        temp_linebuf = self.linebuf + buf
+        self.linebuf = ''
+        for line in temp_linebuf.splitlines(True):
+            # From the io.TextIOWrapper docs:
+            #   On output, if newline is None, any '\n' characters written
+            #   are translated to the system default line separator.
+            # By default sys.stdout.write() expects '\n' newlines and then
+            # translates them so this is still cross platform.
+            if line[-1] == '\n':
+                encoded_message = line.encode('utf-8', 'ignore').decode('utf-8')
+                self.logger.log(self.log_level, encoded_message.rstrip())
+            else:
+                self.linebuf += line
+
+    def flush(self):
+        if self.linebuf != '':
+            encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8')
+            self.logger.log(self.log_level, encoded_message.rstrip())
+        self.linebuf = ''
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+    """
+    Check whether the text violates OpenAI moderation API.
+    """
+    url = "https://api.openai.com/v1/moderations"
+    headers = {"Content-Type": "application/json",
+               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+    text = text.replace("\n", "")
+    data = "{" + '"input": ' + f'"{text}"' + "}"
+    data = data.encode("utf-8")
+    try:
+        ret = requests.post(url, headers=headers, data=data, timeout=5)
+        flagged = ret.json()["results"][0]["flagged"]
+    except requests.exceptions.RequestException as e:
+        flagged = False
+    except KeyError as e:
+        flagged = False
+
+    return flagged
+
+
+def pretty_print_semaphore(semaphore):
+    if semaphore is None:
+        return "None"
+    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
diff --git a/requirements.txt b/requirements.txt
index cdf08db8f..3e86cd311 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -48,7 +48,6 @@ notebook
 gradio==3.23
 gradio-client==0.0.8
 wandb
-fschat==0.1.10
 llama-index==0.5.27
 pymysql
 unstructured==0.6.3
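
A few hedged sketches for reviewers follow; none of this code is part of the patch itself. First, a round-trip check of the vendored quantization module. It assumes the repository root is on `PYTHONPATH` so `pilot` is importable, and the tensor shape is arbitrary; with the default config (8 bits, groups of 256 along dim 1), `compress` packs to int8 and `decompress` restores the tensor up to quantization error.

```python
# Sketch only: exercises pilot/model/compression.py as added by this patch.
# Assumes the DB-GPT repo root is on PYTHONPATH; the shape is illustrative.
import torch

from pilot.model.compression import (
    compress,
    decompress,
    default_compression_config,
)

weight = torch.randn(4, 1024)  # 1024 splits evenly into groups of 256

# Symmetric mode packs to (int8 data, per-group scale, original shape).
packed = compress(weight, default_compression_config)
restored = decompress(packed, default_compression_config)

assert restored.shape == weight.shape
print((weight - restored).abs().max())  # small per-group quantization error
```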
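
Second, the `loader.py` change swaps only the import source; the call itself is unchanged. A hypothetical call site might look like this (the model name and device are illustrative assumptions, not taken from the patch):

```python
# Sketch only: model name and device are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM

from pilot.model.compression import compress_module

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Recursively replaces each nn.Linear with a CLinear that stores an
# int8-quantized weight and dequantizes it on every forward pass,
# trading extra compute for lower weight memory.
compress_module(model, torch.device("cpu"))
```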
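
Finally, `webserver.py` now imports its logging helpers from `pilot.utils` instead of `fastchat.utils`. A minimal, hypothetical use of the relocated `build_logger` (the logger name and filename are made up; `LOGDIR` from `pilot.configs.model_config` must be writable):

```python
# Sketch only: the logger name and filename are made-up examples.
from pilot.utils import build_logger

# Side effects per the patched code: creates LOGDIR, attaches a daily
# rotating file handler to all known loggers, and redirects stdout/stderr
# through StreamToLogger.
logger = build_logger("webserver", "webserver.log")
logger.info("web server started")
```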