mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[Pipeline Inference] Sync pipeline inference branch to main (#4820)
* [pipeline inference] pipeline inference (#4492) * add pp stage manager as circle stage * fix a bug when create process group * add ppinfer basic framework * add micro batch manager and support kvcache-pp gpt2 fwd * add generate schedule * use mb size to control mb number * support generate with kv cache * add output, remove unused code * add test * reuse shardformer to build model * refactor some code and use the same attribute name of hf * fix review and add test for generation * remove unused file * fix CI * add cache clear * fix code error * fix typo * [Pipeline inference] Modify to tieweight (#4599) * add pp stage manager as circle stage * fix a bug when create process group * add ppinfer basic framework * add micro batch manager and support kvcache-pp gpt2 fwd * add generate schedule * use mb size to control mb number * support generate with kv cache * add output, remove unused code * add test * reuse shardformer to build model * refactor some code and use the same attribute name of hf * fix review and add test for generation * remove unused file * modify the way of saving newtokens * modify to tieweight * modify test * remove unused file * solve review * add docstring * [Pipeline inference] support llama pipeline inference (#4647) * support llama pipeline inference * remove tie weight operation * [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708) * add benchmark verbose * fix export tokens * fix benchmark verbose * add P2POp style to do p2p communication * modify schedule as p2p type when ppsize is 2 * remove unused code and add docstring * [Pipeline inference] Refactor code, add docsting, fix bug (#4790) * add benchmark script * update argparse * fix fp16 load * refactor code style * add docstring * polish code * fix test bug * [Pipeline inference] Add pipeline inference docs (#4817) * add readme doc * add a ico * Add performance * update table of contents * refactor code (#4873)
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .pipeline import PPInferEngine
|
||||
|
||||
__all__ = ['PPInferEngine']
|
||||
|
84
colossalai/inference/pipeline/README.md
Normal file
84
colossalai/inference/pipeline/README.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# 🐳 Pipeline Inference
|
||||
|
||||
## Table of Contents
|
||||
- [💡 Introduction](#introduction)
|
||||
- [🔗 Design](#design)
|
||||
- [🔨 Usage](#usage)
|
||||
- [Example](#example)
|
||||
- [Quick start](#quick-start)
|
||||
- [📊 Performance](#performance)
|
||||
|
||||
## Introduction
|
||||
|
||||
`Pipeline Inference` is a module designed to make inference on a pipeline way. In inference systems, although there is no need to store intermediate information such as activations during forward propagation for backward propagation, the weights of some larger models still cannot fit on a single GPU for inference. This requires us to use model parallelism and other methods to reduce the memory occupation on a single GPU. Pipeline parallelism, as one of the traditional model parallelism approaches, has been widely used due to its reduced all-reduce communication requirements and simple layout. The main issue with pipeline parallelism, known as bubbles, can be almost eliminated in inference because the backward propagation that causes bubbles no longer exists in inference. This makes pipeline parallelism almost bubble-free in the ideal scenario where the sequence length is the same across the pipeline.
|
||||
|
||||
## Design
|
||||
|
||||
Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).
|
||||
|
||||
1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks:
|
||||
- Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`.
|
||||
- Run the pipeline inference model.
|
||||
|
||||
2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks:
|
||||
- Record each micro-batch information, like generated new tokens and kvcache.
|
||||
- Record each micro-batch inference state, like prefill, generate or done.
|
||||
- Update the micro-batch information.
|
||||
|
||||
3. `generate` schedule implements the simple pipeline inference layout. When pipeline size is 2, we use `torch.distributed.P2Pop` to implement the communication between stages, mainly to solve the race communication. When pipeline size is larger than 2, we use `torch.distributed.broadcast` which is faster than `torch.distributed.P2Pop`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Example
|
||||
```python
|
||||
from colossalai.pipeline import PPInferEngine
|
||||
# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example.
|
||||
model = LlamaForCausalLM.from_pretrained('/path/to/model')
|
||||
inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
|
||||
engine = PPInferEngine(
|
||||
pp_size=2,
|
||||
dtype='fp16',
|
||||
micro_batch_size=1,
|
||||
new_length=10,
|
||||
model=model,
|
||||
model_policy=LlamaForCausalLMPipelinePolicy())
|
||||
|
||||
output = engine.inference([inputs])
|
||||
|
||||
```
|
||||
|
||||
### Quick start
|
||||
```shell
|
||||
cd benchmark
|
||||
sh run.sh
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G.
|
||||
|
||||
### Llama Throughput(tokens/s)
|
||||
|
||||
#### 7b, fp16
|
||||
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
|
||||
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
|
||||
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
|
||||
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
|
||||
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
|
||||
|
||||
#### 7b, fp32
|
||||
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
|
||||
| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
|
||||
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
|
||||
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
|
||||
|
||||
#### 13b, fp16
|
||||
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
|
||||
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
|
||||
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
|
||||
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |
|
3
colossalai/inference/pipeline/__init__.py
Normal file
3
colossalai/inference/pipeline/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .engine import PPInferEngine
|
||||
|
||||
__all__ = ['PPInferEngine']
|
112
colossalai/inference/pipeline/benchmark/benchmark.py
Normal file
112
colossalai/inference/pipeline/benchmark/benchmark.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
|
||||
import colossalai
|
||||
import time
|
||||
from colossalai.inference import PPInferEngine
|
||||
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
|
||||
import argparse
|
||||
GIGABYTE = 1024 ** 3
|
||||
MEGABYTE = 1024 * 1024
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
def data_gen(batch_size: int=4, seq_len: int=512):
|
||||
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
|
||||
attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
|
||||
data = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
for k, v in data.items():
|
||||
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = batch_size
|
||||
data[k] = v.to('cuda').repeat(*new_shape)
|
||||
return data
|
||||
|
||||
def print_details_info(timestamps, model_config, args, whole_end2end):
|
||||
if dist.get_rank() == 0:
|
||||
prefill = []
|
||||
encoder = []
|
||||
end2end = []
|
||||
for timestamp in timestamps:
|
||||
prefill.append(timestamp[1] - timestamp[0])
|
||||
encoder.append(
|
||||
sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
|
||||
end2end.append(timestamp[-1] - timestamp[0])
|
||||
print(whole_end2end)
|
||||
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
|
||||
mb_avg_end2end = sum(end2end)/len(end2end)
|
||||
mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
|
||||
whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
|
||||
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
||||
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
|
||||
if args.dtype in ['fp16','bf16']:
|
||||
num_bytes = 2
|
||||
else:
|
||||
num_bytes = 4
|
||||
|
||||
f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
|
||||
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
|
||||
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
|
||||
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
|
||||
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
|
||||
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
|
||||
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
|
||||
f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
|
||||
f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
|
||||
f.write("----------------------------------------------------------\n")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
current_device = torch.cuda.current_device()
|
||||
|
||||
# free memory and the total available memory in bytes
|
||||
global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
|
||||
memory_allocated = torch.cuda.memory_allocated()
|
||||
max_memory_allocated = torch.cuda.max_memory_allocated()
|
||||
memory_reserved = torch.cuda.memory_reserved()
|
||||
max_memory_reserved = torch.cuda.max_memory_reserved()
|
||||
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
|
||||
f.write(
|
||||
f"\nCurrently using GPU: {current_device}\n"
|
||||
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
|
||||
f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
|
||||
f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
|
||||
f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
|
||||
f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
|
||||
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='toy', help='the size of model')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
|
||||
parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
|
||||
parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
|
||||
parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
|
||||
parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log')
|
||||
parser.add_argument('--dtype', type=str, default='fp16', help='data type')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model == 'toy':
|
||||
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
|
||||
elif args.model == '7b':
|
||||
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
|
||||
elif args.model == '13b':
|
||||
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
|
||||
data = data_gen(args.batch_size, args.seq_len)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
whole_end2end = time.time()
|
||||
output, timestamps = engine.inference([data])
|
||||
torch.cuda.synchronize()
|
||||
whole_end2end = time.time() - whole_end2end
|
||||
|
||||
print_details_info(timestamps, model.config, args, whole_end2end)
|
||||
|
50
colossalai/inference/pipeline/benchmark/run.sh
Normal file
50
colossalai/inference/pipeline/benchmark/run.sh
Normal file
@@ -0,0 +1,50 @@
|
||||
script_dir=$(cd "$(dirname "$0")" && pwd)
|
||||
cd "${script_dir}"
|
||||
|
||||
# 7b, fp32, 2 gpu, 1024, 128
|
||||
for BATCH_SIZE in 2 4 8 16; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="7b" \
|
||||
--dtype="fp16" \
|
||||
--batch_size=${BATCH_SIZE} \
|
||||
--seq_len=1024 \
|
||||
--new_length=128 \
|
||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
||||
--pp_size=2
|
||||
done
|
||||
|
||||
# 7b, fp32, 2 gpu, 512, 512
|
||||
for BATCH_SIZE in 2 4 8 16 32; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="7b" \
|
||||
--dtype="fp16" \
|
||||
--batch_size=${BATCH_SIZE} \
|
||||
--seq_len=512 \
|
||||
--new_length=512 \
|
||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
||||
--pp_size=2
|
||||
done
|
||||
|
||||
# 7b, fp32, 2 gpu, 1024, 128
|
||||
for BATCH_SIZE in 2 4 8; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="13b" \
|
||||
--dtype="fp16" \
|
||||
--batch_size=${BATCH_SIZE} \
|
||||
--seq_len=1024 \
|
||||
--new_length=128 \
|
||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
||||
--pp_size=2
|
||||
done
|
||||
|
||||
# 13b, fp16, 2 gpu, 512, 512
|
||||
for BATCH_SIZE in 2 4 8 16; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="13b" \
|
||||
--dtype="fp16" \
|
||||
--batch_size=${BATCH_SIZE} \
|
||||
--seq_len=512 \
|
||||
--new_length=512 \
|
||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
||||
--pp_size=2
|
||||
done
|
98
colossalai/inference/pipeline/engine.py
Normal file
98
colossalai/inference/pipeline/engine.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import Callable, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.schedule.generate import GenerateSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .microbatch_manager import MicroBatchManager
|
||||
|
||||
|
||||
class PPInferEngine:
|
||||
'''
|
||||
PPInferEngine is a class that handles the pipeline parallel inference.
|
||||
|
||||
Args:
|
||||
pp_size (int): the number of pipeline stages.
|
||||
pp_model (`nn.Module`): the model already in pipeline parallelism style.
|
||||
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
|
||||
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
new_length (int): the new length of the input sequence.
|
||||
early_stopping (bool): whether to stop early.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from colossalai.ppinference import PPInferEngine
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
|
||||
# assume the model is infered with 4 pipeline stages
|
||||
inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
|
||||
|
||||
input = ["Hello, my dog is cute, and I like"]
|
||||
tokenized_input = tokenizer(input, return_tensors='pt')
|
||||
output = engine.inference([tokenized_input])
|
||||
```
|
||||
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pp_size: int,
|
||||
dtype: str = 'fp16',
|
||||
pp_model: nn.Module = None,
|
||||
model: nn.Module = None,
|
||||
model_policy: Policy = None,
|
||||
new_length: int = 32,
|
||||
micro_batch_size: int = 1,
|
||||
micro_batch_buffer_size: int = None,
|
||||
verbose: bool = False,
|
||||
# TODO: implement early_stopping, and various gerneration options
|
||||
early_stopping: bool = False,
|
||||
do_sample: bool = False,
|
||||
num_beams: int = 1,
|
||||
) -> None:
|
||||
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
|
||||
self.pp_size = pp_size
|
||||
self.pg_mesh = ProcessGroupMesh(pp_size)
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
|
||||
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
|
||||
|
||||
assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
if dtype == 'fp16':
|
||||
model.half()
|
||||
elif dtype == 'bf16':
|
||||
model.to(torch.bfloat16)
|
||||
self.model = pp_model or self._shardformer(model, model_policy)
|
||||
|
||||
def inference(self, input_list):
|
||||
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
|
||||
if self.verbose:
|
||||
return out, timestamp
|
||||
else:
|
||||
return out
|
||||
|
||||
def _shardformer(self, model, model_policy):
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=None,
|
||||
pipeline_stage_manager=self.stage_manager,
|
||||
enable_tensor_parallelism=False,
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
236
colossalai/inference/pipeline/microbatch_manager.py
Normal file
236
colossalai/inference/pipeline/microbatch_manager.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = 'MicroBatchManager'
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
PREFILL = 1
|
||||
GENERATE = 2
|
||||
DONE = 3
|
||||
COOLDOWN = 4
|
||||
|
||||
|
||||
class MicroBatchDescription():
|
||||
"""
|
||||
This is the class to record the infomation of each microbatch, and also do some update operation.
|
||||
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
|
||||
details, please refer to the doc of these two classes blow.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
output_dict: Dict[str, torch.Tensor],
|
||||
new_length: int,
|
||||
) -> None:
|
||||
assert output_dict.get('hidden_states') is not None
|
||||
self.mb_length = output_dict['hidden_states'].shape[-2]
|
||||
self.target_length = self.mb_length + new_length
|
||||
self.kv_cache = ()
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
if output_dict is not None:
|
||||
self._update_kvcache(output_dict['past_key_values'])
|
||||
|
||||
def _update_kvcache(self, kv_cache: Tuple):
|
||||
assert type(kv_cache) == tuple
|
||||
self.kv_cache = kv_cache
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""
|
||||
Return the state of current micro batch, when current length is equal to target length,
|
||||
the state is DONE, otherwise GENERATE
|
||||
|
||||
"""
|
||||
# TODO: add the condition for early stopping
|
||||
if self.cur_length == self.target_length:
|
||||
return Status.DONE
|
||||
elif self.cur_length == self.target_length - 1:
|
||||
return Status.COOLDOWN
|
||||
else:
|
||||
return Status.GENERATE
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
Return the current sequnence length of micro batch
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HeadMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
|
||||
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the
|
||||
information and the condition to determine the state is different from other stages.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
new_length (int): the new length of the input sequence.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
|
||||
new_length: int) -> None:
|
||||
super().__init__(inputs_dict, output_dict, new_length)
|
||||
assert inputs_dict is not None
|
||||
assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
|
||||
self.input_ids = inputs_dict['input_ids']
|
||||
self.attn_mask = inputs_dict['attention_mask']
|
||||
self.new_tokens = None
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
super().update(output_dict, new_token)
|
||||
if new_token is not None:
|
||||
self._update_newtokens(new_token)
|
||||
if self.state is not Status.DONE and new_token is not None:
|
||||
self._update_attnmask()
|
||||
|
||||
def _update_newtokens(self, new_token: torch.Tensor):
|
||||
if self.new_tokens is None:
|
||||
self.new_tokens = new_token
|
||||
else:
|
||||
self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
|
||||
|
||||
def _update_attnmask(self):
|
||||
self.attn_mask = torch.cat(
|
||||
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token
|
||||
|
||||
"""
|
||||
if self.new_tokens is None:
|
||||
return self.mb_length
|
||||
else:
|
||||
return self.mb_length + len(self.new_tokens[0])
|
||||
|
||||
|
||||
class BodyMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
|
||||
new_length: int) -> None:
|
||||
super().__init__(inputs_dict, output_dict, new_length)
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
super().update(output_dict, new_token)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1
|
||||
|
||||
"""
|
||||
if len(self.kv_cache) == 0:
|
||||
return self.mb_length
|
||||
else:
|
||||
return self.kv_cache[0][0].shape[-2] + 1
|
||||
|
||||
|
||||
class MicroBatchManager():
|
||||
'''
|
||||
MicroBatchManager is a class that manages the micro batch.
|
||||
|
||||
Args:
|
||||
stage (int): stage id of current stage.
|
||||
new_length (int): the new length of the input sequence.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
|
||||
self.stage = stage
|
||||
self.new_length = new_length
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.buffer_size = micro_batch_buffer_size
|
||||
self.mb_descrption_buffer = {}
|
||||
self.new_tokens_buffer = {}
|
||||
self.idx = 0
|
||||
|
||||
def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
"""
|
||||
Update the state if microbatch manager, 2 conditions.
|
||||
1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
|
||||
2. For other conditon, only receive the output of previous stage, and update the descrption.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
new_token (torch.Tensor): the new token generated by current stage.
|
||||
"""
|
||||
# Add descrption first if the descrption is None
|
||||
if inputs_dict is None and output_dict is None and new_token is None:
|
||||
return Status.PREFILL
|
||||
if self.mb_descrption_buffer.get(self.idx) is None:
|
||||
self._add_descrption(inputs_dict, output_dict)
|
||||
self.cur_descrption.update(output_dict, new_token)
|
||||
return self.cur_state
|
||||
|
||||
def export_new_tokens(self):
|
||||
new_tokens_list = []
|
||||
for i in self.mb_descrption_buffer.values():
|
||||
new_tokens_list.extend(i.new_tokens.tolist())
|
||||
return new_tokens_list
|
||||
|
||||
def is_micro_batch_done(self):
|
||||
if len(self.mb_descrption_buffer) == 0:
|
||||
return False
|
||||
for mb in self.mb_descrption_buffer.values():
|
||||
if mb.state != Status.DONE:
|
||||
return False
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
self.mb_descrption_buffer.clear()
|
||||
|
||||
def next(self):
|
||||
self.idx = (self.idx + 1) % self.buffer_size
|
||||
|
||||
def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
|
||||
|
||||
def _remove_descrption(self):
|
||||
self.mb_descrption_buffer.pop(self.idx)
|
||||
|
||||
@property
|
||||
def cur_descrption(self) -> MicroBatchDescription:
|
||||
return self.mb_descrption_buffer.get(self.idx)
|
||||
|
||||
@property
|
||||
def cur_kv_cache(self):
|
||||
if self.cur_descrption is None:
|
||||
return None
|
||||
return self.cur_descrption.kv_cache
|
||||
|
||||
@property
|
||||
def cur_state(self):
|
||||
"""
|
||||
Return the state of current micro batch, when current descrption is None, the state is PREFILL
|
||||
|
||||
"""
|
||||
if self.cur_descrption is None:
|
||||
return Status.PREFILL
|
||||
return self.cur_descrption.state
|
0
colossalai/inference/pipeline/modeling/__init__.py
Normal file
0
colossalai/inference/pipeline/modeling/__init__.py
Normal file
278
colossalai/inference/pipeline/modeling/gpt2.py
Normal file
278
colossalai/inference/pipeline/modeling/gpt2.py
Normal file
@@ -0,0 +1,278 @@
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
class GPT2PipelineForwards:
|
||||
'''
|
||||
This class serves as a micro library for forward function substitution of GPT2 models
|
||||
under pipeline setting.
|
||||
'''
|
||||
|
||||
@staticmethod
|
||||
def gpt2_model_forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
else:
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape[0], input_shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
else:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# Going through held blocks.
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for i, layer_past in zip(range(start_idx, end_idx), past_key_values):
|
||||
block = self.h[i]
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return {'hidden_states': hidden_states, 'past_key_values': presents}
|
||||
|
||||
@staticmethod
|
||||
def gpt2_lmhead_model_forward(
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
|
||||
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# If is first stage and after warmup, go throught lm_head first
|
||||
if stage_manager.is_first_stage() and hidden_states is not None:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
return {'logits': lm_logits}
|
||||
|
||||
# Not first stage or before warmup, go through gpt2 model
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
|
||||
return outputs
|
231
colossalai/inference/pipeline/modeling/llama.py
Normal file
231
colossalai/inference/pipeline/modeling/llama.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
class LlamaPipelineForwards:
|
||||
'''
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
under pipeline setting.
|
||||
'''
|
||||
|
||||
def llama_model_forward(
|
||||
self: LlamaModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=hidden_states.device)
|
||||
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states,
|
||||
past_key_values_length)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * (end_idx - start_idx + 1))
|
||||
|
||||
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
|
||||
decoder_layer = self.layers[idx]
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
# past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# always return dict for imediate stage
|
||||
return {'hidden_states': hidden_states, 'past_key_values': next_cache}
|
||||
|
||||
def llama_for_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
|
||||
# If is first stage and after warmup, go throught lm_head first
|
||||
if stage_manager.is_first_stage() and hidden_states is not None:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
return {'logits': lm_logits}
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = LlamaPipelineForwards.llama_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
return outputs
|
71
colossalai/inference/pipeline/policy/gpt2_ppinfer.py
Normal file
71
colossalai/inference/pipeline/policy/gpt2_ppinfer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.gpt2 import GPT2Policy
|
||||
|
||||
from ..modeling.gpt2 import GPT2PipelineForwards
|
||||
|
||||
|
||||
class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
module_policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
addon_module = {
|
||||
GPT2LMHeadModel:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
|
||||
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
|
||||
policy=module_policy)
|
||||
return module_policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
# make the tie weight lm_head and embedding in the same device to save memory
|
||||
# if self.pipeline_stage_manager.is_first_stage():
|
||||
if self.pipeline_stage_manager.is_first_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
'''The weights of wte and lm_head are shared.'''
|
||||
module = self.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager is not None:
|
||||
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
|
||||
first_stage, last_stage = 0, stage_manager.num_stages - 1
|
||||
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
|
||||
return []
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if not self.pipeline_stage_manager:
|
||||
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == 'GPT2Model':
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.transformer
|
||||
|
||||
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
50
colossalai/inference/pipeline/policy/llama_ppinfer.py
Normal file
50
colossalai/inference/pipeline/policy/llama_ppinfer.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
from ..modeling.llama import LlamaPipelineForwards
|
||||
|
||||
|
||||
class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
LlamaForCausalLM:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(model_cls=LlamaForCausalLM,
|
||||
new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
|
||||
policy=policy)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
35
colossalai/inference/pipeline/utils.py
Normal file
35
colossalai/inference/pipeline/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import List, Optional, Set
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.shardformer._utils import getattr_, setattr_
|
||||
|
||||
|
||||
def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None:
|
||||
"""
|
||||
Set all parameters and buffers of model to None
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to set
|
||||
"""
|
||||
for module_suffix in include:
|
||||
set_module = getattr_(model, module_suffix)
|
||||
for n, p in set_module.named_parameters():
|
||||
setattr_(set_module, n, None)
|
||||
for n, buf in set_module.named_buffers():
|
||||
setattr_(set_module, n, None)
|
||||
setattr_(model, module_suffix, None)
|
||||
|
||||
|
||||
def get_suffix_name(suffix: str, name: str):
|
||||
"""
|
||||
Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit,
|
||||
and 'name' when `suffix` is empty.
|
||||
|
||||
Args:
|
||||
suffix (str): The suffix of the suffix module
|
||||
name (str): The name of the current module
|
||||
"""
|
||||
point = '' if suffix is '' else '.'
|
||||
suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}'
|
||||
return suffix_name
|
Reference in New Issue
Block a user