Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 17:46:42 +00:00)
Merge branch 'main' into feature/shardformer
@@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```

### Results on 2-GPU

| Plugin         | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp      | 84.4%    | 88.6%    |
| torch_ddp_fp16 | 84.7%    | 88.8%    |
| gemini         | 84.0%    | 88.4%    |

## Benchmark
```
bash benchmark.sh
@@ -14,9 +22,9 @@ bash benchmark.sh

The benchmark now reports the following metrics: CUDA memory usage, throughput, and the number of model parameters. If you have custom metrics, you can add them to `benchmark_util`.

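As a rough illustration (not the repository's `benchmark_util` itself), these metrics can be collected with plain PyTorch calls; the `model` and `batch` below are placeholders:

```python
import time
import torch

def count_params(model: torch.nn.Module) -> int:
    # number of model parameters
    return sum(p.numel() for p in model.parameters())

def benchmark_steps(model, batch, num_steps: int = 10):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_steps):
        model(**batch)                                             # forward only, for illustration
    torch.cuda.synchronize()
    elapsed = time.time() - start
    max_mem_gb = torch.cuda.max_memory_allocated() / 1024**3       # peak CUDA memory
    throughput = num_steps * batch['input_ids'].shape[0] / elapsed # samples/s
    return max_mem_gb, throughput
```
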
## Results
### Results

### Bert
#### Bert

|                | max cuda mem | throughput (sample/s) | params |
| :------------- | -----------: | :-------------------: | :----: |
@@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb
| gemini         | 11.0 GB      | 12.9                  | 82M    |
| low_level_zero | 11.29 GB     | 14.7                  | 82M    |

### AlBert
#### AlBert
|                | max cuda mem | throughput (sample/s) | params |
| :------------- | -----------: | :-------------------: | :----: |
| ddp            | OOM          |                       |        |
| ddp_fp16       | OOM          |                       |        |
| gemini         | 69.39 GB     | 1.3                   | 208M   |
| low_level_zero | 56.89 GB     | 1.4                   | 208M   |

@@ -219,7 +219,7 @@ def main():
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
        plugin = GeminiPlugin(initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == 'hybrid_parallel':
@@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}

# The following options are only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
@@ -21,11 +18,8 @@ fi
mkdir -p gemini_logs

torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
  --tp_degree=${TPDEGREE} \
  --model_type=${MODEL_TYPE} \
  --batch_size=${BATCH_SIZE} \
  --placement=${PLACEMENT} \
  ${USE_SHARD_INIT} \
  --distplan=${DISTPLAN} \
  --train_step=${TRAIN_STEP} \
  2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
@@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
|
||||
for DISTPLAN in "CAI_Gemini"; do
|
||||
for BATCH_SIZE in 2; do
|
||||
for GPUNUM in 1 4; do
|
||||
for TPDEGREE in 1 2; do
|
||||
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
|
||||
continue
|
||||
fi
|
||||
for PLACEMENT in "cpu" "auto"; do
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
done
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
for DISTPLAN in "zero1" "zero2"; do
|
||||
for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
|
||||
for BATCH_SIZE in 2; do
|
||||
for GPUNUM in 1 4; do
|
||||
for TPDEGREE in 1; do
|
||||
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
|
||||
continue
|
||||
fi
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
|
||||
bash ./run_gemini.sh
|
||||
done
|
||||
done
|
||||
done
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from time import time
|
||||
|
||||
@@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
|
||||
CAI_VERSION = colossalai.__version__
|
||||
|
||||
@@ -30,24 +30,6 @@ def parse_args():
|
||||
default='CAI_Gemini',
|
||||
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_degree",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placement",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shardinit",
|
||||
action='store_true',
|
||||
help=
|
||||
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
@@ -71,20 +53,6 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
# Parameter Sharding Strategies for Tensor Parallelism
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(0, param, pg)
|
||||
|
||||
|
||||
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(-1, param, pg)
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
@@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
|
||||
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
|
||||
|
||||
|
||||
# Tensor Parallel
|
||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
"""tensor_parallelize
|
||||
Sharding the Model Parameters.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a torch module to be sharded
|
||||
"""
|
||||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
# NOTE() a param maybe shared by two modules
|
||||
if hasattr(param, 'visited'):
|
||||
continue
|
||||
|
||||
# if shard init, then convert param to replica and use the dp-only ProcessGroup
|
||||
param: ColoParameter = param
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.set_process_group(pg)
|
||||
|
||||
# shard it w.r.t tp pattern
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
# keep the shape of the output from c_fc
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.visited = True
|
||||
|
||||
|
||||
def main():
|
||||
# version check
|
||||
# this example is supposed to work for versions greater than 0.2.0
|
||||
@@ -213,30 +140,13 @@ def main():
|
||||
|
||||
# build criterion
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
torch.manual_seed(123)
|
||||
if args.distplan.startswith("CAI"):
|
||||
# all param must use the same process group.
|
||||
world_size = torch.distributed.get_world_size()
|
||||
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
|
||||
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
|
||||
|
||||
if args.shardinit and args.distplan != "CAI_Gemini":
|
||||
raise RuntimeError("You can only use shardinit with CAI_Gemini")
|
||||
|
||||
ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
|
||||
# build GPT model
|
||||
with ColoInitContext(device=get_current_device(),
|
||||
dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
with ctx:
|
||||
model = model_builder(args.model_type)(checkpoint=True)
|
||||
|
||||
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
# Tensor Parallelism (TP)
|
||||
# You should notice that v0.1.10 is not compatible with TP degree > 1
|
||||
if args.tp_degree > 1:
|
||||
tensor_parallelize(model, tp_pg)
|
||||
|
||||
# assign running configurations
|
||||
if args.distplan == "CAI_ZeRO1":
|
||||
zero_stage = 1
|
||||
@@ -254,13 +164,7 @@ def main():
|
||||
overlap_communication=True,
|
||||
verbose=True)
|
||||
elif args.distplan == "CAI_Gemini":
|
||||
plugin = GeminiPlugin(device=get_current_device(),
|
||||
placement_policy=args.placement,
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=args.tp_degree == 1,
|
||||
search_range_m=128,
|
||||
hidden_dim=model.config.n_embd,
|
||||
gpu_margin_mem_ratio=0.)
|
||||
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
|
@@ -1,11 +0,0 @@
# Pretraining LLaMA: best practices for building LLaMA-like base models

<p id="ColossalChat-Speed" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>

- 65-billion-parameter large model pretraining accelerated by 38%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

> Since the main branch is being updated, in order to maintain the stability of the code, this example is temporarily kept as an [independent branch](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama).

194 examples/language/llama2/README.md Normal file
@@ -0,0 +1,194 @@
# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models

### LLaMA2
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/llama2_pretraining.png" width=600/>
</p>

- 70-billion-parameter LLaMA2 model training accelerated by 195%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)

### LLaMA1
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>

- 65-billion-parameter large model pretraining accelerated by 38%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

## Dataset

Different from the original LLaMA, we use the [RedPajama](https://www.together.xyz/blog/redpajama) dataset, a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed.

A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, also available on [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).

RedPajama-Data-1T consists of seven data slices:

|               | RedPajama    | LLaMA         |
|---------------|--------------|---------------|
| CommonCrawl   | 878 billion  | 852 billion   |
| C4            | 175 billion  | 190 billion   |
| Github        | 59 billion   | 100 billion   |
| Books         | 26 billion   | 25 billion    |
| ArXiv         | 28 billion   | 33 billion    |
| Wikipedia     | 24 billion   | 25 billion    |
| StackExchange | 20 billion   | 27 billion    |
| Total         | 1.2 trillion | 1.25 trillion |

## Training

We follow the hyperparameter settings from the original LLaMA paper: AdamW with $\beta_1=0.9$ and $\beta_2=0.95$, and a cosine learning rate schedule such that the final learning rate equals 10% of the maximum learning rate. We use a weight decay of 0.1, gradient clipping of 1.0, and 2,000 warmup steps.

| params | learning rate | batch size |
|--------|---------------|------------|
| 6.7B   | 3.0e-4        | 4M         |
| 13.0B  | 3.0e-4        | 4M         |
| 32.5B  | 1.5e-4        | 4M         |
| 65.2B  | 1.5e-4        | 4M         |

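The sketch below shows how these hyperparameters map onto the optimizer and scheduler that `pretrain.py` uses (`HybridAdam` plus `CosineAnnealingWarmupLR`); the model and schedule length are placeholders, and the learning rate is taken from the table above:

```python
import torch.nn as nn

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

model = nn.Linear(8, 8)                    # tiny stand-in for the real LLaMA model
num_epochs, steps_per_epoch = 1, 10_000    # placeholder schedule length
lr = 3.0e-4                                # e.g. the 6.7B setting from the table

# AdamW-style optimizer with beta1=0.9, beta2=0.95 and weight decay 0.1
optimizer = HybridAdam(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
# cosine schedule with 2,000 warmup steps, decaying to 10% of the peak LR
lr_scheduler = CosineAnnealingWarmupLR(optimizer,
                                       total_steps=num_epochs * steps_per_epoch,
                                       warmup_steps=2000,
                                       eta_min=0.1 * lr)
```
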
## Usage

### 1. Installation

Please install the latest ColossalAI from source.

```bash
CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```

Then install the other dependencies.

```bash
pip install -r requirements.txt
```

Additionally, we recommend using torch 1.13.1; we have tested this version and found it compatible with our code and with flash attention.

### 2. Download the dataset

The dataset is downloaded automatically via `huggingface/datasets`. You can specify the dataset path with `-d` or `--dataset`; the default is `togethercomputer/RedPajama-Data-1T-Sample`.

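For example, the sample subset can be fetched ahead of time with `datasets`; this is essentially the same call `pretrain.py` issues internally:

```python
from datasets import load_dataset

# downloads and caches the 1B-token RedPajama sample on first use
dataset = load_dataset('togethercomputer/RedPajama-Data-1T-Sample')
print(dataset['train'][0]['text'][:200])   # peek at one document
```
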
### 3. Command line arguments

You can use `colossalai run` to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
   pretrain.py --OTHER_CONFIGURATIONS
```

Here is a sample hostfile:

```text
hostname1
hostname2
hostname3
hostname4
```

Make sure the master node can reach all nodes (including itself) via passwordless SSH.

Here are the details of the CLI arguments:

- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1; `7b`, `13b`, and `70b` are supported for LLaMA-2.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. Any dataset from `datasets` with the same data format as RedPajama is supported.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
- Gradient checkpointing: `-g`, `--grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed; we recommend enabling it when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (in steps) for saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path for saving checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--grad_clip`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path for saving tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. Flash attention accelerates training while saving memory, so we recommend always enabling it.

### 4. Shell Script Examples

For your convenience, we provide shell scripts that run the benchmark with various configurations.

You can find them in the `scripts/benchmark_7B` and `scripts/benchmark_70B` directories. The main command has the format:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
 benchmark.py --OTHER_CONFIGURATIONS
```
Below is an example of how to run LLaMA pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.

#### a. Running environment
This experiment was performed on 4 computing nodes with 32 A800 GPUs in total for LLaMA-1 65B. The nodes are
connected with RDMA, and GPUs within one node are fully connected with NVLink.

#### b. Running command

```bash
cd scripts/benchmark_7B
```

First, put your hostfile (`hosts.txt`) in this directory with your real host IPs or host names.

Here is a sample `hosts.txt`:
```text
hostname1
hostname2
hostname3
hostname4
```

Then add environment variables to the script if needed.

Finally, run the following command to start training:

```bash
bash gemini.sh
```
#### c. Results
If the command runs successfully, you will get results like the following:
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.

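These throughput and TFLOPS numbers come from `performance_evaluator.py` in this example; a condensed sketch of the arithmetic it performs (the numeric values below are placeholders) is:

```python
# Per-step FLOP estimate used by PerformanceEvaluator: forward + backward is
# roughly 3x the forward matmul cost, plus one extra forward when gradient
# checkpointing recomputes activations.
batch_size, seq_len, model_numel = 16, 4096, 6_738_000_000   # placeholder values
grad_checkpoint = True
flop = batch_size * seq_len * model_numel * 2 * (3 + int(grad_checkpoint))

duration = 2.7            # measured seconds per step (placeholder)
mp_world_size = 1         # model-parallel group size (world_size // dp_size)
tflops_per_gpu = flop / 1e12 / duration / mp_world_size
print(f'TFLOPS per GPU: {tflops_per_gpu:.2f}')
```
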
## Reference
```bibtex
@article{bian2021colossal,
  title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
  author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
  journal={arXiv preprint arXiv:2110.14883},
  year={2021}
}
```

```bibtex
@software{openlm2023openllama,
  author = {Geng, Xinyang and Liu, Hao},
  title = {OpenLLaMA: An Open Reproduction of LLaMA},
  month = May,
  year = 2023,
  url = {https://github.com/openlm-research/open_llama}
}
```

```bibtex
@software{together2023redpajama,
  author = {Together Computer},
  title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
  month = April,
  year = 2023,
  url = {https://github.com/togethercomputer/RedPajama-Data}
}
```

```bibtex
@article{touvron2023llama,
  title={Llama: Open and efficient foundation language models},
  author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
  journal={arXiv preprint arXiv:2302.13971},
  year={2023}
}
```
83 examples/language/llama2/attn.py Normal file
@@ -0,0 +1,83 @@
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
SUPPORT_XFORMERS = False
|
||||
SUPPORT_FLASH2 = False
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
SUPPORT_XFORMERS = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
SUPPORT_FLASH2 = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2
|
||||
|
||||
|
||||
def llama_flash_attention(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K]
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if SUPPORT_FLASH2:
|
||||
attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
|
||||
else:
|
||||
attn_output = xops.memory_efficient_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xops.LowerTriangularMask())
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def replace_xformers(model: nn.Module):
|
||||
for module in model.modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(llama_flash_attention, module)
|
211 examples/language/llama2/benchmark.py Normal file
@@ -0,0 +1,211 @@
|
||||
import argparse
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from attn import SUPPORT_FLASH, replace_xformers
|
||||
from data_utils import RandomDataset
|
||||
from model_utils import format_numel_str, get_model_numel
|
||||
from performance_evaluator import PerformanceEvaluator
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
# ==============================
|
||||
# Constants
|
||||
# ==============================
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
'7b':
|
||||
LlamaConfig(max_position_embeddings=4096),
|
||||
'13b':
|
||||
LlamaConfig(hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=40,
|
||||
max_position_embeddings=4096),
|
||||
'70b':
|
||||
LlamaConfig(hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
max_position_embeddings=4096,
|
||||
num_key_value_heads=8),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'],
|
||||
default='gemini',
|
||||
help='Choose which plugin to use')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size')
|
||||
parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run')
|
||||
parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore')
|
||||
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
|
||||
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
|
||||
parser.add_argument('-w',
|
||||
'--warmup_ratio',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='warm up ratio of non-model data. Only for gemini-auto')
|
||||
parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb')
|
||||
parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers')
|
||||
parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini')
|
||||
parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini')
|
||||
parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini')
|
||||
parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size')
|
||||
parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size')
|
||||
parser.add_argument('--mbs', type=int, default=1)
|
||||
parser.add_argument('--zero', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
def empty_init():
|
||||
pass
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
use_empty_init = True
|
||||
if args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(precision='bf16',
|
||||
shard_param_frac=args.shard_param_frac,
|
||||
offload_optim_frac=args.offload_optim_frac,
|
||||
offload_param_frac=args.offload_param_frac)
|
||||
elif args.plugin == 'gemini_auto':
|
||||
plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio)
|
||||
elif args.plugin == 'fsdp':
|
||||
if use_empty_init:
|
||||
plugin = TorchFSDPPlugin(
|
||||
mixed_precision=MixedPrecision(param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16),
|
||||
param_init_fn=empty_init(),
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16))
|
||||
elif args.plugin == 'fsdp_cpu':
|
||||
if use_empty_init:
|
||||
plugin = TorchFSDPPlugin(
|
||||
mixed_precision=MixedPrecision(param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16),
|
||||
cpu_offload=CPUOffload(offload_params=True),
|
||||
param_init_fn=empty_init(),
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16),
|
||||
cpu_offload=CPUOffload(offload_params=True))
|
||||
elif args.plugin == '3d':
|
||||
plugin = HybridParallelPlugin(tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
zero_stage=args.zero,
|
||||
enable_fused_normalization=True,
|
||||
num_microbatches=args.mbs,
|
||||
precision='bf16')
|
||||
elif args.plugin == '3d_cpu':
|
||||
plugin = HybridParallelPlugin(tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
zero_stage=args.zero,
|
||||
cpu_offload=True,
|
||||
enable_fused_normalization=True,
|
||||
num_microbatches=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision='bf16')
|
||||
else:
|
||||
raise ValueError(f'Unknown plugin {args.plugin}')
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ==============================
|
||||
# Initialize Dataset and Dataloader
|
||||
# ==============================
|
||||
dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
|
||||
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size,
|
||||
max_length=args.max_length,
|
||||
vocab_size=config.vocab_size)
|
||||
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
|
||||
|
||||
# ==============================
|
||||
# Initialize Model and Optimizer
|
||||
# ==============================
|
||||
init_ctx = LazyInitContext(
|
||||
default_device=get_current_device()) if isinstance(plugin,
|
||||
(GeminiPlugin, HybridParallelPlugin)) else nullcontext()
|
||||
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(config)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
if args.xformers:
|
||||
assert SUPPORT_FLASH, 'Flash attention requested, but neither flash-attn nor xformers is installed'
|
||||
replace_xformers(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
|
||||
performance_evaluator = PerformanceEvaluator(model_numel,
|
||||
args.grad_checkpoint,
|
||||
args.ignore_steps,
|
||||
dp_world_size=dp_size)
|
||||
|
||||
optimizer = HybridAdam(model.parameters())
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||
torch.set_default_dtype(torch.float)
|
||||
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
coordinator.print_on_master(
|
||||
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
|
||||
|
||||
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||
data_iter = iter(dataloader)
|
||||
for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()):
|
||||
performance_evaluator.on_step_start(step)
|
||||
booster.execute_pipeline(data_iter,
|
||||
model,
|
||||
criterion=lambda outputs, inputs: outputs[0],
|
||||
optimizer=optimizer,
|
||||
return_loss=False)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
|
||||
else:
|
||||
for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())):
|
||||
performance_evaluator.on_step_start(step)
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
performance_evaluator.on_step_end(**batch)
|
||||
|
||||
performance_evaluator.on_fit_end()
|
||||
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
119 examples/language/llama2/data_utils.py Normal file
@@ -0,0 +1,119 @@
|
||||
import json
|
||||
import random
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
|
||||
def __init__(self,
|
||||
dataset: Dataset,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False) -> None:
|
||||
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
||||
self.start_index: int = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index:]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def prepare_dataloader(dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
**kwargs):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
|
||||
|
||||
|
||||
Args:
|
||||
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
|
||||
|
||||
def load_json(file_path: str):
|
||||
with open(file_path, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_json(data, file_path: str):
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
'input_ids': self.input_ids[idx],
|
||||
'attention_mask': self.attention_mask[idx],
|
||||
'labels': self.input_ids[idx]
|
||||
}
|
32 examples/language/llama2/model_utils.py Normal file
@@ -0,0 +1,32 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def low_precision_init(target_dtype: torch.dtype = torch.float16):
|
||||
dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(target_dtype)
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f'{numel / B:.2f} B'
|
||||
elif numel >= M:
|
||||
return f'{numel / M:.2f} M'
|
||||
elif numel >= K:
|
||||
return f'{numel / K:.2f} K'
|
||||
else:
|
||||
return f'{numel}'
|
102 examples/language/llama2/performance_evaluator.py Normal file
@@ -0,0 +1,102 @@
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def divide(x: float, y: float) -> float:
|
||||
if y == 0:
|
||||
return float('inf')
|
||||
elif y == float('inf'):
|
||||
return float('nan')
|
||||
return x / y
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class Timer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
|
||||
def end(self) -> None:
|
||||
assert self.start_time is not None
|
||||
self.duration += time() - self.start_time
|
||||
self.start_time = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.
|
||||
|
||||
|
||||
class PerformanceEvaluator:
|
||||
"""
|
||||
Callback for valuate the performance of the model.
|
||||
Args:
|
||||
actor_num_params: The number of parameters of the actor model.
|
||||
critic_num_params: The number of parameters of the critic model.
|
||||
initial_model_num_params: The number of parameters of the initial model.
|
||||
reward_model_num_params: The number of parameters of the reward model.
|
||||
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
||||
ignore_episodes: The number of episodes to ignore when calculating the performance.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_numel: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_steps: int = 0,
|
||||
dp_world_size: Optional[int] = None) -> None:
|
||||
self.model_numel = model_numel
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_steps = ignore_steps
|
||||
|
||||
self.coordinator = DistCoordinator()
|
||||
self.dp_world_size = dp_world_size or self.coordinator.world_size
|
||||
self.disable: bool = False
|
||||
self.timer = Timer()
|
||||
self.num_samples: int = 0
|
||||
self.flop: int = 0
|
||||
|
||||
def on_step_start(self, step: int) -> None:
|
||||
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.start()
|
||||
|
||||
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.end()
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
self.num_samples += batch_size
|
||||
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
|
||||
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
|
||||
mp_world_size = self.coordinator.world_size // self.dp_world_size
|
||||
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
|
||||
self.coordinator.print_on_master(
|
||||
f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, '
|
||||
f'avg_throughput: {avg_throughput}')
|
||||
self.coordinator.print_on_master(
|
||||
f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}')
|
275 examples/language/llama2/pretrain.py Normal file
@@ -0,0 +1,275 @@
|
||||
import argparse
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from attn import SUPPORT_XFORMERS, replace_xformers
|
||||
from data_utils import load_json, prepare_dataloader, save_json
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
'7b':
|
||||
LlamaConfig(max_position_embeddings=4096),
|
||||
'13b':
|
||||
LlamaConfig(hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=40,
|
||||
max_position_embeddings=4096),
|
||||
'70b':
|
||||
LlamaConfig(hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
max_position_embeddings=4096,
|
||||
num_key_value_heads=8),
|
||||
}
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f'{numel / B:.2f} B'
|
||||
elif numel >= M:
|
||||
return f'{numel / M:.2f} M'
|
||||
elif numel >= K:
|
||||
return f'{numel / K:.2f} K'
|
||||
else:
|
||||
return f'{numel}'
|
||||
|
||||
|
||||
def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
|
||||
texts = [sample['text'] for sample in batch]
|
||||
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
|
||||
data['labels'] = data['input_ids'].clone()
|
||||
return data
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
|
||||
batch_size: int, coordinator: DistCoordinator, save_dir: str):
|
||||
save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
|
||||
os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
|
||||
running_states = {
|
||||
'epoch': epoch,
|
||||
'step': step,
|
||||
'sample_start_index': step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, 'running_states.json'))
|
||||
|
||||
|
||||
def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
|
||||
load_dir: str) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, 'model'))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
|
||||
running_states = load_json(os.path.join(load_dir, 'running_states.json'))
|
||||
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
|
||||
|
||||
|
||||
def main():
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
|
||||
parser.add_argument('-p',
|
||||
'--plugin',
|
||||
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
|
||||
default='gemini',
|
||||
help='Choose which plugin to use')
|
||||
parser.add_argument('-d',
|
||||
'--dataset',
|
||||
type=str,
|
||||
default='togethercomputer/RedPajama-Data-1T-Sample',
|
||||
help='Data set path')
|
||||
parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
|
||||
parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
|
||||
parser.add_argument('-w', '--weight_decay', type=float, default=0.1, help='Weight decay')
|
||||
parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps')
|
||||
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
|
||||
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
|
||||
parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
|
||||
parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
|
||||
parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
|
||||
parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
|
||||
parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
|
||||
parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
|
||||
parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
|
||||
args = parser.parse_args()
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# ==============================
|
||||
if coordinator.is_master():
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == 'gemini':
|
||||
plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
|
||||
elif args.plugin == 'gemini_auto':
|
||||
plugin = GeminiPlugin(precision=args.mixed_precision,
|
||||
placement_policy='auto',
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip)
|
||||
elif args.plugin == 'zero2':
|
||||
plugin = LowLevelZeroPlugin(stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip)
|
||||
elif args.plugin == 'zero2_cpu':
|
||||
plugin = LowLevelZeroPlugin(stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip)
|
||||
else:
|
||||
raise ValueError(f'Unknown plugin {args.plugin}')
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ==============================
|
||||
# Initialize Tokenizer, Dataset and Dataloader
|
||||
# ==============================
|
||||
tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
|
||||
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
|
||||
dataset = load_dataset(args.dataset)
|
||||
train_ds = dataset['train']
|
||||
dataloader = prepare_dataloader(train_ds,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
|
||||
|
||||
# ==============================
|
||||
# Initialize Model, Optimizer and LR Scheduler
|
||||
# ==============================
|
||||
config = MODEL_CONFIGS[args.config]
|
||||
init_ctx = LazyInitContext(
|
||||
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
|
||||
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(config)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.flash_attention:
|
||||
assert SUPPORT_XFORMERS, 'Flash attention requested, but xformers is not installed'
|
||||
replace_xformers(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer,
|
||||
total_steps=args.num_epochs * len(dataloader),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr)
|
||||
default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
|
||||
optimizer,
|
||||
dataloader=dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
coordinator.print_on_master(
|
||||
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
|
||||
|
||||
# load checkpoint if specified
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
sampler_start_idx = 0
|
||||
if args.load is not None:
|
||||
coordinator.print_on_master('Loading checkpoint')
|
||||
start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
|
||||
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
|
||||
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
# if resume training, set the sampler start index to the correct value
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
with tqdm(enumerate(dataloader),
|
||||
desc=f'Epoch {epoch}',
|
||||
disable=not coordinator.is_master(),
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step) as pbar:
|
||||
for step, batch in pbar:
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs[0]
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
all_reduce_mean(loss)
|
||||
pbar.set_postfix({'loss': loss.item()})
|
||||
if coordinator.is_master():
|
||||
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
|
||||
|
||||
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f'Saving checkpoint')
|
||||
save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
|
||||
args.save_dir)
|
||||
coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
|
||||
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
9 examples/language/llama2/requirements.txt Normal file
@@ -0,0 +1,9 @@
|
||||
colossalai>=0.3.0
|
||||
datasets
|
||||
numpy
|
||||
torch>=1.12.0,<=2.0.0
|
||||
tqdm
|
||||
transformers
|
||||
flash-attn>=2.0.0,<=2.0.5
|
||||
SentencePiece==0.1.99
|
||||
tensorboard==2.14.0
|
17 examples/language/llama2/scripts/benchmark_70B/3d.sh Normal file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
# TODO: fix this
|
||||
echo "3D parallel for LLaMA-2 is not ready yet"
|
||||
exit 1
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4
|
13 examples/language/llama2/scripts/benchmark_70B/gemini.sh Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -g -x -b 2
|
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
################
|
||||
#Load your environments and modules here
|
||||
################
|
||||
|
||||
HOSTFILE=$(realpath hosts.txt)
|
||||
|
||||
cd ../..
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p gemini_auto -g -x -b 2
|
13 examples/language/llama2/scripts/benchmark_7B/gemini.sh Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash

################
# Load your environments and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16
examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh (new file, 13 lines)
@@ -0,0 +1,13 @@
#!/bin/bash

################
# Load your environments and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16
@@ -1,22 +1,18 @@
import time

import torch
import tqdm
import transformers
from args import parse_benchmark_args
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_benchmark_args
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@@ -61,11 +57,11 @@ def main():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()


    # Whether to set limit of memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)


    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM(config=config)
@@ -81,11 +77,7 @@ def main():
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])
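The hunk above replaces the older Gemini constructor arguments (`device`, `placement_policy`, `strict_ddp_mode`) with the newer `offload_optim_frac` style. A minimal sketch of the updated setup outside the diff; the linear layer and learning rate are placeholders standing in for the OPT model and arguments built elsewhere in this script:

```
import colossalai
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes the process was started by `torchrun` / `colossalai run`
# so the distributed environment variables are already set.
colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024)    # stand-in for the OPT model built in this example
optimizer = HybridAdam(model.parameters(), lr=1e-4)

# offload_optim_frac=1.0 offloads all optimizer states to CPU; pin_memory speeds up host-to-device copies.
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```

Once boosted, the training loop calls `booster.backward(loss, optimizer)` instead of `loss.backward()`, as the hunks further down show.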
@@ -96,18 +88,18 @@ def main():
    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)


    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    # Start training.
    logger.info(f"Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())


    torch.cuda.synchronize()
    model.train()
    start_time = time.time()


    for _ in range(args.max_train_steps):

        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
@@ -119,18 +111,19 @@ def main():

        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute Statistics

    # Compute Statistics
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(f"Testing finished, "
                f"batch size per gpu: {args.batch_size}, "
                f"plugin: {args.plugin}, "
                f"throughput: {throughput}, "
                f"maximum memory usage per gpu: {max_mem}.",
                ranks=[0])

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0])


if __name__ == "__main__":
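The throughput reported above is samples per second aggregated over all ranks. A quick worked example with illustrative numbers, not taken from the diff:

```
world_size, max_train_steps, batch_size = 8, 10, 32    # illustrative values
elapsed = 20.0                                          # seconds between the two time.time() calls
throughput = (world_size * max_train_steps * batch_size) / elapsed
print("{:.4f}".format(throughput))    # 128.0000 samples/s across all ranks
```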
@@ -1,25 +1,20 @@
import time

import torch
import datasets
import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
@@ -30,18 +25,18 @@ def move_to_cuda(batch, device):


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):


    torch.cuda.synchronize()
    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:


        for batch in pbar:

            # Forward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())


            outputs = model(use_cache=False, **batch)
            loss = outputs['loss']

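`move_to_cuda` is defined earlier in this file; only its signature appears in the hunk header above. A minimal sketch of what such a helper typically does, assuming batches are dicts of tensors (the repo's exact implementation may differ):

```
import torch


def move_to_cuda(batch: dict, device: torch.device) -> dict:
    # Move every tensor in the batch dict onto the target device.
    return {key: value.to(device) for key, value in batch.items()}
```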
@@ -72,7 +67,7 @@ def main():
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()


    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
@@ -88,43 +83,35 @@ def main():
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=netflix_collator)


    # Set optimizer
    optimizer = HybridAdam(model.parameters(),
                           lr=(args.learning_rate * world_size),
                           weight_decay=args.weight_decay)
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

    # Set lr scheduler
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=len(dataloader) * args.num_epoch
    )
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=num_warmup_steps,
                                                   num_training_steps=len(dataloader) * args.num_epoch)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)

    # Start finetuning
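The warmup length above is derived from the total number of optimizer steps, so the linear schedule warms up for a fixed fraction of training. A quick worked example with illustrative values, not taken from the diff:

```
# illustrative values
steps_per_epoch = 500                                # len(dataloader)
num_epoch = 3
warmup_ratio = 0.1

total_steps = steps_per_epoch * num_epoch            # 1500 optimizer steps overall
num_warmup_steps = int(warmup_ratio * total_steps)   # 150 steps of linear LR warmup
```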
@@ -1,5 +1,5 @@
import gzip
import random
from contextlib import nullcontext
from functools import partial
from time import time

@@ -8,20 +8,17 @@ import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version

from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

# constants

@@ -44,23 +41,10 @@ def parse_args():
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        type=bool,
        default=False,
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
        "--offload_optim_frac",
        type=float,
        default=1.0,
        help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
    )
    parser.add_argument('-p',
                        '--plugin',
@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
    return total_numel




# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.
    Args:
        model (torch.nn.Module): a torch module to be sharded
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue
            param.set_dist_spec(ReplicaSpec())
            if 'net.0' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_q' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_kv' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif 'to_out' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif '1.1' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif '1.2' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True


args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise TypeError(f"{args.distplan} is error")
@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)

    default_pg = ProcessGroup(tp_degree=args.tp_degree)
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()

    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    pg = default_pg
    tensor_parallelize(model, pg)

    # optimizer

    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
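Pulling the new lines of this hunk together, the post-change initialization path for the PaLM example looks roughly like the sketch below. The constants and the distributed launch are assumed to be handled elsewhere in the script; treat this as an illustration of the lazy-init plus Gemini pattern, not a copy of the file:

```
from contextlib import nullcontext

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

SEQ_LEN = 1024          # illustrative; the script defines its own constant
LEARNING_RATE = 2e-4    # illustrative; the script defines its own constant
use_gemini = True

plugin = GeminiPlugin(offload_optim_frac=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)

# With lazy init, parameter materialization is deferred until booster.boost,
# so the full fp32 model never has to be built on a single device.
ctx = LazyInitContext(default_device=get_current_device()) if use_gemini else nullcontext()
with ctx:
    model = PaLM(num_tokens=50304, dim=4096, depth=64)
    model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
model, optimizer, *_ = booster.boost(model, optimizer)
```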