[inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Jianghai 2023-10-11 17:52:52 +08:00 committed by GitHub
parent 8aed02b957
commit e0757c31fb
16 changed files with 1220 additions and 47 deletions

View File

@@ -0,0 +1,346 @@
import collections
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
from colossalai.inference.tensor_parallel import MemoryManager
# make batch infer state an attr of InferBatch
class InferSamplingParams:
def __init__(
self,
do_sample: bool = False,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
vocab_size: int = -1,
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
if self.top_k == -1:
self.top_k = vocab_size
return
@dataclass
class InferBatch:
batch_id: int
requests: List
requests_idx_mapping: Dict[int, int]
input_ids: torch.Tensor
all_input_ids: List[List[int]]
input_lengths: List[int]
out_token_id_counts: List
sampling_param_list: List[InferSamplingParams]
nopad_total_token_num: int
nopad_max_len_in_batch: int
nopad_b_loc: torch.Tensor
nopad_b_start_loc: torch.Tensor
nopad_b_seq_len: torch.Tensor
cache_manager: MemoryManager
max_total_len: int
@classmethod
@torch.no_grad()
def init_batch(
cls,
batch_id,
requests,
dtype: torch.dtype,
device: torch.device,
cache_manager: MemoryManager,
vocab_size: int,
max_total_len: int,
) -> 'InferBatch':
input_lengths = []
all_input_ids = []
requests_idx_mapping = {}
out_token_id_counts = []
sampling_param_list = []
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
# to avoid a memory leak, we pre-allocate 12 extra slots for each sequence in the batch.
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
for i, r in enumerate(requests):
# request id -> idx in list mapping
requests_idx_mapping[r["request_id"]] = i
tokenized_input = r["input_id"]
input_length = len(tokenized_input)
input_lengths.append(input_length)
all_input_ids.append(tokenized_input)
out_token_id_counts.append(collections.defaultdict(int))
# postprocessor
sampling_param = r["sampling_param"]
sampling_param["vocab_size"] = vocab_size
sampling_param_list.append(InferSamplingParams(**sampling_param))
nopad_total_token_num += input_length
nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)
nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
if len(requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
else:
input_ids = all_input_ids[0]
# Create tensors on device
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
return cls(
batch_id=batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
cache_manager=cache_manager,
max_total_len=max_total_len,
)
@torch.no_grad()
def free_self(self) -> None:
"""
Free the memory of the InferBatch itself
"""
remove_index = []
for idx in range(len(self)):
remove_index.append(
self.nopad_b_loc[
idx,
(self.nopad_max_len_in_batch - 1)
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
]
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)
@torch.no_grad()
def filter(self, request_ids: List[int]) -> 'InferBatch':
"""
Filter out finished requests and return a new InferBatch containing the remaining ones.
"""
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
requests_idx_mapping = {}
indices = []
requests = []
all_input_ids = []
input_lengths = []
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
left_idx = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
left_idx.append(idx)
left_idx_set = set(left_idx)
remove_index = []
for idx in range(len(self)):
if idx not in left_idx_set:
remove_index.append(
self.nopad_b_loc[
idx,
(self.nopad_max_len_in_batch - 1)
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
]
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)
nopad_max_len_in_batch = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
nopad_total_token_num = torch.sum(nopad_b_seq_len).item()
nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
indices,
(self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
]
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
all_input_ids.append(self.all_input_ids[idx])
input_lengths.append(self.input_lengths[idx])
input_ids = self.input_ids[indices]
return InferBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
cache_manager=self.cache_manager,
max_total_len=self.max_total_len,
)
@classmethod
@torch.no_grad()
def merge(cls, batch1, batch2) -> 'InferBatch':
"""
Return a new InferBatch merged from batch1 and batch2.
"""
requests = batch1.requests + batch2.requests
requests_idx_mapping = {}
new_batch_size = len(batch1) + len(batch2)
input_ids = batch1.input_ids.new_empty(new_batch_size)
all_input_ids = []
input_lengths = []
out_token_id_counts = []
sampling_param_list = []
cumulative_batch_size = 0
nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
max_total_len = max(batch1.max_total_len, batch2.max_total_len)
nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
nopad_start_loc_len_temp = 0
batches = [batch1, batch2]
for i, batch in enumerate(batches):
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
input_ids[start_index:end_index] = batch.input_ids
nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
nopad_b_loc[
start_index:end_index,
nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]
all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths)
out_token_id_counts.extend(batch.out_token_id_counts)
sampling_param_list.extend(batch.sampling_param_list)
# Update
cumulative_batch_size += len(batch)
nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
)
return InferBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
cache_manager=batches[0].cache_manager,
max_total_len=max_total_len,
)
def __len__(self):
return len(self.requests)
def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
temperatures: List[float] = []
top_ps: List[float] = []
top_ks: List[int] = []
p_token_ids: List[int] = []
p_token_counts: List[int] = []
p_seq_len: List[int] = [0]
p_max_len_in_batch: int = 0
for i, id_to_count in enumerate(self.out_token_id_counts):
sample_param = self.sampling_param_list[i]
presence_penalties.append(sample_param.presence_penalty)
frequency_penalties.append(sample_param.frequency_penalty)
temperatures.append(sample_param.temperature)
top_ps.append(sample_param.top_p)
top_ks.append(sample_param.top_k)
for token_id, count in id_to_count.items():
p_token_ids.append(token_id)
p_token_counts.append(count)
p_seq_len.append(len(id_to_count))
p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))
presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
return (
presence_penalties,
frequency_penalties,
temperatures,
top_ps,
top_ks,
p_token_ids,
p_token_counts,
p_cumsum_seq_len,
p_max_len_in_batch,
)
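For reference, a minimal sketch of how InferBatch.init_batch might be driven directly, outside the batch manager. The request dicts, ids, sizes, and the pre-built MemoryManager are illustrative assumptions, not part of this commit.

# Illustrative sketch (assumes a CUDA device and an existing MemoryManager `cache_manager`).
# Each request is a plain dict with the keys init_batch reads:
# "request_id", "input_id" (token ids), and "sampling_param".
import torch

requests = [
    {"request_id": 0, "input_id": [1, 15043, 29892], "sampling_param": {"do_sample": False}},
    {"request_id": 1, "input_id": [1, 590, 11203, 338], "sampling_param": {"do_sample": False}},
]
batch = InferBatch.init_batch(
    batch_id="demo",
    requests=requests,
    dtype=torch.float16,
    device=torch.device("cuda"),
    cache_manager=cache_manager,  # assumed to exist
    vocab_size=32000,
    max_total_len=64,
)
batch = batch.filter([1])  # keep only request 1; cache rows of request 0 are freed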

View File

@@ -0,0 +1,149 @@
from typing import Dict, List, Tuple
from .sampling_params import SamplingParams
class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
self.request_id = request_id
self.prompt_ids = prompt_ids
self.input_len = len(prompt_ids)
self.max_output_len = sample_params.max_new_tokens
self.sample_params = sample_params
self.output_ids = []
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False
def to_rpc_obj(self):
return {
"request_id": self.request_id,
"input_id": self.prompt_ids,
"output_len": self.max_output_len,
"sampling_param": self.sample_params.to_dict(),
}
def to_req_detokenization_state(self):
out = ReqDetokenizationState(
self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos
)
if self.output_metadata_list:
out.gen_metadata.update(self.output_metadata_list[-1])
return out
def stop_sequences_matched(self):
# should we add stop sequences to the sample params?
if self.sample_params.stop_sequences is not None:
for stop_token_ids in self.sample_params.stop_sequences:
stop_len = len(stop_token_ids)
if (
stop_len > 0
and len(self.output_ids) >= stop_len
and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
):
return True
return False
def __repr__(self):
return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, "
class ReqDetokenizationState:
def __init__(
self,
request_id: str,
prompt_ids: List[int],
max_output_len: int,
ignore_eos: bool,
) -> None:
self.request_id = request_id
self.prompt_ids = prompt_ids
self.output_ids = []
self.output_tokens = []
self.output_str = ""
self.sub_texts = []
self.current_sub_text = []
self.max_output_len = max_output_len
self.ignore_eos = ignore_eos
self.gen_metadata = {}
class Batch:
def __init__(self, batch_id, reqs: List[Req]):
self.batch_id = batch_id
self.reqs = reqs
self.id_to_reqs = {req.request_id: req for req in reqs}
def input_tokens(self):
batch_input_tokens = 0
for req in self.reqs:
batch_input_tokens += req.input_len
return batch_input_tokens
def calcu_max_tokens(self):
tokens = 0
for req in self.reqs:
tokens += req.input_len + req.max_output_len
return tokens
def calcu_used_tokens(self):
tokens = 0
for req in self.reqs:
tokens += req.input_len + len(req.output_ids)
return tokens
def mark_finished_req(self, eos_id):
has_new_finish = False
for req in self.reqs:
if req.stop_sequences_matched():
req.has_generate_finished = True
has_new_finish = True
if req.output_ids and req.output_ids[-1] == eos_id and not req.sample_params.ignore_eos:
req.has_generate_finished = True
has_new_finish = True
if len(req.output_ids) >= req.max_output_len or req.aborted:
req.has_generate_finished = True
has_new_finish = True
return has_new_finish
def filter_finished(self):
"""
Filter finished requests from the batch; the finished ones are removed from 'reqs'.
"""
# TODO: the logic of return should be defined here.
unfinished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
def is_clear(self):
return len(self.reqs) == 0
def merge(self, mini_batch):
for _req in mini_batch.reqs:
self.reqs.append(_req)
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return
def __repr__(self):
return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "
def __len__(self):
return len(self.reqs)
class BatchTokenIdOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, int, Dict, bool, bool]
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, str, Dict, bool, bool]
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq:
def __init__(self, req_id):
self.req_id = req_id
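A short, hedged sketch of the request/batch life cycle these structures support; the token ids and sizes below are made up.

# Illustrative only: build two requests, group them, and exercise the bookkeeping.
from colossalai.inference.dynamic_batching.io_struct import Batch, Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

params = SamplingParams(max_new_tokens=4)
reqs = [Req(0, [1, 2, 3], params), Req(1, [4, 5], params)]
batch = Batch("batch-0", reqs)

assert batch.input_tokens() == 5               # 3 + 2 prompt tokens
assert batch.calcu_max_tokens() == 5 + 2 * 4   # prompts + max outputs
reqs[0].output_ids.extend([7, 7, 7, 7])        # request 0 hits max_output_len
reqs[1].output_ids.append(9)                   # request 1 is still generating
batch.mark_finished_req(eos_id=2)              # flags request 0 as finished
batch.filter_finished()
assert len(batch) == 1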

View File

@@ -0,0 +1,71 @@
import uuid
from typing import List
import numpy as np
from .io_struct import Batch, Req
class ReqQueue:
def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=None) -> None:
self.max_total_tokens = max_total_tokens
assert batch_max_tokens is not None
self.batch_max_tokens = batch_max_tokens
self.running_max_req_size = running_max_req_size
# default to None instead of [] to avoid sharing one mutable list across instances
self.waiting_req_list: List[Req] = waiting_req_list if waiting_req_list is not None else []
def append(self, req):
self.waiting_req_list.append(req)
return
def _init_cache_list(self, current_batch: Batch):
if current_batch is not None:
self.cache_len_list = [
(req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
for req in current_batch.reqs
]
else:
self.cache_len_list = []
# @calculate_time(show=True, min_cost_ms=0.1)
def _can_add_new_req(self, req):
self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))  # hard to analyze precisely
self.cache_len_list.sort(key=lambda x: -x[1])
left_out_len_array = np.array([e[1] for e in self.cache_len_list])
# assert left_out_len_array.min() >= 0
has_run_len_array = np.array([e[0] for e in self.cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
# NOTE: change here < to <=
return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size
def generate_new_batch(self, current_batch: Batch = None):
if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
return None
self._init_cache_list(current_batch)
can_run_list = []
new_batch_total_tokens = 0
aborted_count = 0
for req in self.waiting_req_list:
flag = self._can_add_new_req(req)
if req.aborted:
aborted_count += 1
continue
if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
can_run_list.append(req)
new_batch_total_tokens += req.input_len
else:
break
if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().hex, can_run_list)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
else:
return None
def __len__(self):
return len(self.waiting_req_list)
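The admission test in _can_add_new_req bounds the peak token demand if every running request decodes to completion: with requests sorted by remaining output length (descending), the worst case over each prefix of size k is left_out_len[k] * k + cumulative_current_len[k]. A standalone, hand-checked sketch of that bound with made-up numbers:

import numpy as np

# (current_len, remaining_out_len) for three hypothetical running requests,
# already sorted by remaining output length, descending.
cache_len_list = [(6, 10), (4, 8), (9, 3)]
left_out = np.array([r[1] for r in cache_len_list])     # [10, 8, 3]
cum_run = np.cumsum([r[0] for r in cache_len_list])     # [6, 10, 19]
size = np.arange(1, len(cache_len_list) + 1)            # [1, 2, 3]
need_max_token_num = (left_out * size + cum_run).max()  # max(16, 26, 28)
assert need_max_token_num == 28
# A request is admitted only if this bound stays <= max_total_tokens
# and the running batch stays within running_max_req_size.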

View File

@@ -0,0 +1,82 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams:
def __init__(
self,
do_sample: bool = False,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1, # -1 is for all
ignore_eos: bool = False,
max_new_tokens: int = 16,
stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.ignore_eos = ignore_eos
self.max_new_tokens = max_new_tokens
self.stop_sequences = stop_sequences
if not self.do_sample:
self.temperature = 1.0
self.top_p = 1.0
self.top_k = 1
if 0.0 <= self.temperature < _SAMPLING_EPS:  # temperature is too low, fall back to greedy search
self.temperature = 1.0
self.top_k = 1
return
def verify(self):
if self.presence_penalty < 0.0:
raise ValueError(f"presence_penalty must be >= 0.0, got {self.presence_penalty}")
if self.frequency_penalty < 0.0:
raise ValueError(f"frequency_penalty must be >= 0.0, got {self.frequency_penalty}")
if self.temperature <= 0.0:
raise ValueError(f"temperature must be > 0.0, got {self.temperature}")
if self.top_p <= 0.0 or self.top_p > 1.0:
raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
return
def stop_sentences_to_token_ids(self, tokenizer):
if self.stop_sequences is None:
self.stop_sequences = []
else:
if isinstance(self.stop_sequences, str):
self.stop_sequences = [self.stop_sequences]
new_stop_sequences = []
for stop_str in self.stop_sequences:
stop_str_ids = tokenizer.encode(stop_str)
if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id
stop_str_ids = stop_str_ids[1:]
if len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
self.stop_sequences = new_stop_sequences
return
def to_dict(self):
ret = {}
ret["do_sample"] = self.do_sample
ret["presence_penalty"] = self.presence_penalty
ret["frequency_penalty"] = self.frequency_penalty
ret["temperature"] = self.temperature
ret["top_p"] = self.top_p
ret["top_k"] = self.top_k
# if self.ignore_eos is not None:
# ret["ignore_eos"] = self.ignore_eos
# if self.max_tokens is not None:
# ret["max_tokens"] = self.max_tokens
return ret
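Note the constructor's silent coercion above: with do_sample left False, or a near-zero temperature, decoding collapses to greedy search. A quick illustration with hypothetical values:

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

greedy = SamplingParams()  # do_sample defaults to False
assert (greedy.top_k, greedy.temperature) == (1, 1.0)

sampled = SamplingParams(do_sample=True, temperature=1e-7, top_k=50)
assert sampled.top_k == 1  # temperature below _SAMPLING_EPS falls back to greedy
assert sampled.to_dict()["top_k"] == 1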

View File

@@ -0,0 +1,43 @@
import time
class Stats:
def __init__(self, log_stats, log_stats_interval) -> None:
self.log_stats = log_stats
self.log_stats_interval = log_stats_interval
self.last_log_time = time.time()
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
return
def count_prompt_tokens(self, run_batch):
if self.log_stats:
tokens = run_batch.input_tokens()
self.prompt_tokens += tokens
self.all_tokens += tokens
return
def count_output_tokens(self, run_batch):
if self.log_stats:
tokens = len(run_batch.reqs)
self.output_tokens += tokens
self.all_tokens += tokens
return
def print_stats(self):
if not self.log_stats:
return
now = time.time()
if now - self.last_log_time > self.log_stats_interval:
print(
f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
)
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
self.last_log_time = now
return
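The accounting is per step: a prefill counts all prompt tokens at once, while each decode step counts one output token per running request. A toy illustration using the Stats class defined above (the fake batch object is an assumption for the sketch, not part of the commit):

class _FakeBatch:
    reqs = [None, None, None]  # pretend three requests are running
    def input_tokens(self):
        return 12  # pretend the prompts total 12 tokens

stats = Stats(True, 1)                   # log stats, print at most once per second
stats.count_prompt_tokens(_FakeBatch())  # all_tokens += 12
stats.count_output_tokens(_FakeBatch())  # all_tokens += len(reqs) == 3
assert stats.all_tokens == 15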

View File

@@ -0,0 +1,243 @@
import time
from typing import List, Optional
import torch
from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
from .dynamic_batching.req_queue import ReqQueue
from .dynamic_batching.sampling_params import SamplingParams
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine
class DynamicBatchManager:
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
eos_id,
log_stats=True,
log_stats_interval=10,
running_batch: Optional[Batch] = None,
waiting_req_list: Optional[List] = None,
):
"""
Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager
max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len)
batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
eos_id : The end token of a seq
log_stats : whether to log stats
log_stats_interval : log stats interval
running_batch : running batch
waiting_req_list : list of waiting requests, initialized before dynamic batch manager
"""
self.engine = tp_engine
self.max_total_token_num = max_total_token_num
running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
# all the inputs should be put into req_queue: waiting req list
self.running_batch: Batch = running_batch
self.eos_id = eos_id
self.has_wait_tokens = 0
self.max_wait_tokens = 10
self.stats_tool = Stats(log_stats, log_stats_interval)
self.mem_usage_interval = log_stats_interval * 2
def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
"""
Add a new request to the request queue; during initialization all requests are held in the waiting list.
"""
req = Req(request_id, prompt_ids, sampling_params)
self.req_queue.append(req)
return
def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
for req in self.req_queue.waiting_req_list:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
return
def loop_for_fwd(self):
"""
The main loop for a dynamic batching process.
"""
counter_count = 0
while self.running_batch is not None or self.req_queue.waiting_req_list:
self._step()
counter_count += 1
if self.running_batch is not None:
if counter_count % self.mem_usage_interval == 0:
print(
"current batch size:",
len(self.running_batch.reqs),
"token used ratio:",
self.running_batch.calcu_used_tokens() / self.max_total_token_num,
)
self.stats_tool.print_stats()
if self.running_batch is None:
time.sleep(0.1)  # 100 ms
def _step(self):
"""
Logic for handling requests
"""
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
self._prefill_batch(self.running_batch)
self._filter_running_batch()
self.has_wait_tokens = 0
return
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
self._filter_running_batch()
self.has_wait_tokens += 1
return
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0
else:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
self._filter_running_batch()
self.has_wait_tokens += 1
return
def _init_batch(self, batch: Batch, dtype="fp16"):
reqs = [r.to_rpc_obj() for r in batch.reqs]
batch_id = batch.batch_id
if dtype == "fp16":
dtype = torch.float16
else:
raise ValueError(f"unsupported dtype {dtype}; only fp16 is supported")
batch_data = InferBatch.init_batch(
batch_id,
reqs,
dtype,
torch.cuda.current_device(),
self.engine.cache_manager,
self.engine.model.config.vocab_size,
self.engine.max_input_len + self.engine.max_output_len,
)
self.engine.cache[batch_id] = batch_data
def _prefill_batch(self, batch):
"""
For any batch, whether it is a new batch or a mini-batch, we need to run prefill first.
"""
self._init_batch(batch)
# TODO: figure out if cache and batch id are needed
req_to_out_token_id = self.engine._prefill_batch(batch.batch_id)
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
self._handle_finish_req(batch, has_new_finished_req)
# delete finished reqs
def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
req_to_out_token_id = self.engine._decode_batch(batch.batch_id)
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
self._handle_finish_req(batch, has_new_finished_req)
def _filter_batch(self, batch: Batch):
batch_id = batch.batch_id
req_id_list = [r.request_id for r in batch.reqs]
batch = self.engine.cache.pop(batch_id)
filter_batch = batch.filter(req_id_list)
del batch
self.engine.cache[batch_id] = filter_batch
def _merge_batch(self, batch1, batch2):
"""
Merge new mini batch into running batch.
"""
batch1 = self.engine.cache.pop(batch1.batch_id)
batch2 = self.engine.cache.pop(batch2.batch_id)
m_batch = InferBatch.merge(batch1, batch2)
self.engine.cache[batch1.batch_id] = m_batch
del batch1
del batch2
def _remove_batch(self, batch):
"""
Remove finished batch.
"""
batch = self.engine.cache.pop(batch.batch_id)
batch.free_self()
del batch
def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
def _filter_running_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None
def _add_token_id_to_req(self, batch: Batch, req_ans):
for req_id, (new_token_id, new_gen_metadata) in req_ans.items():
req = batch.id_to_reqs[req_id]
req.output_ids.append(new_token_id)
req.output_metadata_list.append(new_gen_metadata)
return
def clean_up(self):
# this logic should be implemented in the future.
pass
def start_dynamic_batching(args, tp_engine, waiting_req_list):
# try:
batch_manager = DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)
# except Exception:
# batch_manager.clean_up()
# raise
batch_manager.loop_for_fwd()
return
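A hedged end-to-end sketch of driving the manager directly, mirroring the tests further below; the engine construction, eos id, budgets, and token ids are illustrative assumptions.

# Assumes an already-built TPInferEngine `infer_engine` on a CUDA machine.
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager

manager = DynamicBatchManager(
    tp_engine=infer_engine,
    max_total_token_num=2048,
    batch_max_tokens=1024,
    eos_id=2,
    log_stats=False,
    log_stats_interval=10,
)
manager.add_req(prompt_ids=[1, 15043, 29892], sampling_params=SamplingParams(max_new_tokens=8), request_id="r0")
manager.loop_for_fwd()  # prefills, decodes step by step, retires finished requests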

View File

@@ -1,7 +1,6 @@
from typing import Any, Callable, List, Optional, Union
import torch
-import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
@@ -14,6 +13,8 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
+# from dynamic_batching.infer_batch import InferBatch
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = [
@@ -90,6 +91,8 @@ class TPInferEngine:
self.shard_config = shard_config
self.model = None
+self.cache = {}
# optimize the original model by sharding with ShardFormer
self._optimize_model(model=model.to(device))
@@ -116,13 +119,15 @@ class TPInferEngine:
def _post_init_gptq_buffer(self, model: nn.Module) -> None:
from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
-warnings.warn('CUDA gptq is not installed')
+warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
for name, submodule in model.named_modules():
@@ -130,8 +135,9 @@ class TPInferEngine:
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
if self.use_act_order:
-self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
-submodule.outfeatures)
+self.max_inner_outer_dim = max(
+self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return
@@ -141,15 +147,16 @@ class TPInferEngine:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
-dtype=torch.float16,
-device=torch.cuda.current_device())
-self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
-dtype=torch.float16,
-device=torch.cuda.current_device())
-gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
-self.gptq_temp_dq_buffer)
+self.gptq_temp_state_buffer = torch.zeros(
+(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+)
+self.gptq_temp_dq_buffer = torch.zeros(
+(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+)
+gptq_cuda.prepare_buffers(
+torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
@@ -270,7 +277,6 @@ class TPInferEngine:
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
@@ -304,6 +310,7 @@ class TPInferEngine:
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
return batch_infer_state
@torch.no_grad()
@@ -367,6 +374,86 @@ class TPInferEngine:
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
infer_state.seq_len += 1
@torch.no_grad()
def forward(self, batch_id, is_prefill):
"""
Forward is used in Dynamic Batching Manager
"""
batch = self.cache.pop(batch_id)
if is_prefill:
input_ = torch.tensor(batch.all_input_ids).cuda()
else:
input_ = batch.input_ids.reshape(len(batch), 1)
batch_args = {
"batch_size": len(batch),
"max_len_in_batch": batch.nopad_max_len_in_batch,
"block_loc": batch.nopad_b_loc,
"start_loc": batch.nopad_b_start_loc,
"seq_len": batch.nopad_b_seq_len,
"cache_manager": batch.cache_manager,
"is_context_stage": is_prefill,
}
infer_state = BatchInferState(**batch_args)
model = self.model
if isinstance(model, LlamaForCausalLM):
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
setattr(model, "infer_state", infer_state)
output = self.model.forward(input_ids=input_)
logits = output.logits  # (bsz, seq_len, vocab_size)
# probability distribution over the last position: (bsz, vocab_size)
prob_out = torch.softmax(logits[:, -1, :], dim=-1)
predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True)
prob_out = torch.log(prob_out).detach().cpu().numpy()
predict_ids = predict_ids.detach().cpu().numpy()
# [ batch_size, 1 ]
output_dict = {}
new_input_ids = []
for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate(
zip(batch.requests, batch.all_input_ids, predict_ids, prob_out)
):
next_token_id = int(next_token_id)
next_token_logprob = next_token_logprob[next_token_id]
# all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda")
all_input_ids.append(next_token_id)
# all_input_ids_tensor = None
new_input_ids.append(next_token_id)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] += 1
batch.out_token_id_counts[i][next_token_id] += 1
metadata = {
"id": int(next_token_id),
"logprob": float(next_token_logprob),
}
output_dict[r["request_id"]] = (int(next_token_id), metadata)
batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda()
batch.nopad_total_token_num += len(batch)
batch.nopad_max_len_in_batch += 1
self.cache[batch.batch_id] = batch
return output_dict
@torch.no_grad()
def _prefill_batch(self, batch_id):
return self.forward(batch_id, is_prefill=True)
@torch.no_grad()
def _decode_batch(self, batch_id):
return self.forward(batch_id, is_prefill=False)
# might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text
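Taken together with the manager, the intended call pattern for the new engine entry points is roughly the following schematic; the batch id, decode budget, and cache entry are assumptions for illustration, not code from the commit.

# Schematic sketch of the new engine API (objects assumed to exist).
batch_id = "demo-batch"
engine.cache[batch_id] = infer_batch      # what DynamicBatchManager._init_batch registers
out = engine._prefill_batch(batch_id)     # context stage over the full prompts
for _ in range(max_new_tokens - 1):       # hypothetical decode budget
    out = engine._decode_batch(batch_id)  # one new token per request per step
# each `out` maps request_id -> (next_token_id, {"id": ..., "logprob": ...})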

View File

@@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
base = float(base)
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
+ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
if ntk_alpha is not None:
ntk_alpha = float(ntk_alpha)

View File

@@ -62,12 +62,11 @@ class LlamaInferenceForwards:
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
+batch_size = input_ids.shape[0]  # input_ids.shape[0]
infer_state = self.infer_state
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
@@ -78,15 +77,12 @@ class LlamaInferenceForwards:
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-seq_length_with_past = seq_length
-past_key_values_length = 0
-if past_key_values is not None:
-# NOT READY FOR PRIME TIME
-# dummy but work, revise it
-past_key_values_length = infer_state.cache_manager.past_key_values_length
-# past_key_values_length = past_key_values[0][0].shape[2]
-seq_length_with_past = seq_length_with_past + past_key_values_length
+# NOT READY FOR PRIME TIME
+# dummy but work, revise it
+if infer_state.is_context_stage:
+past_key_values_length = 0
+else:
+past_key_values_length = infer_state.max_len_in_batch - 1
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@@ -106,23 +102,23 @@ class LlamaInferenceForwards:
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
-infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
-print(
-f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
-)
+print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
+position_ids = position_ids.repeat(batch_size, 1)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@@ -134,6 +130,7 @@ class LlamaInferenceForwards:
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
@@ -145,7 +142,7 @@ class LlamaInferenceForwards:
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
-(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
@@ -160,7 +157,6 @@ class LlamaInferenceForwards:
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
# NOTE: modify here for passing args to decoder layer
@@ -184,7 +180,7 @@ class LlamaInferenceForwards:
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
-infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
if not return_dict:
@@ -211,7 +207,6 @@ class LlamaInferenceForwards:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@@ -267,11 +262,8 @@ class LlamaInferenceForwards:
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
-if infer_state.decode_layer_id == 0:  # once per model.forward
-infer_state.cache_manager.past_key_values_length += q_len  # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
-# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
@@ -282,7 +274,6 @@ class LlamaInferenceForwards:
if infer_state.is_context_stage:
# first token generation
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
@@ -291,9 +282,7 @@ class LlamaInferenceForwards:
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_context_attn_fwd(
query_states,
@@ -301,7 +290,7 @@ class LlamaInferenceForwards:
attn_output,
infer_state.start_loc,
infer_state.seq_len,
-infer_state.cache_manager.past_key_values_length,
+infer_state.max_len_in_batch,
)
else:
if infer_state.decode_is_contiguous:
@@ -338,7 +327,7 @@ class LlamaInferenceForwards:
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
-infer_state.cache_manager.past_key_values_length,
+infer_state.max_len_in_batch,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)

View File

@@ -2,7 +2,6 @@ try:
import triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")

View File

@@ -51,7 +51,6 @@ if HAS_TRITON:
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr,
dest_index_ptr,

View File

@@ -27,8 +27,10 @@ if HAS_LLAMA:
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
-input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
-attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
+input_ids = torch.Tensor(
+[[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
+).long()
+attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for causal lm

View File

@@ -0,0 +1,94 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
def run():
sampling_params = SamplingParams()
req1 = Req(0, [1], sampling_params)
req2 = Req(1, [2], sampling_params)
req3 = Req(2, [3], sampling_params)
# req 1-3 are initialized as token forward requests
req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
waiting_list = []
waiting_list.append(req1)
waiting_list.append(req2)
waiting_list.append(req3)
# init model and tp engine
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
dynamic_batch_manager = DynamicBatchManager(
tp_engine=infer_engine,
max_total_token_num=42,
batch_max_tokens=42,
eos_id=0,
log_stats=False,
log_stats_interval=10,
waiting_req_list=waiting_list,
)
before_add = len(dynamic_batch_manager.req_queue)
# test add req function
dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id)
assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1
# test abort function
dynamic_batch_manager.abort(req4.request_id)
assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted
# test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested
batch = dynamic_batch_manager.req_queue.generate_new_batch()
assert len(batch) == 2
dynamic_batch_manager._init_batch(batch)
assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None
batch.reqs[0].has_generate_finished = True
# filter one finished
batch.filter_finished()
dynamic_batch_manager._filter_batch(batch)
assert len(dynamic_batch_manager.engine.cache) == 1
# test merge batch
new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
assert len(new_batch) == 1
dynamic_batch_manager._init_batch(new_batch)
dynamic_batch_manager._merge_batch(batch, new_batch)
assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2
def check_dynamic_batching_manager(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
spawn(check_dynamic_batching_manager, 1)
if __name__ == "__main__":
test_dynamic_batching_manager()

View File

@@ -0,0 +1,70 @@
import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
import colossalai
from dataclasses import dataclass
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@dataclass
class args:
max_total_token_num: int
batch_max_tokens: int
eos_id: int
disable_log_stats: bool
log_stats_interval: int
def run():
arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
sampling_params = SamplingParams()
req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)
waiting_list = []
waiting_list.append(req1)
waiting_list.append(req2)
waiting_list.append(req3)
waiting_list.append(req4)
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()
shard_config = ShardConfig(enable_tensor_parallelism=(TP_SIZE > 1), inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
def check_dynamic_forward(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
spawn(check_dynamic_forward, TP_SIZE)
if __name__ == "__main__":
test_dynamic_batching()

View File

@@ -38,7 +38,6 @@ def run_llama_test(test_config):
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-init_to_get_rotary(model.model, base=10000)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
input_tokens = {