[example] llama3 (#5631)
* release llama3
* [release] llama3
* [release] llama3
* [release] llama3
* [release] llama3
`examples/language/llama/README.md` (new file, 127 lines)

# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models

### LLaMA3
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA3-70B-H100.png" width=600/>
</p>

- 70-billion-parameter LLaMA3 model training accelerated by 18%

### LLaMA2
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/llama2_pretraining.png" width=600/>
</p>

- 70-billion-parameter LLaMA2 model training accelerated by 195%
[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)

### LLaMA1
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>

- 65-billion-parameter large model pretraining accelerated by 38%
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

## Usage

> ⚠ This example only provides a benchmarking script. For training/fine-tuning, please refer to [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).

### 1. Installation

Please install the latest ColossalAI from source.

```bash
BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```

Then install the other dependencies.

```bash
pip install -r requirements.txt
```
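
To quickly verify the installation, you can check that ColossalAI and FlashAttention import cleanly (a minimal sanity check, assuming a CUDA-enabled environment):

```bash
python -c "import colossalai; print(colossalai.__version__)"
python -c "import flash_attn; print(flash_attn.__version__)"
```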

### 2. Shell Script Examples

For your convenience, we provide shell scripts that run the benchmark with various configurations.

You can find them in the `scripts/benchmark_7B` and `scripts/benchmark_70B` directories. The main command follows the format:

```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
    benchmark.py --OTHER_CONFIGURATIONS
```
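
For example, the `gemini.sh` script under `scripts/benchmark_7B` effectively expands to the command below (defaults such as `-c 7b`, `-p gemini`, and `-l 4096` are spelled out here for clarity; adjust `--nproc_per_node` to match your GPUs per node):

```bash
colossalai run --nproc_per_node 8 --hostfile hosts.txt \
    benchmark.py -c 7b -p gemini -b 16 -l 4096 -g -x
```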

Here we show an example of how to run LLaMA pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.

#### a. Running environment
This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total (for LLaMA-1 65B or LLaMA-2 70B). The nodes are connected with RDMA, and the GPUs within each node are fully connected with NVLink.

#### b. Running command

```bash
cd scripts/benchmark_7B
```

First, put your host file (`hosts.txt`) in this directory with your real host IPs or host names.

Here is a sample `hosts.txt`:
```text
hostname1
hostname2
hostname3
hostname4
```

Then add environment variables to the script if needed.
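
For example, you might export thread and NCCL settings at the top of the script (illustrative values; `OMP_NUM_THREADS=8` is what the provided scripts already set, while the NCCL variables depend on your cluster setup):

```bash
export OMP_NUM_THREADS=8   # already set by the provided scripts
export NCCL_DEBUG=WARN     # illustrative: surface NCCL warnings while benchmarking
export NCCL_IB_HCA=mlx5    # illustrative: pin the InfiniBand HCA on RDMA clusters
```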

Finally, run the following command to start training:

```bash
bash gemini.sh
```

If you encounter an out-of-memory (OOM) error during training with the `gemini.sh` script, switching to `gemini_auto.sh` may resolve it: `gemini_auto` caps GPU memory usage by offloading part of the model parameters and optimizer states to CPU memory. The trade-off is that `gemini_auto.sh` runs somewhat slower, since more data is transferred between CPU and GPU.
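
The switch is a one-line change, run from the same directory:

```bash
bash gemini_auto.sh
```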

#### c. Results
If you run the above command successfully, you should see results similar to:
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.

## Reference
```bibtex
@article{bian2021colossal,
  title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
  author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
  journal={arXiv preprint arXiv:2110.14883},
  year={2021}
}
```

```bibtex
@software{openlm2023openllama,
  author = {Geng, Xinyang and Liu, Hao},
  title = {OpenLLaMA: An Open Reproduction of LLaMA},
  month = May,
  year = 2023,
  url = {https://github.com/openlm-research/open_llama}
}
```

```bibtex
@software{together2023redpajama,
  author = {Together Computer},
  title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
  month = April,
  year = 2023,
  url = {https://github.com/togethercomputer/RedPajama-Data}
}
```

```bibtex
@article{touvron2023llama,
  title={Llama: Open and efficient foundation language models},
  author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
  journal={arXiv preprint arXiv:2302.13971},
  year={2023}
}
```

`examples/language/llama/benchmark.py` (new file, 278 lines)

import argparse
import resource
from contextlib import nullcontext

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
from examples.language.data_utils import RandomDataset
from examples.language.model_utils import format_numel_str, get_model_numel
from examples.language.performance_evaluator import PerformanceEvaluator

# ==============================
# Constants
# ==============================

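# Model configs roughly follow the LLaMA 7B/13B/70B architectures with a 4096-token context;
# "70b" uses grouped-query attention (num_key_value_heads=8).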
MODEL_CONFIGS = {
    "7b": LlamaConfig(max_position_embeddings=4096),
    "13b": LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_hidden_layers=40,
        num_attention_heads=40,
        max_position_embeddings=4096,
    ),
    "70b": LlamaConfig(
        hidden_size=8192,
        intermediate_size=28672,
        num_hidden_layers=80,
        num_attention_heads=64,
        max_position_embeddings=4096,
        num_key_value_heads=8,
    ),
}


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument(
        "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
    )
    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
    args = parser.parse_args()

    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    def empty_init():
        pass

    # ckpt config for LLaMA3-70B on 64 H100 GPUs
    ckpt_config = (
        PipelineGradientCheckpointConfig(
            num_stages=args.pp,
            num_model_chunks=1,
            num_model_layers=80,
            num_layers_per_stage=[19, 20, 20, 21],
            num_ckpt_layers_per_stage=[19, 19, 19, 13],
        )
        if args.custom_ckpt
        else None
    )

    # ==============================
    # Initialize Booster
    # ==============================
    use_empty_init = True
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision="bf16",
            shard_param_frac=args.shard_param_frac,
            offload_optim_frac=args.offload_optim_frac,
            offload_param_frac=args.offload_param_frac,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            placement_policy="auto",
            precision="bf16",
            warmup_non_model_data_ratio=args.warmup_ratio,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
        )
    elif args.plugin == "fsdp":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
            )
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
            )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            gradient_checkpoint_config=ckpt_config,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ==============================
    # Initialize Dataset and Dataloader
    # ==============================
    dp_size = getattr(plugin, "dp_size", coordinator.world_size)

    if args.config in MODEL_CONFIGS:
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    # ==============================
    # Initialize Model and Optimizer
    # ==============================
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )

    init_kwargs = {}
    if config.model_type == "chatglm":
        init_kwargs["empty_init"] = False

    with init_ctx:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
        if config.model_type == "chatglm":
            model.transformer.encoder.gradient_checkpointing = True

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        model.config.num_hidden_layers,
        model.config.hidden_size,
        model.config.vocab_size,
        args.grad_checkpoint,
        args.ignore_steps,
        dp_world_size=dp_size,
    )

    optimizer = HybridAdam(model.parameters())
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(
        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

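    # With pipeline parallelism, the Booster schedules the forward/backward passes across stages;
    # otherwise, iterate over the dataloader and run backward/step manually.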
    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
        data_iter = iter(dataloader)
        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
            performance_evaluator.on_step_start(step)
            booster.execute_pipeline(
                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
            )
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
    else:
        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
            performance_evaluator.on_step_start(step)
            outputs = model(**batch)
            loss = outputs[0]
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(**batch)

    performance_evaluator.on_fit_end()
    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()

`examples/language/llama/requirements.txt` (new file, 8 lines)

colossalai>=0.3.6
datasets
numpy
tqdm
transformers
flash-attn>=2.0.0
SentencePiece==0.1.99
tensorboard==2.14.0

`examples/language/llama/scripts/benchmark_70B/3d.sh` (new file, 17 lines)

#!/bin/bash

# TODO: fix this
echo "3D parallel for LLaMA-2 is not ready yet"
exit 1

################
# Load your environment and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1

`examples/language/llama/scripts/benchmark_70B/gemini.sh` (new file, 13 lines)

#!/bin/bash

################
# Load your environment and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -g -x -b 2

`examples/language/llama/scripts/benchmark_70B/gemini_auto.sh` (new file, 13 lines)

#!/bin/bash

################
# Load your environment and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p gemini_auto -g -x -b 2

`examples/language/llama/scripts/benchmark_7B/gemini.sh` (new file, 13 lines)

#!/bin/bash

################
# Load your environment and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16

`examples/language/llama/scripts/benchmark_7B/gemini_auto.sh` (new file, 13 lines)

#!/bin/bash

################
# Load your environment and modules here
################

HOSTFILE=$(realpath hosts.txt)

cd ../..

export OMP_NUM_THREADS=8

colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16

`examples/language/llama/test_ci.sh` (new executable file, empty)