refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
37
dbgpt/model/llm/base.py
Normal file
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from dataclasses import dataclass
from typing import Optional, TypedDict


class Message(TypedDict):
    """Vicuna Message object containing a role and the message content."""

    role: str
    content: str


@dataclass
class ModelInfo:
    """Struct for model information.

    Would be lovely to eventually get this directly from APIs.
    """

    name: str
    max_tokens: int


@dataclass
class LLMResponse:
    """Standard response struct for a response from an LLM model."""

    model_info: Optional[ModelInfo] = None


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a chat response from an LLM model."""

    content: Optional[str] = None
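A minimal usage sketch for these structs (not part of the diff); the role, model name, and token limit below are illustrative assumptions only:

# Hypothetical usage sketch; all values are placeholders.
from dbgpt.model.llm.base import ChatModelResponse, Message, ModelInfo

msg: Message = {"role": "user", "content": "Hello"}
info = ModelInfo(name="vicuna-13b", max_tokens=4096)
response = ChatModelResponse(model_info=info, content="Hi there")
print(response)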
0
dbgpt/model/llm/llama_cpp/__init__.py
Normal file
147
dbgpt/model/llm/llama_cpp/llama_cpp.py
Normal file
@@ -0,0 +1,147 @@
"""
Forked from text-generation-webui: https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
"""
import logging
import re
from typing import Dict

import llama_cpp
import torch

from dbgpt.model.parameter import LlamaCppModelParameters

logger = logging.getLogger(__name__)

if torch.cuda.is_available() and not torch.version.hip:
    try:
        import llama_cpp_cuda
    except ImportError:
        llama_cpp_cuda = None
else:
    llama_cpp_cuda = None


def llama_cpp_lib(prefer_cpu: bool = False):
    """Return the CUDA build of llama-cpp if available, otherwise the CPU build."""
    if prefer_cpu or llama_cpp_cuda is None:
        logger.info("Llama.cpp is using the CPU")
        return llama_cpp
    else:
        return llama_cpp_cuda


def ban_eos_logits_processor(eos_token, input_ids, logits):
    """Prevent the EOS token from being sampled by setting its logit to -inf."""
    logits[eos_token] = -float("inf")
    return logits


def get_params(model_path: str, model_params: LlamaCppModelParameters) -> Dict:
    """Build the keyword arguments passed to the llama-cpp Llama constructor."""
    return {
        "model_path": model_path,
        "n_ctx": model_params.max_context_size,
        "seed": model_params.seed,
        "n_threads": model_params.n_threads,
        "n_batch": model_params.n_batch,
        "use_mmap": True,
        "use_mlock": False,
        "low_vram": False,
        "n_gpu_layers": 0 if model_params.prefer_cpu else model_params.n_gpu_layers,
        "n_gqa": model_params.n_gqa,
        "logits_all": True,
        "rms_norm_eps": model_params.rms_norm_eps,
    }


class LlamaCppModel:
    def __init__(self):
        self.initialized = False
        self.model = None
        self.verbose = True

    def __del__(self):
        if self.model:
            self.model.__del__()

    @classmethod
    def from_pretrained(cls, model_path, model_params: LlamaCppModelParameters):
        Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama
        LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache

        result = cls()
        cache_capacity = 0
        cache_capacity_str = model_params.cache_capacity
        if cache_capacity_str is not None:
            if "GiB" in cache_capacity_str:
                cache_capacity = (
                    int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000 * 1000
                )
            elif "MiB" in cache_capacity_str:
                cache_capacity = (
                    int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000
                )
            else:
                cache_capacity = int(cache_capacity_str)

        params = get_params(model_path, model_params)
        logger.info(f"Cache capacity is {cache_capacity} bytes")
        logger.info(f"Load llama.cpp model with params: {params}")

        result.model = Llama(**params)
        result.verbose = model_params.verbose
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        if isinstance(string, str):
            string = string.encode()
        return self.model.tokenize(string)

    def decode(self, tokens):
        return self.model.detokenize(tokens)

    def generate_streaming(self, params, context_len: int):
        # LogitsProcessorList = llama_cpp_lib().LogitsProcessorList

        # Read parameters
        prompt = params["prompt"]
        if self.verbose:
            print(f"Prompt of model: \n{prompt}")

        temperature = float(params.get("temperature", 1.0))
        repetition_penalty = float(params.get("repetition_penalty", 1.1))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", -1))  # -1 means disabled
        max_new_tokens = int(params.get("max_new_tokens", 2048))
        echo = bool(params.get("echo", True))

        # Truncate the prompt so it leaves room for max_new_tokens in the context window.
        max_src_len = context_len - max_new_tokens
        prompt = self.encode(prompt)
        prompt = prompt[-max_src_len:]
        prompt = self.decode(prompt).decode("utf-8")

        # TODO: Compared with the original llama model, llama.cpp's output quality for
        # Chinese is mediocre and still needs debugging.
        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repetition_penalty,
            # tfs_z=params['tfs'],
            # mirostat_mode=int(params['mirostat_mode']),
            # mirostat_tau=params['mirostat_tau'],
            # mirostat_eta=params['mirostat_eta'],
            stream=True,
            echo=echo,
            logits_processor=None,
        )

        # Accumulate streamed chunks and yield the full output so far after each chunk.
        output = ""
        for completion_chunk in completion_chunks:
            text = completion_chunk["choices"][0]["text"]
            output += text
            yield output
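A minimal sketch of how this wrapper might be driven (not part of the diff); the model path and the LlamaCppModelParameters fields shown are illustrative assumptions:

# Hypothetical usage sketch; the path and parameter fields are placeholders.
from dbgpt.model.llm.llama_cpp.llama_cpp import LlamaCppModel
from dbgpt.model.parameter import LlamaCppModelParameters

model_path = "/models/vicuna-13b.Q4_K_M.gguf"  # assumed local model file
model_params = LlamaCppModelParameters(model_name="vicuna-13b", model_path=model_path)

# from_pretrained returns the same object twice (model and tokenizer are one object here).
model, tokenizer = LlamaCppModel.from_pretrained(model_path, model_params)

gen_params = {"prompt": "USER: Hello\nASSISTANT:", "temperature": 0.7, "max_new_tokens": 256}
for partial_output in model.generate_streaming(gen_params, context_len=4096):
    pass  # each iteration yields the full text generated so far
print(partial_output)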
74
dbgpt/model/llm/llm_utils.py
Normal file
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import abc
import functools
import time


# TODO: Rewrite this
def retry_stream_api(
    num_retries: int = 10, backoff_base: float = 2.0, warn_user: bool = True
):
    """Retry a Vicuna server call.

    Args:
        num_retries (int): Number of retries. Defaults to 10.
        backoff_base (float): Base for exponential backoff. Defaults to 2.
        warn_user (bool): Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = "Error: Reached rate limit, passing..."
    backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            user_warned = not warn_user
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Only retry on HTTP 502 responses; re-raise anything else,
                    # and re-raise on the last attempt.
                    if (getattr(e, "http_status", None) != 502) or (
                        attempt == num_attempts
                    ):
                        raise

                backoff = backoff_base ** (attempt + 2)
                time.sleep(backoff)

        return _wrapped

    return _wrapper


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str) -> str:
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream, skip_echo_len: int):
        """Stream output."""


class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str) -> str:
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            # Split into words so only fully completed words are printed while streaming.
            outputs = outputs.split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now

        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)
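A minimal sketch of how the retry decorator and SimpleChatIO might be combined (not part of the diff); fetch_stream and its fake output stream are hypothetical stand-ins:

# Hypothetical usage sketch; fetch_stream and its outputs are made up for illustration.
from dbgpt.model.llm.llm_utils import SimpleChatIO, retry_stream_api


@retry_stream_api(num_retries=3, backoff_base=2.0)
def fetch_stream():
    # Stand-in for a Vicuna server call that yields the growing output text.
    return ["Hello", "Hello there, how", "Hello there, how can I help?"]


chat_io = SimpleChatIO()
chat_io.prompt_for_output("ASSISTANT")
answer = chat_io.stream_output(fetch_stream(), skip_echo_len=0)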
125
dbgpt/model/llm/monkey_patch.py
Normal file
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math
from typing import Optional, Tuple

import torch
import transformers
from torch import nn


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2].clone()
    x2 = x[..., x.shape[-1] // 2 :].clone()
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings without in-place tensor operations."""
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Replacement for LlamaAttention.forward that avoids in-place operations."""
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        self.head_dim
    )

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(
            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
        )

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def replace_llama_attn_with_non_inplace_operations():
    """Avoid bugs in mps backend by not using in-place operations."""
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
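A minimal sketch of when this patch might be applied (not part of the diff); the model name and the device check are illustrative assumptions:

# Hypothetical usage sketch; the model name is a placeholder.
import torch
from transformers import AutoModelForCausalLM

from dbgpt.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations

if torch.backends.mps.is_available():
    # Swap in the non-inplace attention forward before loading a llama model on Apple Silicon.
    replace_llama_attn_with_non_inplace_operations()

model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5").to("mps")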