ColossalAI/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py
pre-commit-ci[bot] 7c2f79fa98
[pre-commit.ci] pre-commit autoupdate (#5572)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/PyCQA/autoflake: v2.2.1 → v2.3.1](https://github.com/PyCQA/autoflake/compare/v2.2.1...v2.3.1)
- [github.com/pycqa/isort: 5.12.0 → 5.13.2](https://github.com/pycqa/isort/compare/5.12.0...5.13.2)
- [github.com/psf/black-pre-commit-mirror: 23.9.1 → 24.4.2](https://github.com/psf/black-pre-commit-mirror/compare/23.9.1...24.4.2)
- [github.com/pre-commit/mirrors-clang-format: v13.0.1 → v18.1.7](https://github.com/pre-commit/mirrors-clang-format/compare/v13.0.1...v18.1.7)
- [github.com/pre-commit/pre-commit-hooks: v4.3.0 → v4.6.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.3.0...v4.6.0)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-07-01 17:16:41 +08:00

120 lines
4.4 KiB
Python

# might want to consider combining this with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Optional, Union

import torch
from transformers.tokenization_utils_base import BatchEncoding

from .kvcache_manager import MemoryManager
# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
    r"""
    Information to be passed and used for a batch of inputs during
    a single model forward.

    Instances are normally created through :meth:`init_from_batch`, which
    fills in the batch geometry (``batch_size``, ``max_len_in_batch``,
    ``seq_len``, ``start_loc``, ``block_loc``) and marks the state as being in
    the context stage. The remaining ``context_*`` / ``decode_*`` fields are
    not written anywhere in this class; they are left at their defaults for
    the caller to populate.
    """

    # Number of sequences in the batch.
    batch_size: int
    # Longest sequence length in the batch, as computed by ``init_from_batch``.
    max_len_in_batch: int
    cache_manager: Optional[MemoryManager] = None
    # (batch_size, max_input_len + max_output_len) table; ``init_block_loc``
    # writes cache slot indices into the right-aligned part of each row.
    block_loc: Optional[torch.Tensor] = None
    # Offset of each sequence inside the flattened batch
    # (exclusive prefix sum of ``seq_len``).
    start_loc: Optional[torch.Tensor] = None
    # Length of each sequence in the batch.
    seq_len: Optional[torch.Tensor] = None
    past_key_values_len: Optional[int] = None
    is_context_stage: bool = False
    # NOTE(review): the fields below are only declared here; their exact
    # semantics are defined by the code that fills them in - confirm there.
    context_mem_index: Optional[torch.Tensor] = None
    decode_is_contiguous: Optional[bool] = None
    decode_mem_start: Optional[int] = None
    decode_mem_end: Optional[int] = None
    decode_mem_index: Optional[torch.Tensor] = None
    decode_layer_id: Optional[int] = None
    device: torch.device = torch.device("cuda")

    @property
    def total_token_num(self):
        """Total number of tokens in the batch, i.e. the sum of ``seq_len``."""
        assert self.seq_len is not None and self.seq_len.size(0) > 0
        return int(self.seq_len.sum())

    def set_cache_manager(self, manager: MemoryManager):
        """Attach the KV-cache memory manager used for this batch."""
        self.cache_manager = manager

    # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
    @staticmethod
    def init_block_loc(
        b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
    ) -> None:
        """In-place update of the block loc mapping based on the sequence lengths of the current batch.

        For row ``i`` the last ``seq_len[i]`` columns before ``max_len_in_batch``
        are overwritten with the next ``seq_len[i]`` entries of
        ``alloc_mem_index``; entries are consumed consecutively, row after row.

        Args:
            b_loc: 2-D tensor that is modified in place.
            seq_len: 1-D tensor holding the length of every sequence.
            max_len_in_batch: exclusive end column of the region that is written.
            alloc_mem_index: 1-D tensor with at least ``sum(seq_len)`` entries.
        """
        consumed = 0
        # ``tolist()`` yields plain Python ints, so the slice arithmetic below
        # never runs on fixed-width scalar types.
        for row, length in enumerate(seq_len.tolist()):
            b_loc[row, max_len_in_batch - length : max_len_in_batch] = alloc_mem_index[consumed : consumed + length]
            consumed += length

    @classmethod
    def init_from_batch(
        cls,
        batch: Union[BatchEncoding, dict, list, torch.Tensor],
        max_input_len: int,
        max_output_len: int,
        cache_manager: MemoryManager,
        device: Union[str, torch.device] = "cuda",
    ):
        """Build the context-stage state for a batch of tokenized inputs.

        Args:
            batch: either a mapping (``BatchEncoding`` / ``dict``) providing
                ``"input_ids"`` and ``"attention_mask"``, or the input ids
                themselves as a list or tensor. A flat list of ``int`` is
                treated as a single sequence.
            max_input_len: maximum number of input tokens per sequence.
            max_output_len: maximum number of generated tokens per sequence.
            cache_manager: memory manager stored on the returned state.
            device: device on which ``seq_len``, ``start_loc`` and
                ``block_loc`` are allocated. Defaults to ``"cuda"``.

        Returns:
            A ``BatchInferState`` with ``is_context_stage=True``,
            ``past_key_values_len=0`` and ``decode_layer_id=0``.

        Raises:
            TypeError: if ``batch`` is of an unsupported type.
            ValueError: if ``attention_mask`` and ``input_ids`` do not contain
                the same number of sequences.
        """
        if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
            raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")

        from_mapping = isinstance(batch, (BatchEncoding, dict))
        if from_mapping:
            input_ids_list = batch["input_ids"]
            attention_mask = batch["attention_mask"]
        else:
            input_ids_list = batch
            attention_mask = None

        if isinstance(input_ids_list[0], int):  # for a single input
            input_ids_list = [input_ids_list]
            attention_mask = [attention_mask] if attention_mask is not None else attention_mask

        batch_size = len(input_ids_list)

        if from_mapping:
            if len(attention_mask) != batch_size:
                raise ValueError(
                    f"attention_mask has {len(attention_mask)} sequences but input_ids has {batch_size}"
                )
            # The length is the size of the mask row (i.e. the padded length),
            # not the number of non-zero mask entries.
            lengths = [len(mask) for mask in attention_mask]
        else:
            # Without a mask every sequence is accounted for at the length of
            # the longest one in the batch.
            longest = max(len(ids) for ids in input_ids_list)
            lengths = [longest] * batch_size

        max_len_in_batch = max(lengths)

        # Build the tensors in one shot instead of writing element by element
        # into device memory.
        seq_lengths = torch.tensor(lengths, dtype=torch.int32, device=device)
        # Exclusive prefix sum: where each sequence starts in the flattened batch.
        seq_start_indexes = torch.cumsum(seq_lengths, dim=0, dtype=torch.int32) - seq_lengths
        block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device=device)

        return cls(
            batch_size=batch_size,
            max_len_in_batch=max_len_in_batch,
            seq_len=seq_lengths,
            start_loc=seq_start_indexes,
            block_loc=block_loc,
            decode_layer_id=0,
            past_key_values_len=0,
            is_context_stage=True,
            cache_manager=cache_manager,
            device=torch.device(device),
        )