[inference] Refactor inference architecture (#5057)

* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] udpate example (#5053)

* udpate example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
Xu Kai
2023-11-19 21:05:05 +08:00
committed by GitHub
parent bc09b95f50
commit fd6482ad8c
115 changed files with 6027 additions and 1431 deletions

View File

@@ -0,0 +1,2 @@
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager

View File

@@ -0,0 +1,118 @@
# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
import torch
from transformers.tokenization_utils_base import BatchEncoding
from .kvcache_manager import MemoryManager
# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
r"""
Information to be passed and used for a batch of inputs during
a single model forward
"""
batch_size: int
max_len_in_batch: int
cache_manager: MemoryManager = None
block_loc: torch.Tensor = None
start_loc: torch.Tensor = None
seq_len: torch.Tensor = None
past_key_values_len: int = None
is_context_stage: bool = False
context_mem_index: torch.Tensor = None
decode_is_contiguous: bool = None
decode_mem_start: int = None
decode_mem_end: int = None
decode_mem_index: torch.Tensor = None
decode_layer_id: int = None
device: torch.device = torch.device("cuda")
@property
def total_token_num(self):
# return self.batch_size * self.max_len_in_batch
assert self.seq_len is not None and self.seq_len.size(0) > 0
return int(torch.sum(self.seq_len))
def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager
# adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
@staticmethod
def init_block_loc(
b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
):
"""in-place update block loc mapping based on the sequence length of the inputs in current bath"""
start_index = 0
seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy):
b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
start_index : start_index + cur_seq_len
]
start_index += cur_seq_len
return
@classmethod
def init_from_batch(
cls,
batch: torch.Tensor,
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
):
if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")
input_ids_list = None
attention_mask = None
if isinstance(batch, (BatchEncoding, dict)):
input_ids_list = batch["input_ids"]
attention_mask = batch["attention_mask"]
else:
input_ids_list = batch
if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
if isinstance(batch, (BatchEncoding, dict)):
for i, attn_mask in enumerate(attention_mask):
curr_seq_len = len(attn_mask)
seq_lengths[i] = curr_seq_len
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
else:
length = max(len(input_id) for input_id in input_ids_list)
for i, input_ids in enumerate(input_ids_list):
curr_seq_len = length
seq_lengths[i] = curr_seq_len
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda")
return cls(
batch_size=batch_size,
max_len_in_batch=max_len_in_batch,
seq_len=seq_lengths.to("cuda"),
start_loc=seq_start_indexes.to("cuda"),
block_loc=block_loc,
decode_layer_id=0,
past_key_values_len=0,
is_context_stage=True,
cache_manager=cache_manager,
)

View File

@@ -0,0 +1,106 @@
"""
Refered/Modified from lightllm/common/mem_manager.py
of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
"""
import torch
from transformers.utils import logging
class MemoryManager:
r"""
Manage token block indexes and allocate physical memory for key and value cache
Args:
size: maximum token number used as the size of key and value buffer
dtype: data type of cached key and value
head_num: number of heads the memory manager is responsible for
head_dim: embedded size per head
layer_num: the number of layers in the model
device: device used to store the key and value cache
"""
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device("cuda"),
):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.max_len_in_batch = 0
self._init_mem_states(size, device)
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device):
"""Initialize tensors used to manage memory states"""
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
"""Initialize key buffer and value buffer on specified device"""
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
self.value_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
@torch.no_grad()
def alloc(self, required_size):
"""allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
select_index = self.indexes[select_index]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
return select_index
@torch.no_grad()
def alloc_contiguous(self, required_size):
"""allocate contiguous space of required_size"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
loc_sums = (
self.mem_cum_sum[required_size - 1 :]
- self.mem_cum_sum[0 : sum_size - required_size + 1]
+ self.mem_state[0 : sum_size - required_size + 1]
)
can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
self.logger.info(
f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
)
return None
start_loc = can_used_loc[0]
select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
end = start + required_size
return select_index, start, end
@torch.no_grad()
def free(self, free_index):
"""free memory by updating memory states based on given indexes"""
self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1
@torch.no_grad()
def free_all(self):
"""free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.max_len_in_batch = 0
# self.logger.info("freed all space of memory manager")