mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998) * support only tp * enable tp * add support for bloom (#5008) * [refactor] refactor gptq and smoothquant llama (#5012) * refactor gptq and smoothquant llama * fix import error * fix linear import torch-int * fix smoothquant llama import error * fix import accelerate error * fix bug * fix import smooth cuda * fix smoothcuda * [Inference Refactor] Merge chatglm2 with pp and tp (#5023) merge chatglm with pp and tp * [Refactor] remove useless inference code (#5022) * remove useless code * fix quant model * fix test import bug * mv original inference legacy * fix chatglm2 * [Refactor] refactor policy search and quant type controlling in inference (#5035) * [Refactor] refactor policy search and quant type controling in inference * [inference] update readme (#5051) * update readme * update readme * fix architecture * fix table * fix table * [inference] udpate example (#5053) * udpate example * fix run.sh * fix rebase bug * fix some errors * update readme * add some features * update interface * update readme * update benchmark * add requirements-infer --------- Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
346
colossalai/legacy/inference/dynamic_batching/infer_batch.py
Normal file
346
colossalai/legacy/inference/dynamic_batching/infer_batch.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Adapted from https://github.com/ModelTC/lightllm
|
||||
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.inference.tensor_parallel import MemoryManager
|
||||
|
||||
|
||||
# make batch infer state an attr of InferBatch
|
||||
class InferSamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
do_sample: bool = False,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
vocab_size: int = -1,
|
||||
) -> None:
|
||||
self.do_sample = do_sample
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
if self.top_k == -1:
|
||||
self.top_k = vocab_size
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferBatch:
|
||||
batch_id: int
|
||||
requests: List
|
||||
requests_idx_mapping: Dict[int, int]
|
||||
|
||||
input_ids: torch.Tensor
|
||||
|
||||
all_input_ids: List[List[int]]
|
||||
input_lengths: List[int]
|
||||
|
||||
out_token_id_counts: List
|
||||
sampling_param_list: List[InferSamplingParams]
|
||||
|
||||
nopad_total_token_num: int
|
||||
nopad_max_len_in_batch: int
|
||||
nopad_b_loc: torch.Tensor
|
||||
nopad_b_start_loc: torch.Tensor
|
||||
nopad_b_seq_len: torch.Tensor
|
||||
cache_manager: MemoryManager
|
||||
max_total_len: int
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def init_batch(
|
||||
cls,
|
||||
batch_id,
|
||||
requests,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cache_manager: MemoryManager,
|
||||
vocab_size: int,
|
||||
max_total_len: int,
|
||||
) -> "InferBatch":
|
||||
input_lengths = []
|
||||
all_input_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
out_token_id_counts = []
|
||||
sampling_param_list = []
|
||||
|
||||
nopad_total_token_num = 0
|
||||
nopad_max_len_in_batch = 0
|
||||
nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
|
||||
# to avoid memory leak , we pre-allocate 12 more space for each batch.
|
||||
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
|
||||
for i, r in enumerate(requests):
|
||||
# request id -> idx in list mapping
|
||||
requests_idx_mapping[r["request_id"]] = i
|
||||
|
||||
tokenized_input = r["input_id"]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
all_input_ids.append(tokenized_input)
|
||||
out_token_id_counts.append(collections.defaultdict(int))
|
||||
|
||||
# postprocessor
|
||||
sampling_param = r["sampling_param"]
|
||||
sampling_param["vocab_size"] = vocab_size
|
||||
sampling_param_list.append(InferSamplingParams(**sampling_param))
|
||||
|
||||
nopad_total_token_num += input_length
|
||||
nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)
|
||||
|
||||
nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
|
||||
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
|
||||
|
||||
if len(requests) > 1:
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
|
||||
# Create tensors on device
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
|
||||
return cls(
|
||||
batch_id=batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
input_lengths=input_lengths,
|
||||
all_input_ids=all_input_ids,
|
||||
nopad_total_token_num=nopad_total_token_num,
|
||||
nopad_max_len_in_batch=nopad_max_len_in_batch,
|
||||
nopad_b_loc=nopad_b_loc,
|
||||
nopad_b_start_loc=nopad_b_start_loc,
|
||||
nopad_b_seq_len=nopad_b_seq_len,
|
||||
out_token_id_counts=out_token_id_counts,
|
||||
sampling_param_list=sampling_param_list,
|
||||
cache_manager=cache_manager,
|
||||
max_total_len=max_total_len,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def free_self(self) -> None:
|
||||
"""
|
||||
Free the memory of the InferBatch itself
|
||||
"""
|
||||
remove_index = []
|
||||
for idx in range(len(self)):
|
||||
remove_index.append(
|
||||
self.nopad_b_loc[
|
||||
idx,
|
||||
(self.nopad_max_len_in_batch - 1)
|
||||
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
|
||||
]
|
||||
)
|
||||
remove_index = torch.cat(remove_index, dim=-1)
|
||||
self.cache_manager.free(remove_index)
|
||||
|
||||
@torch.no_grad()
|
||||
def filter(self, request_ids: List[int]) -> "InferBatch":
|
||||
"""
|
||||
Filter finished batch and return a new InferBatch with left ones.
|
||||
"""
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
if len(request_ids) == len(self):
|
||||
return self
|
||||
requests_idx_mapping = {}
|
||||
indices = []
|
||||
requests = []
|
||||
all_input_ids = []
|
||||
input_lengths = []
|
||||
nopad_total_token_num = 0
|
||||
nopad_max_len_in_batch = 0
|
||||
nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
|
||||
nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
|
||||
nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
|
||||
|
||||
left_idx = []
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
left_idx.append(idx)
|
||||
|
||||
left_idx_set = set(left_idx)
|
||||
remove_index = []
|
||||
for idx in range(len(self)):
|
||||
if idx not in left_idx_set:
|
||||
remove_index.append(
|
||||
self.nopad_b_loc[
|
||||
idx,
|
||||
(self.nopad_max_len_in_batch - 1)
|
||||
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
|
||||
]
|
||||
)
|
||||
remove_index = torch.cat(remove_index, dim=-1)
|
||||
self.cache_manager.free(remove_index)
|
||||
|
||||
nopad_max_len_in_batch = 0
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
indices.append(idx)
|
||||
|
||||
nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
|
||||
nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
|
||||
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
|
||||
nopad_total_token_num = torch.sum(nopad_b_seq_len).item()
|
||||
|
||||
nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
|
||||
indices,
|
||||
(self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
|
||||
]
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
requests_idx_mapping[request_id] = i
|
||||
requests.append(self.requests[idx])
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
input_lengths.append(self.input_lengths[idx])
|
||||
|
||||
input_ids = self.input_ids[indices]
|
||||
|
||||
return InferBatch(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
input_lengths=input_lengths,
|
||||
all_input_ids=all_input_ids,
|
||||
nopad_total_token_num=nopad_total_token_num,
|
||||
nopad_max_len_in_batch=nopad_max_len_in_batch,
|
||||
nopad_b_loc=nopad_b_loc,
|
||||
nopad_b_start_loc=nopad_b_start_loc,
|
||||
nopad_b_seq_len=nopad_b_seq_len,
|
||||
out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
|
||||
sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
|
||||
cache_manager=self.cache_manager,
|
||||
max_total_len=self.max_total_len,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def merge(cls, batch1, batch2) -> "InferBatch":
|
||||
"""
|
||||
Return megerd new InferBatch
|
||||
"""
|
||||
requests = batch1.requests + batch2.requests
|
||||
requests_idx_mapping = {}
|
||||
new_batch_size = len(batch1) + len(batch2)
|
||||
|
||||
input_ids = batch1.input_ids.new_empty(new_batch_size)
|
||||
all_input_ids = []
|
||||
input_lengths = []
|
||||
out_token_id_counts = []
|
||||
sampling_param_list = []
|
||||
|
||||
cumulative_batch_size = 0
|
||||
nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
|
||||
nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
|
||||
max_total_len = max(batch1.max_total_len, batch2.max_total_len)
|
||||
nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
|
||||
nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
|
||||
nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
|
||||
nopad_start_loc_len_temp = 0
|
||||
batches = [batch1, batch2]
|
||||
for i, batch in enumerate(batches):
|
||||
if i == 0:
|
||||
requests_idx_mapping = batch.requests_idx_mapping
|
||||
else:
|
||||
for k, v in batch.requests_idx_mapping.items():
|
||||
requests_idx_mapping[k] = v + cumulative_batch_size
|
||||
start_index = cumulative_batch_size
|
||||
end_index = cumulative_batch_size + len(batch)
|
||||
input_ids[start_index:end_index] = batch.input_ids
|
||||
nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
|
||||
nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
|
||||
nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
|
||||
nopad_b_loc[
|
||||
start_index:end_index,
|
||||
nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
|
||||
] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
out_token_id_counts.extend(batch.out_token_id_counts)
|
||||
sampling_param_list.extend(batch.sampling_param_list)
|
||||
# Update
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
|
||||
nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
|
||||
)
|
||||
return InferBatch(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
input_lengths=input_lengths,
|
||||
all_input_ids=all_input_ids,
|
||||
nopad_total_token_num=nopad_total_token_num,
|
||||
nopad_max_len_in_batch=nopad_max_len_in_batch,
|
||||
nopad_b_loc=nopad_b_loc,
|
||||
nopad_b_start_loc=nopad_b_start_loc,
|
||||
nopad_b_seq_len=nopad_b_seq_len,
|
||||
out_token_id_counts=out_token_id_counts,
|
||||
sampling_param_list=sampling_param_list,
|
||||
cache_manager=batches[0].cache_manager,
|
||||
max_total_len=max_total_len,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
presence_penalties: List[float] = []
|
||||
frequency_penalties: List[float] = []
|
||||
temperatures: List[float] = []
|
||||
top_ps: List[float] = []
|
||||
top_ks: List[int] = []
|
||||
p_token_ids: List[int] = []
|
||||
p_token_counts: List[int] = []
|
||||
p_seq_len: List[int] = [
|
||||
0,
|
||||
]
|
||||
p_max_len_in_batch: int = 0
|
||||
for i, id_to_count in enumerate(self.out_token_id_counts):
|
||||
sample_param = self.sampling_param_list[i]
|
||||
presence_penalties.append(sample_param.presence_penalty)
|
||||
frequency_penalties.append(sample_param.frequency_penalty)
|
||||
temperatures.append(sample_param.temperature)
|
||||
top_ps.append(sample_param.top_p)
|
||||
top_ks.append(sample_param.top_k)
|
||||
|
||||
for token_id, count in id_to_count.items():
|
||||
p_token_ids.append(token_id)
|
||||
p_token_counts.append(count)
|
||||
p_seq_len.append(len(id_to_count))
|
||||
p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))
|
||||
|
||||
presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
|
||||
frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
|
||||
temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
|
||||
top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
|
||||
top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
|
||||
p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
|
||||
p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
|
||||
p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
|
||||
p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
|
||||
return (
|
||||
presence_penalties,
|
||||
frequency_penalties,
|
||||
temperatures,
|
||||
top_ps,
|
||||
top_ks,
|
||||
p_token_ids,
|
||||
p_token_counts,
|
||||
p_cumsum_seq_len,
|
||||
p_max_len_in_batch,
|
||||
)
|
Reference in New Issue
Block a user