Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-15 22:19:28 +00:00)

Commit: fix conflict
pilot/model/compression.py  (new file, 121 lines)
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import dataclasses

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F


@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""
    num_bits: int
    group_size: int
    group_dim: int
    symmetric: bool
    enabled: bool = True


default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True)


class CLinear(nn.Module):
    """Compressed Linear Layer: stores quantized weights and dequantizes them on the fly."""

    def __init__(self, weight, bias, device):
        super().__init__()
        self.weight = compress(weight.data.to(device), default_compression_config)
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        return F.linear(input, weight, self.bias)


def compress_module(module, target_device):
    """Recursively replace every nn.Linear in the module tree with a CLinear."""
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            setattr(module, attr_str,
                    CLinear(target_attr.weight, target_attr.bias, target_device))
    for name, child in module.named_children():
        compress_module(child, target_device)


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
                 original_shape[group_dim + 1:])

    # Pad the grouped dimension up to a multiple of group_size
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1:]
        tensor = torch.cat([
            tensor,
            torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim)
    data = tensor.view(new_shape)

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2 ** num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)

    # Dequantize
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim] +
            (original_shape[group_dim] + pad_len,) +
            original_shape[group_dim + 1:])
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        return data[indices].contiguous()
    else:
        return data.view(original_shape)
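For orientation, a minimal round-trip sketch (not part of the commit) of how compress and decompress pair up under default_compression_config; the weight shape is an arbitrary placeholder.

import torch
from pilot.model.compression import compress, decompress, default_compression_config

weight = torch.randn(1024, 4096)  # placeholder tensor; dim 1 is quantized in groups of 256
packed = compress(weight, default_compression_config)       # -> (int8 data, scale, original_shape)
restored = decompress(packed, default_compression_config)   # float tensor, approximately equal to weight
print((weight - restored).abs().max())                      # small group-wise quantization error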
@@ -5,13 +5,13 @@ import torch

 @torch.inference_mode()
 def generate_stream(model, tokenizer, params, device,
-                    context_len=2048, stream_interval=2):
+                    context_len=4096, stream_interval=2):
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
     prompt = params["prompt"]
     l_prompt = len(prompt)
     temperature = float(params.get("temperature", 1.0))
-    max_new_tokens = int(params.get("max_new_tokens", 256))
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
     stop_str = params.get("stop", None)

     input_ids = tokenizer(prompt).input_ids
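As a hedged illustration (not in the diff), the params dict below mirrors the keys generate_stream reads above; the model/tokenizer setup uses standard transformers calls with placeholder paths, and the import of generate_stream is omitted because this excerpt does not show its module path.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/vicuna-13b")            # placeholder path
model = AutoModelForCausalLM.from_pretrained("/path/to/vicuna-13b").cuda()  # placeholder path

params = {
    "prompt": "A chat between a user and an assistant. USER: Hello. ASSISTANT:",
    "temperature": 0.7,
    "max_new_tokens": 512,   # overrides the new default of 2048
    "stop": "USER:",
}
# generate_stream yields the decoded output incrementally (FastChat-style generator).
for partial in generate_stream(model, tokenizer, params, device="cuda"):
    print(partial, end="\r", flush=True)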
pilot/model/llm/base.py  (new file, 34 lines)
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from dataclasses import dataclass, field
from typing import List, TypedDict


class Message(TypedDict):
    """Vicuna Message object containing a role and the message content."""

    role: str
    content: str


@dataclass
class ModelInfo:
    """Struct for model information.

    Would be lovely to eventually get this directly from APIs.
    """
    name: str
    max_tokens: int


@dataclass
class LLMResponse:
    """Standard response struct for a response from an LLM model."""
    model_info: ModelInfo


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a chat response from an LLM model."""

    content: str = None
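A brief illustrative sketch (not in the commit) of how these structs compose; the model name and token limit are placeholder values, and it assumes model_info is an ordinary dataclass field.

from pilot.model.llm.base import ChatModelResponse, Message, ModelInfo

vicuna_info = ModelInfo(name="vicuna-13b", max_tokens=4096)           # placeholder values
msg: Message = {"role": "user", "content": "Hello"}
resp = ChatModelResponse(model_info=vicuna_info, content="Hi there")  # assumes model_info is a field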
pilot/model/llm/llm_utils.py  (new file, 108 lines)
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import abc
import time
import functools
from typing import List, Optional
from pilot.model.llm.base import Message
from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot
from pilot.configs.config import Config


# TODO Rewrite this
def retry_stream_api(
    num_retries: int = 10,
    backoff_base: float = 2.0,
    warn_user: bool = True
):
    """Retry a Vicuna server call.

    Args:
        num_retries int: Number of retries. Defaults to 10.
        backoff_base float: Base for exponential backoff. Defaults to 2.
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = "Error: Reached rate limit, passing..."
    backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            user_warned = not warn_user
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Only retry on bad-gateway errors; re-raise anything else,
                    # or once the attempt budget is exhausted.
                    if (e.http_status != 502) or (attempt == num_attempts):
                        raise

                backoff = backoff_base ** (attempt + 2)
                time.sleep(backoff)
        return _wrapped
    return _wrapper


# Overly simple abstraction until we create something better:
# a simple retry mechanism for rate-limit or bad-gateway errors.
def create_chat_competion(
    conv: Conversation,
    model: Optional[str] = None,
    temperature: float = None,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Create a chat completion using the Vicuna-13b model.

    Args:
        messages(List[Message]): The messages to send to the chat completion
        model (str, optional): The model to use. Defaults to None.
        temperature (float, optional): The temperature to use. Defaults to 0.7.
        max_tokens (int, optional): The max tokens to use. Defaults to None.

    Returns:
        str: The response from the chat completion
    """
    cfg = Config()
    if temperature is None:
        temperature = cfg.temperature

    # TODO request vicuna model get response
    # convert vicuna message to chat completion.
    for plugin in cfg.plugins:
        if plugin.can_handle_chat_completion():
            pass


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str) -> str:
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream, skip_echo_len: int):
        """Stream output."""


class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str) -> str:
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        pre = 0
        for outputs in output_stream:
            # Work on whitespace-separated words so partial words are not flushed.
            outputs = outputs[skip_echo_len:].strip().split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now

        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)
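As an illustration (not part of the commit), SimpleChatIO.stream_output can be driven by any iterator that yields the full decoded output so far; the fake stream below is invented for the example.

from pilot.model.llm.llm_utils import SimpleChatIO

def fake_stream():
    text = "USER: hi ASSISTANT: Hello, how can I help you today?"
    for i in range(20, len(text), 10):   # emit progressively longer prefixes
        yield text[:i]
    yield text

chat_io = SimpleChatIO()
chat_io.prompt_for_output("ASSISTANT")
answer = chat_io.stream_output(fake_stream(), skip_echo_len=len("USER: hi "))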
@@ -2,15 +2,17 @@
 # -*- coding: utf-8 -*-

 import torch
+from pilot.singleton import Singleton

 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     AutoModel
 )

-from fastchat.serve.compression import compress_module
+from pilot.model.compression import compress_module

-class ModelLoader:
+class ModelLoader(metaclass=Singleton):
     """Model loader is a class for model load

     Args: model_path
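A small hedged check (not in the diff) of what the Singleton metaclass change implies, assuming pilot.singleton.Singleton is a conventional singleton metaclass and that ModelLoader takes a model_path as its docstring suggests; the import of ModelLoader is omitted because this excerpt does not show the file path.

# Repeated construction should now hand back the same instance.
loader_a = ModelLoader(model_path="/path/to/vicuna-13b")   # placeholder path
loader_b = ModelLoader(model_path="/path/to/vicuna-13b")
assert loader_a is loader_b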