Merge branch 'main' into sync/npu
@@ -88,6 +88,7 @@ class GLUEDataBuilder:
         )

     def val_dataloader(self):
+        # as the last batch may not be divisible by the number of microbatches
         if len(self.eval_splits) == 1:
             return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
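
The comment added here points at a real pipeline constraint: `execute_pipeline` splits every batch into microbatches, so a ragged final validation batch can break the schedule. One common remedy, sketched under the assumption that `prepare_dataloader` forwards standard `DataLoader` keywords such as `drop_last` (names taken from the hunk above):

```python
# Hedged sketch: drop the ragged final batch so every batch divides evenly
# into pipeline microbatches. Assumes prepare_dataloader forwards standard
# DataLoader kwargs such as drop_last; plugin/dataset names are as in the hunk.
val_loader = plugin.prepare_dataloader(
    dataset["validation"],
    batch_size=eval_batch_size,
    drop_last=True,  # avoid a final batch that cannot be split into microbatches
)
```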
@@ -57,7 +57,7 @@ def evaluate_model(

     def evaluate_subset(dataloader: DataLoader):
         use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-        is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)

         accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         for batch in dataloader:
@@ -69,9 +69,10 @@ def evaluate_model(
                 current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
                 current_rank = dist.get_rank()
                 batch = iter([batch])

                 outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

-                if is_pp_last_stage:
+                if is_pp_last_device:
                     logits = outputs["outputs"]["logits"]
                     val_loss = outputs["loss"]
                     accum_loss.add_(val_loss)
@@ -135,8 +136,8 @@ def train_epoch(
     coordinator: DistCoordinator,
 ):
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
-    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
+    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
     total_step = len(train_dataloader)

     model.train()
@@ -150,7 +151,7 @@ def train_epoch(
                     train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                 )
                 # Backward and optimize
-                if is_pp_last_stage:
+                if is_pp_last_device:
                     loss = outputs["loss"]
                     pbar.set_postfix({"loss": loss.item()})
             else:
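
A note on the recurring rename above: with an interleaved pipeline schedule each rank hosts several model chunks, so `is_last_stage()` is true only for the last virtual stage of the rank's current chunk, while `is_last_stage(ignore_chunk=True)` asks whether the rank is the physically last pipeline device, which is the only place the loss and logits materialize. A minimal sketch of the guard, assuming only the `stage_manager` API already used in this diff (the helper name `should_handle_output` is ours):

```python
# Hedged sketch of the guard pattern this commit converges on; assumes
# booster.plugin exposes pp_size and a stage_manager with
# is_last_stage(ignore_chunk=...), as the hunks above do.
def should_handle_output(booster) -> bool:
    plugin = booster.plugin
    if getattr(plugin, "pp_size", 1) <= 1:
        return True  # no pipeline parallelism: every rank computes the full model
    # With pp_style="interleaved", each rank holds several model chunks, so
    # "last stage of my current chunk" and "last device of the pipeline"
    # differ; ignore_chunk=True selects the latter.
    return plugin.stage_manager.is_last_stage(ignore_chunk=True)
```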
@@ -224,7 +225,9 @@ def main():
             tp_size=1,
             pp_size=2,
             num_microbatches=None,
-            microbatch_size=1,
+            pp_style="interleaved",
+            num_model_chunks=2,
+            microbatch_size=16,
             enable_all_optimization=True,
             zero_stage=1,
             precision="fp16",
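
For readers following the plugin change above: with `pp_size=2` and `num_model_chunks=2` the interleaved schedule creates four virtual stages over two devices, and since `num_microbatches` stays `None`, the plugin derives the microbatch count from the global batch and `microbatch_size=16`. A hedged sketch of the resulting configuration, assuming the public `HybridParallelPlugin` constructor used throughout this commit:

```python
# Hedged sketch: the finetune configuration after this hunk, spelled out.
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,               # two pipeline devices...
    pp_style="interleaved",  # ...each holding num_model_chunks chunks,
    num_model_chunks=2,      # i.e. four virtual stages in total
    num_microbatches=None,   # let the plugin derive the count...
    microbatch_size=16,      # ...from the global batch size instead
    enable_all_optimization=True,
    zero_stage=1,
    precision="fp16",
)
```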
@@ -1,8 +1,17 @@
 #!/bin/bash
-set -xe
+set -x

 pip install -r requirements.txt

+FAIL_LIMIT=3
+
 for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
-    torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
+    for i in $(seq 1 $FAIL_LIMIT); do
+        torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break
+        echo "Failed $i times"
+        if [ $i -eq $FAIL_LIMIT ]; then
+            echo "Failed $FAIL_LIMIT times, exiting"
+            exit 1
+        fi
+    done
 done
@@ -6,7 +6,6 @@
 </p>

 - 70 billion parameter LLaMA2 model training accelerated by 195%
 [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
 [[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)

 ### LLaMA1
@@ -15,7 +14,6 @@
 </p>

 - 65-billion-parameter large model pretraining accelerated by 38%
 [[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
 [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

 ## Dataset
@@ -103,7 +101,7 @@ Here is details about CLI arguments:
 - Max length: `-l`, `--max_length`. The default value is 4096.
 - Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
 - Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
-- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`.
+- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
 - Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
 - Gradient clipping: `--gradient_clipping`. The default value is 1.0.
 - Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
@@ -123,7 +121,7 @@ Here we will show an example of how to run training
 llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.

 #### a. Running environment
-This experiment was performed on 4 computing nodes with 32 A800 GPUs in total for LLaMA-1 65B. The nodes are
+This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total for LLaMA-1 65B or LLaMA-2 70B. The nodes are
 connected with RDMA and GPUs within one node are fully connected with NVLink.

 #### b. Running command
@@ -217,7 +215,7 @@ Here is details about CLI arguments:
 - Max length: `-l`, `--max_length`. The default value is 4096.
 - Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
 - Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
-- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`.
+- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
 - Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
 - Gradient clipping: `--gradient_clipping`. The default value is 1.0.
 - Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
@@ -71,9 +71,10 @@ def main():
     parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
     parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
-    parser.add_argument("--mbs", type=int, default=1)
-    parser.add_argument("--zero", type=int, default=0)
+    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
+    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     args = parser.parse_args()

     colossalai.launch_from_torch({})
@@ -92,9 +93,17 @@ def main():
             shard_param_frac=args.shard_param_frac,
             offload_optim_frac=args.offload_optim_frac,
             offload_param_frac=args.offload_param_frac,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
         )
     elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio)
+        plugin = GeminiPlugin(
+            placement_policy="auto",
+            precision="bf16",
+            warmup_non_model_data_ratio=args.warmup_ratio,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
+        )
     elif args.plugin == "fsdp":
         if use_empty_init:
             plugin = TorchFSDPPlugin(
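
To make the new knobs above concrete: with Gemini, `tp_size` carves out tensor-parallel groups and `extra_dp_size` adds an outer data-parallel replica dimension on top of ZeRO-style sharding, so the sizes should multiply out to the world size. A small hedged sketch of that arithmetic (our reading of the parameters, not code from the repository):

```python
# Hedged sketch: how --tp and --extra_dp plausibly partition the world
# size under Gemini; the remaining ranks form the sharding group.
world_size = 8
tp_size = 2        # args.tp: tensor parallelism within each shard group
extra_dp_size = 2  # args.extra_dp: outer data-parallel replicas
zero_size = world_size // (tp_size * extra_dp_size)
assert tp_size * extra_dp_size * zero_size == world_size
print(f"sharding group size: {zero_size}")  # -> 2
```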
@@ -129,9 +138,11 @@ def main():
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
+            pp_style="interleaved",
             zero_stage=args.zero,
+            num_model_chunks=2,
             enable_fused_normalization=torch.cuda.is_available(),
-            num_microbatches=args.mbs,
+            microbatch_size=args.mbs,
             precision="bf16",
         )
     elif args.plugin == "3d_cpu":
@@ -141,7 +152,7 @@
             zero_stage=args.zero,
             cpu_offload=True,
             enable_fused_normalization=torch.cuda.is_available(),
-            num_microbatches=args.mbs,
+            microbatch_size=args.mbs,
             initial_scale=2**8,
             precision="bf16",
         )
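
The two hunks above fix the same confusion: `--mbs` is a micro batch *size*, but it was being passed to `num_microbatches`, a microbatch *count*. A hedged sketch of the arithmetic that distinguishes the two (values are illustrative, not from the benchmark):

```python
# Hedged sketch: microbatch_size vs. num_microbatches.
# The pipeline splits each global batch into microbatches; fixing the
# size derives the count, and vice versa.
global_batch_size = 32
microbatch_size = 4  # what --mbs actually means
num_microbatches = global_batch_size // microbatch_size  # -> 8
assert num_microbatches * microbatch_size == global_batch_size
```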
@@ -14,4 +14,4 @@ cd ../..

 export OMP_NUM_THREADS=8

-colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4
+colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1
@@ -1,6 +1,15 @@
 ## OpenMoE
 [OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in JAX, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use it. The following example from [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates finetuning and inference methods.
+
+
+<p align="center">
+  <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/MOE_training.png" width=800/>
+</p>
+
+* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/openmoe)
+[[blog]](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)

 ## Usage

 ### 1. Installation