refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

dbgpt/model/llm/base.py Normal file

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dataclasses import dataclass
from typing import Optional, TypedDict
class Message(TypedDict):
"""Vicuna Message object containing a role and the message content"""
role: str
content: str
@dataclass
class ModelInfo:
"""Struct for model information.
Would be lovely to eventually get this directly from APIs
"""
name: str
max_tokens: int
@dataclass
class LLMResponse:
"""Standard response struct for a response from a LLM model."""
model_info = ModelInfo
@dataclass
class ChatModelResponse(LLMResponse):
"""Standard response struct for a response from an LLM model."""
content: str = None

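To show how these structures fit together, here is a minimal, hypothetical usage sketch; the model name and token limit below are illustrative values, not part of this commit:

# Hypothetical usage sketch -- the values below are illustrative only.
info = ModelInfo(name="vicuna-13b-v1.5", max_tokens=4096)
user_msg: Message = {"role": "user", "content": "Hello"}
response = ChatModelResponse(model_info=info, content="Hi, how can I help you?")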

@@ -0,0 +1,147 @@
"""
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
"""
import re
from typing import Dict
import logging
import torch
import llama_cpp
from dbgpt.model.parameter import LlamaCppModelParameters
logger = logging.getLogger(__name__)
if torch.cuda.is_available() and not torch.version.hip:
try:
import llama_cpp_cuda
    except ImportError:
llama_cpp_cuda = None
else:
llama_cpp_cuda = None
def llama_cpp_lib(prefer_cpu: bool = False):
if prefer_cpu or llama_cpp_cuda is None:
logger.info(f"Llama.cpp use cpu")
return llama_cpp
else:
return llama_cpp_cuda
def ban_eos_logits_processor(eos_token, input_ids, logits):
logits[eos_token] = -float("inf")
return logits
def get_params(model_path: str, model_params: LlamaCppModelParameters) -> Dict:
return {
"model_path": model_path,
"n_ctx": model_params.max_context_size,
"seed": model_params.seed,
"n_threads": model_params.n_threads,
"n_batch": model_params.n_batch,
"use_mmap": True,
"use_mlock": False,
"low_vram": False,
"n_gpu_layers": 0 if model_params.prefer_cpu else model_params.n_gpu_layers,
"n_gqa": model_params.n_gqa,
"logits_all": True,
"rms_norm_eps": model_params.rms_norm_eps,
}
class LlamaCppModel:
def __init__(self):
self.initialized = False
self.model = None
self.verbose = True
def __del__(self):
if self.model:
self.model.__del__()
@classmethod
    def from_pretrained(cls, model_path, model_params: LlamaCppModelParameters):
        Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama
        LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache
        result = cls()
cache_capacity = 0
cache_capacity_str = model_params.cache_capacity
if cache_capacity_str is not None:
if "GiB" in cache_capacity_str:
cache_capacity = (
int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000 * 1000
)
elif "MiB" in cache_capacity_str:
cache_capacity = (
int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000
)
else:
cache_capacity = int(cache_capacity_str)
params = get_params(model_path, model_params)
logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
logger.info(f"Load LLama model with params: {params}")
result.model = Llama(**params)
result.verbose = model_params.verbose
if cache_capacity > 0:
result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
# This is ugly, but the model and the tokenizer are the same object in this library.
return result, result
def encode(self, string):
        if isinstance(string, str):
string = string.encode()
return self.model.tokenize(string)
def decode(self, tokens):
return self.model.detokenize(tokens)
def generate_streaming(self, params, context_len: int):
# LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
# Read parameters
prompt = params["prompt"]
if self.verbose:
print(f"Prompt of model: \n{prompt}")
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.1))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 2048))
echo = bool(params.get("echo", True))
max_src_len = context_len - max_new_tokens
# Handle truncation
prompt = self.encode(prompt)
prompt = prompt[-max_src_len:]
prompt = self.decode(prompt).decode("utf-8")
        # TODO: Compared with the original LLaMA model, llama.cpp's Chinese output
        # quality is mediocre and still needs debugging.
completion_chunks = self.model.create_completion(
prompt=prompt,
max_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repetition_penalty,
# tfs_z=params['tfs'],
# mirostat_mode=int(params['mirostat_mode']),
# mirostat_tau=params['mirostat_tau'],
# mirostat_eta=params['mirostat_eta'],
stream=True,
echo=echo,
logits_processor=None,
)
output = ""
for completion_chunk in completion_chunks:
text = completion_chunk["choices"][0]["text"]
output += text
# print(output)
yield output

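For context, a rough usage sketch of this wrapper follows. The LlamaCppModelParameters constructor fields and the model path are assumptions; only the fields read by get_params and from_pretrained are visible in this diff:

# Hypothetical usage sketch -- the constructor fields and path are assumptions.
model_params = LlamaCppModelParameters(
    model_name="llama-cpp",                     # assumed field
    model_path="/models/ggml-model-q4_0.gguf",  # assumed field
)
# from_pretrained returns the same object twice: in llama-cpp-python the
# model and the tokenizer are the same object.
model, tokenizer = LlamaCppModel.from_pretrained(
    "/models/ggml-model-q4_0.gguf", model_params
)
params = {"prompt": "USER: Hello\nASSISTANT:", "temperature": 0.7, "max_new_tokens": 256}
for partial_output in model.generate_streaming(params, context_len=4096):
    print(partial_output, end="\r", flush=True)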

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import abc
import functools
import time
# TODO Rewrite this
def retry_stream_api(
num_retries: int = 10, backoff_base: float = 2.0, warn_user: bool = True
):
"""Retry an Vicuna Server call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"Error: Reached rate limit, passing..."
backoff_msg = f"Error: API Bad gateway. Waiting {{backoff}} seconds..."
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
except Exception as e:
                    if getattr(e, "http_status", None) != 502 or attempt == num_attempts:
raise
backoff = backoff_base ** (attempt + 2)
time.sleep(backoff)
return _wrapped
return _wrapper
class ChatIO(abc.ABC):
@abc.abstractmethod
def prompt_for_input(self, role: str) -> str:
"""Prompt for input from a role."""
@abc.abstractmethod
def prompt_for_output(self, role: str) -> str:
"""Prompt for output from a role."""
@abc.abstractmethod
def stream_output(self, output_stream, skip_echo_len: int):
"""Stream output."""
class SimpleChatIO(ChatIO):
def prompt_for_input(self, role: str) -> str:
return input(f"{role}: ")
def prompt_for_output(self, role: str) -> str:
print(f"{role}: ", end="", flush=True)
def stream_output(self, output_stream, skip_echo_len: int):
pre = 0
for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            outputs = outputs.split(" ")
now = len(outputs) - 1
if now > pre:
print(" ".join(outputs[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(outputs[pre:]), flush=True)
return " ".join(outputs)

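A short, hypothetical sketch of how these helpers are meant to be combined; the fetch_stream function below is a made-up stand-in for a real streaming call to the model server:

# Hypothetical usage sketch -- fetch_stream is illustrative only.
chatio = SimpleChatIO()

@retry_stream_api(num_retries=3, backoff_base=2.0)
def fetch_stream(prompt: str):
    # Stand-in for a streaming request; yields cumulative output strings.
    yield from [prompt + " -> Hello", prompt + " -> Hello, world"]

chatio.prompt_for_output("ASSISTANT")
answer = chatio.stream_output(fetch_stream("USER: Hi"), skip_echo_len=0)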

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import math
from typing import Optional, Tuple
import torch
import transformers
from torch import nn
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2].clone()
x2 = x[..., x.shape[-1] // 2 :].clone()
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
self.head_dim
)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def replace_llama_attn_with_non_inplace_operations():
"""Avoid bugs in mps backend by not using in-place operations."""
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
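
As a usage note, the patch is applied before the model is instantiated; a hypothetical sketch follows (the model name is illustrative, not part of this commit):

# Hypothetical usage sketch -- the model name is illustrative.
from transformers import AutoModelForCausalLM

replace_llama_attn_with_non_inplace_operations()  # patch before loading
model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
model = model.to("mps")  # the Apple MPS backend that motivates this patch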