[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
* support only tp
* enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
  merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
* remove useless code
* fix quant model
* fix test import bug
* mv original inference legacy
* fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [Refactor] refactor policy search and quant type controlling in inference
* [inference] update readme (#5051)
* update readme
* update readme
* fix architecture
* fix table
* fix table
* [inference] update example (#5053)
* update example
* fix run.sh
* fix rebase bug
* fix some errors
* update readme
* add some features
* update interface
* update readme
* update benchmark
* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
colossalai/legacy/inference/pipeline/README.md (new file, +83)
@@ -0,0 +1,83 @@
# 🐳 Pipeline Inference

## Table of Contents

- [💡 Introduction](#introduction)
- [🔗 Design](#design)
- [🔨 Usage](#usage)
  - [Example](#example)
  - [Quick start](#quick-start)
- [📊 Performance](#performance)

## Introduction

`Pipeline Inference` is a module for running inference in a pipelined manner. Although an inference system does not need to keep intermediate activations from the forward pass for backpropagation, the weights of larger models still may not fit on a single GPU, so model parallelism is needed to reduce per-GPU memory usage. Pipeline parallelism, one of the classic model-parallel approaches, is widely used because it requires little all-reduce communication and has a simple layout. Its main drawback, pipeline bubbles, largely disappears in inference because there is no backward pass to cause them; in the ideal case where every micro-batch has the same sequence length, the pipeline is almost bubble-free.
## Design

Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and the `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).

1. `PPInferEngine` is the high-level API exposed to users. It is responsible for the following tasks:
   - Initialize the pipeline inference environment with `PipelineStageManager` and shard the model with `ShardFormer`.
   - Run the pipeline inference model.

2. `MicroBatchManager` is a structure that manages micro-batch information. It is responsible for the following tasks:
   - Record each micro-batch's information, such as the generated new tokens and the KV cache.
   - Record each micro-batch's inference state, such as prefill, generate or done.
   - Update the micro-batch information.

3. The `generate` schedule implements the simple pipeline inference layout. When the pipeline size is 2, we use `torch.distributed.P2POp` for communication between stages, mainly to avoid the ordering race between the paired send and receive (see the sketch below). When the pipeline size is larger than 2, we use `torch.distributed.broadcast`, which is faster than `torch.distributed.P2POp`.
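Below is a minimal sketch (not part of this commit) of the paired send/receive pattern mentioned in point 3. The helper name, rank arguments and tensor shapes are illustrative only; `torch.distributed.P2POp` and `torch.distributed.batch_isend_irecv` are the real PyTorch primitives such a pattern relies on.

```python
import torch
import torch.distributed as dist


def exchange_hidden_states(send_tensor: torch.Tensor, recv_shape, prev_rank: int, next_rank: int):
    # Batch the isend to the next stage together with the irecv from the
    # previous stage, so neither rank blocks waiting for its peer to post
    # the matching operation first.
    recv_tensor = torch.empty(recv_shape, dtype=send_tensor.dtype, device=send_tensor.device)
    ops = [
        dist.P2POp(dist.isend, send_tensor, next_rank),
        dist.P2POp(dist.irecv, recv_tensor, prev_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_tensor
```

Batching the two operations means neither stage has to care whether its peer posts the send or the receive first, which is exactly the ordering race that plain blocking `send`/`recv` between two stages would expose.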
## Usage

### Example

```python
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")

# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32)

input = ["Introduce a landmark in London", "Introduce a landmark in Singapore"]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference(data.to('cuda'))
print(tokenizer.batch_decode(output))
```
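Since `pp_size=2`, the script above has to be launched with two processes, one per pipeline stage, for example with `colossalai run --nproc_per_node 2 your_script.py` (the same launcher used in `benchmark/run.sh` below); `colossalai.launch_from_torch` then picks up the distributed environment set by the launcher.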
## Performance

We ran several benchmarks to evaluate performance, comparing the inference `latency` and `throughput` of `Pipeline Inference` against the Hugging Face pipeline. The test environments are 2 x A10 (20 GB) and 2 x A800 (80 GB).

### Llama Throughput (tokens/s) | input length=1024, output length=128

#### A10 7b, fp16

| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |

#### A10 13b, fp16

| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |

#### A800 7b, fp16

| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |

#### A800 13b, fp16

| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
colossalai/legacy/inference/pipeline/__init__.py (new file, +3)
@@ -0,0 +1,3 @@

from .microbatch_manager import MicroBatchManager

__all__ = ["MicroBatchManager"]
colossalai/legacy/inference/pipeline/benchmark/benchmark.py (new file, +134)
@@ -0,0 +1,134 @@
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

colossalai.launch_from_torch(config={})


def data_gen(batch_size: int = 4, seq_len: int = 512):
    input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
    attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
    data = dict(input_ids=input_ids, attention_mask=attention_mask)
    for k, v in data.items():
        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
            new_shape = [1] * v.dim()
            new_shape[0] = batch_size
            data[k] = v.to("cuda").repeat(*new_shape)
    return data

def print_details_info(timestamps, model_config, args, whole_end2end):
    if dist.get_rank() == 0:
        prefill = []
        encoder = []
        end2end = []
        for timestamp in timestamps:
            prefill.append(timestamp[1] - timestamp[0])
            encoder.append(
                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
            )
            end2end.append(timestamp[-1] - timestamp[0])
        print(whole_end2end)
        with open(
            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
            "w+",
        ) as f:
            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
            num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
            num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
            if args.dtype in ["fp16", "bf16"]:
                num_bytes = 2
            else:
                num_bytes = 4

            f.write(
                f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
            )
            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
            f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
            f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
            f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
            f.write("----------------------------------------------------------\n")

    if torch.cuda.is_available():
        current_device = torch.cuda.current_device()

        # free memory and the total available memory in bytes
        global_free_memory, total_gpu_memory = torch.cuda.mem_get_info()
        memory_allocated = torch.cuda.memory_allocated()
        max_memory_allocated = torch.cuda.max_memory_allocated()
        memory_reserved = torch.cuda.memory_reserved()
        max_memory_reserved = torch.cuda.max_memory_reserved()
        with open(
            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
            "a",
        ) as f:
            f.write(
                f"\nCurrently using GPU: {current_device}\n"
                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
                f"total memory: {total_gpu_memory / GIGABYTE:.4f} GB,\n"
                f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
                f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
            )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="toy", help="the size of model")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
    parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
    parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
    parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
    args = parser.parse_args()

    if args.model == "toy":
        model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
    elif args.model == "7b":
        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
    elif args.model == "13b":
        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
    else:
        raise NotImplementedError

    engine = PPInferEngine(
        pp_size=args.pp_size,
        dtype=args.dtype,
        micro_batch_size=args.mb_size,
        new_length=args.new_length,
        model=model,
        model_policy=LlamaModelInferPolicy(),
        verbose=True,
        max_batch_size=args.mb_size,
        max_input_len=args.seq_len,
        max_output_len=args.seq_len + args.new_length + 256,
    )
    data = data_gen(args.batch_size, args.seq_len)

    torch.cuda.synchronize()
    whole_end2end = time.time()
    output, timestamps = engine.inference([data])
    torch.cuda.synchronize()
    whole_end2end = time.time() - whole_end2end

    print_details_info(timestamps, model.config, args, whole_end2end)
colossalai/legacy/inference/pipeline/benchmark/run.sh (new file, +50)
@@ -0,0 +1,50 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"

# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="7b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=1024 \
        --new_length=128 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done

# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="7b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=512 \
        --new_length=512 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done

# 13b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="13b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=1024 \
        --new_length=128 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done

# 13b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16; do
    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
        --model="13b" \
        --dtype="fp16" \
        --batch_size=${BATCH_SIZE} \
        --seq_len=512 \
        --new_length=512 \
        --mb_size=$((${BATCH_SIZE}/2)) \
        --pp_size=2
done
colossalai/legacy/inference/pipeline/microbatch_manager.py (new file, +249)
@@ -0,0 +1,249 @@
from enum import Enum
from typing import Dict

import torch

from ..tensor_parallel.batch_infer_state import BatchInferState
from ..tensor_parallel.kvcache_manager import MemoryManager

__all__ = ["MicroBatchManager"]


class Status(Enum):
    PREFILL = 1
    GENERATE = 2
    DONE = 3
    COOLDOWN = 4

class MicroBatchDescription:
    """
    This class records the information of each micro batch and performs the related update operations.
    It is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
    details, please refer to the docs of those two classes below.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
        output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
    """

    def __init__(
        self,
        inputs_dict: Dict[str, torch.Tensor],
        max_input_len: int,
        max_output_len: int,
        cache_manager: MemoryManager,
    ) -> None:
        self.mb_length = inputs_dict["input_ids"].shape[-1]
        self.target_length = self.mb_length + max_output_len
        self.infer_state = BatchInferState.init_from_batch(
            batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
        )
        # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")

    def update(self, *args, **kwargs):
        pass

    @property
    def state(self):
        """
        Return the state of the current micro batch: DONE when the current length equals the target length,
        COOLDOWN when it is one token short of the target length, otherwise GENERATE.
        """
        # TODO: add the condition for early stopping
        if self.cur_length == self.target_length:
            return Status.DONE
        elif self.cur_length == self.target_length - 1:
            return Status.COOLDOWN
        else:
            return Status.GENERATE

    @property
    def cur_length(self):
        """
        Return the current sequence length of the micro batch. Implemented by subclasses.
        """

class HeadMicroBatchDescription(MicroBatchDescription):
    """
    This class records the information of the first pipeline stage. The first stage keeps the attributes
    `input_ids`, `attention_mask` and `new_tokens`, where `new_tokens` are the tokens generated so far.
    Because of the pipeline schedule, the way this information is updated and the condition that determines
    the state differ from the other stages.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
        output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
    """

    def __init__(
        self,
        inputs_dict: Dict[str, torch.Tensor],
        max_input_len: int,
        max_output_len: int,
        cache_manager: MemoryManager,
    ) -> None:
        super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
        assert inputs_dict is not None
        assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
        self.input_ids = inputs_dict["input_ids"]
        self.attn_mask = inputs_dict["attention_mask"]
        self.new_tokens = None

    def update(self, new_token: torch.Tensor = None):
        if new_token is not None:
            self._update_newtokens(new_token)
        if self.state is not Status.DONE and new_token is not None:
            self._update_attnmask()

    def _update_newtokens(self, new_token: torch.Tensor):
        if self.new_tokens is None:
            self.new_tokens = new_token
        else:
            self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)

    def _update_attnmask(self):
        self.attn_mask = torch.cat(
            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
        )

    @property
    def cur_length(self):
        """
        When there are no new tokens yet, the length is `mb_length`; otherwise it is `mb_length` plus the number of new tokens.
        """
        if self.new_tokens is None:
            return self.mb_length
        else:
            return self.mb_length + len(self.new_tokens[0])

class BodyMicroBatchDescription(MicroBatchDescription):
    """
    This class records the information of the pipeline stages other than the first one; these stages keep
    the attributes `hidden_states` and `past_key_values`.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): will always be `None`, since these stages only receive hidden states from the previous stage.
    """

    def __init__(
        self,
        inputs_dict: Dict[str, torch.Tensor],
        max_input_len: int,
        max_output_len: int,
        cache_manager: MemoryManager,
    ) -> None:
        super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)

    @property
    def cur_length(self):
        """
        Return the maximum sequence length tracked in `infer_state` for this micro batch.
        """
        return self.infer_state.seq_len.max().item()

class MicroBatchManager:
    """
    MicroBatchManager is a class that manages the micro batches.

    Args:
        stage (int): stage id of the current stage.
        micro_batch_size (int): the micro batch size.
        micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should equal the number of pipeline stages.
    """

    def __init__(
        self,
        stage: int,
        micro_batch_size: int,
        micro_batch_buffer_size: int,
        max_input_len: int,
        max_output_len: int,
        cache_manager_list: MemoryManager,
    ):
        self.stage = stage
        self.micro_batch_size = micro_batch_size
        self.buffer_size = micro_batch_buffer_size
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.cache_manager_list = cache_manager_list
        self.mb_descrption_buffer = {}
        self.new_tokens_buffer = {}
        self.idx = 0

    def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
        if self.stage == 0:
            self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
                inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
            )
        else:
            self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
                inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
            )

    def step(self, new_token: torch.Tensor = None):
        """
        Update the state of the micro batch manager. There are two cases:
        1. For the first stage in PREFILL, `add_descrption` has already saved the inputs.
        2. In every other case, only the output of the previous stage is received, and the description is updated.

        Args:
            new_token (torch.Tensor): the new token generated by the current stage.
        """
        # The description must have been added before stepping
        self.cur_descrption.update(new_token)
        return self.cur_state

    def export_new_tokens(self):
        new_tokens_list = []
        for i in self.mb_descrption_buffer.values():
            new_tokens_list.extend(i.new_tokens.tolist())
        return new_tokens_list

    def is_micro_batch_done(self):
        if len(self.mb_descrption_buffer) == 0:
            return False
        for mb in self.mb_descrption_buffer.values():
            if mb.state != Status.DONE:
                return False
        return True

    def clear(self):
        self.mb_descrption_buffer.clear()
        for cache in self.cache_manager_list:
            cache.free_all()

    def next(self):
        self.idx = (self.idx + 1) % self.buffer_size

    def _remove_descrption(self):
        self.mb_descrption_buffer.pop(self.idx)

    @property
    def cur_descrption(self) -> MicroBatchDescription:
        return self.mb_descrption_buffer.get(self.idx)

    @property
    def cur_infer_state(self):
        if self.cur_descrption is None:
            return None
        return self.cur_descrption.infer_state

    @property
    def cur_state(self):
        """
        Return the state of the current micro batch; when the current description is None, the state is PREFILL.
        """
        if self.cur_descrption is None:
            return Status.PREFILL
        return self.cur_descrption.state