mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-30 04:35:17 +00:00
* [inference] support only TP (#4998) * support only tp * enable tp * add support for bloom (#5008) * [refactor] refactor gptq and smoothquant llama (#5012) * refactor gptq and smoothquant llama * fix import error * fix linear import torch-int * fix smoothquant llama import error * fix import accelerate error * fix bug * fix import smooth cuda * fix smoothcuda * [Inference Refactor] Merge chatglm2 with pp and tp (#5023) merge chatglm with pp and tp * [Refactor] remove useless inference code (#5022) * remove useless code * fix quant model * fix test import bug * mv original inference legacy * fix chatglm2 * [Refactor] refactor policy search and quant type controlling in inference (#5035) * [Refactor] refactor policy search and quant type controling in inference * [inference] update readme (#5051) * update readme * update readme * fix architecture * fix table * fix table * [inference] udpate example (#5053) * udpate example * fix run.sh * fix rebase bug * fix some errors * update readme * add some features * update interface * update readme * update benchmark * add requirements-infer --------- Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
107 lines
4.5 KiB
Python
107 lines
4.5 KiB
Python
"""
|
|
Refered/Modified from lightllm/common/mem_manager.py
|
|
of the ModelTC/lightllm GitHub repository
|
|
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
|
|
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
|
|
"""
|
|
import torch
|
|
from transformers.utils import logging
|
|
|
|
|
|
class MemoryManager:
|
|
r"""
|
|
Manage token block indexes and allocate physical memory for key and value cache
|
|
|
|
Args:
|
|
size: maximum token number used as the size of key and value buffer
|
|
dtype: data type of cached key and value
|
|
head_num: number of heads the memory manager is responsible for
|
|
head_dim: embedded size per head
|
|
layer_num: the number of layers in the model
|
|
device: device used to store the key and value cache
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
size: int,
|
|
dtype: torch.dtype,
|
|
head_num: int,
|
|
head_dim: int,
|
|
layer_num: int,
|
|
device: torch.device = torch.device("cuda"),
|
|
):
|
|
self.logger = logging.get_logger(__name__)
|
|
self.available_size = size
|
|
self.max_len_in_batch = 0
|
|
self._init_mem_states(size, device)
|
|
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
|
|
|
|
def _init_mem_states(self, size, device):
|
|
"""Initialize tensors used to manage memory states"""
|
|
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
|
|
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
|
|
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
|
|
|
|
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
|
|
"""Initialize key buffer and value buffer on specified device"""
|
|
self.key_buffer = [
|
|
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
|
|
]
|
|
self.value_buffer = [
|
|
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
|
|
]
|
|
|
|
@torch.no_grad()
|
|
def alloc(self, required_size):
|
|
"""allocate space of required_size by providing indexes representing available physical spaces"""
|
|
if required_size > self.available_size:
|
|
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
|
|
return None
|
|
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
|
|
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
|
|
select_index = self.indexes[select_index]
|
|
self.mem_state[select_index] = 0
|
|
self.available_size -= len(select_index)
|
|
return select_index
|
|
|
|
@torch.no_grad()
|
|
def alloc_contiguous(self, required_size):
|
|
"""allocate contiguous space of required_size"""
|
|
if required_size > self.available_size:
|
|
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
|
|
return None
|
|
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
|
|
sum_size = len(self.mem_cum_sum)
|
|
loc_sums = (
|
|
self.mem_cum_sum[required_size - 1 :]
|
|
- self.mem_cum_sum[0 : sum_size - required_size + 1]
|
|
+ self.mem_state[0 : sum_size - required_size + 1]
|
|
)
|
|
can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
|
|
if can_used_loc.shape[0] == 0:
|
|
self.logger.info(
|
|
f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
|
|
)
|
|
return None
|
|
start_loc = can_used_loc[0]
|
|
select_index = self.indexes[start_loc : start_loc + required_size]
|
|
self.mem_state[select_index] = 0
|
|
self.available_size -= len(select_index)
|
|
start = start_loc.item()
|
|
end = start + required_size
|
|
return select_index, start, end
|
|
|
|
@torch.no_grad()
|
|
def free(self, free_index):
|
|
"""free memory by updating memory states based on given indexes"""
|
|
self.available_size += free_index.shape[0]
|
|
self.mem_state[free_index] = 1
|
|
|
|
@torch.no_grad()
|
|
def free_all(self):
|
|
"""free all memory by updating memory states"""
|
|
self.available_size = len(self.mem_state)
|
|
self.mem_state[:] = 1
|
|
self.max_len_in_batch = 0
|
|
self.logger.info("freed all space of memory manager")
|