Date: Thu, 18 Jul 2024 08:38:56 +0000
Subject: [PATCH 032/116] fix style, add kto data sample
---
.../ColossalChat/benchmarks/benchmark_kto.sh | 4 +--
.../ColossalChat/benchmarks/benchmark_sft.sh | 8 ++---
.../ColossalChat/coati/models/loss.py | 11 ++-----
.../ColossalChat/coati/trainer/kto.py | 4 +--
applications/ColossalChat/examples/README.md | 21 ++++++++++++-
.../prepare_kto_dataset.sh | 6 ++--
.../prepare_preference_dataset.sh | 3 +-
.../prepare_prompt_dataset.sh | 3 +-
.../prepare_sft_dataset.sh | 1 +
.../examples/training_scripts/train_kto.sh | 30 +++++++++----------
.../examples/training_scripts/train_sft.sh | 30 +++++++++----------
11 files changed, 68 insertions(+), 53 deletions(-)
diff --git a/applications/ColossalChat/benchmarks/benchmark_kto.sh b/applications/ColossalChat/benchmarks/benchmark_kto.sh
index 4d651914e..571915c3b 100755
--- a/applications/ColossalChat/benchmarks/benchmark_kto.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_kto.sh
@@ -17,8 +17,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.sh b/applications/ColossalChat/benchmarks/benchmark_sft.sh
index 376a638c3..70c0e5873 100755
--- a/applications/ColossalChat/benchmarks/benchmark_sft.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_sft.sh
@@ -14,11 +14,11 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
-# export CUDA_VISIBLE_DEVICES=3,4
+
PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
@@ -29,7 +29,7 @@ colossalai run --nproc_per_node 1 --master_port 31312 benchmark_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--config_file $CONFIG_FILE \
- --plugin ddp \
+ --plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
--accumulation_steps 1 \
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index 3d75aa9fa..840cca074 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -251,17 +251,13 @@ class KTOLoss(nn.Module):
# all gather
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl = (kl / dist.get_world_size()).clamp(min=0)
- # kl = 0
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
chosen_logratios = chosen_logps - ref_chosen_logps
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
- # important to cast to policy_dtype; otherwise error will occur during all_gather
- chosen_losses = torch.Tensor([]).to(
- kl_logps.device
- ) # torch.Tensor(0.).to(chosen_logps.dtype).to(chosen_logps.device)
+ chosen_losses = torch.Tensor([]).to(kl_logps.device)
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
@@ -269,10 +265,7 @@ class KTOLoss(nn.Module):
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
- # important to cast to policy_dtype; otherwise error will occur during all_gather
- rejected_losses = torch.Tensor([]).to(
- kl_logps.device
- ) # torch.Tensor(0.).to(rejected_logps.dtype).to(rejected_logps.device)
+ rejected_losses = torch.Tensor([]).to(kl_logps.device)
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
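For reference, a minimal standalone sketch of the KTO loss terms kept by the hunk above, assuming non-empty chosen/rejected batches and a `kl` baseline that has already been all-reduced across ranks (as done in `KTOLoss.forward`); the function name and default values are illustrative and not part of the patch:

```python
import torch

def kto_loss_terms(
    chosen_logps, ref_chosen_logps, rejected_logps, ref_rejected_logps, kl,
    beta=0.1, desirable_weight=1.0, undesirable_weight=1.0,
):
    # log(pi / pi_ref) for desirable and undesirable responses
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    # a desirable response is rewarded for beating the batch-level KL baseline, and vice versa
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    # weighted mean over both groups, mirroring the torch.cat(...).mean() line above
    return torch.cat((desirable_weight * chosen_losses, undesirable_weight * rejected_losses), 0).mean()
```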
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
index d0b6d63f7..83bff3bbe 100755
--- a/applications/ColossalChat/coati/trainer/kto.py
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -180,7 +180,7 @@ class KTOTrainer(SLTrainer):
self.optimizer.zero_grad()
self.actor_scheduler.step()
- # # sync
+ # sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
@@ -297,7 +297,7 @@ class KTOTrainer(SLTrainer):
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
)
- # # sync
+ # sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index fbbb25332..54f1bdba6 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -754,7 +754,26 @@ We support the method introduced in the paper [ORPO: Monolithic Preference Optim
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
-We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. To use KTO in alignment, use the [train_kto.sh](./examples/training_scripts/train_orpo.sh) script, You may need to set the value for `beta` (which determine how strongly the reinforcement learning loss affect the training), `desirable_weight` and `undesirable_weight` if your data is biased (has unequal number of chosen and rejected samples).
+We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results.
+
+For KTO data preparation, please use the script [prepare_kto_dataset.sh](./data_preparation_scripts/prepare_kto_dataset.sh). Unlike DPO and its derivatives, you no longer need a pair of chosen/rejected responses for the same input; you only need data in which each response carries a preference label indicating whether it is desirable or not (see the paper for more details). Before running the data preparation script, convert your data to the following intermediate format, where "completion" must contain a single assistant line and the boolean "label" marks whether that response is favorable:
+
+```jsonl
+{
+    "prompt": [
+        {"from": "human", "content": "What are some cuss words in english?"},
+        {
+            "from": "assistant",
+            "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."
+        },
+        {"from": "human", "content": "What's your favorite one?"}
+    ],
+    "completion": {"from": "assistant", "content": "Ass."},
+    "label": false
+}
+```
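+
+For reference, here is a minimal Python snippet that appends one sample in this format to a JSONL file (the output path is illustrative); note that `label` is serialized as a JSON boolean:
+
+```python
+import json
+
+sample = {
+    "prompt": [
+        {"from": "human", "content": "What are some cuss words in english?"},
+        {"from": "assistant", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."},
+        {"from": "human", "content": "What's your favorite one?"},
+    ],
+    "completion": {"from": "assistant", "content": "Ass."},
+    "label": False,  # serialized as JSON false
+}
+
+with open("kto_data.jsonl", "a", encoding="utf-8") as f:
+    f.write(json.dumps(sample, ensure_ascii=False) + "\n")
+```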
+
+For training, use the [train_kto.sh](./training_scripts/train_kto.sh) script. You may need to set the value of `beta` (which determines how strongly the reinforcement learning loss affects training); also set `desirable_weight` and `undesirable_weight` if your data is biased (i.e., has an unequal number of chosen and rejected samples).
#### KTO Result
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh
index 274e2a3fd..0450d570d 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh
@@ -5,9 +5,9 @@ rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type kto \
- --data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/kto_format/data \
- --conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
- --tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
+ --data_input_dirs /PATH/TO/KTO/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh
index b6546a21e..5c06b43fe 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh
@@ -10,4 +10,5 @@ python prepare_dataset.py --type preference \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
- --data_arrow_output_dir $SAVE_DIR/arrow
+ --data_arrow_output_dir $SAVE_DIR/arrow \
+ --max_length 1024
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh
index 8d3d6c2c2..d74667889 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh
@@ -10,4 +10,5 @@ python prepare_dataset.py --type prompt \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
- --data_arrow_output_dir $SAVE_DIR/arrow
+ --data_arrow_output_dir $SAVE_DIR/arrow \
+ --max_length 1024
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
index 3f555883d..dbd323438 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
@@ -11,3 +11,4 @@ python prepare_dataset.py --type sft \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
+ --max_length 4096
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.sh b/applications/ColossalChat/examples/training_scripts/train_kto.sh
index 86b5897f1..3dcf6af02 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.sh
@@ -16,23 +16,23 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
-PARENT_SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto/checkpoint" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save logs
-PARENT_CONFIG_FILE="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/home/nvme-share/home/yeanbang/data/model/hh_rlhf_sheared_llamasft-2024-07-17-07-29-29/modeling" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00000
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00001
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00002
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00003
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00004
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00005
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00006
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00007
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00008
- /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00009
+ /Your/KTO/Data/arrow/part-00000
+ /Your/KTO/Data/arrow/part-00001
+ /Your/KTO/Data/arrow/part-00002
+ /Your/KTO/Data/arrow/part-00003
+ /Your/KTO/Data/arrow/part-00004
+ /Your/KTO/Data/arrow/part-00005
+ /Your/KTO/Data/arrow/part-00006
+ /Your/KTO/Data/arrow/part-00007
+ /Your/KTO/Data/arrow/part-00008
+ /Your/KTO/Data/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh
index 31c2ab562..f0a281475 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh
@@ -15,22 +15,22 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
-PARENT_SAVE_DIR="/home/nvme-share/home/yeanbang/data/model/hh_rlhf_sheared_llama" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft/log" # Path to a folder to save logs
-PARENT_CONFIG_FILE="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00000
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00001
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00002
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00003
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00004
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00005
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00006
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00007
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00008
- /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00009
+ /Your/SFT/Data/arrow/part-00000
+ /Your/SFT/Data/arrow/part-00001
+ /Your/SFT/Data/arrow/part-00002
+ /Your/SFT/Data/arrow/part-00003
+ /Your/SFT/Data/arrow/part-00004
+ /Your/SFT/Data/arrow/part-00005
+ /Your/SFT/Data/arrow/part-00006
+ /Your/SFT/Data/arrow/part-00007
+ /Your/SFT/Data/arrow/part-00008
+ /Your/SFT/Data/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
From 8cc8f645cd1d971a3bef52f625b7881f17c6d22b Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 19 Jul 2024 10:10:08 +0800
Subject: [PATCH 033/116] [Examples] Add lazy init to OPT and GPT examples
(#5924)
Co-authored-by: Edenzzzz
---
.../gpt/hybridparallelism/finetune.py | 17 ++++++++----
examples/language/opt/opt_benchmark.py | 24 +++++++++++------
examples/language/opt/opt_train_demo.py | 27 +++++++++++++------
3 files changed, 47 insertions(+), 21 deletions(-)
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
index 777d16cb9..ae6d655f4 100644
--- a/examples/language/gpt/hybridparallelism/finetune.py
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -1,4 +1,5 @@
import argparse
+from contextlib import nullcontext
from typing import Callable, List, Union
import evaluate
@@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
# ==============================
@@ -186,7 +188,6 @@ def main():
help="only gpt2 now",
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
- parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()
if args.model_type == "gpt2":
@@ -250,10 +251,16 @@ def main():
pad_token_id=data_builder.tokenizer.pad_token_id,
)
- if model_name == "gpt2":
- model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
- else:
- raise RuntimeError
+ init_ctx = (
+ LazyInitContext(default_device=get_accelerator().get_current_device())
+ if isinstance(plugin, (GeminiPlugin))
+ else nullcontext()
+ )
+ with init_ctx:
+ if model_name == "gpt2":
+ model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+ else:
+ raise RuntimeError
# optimizer
no_decay = ["bias", "LayerNorm.weight"]
diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py
index c2883d96c..ca9b63d1a 100755
--- a/examples/language/opt/opt_benchmark.py
+++ b/examples/language/opt/opt_benchmark.py
@@ -1,4 +1,5 @@
import time
+from contextlib import nullcontext
import torch
import tqdm
@@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -62,14 +65,6 @@ def main():
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
- # Build OPT model
- config = AutoConfig.from_pretrained(args.model_name_or_path)
- model = OPTForCausalLM(config=config)
- logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
- # Enable gradient checkpointing
- model.gradient_checkpointing_enable()
-
# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@ def main():
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+ # Build OPT model
+ init_ctx = (
+ LazyInitContext(default_device=get_accelerator().get_current_device())
+ if isinstance(plugin, (GeminiPlugin))
+ else nullcontext()
+ )
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ with init_ctx:
+ model = OPTForCausalLM(config=config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)
diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py
index b5b50305c..50dfc7bff 100644
--- a/examples/language/opt/opt_train_demo.py
+++ b/examples/language/opt/opt_train_demo.py
@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
import datasets
import torch
import transformers
@@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s
from transformers.utils.versions import require_version
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -78,14 +82,6 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
- # Build OPT model
- config = AutoConfig.from_pretrained(args.model_name_or_path)
- model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
- logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
- # Enable gradient checkpointing
- model.gradient_checkpointing_enable()
-
# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@ def main():
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+ # Build OPT model
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ # Build OPT model
+ init_ctx = (
+ LazyInitContext(default_device=get_accelerator().get_current_device())
+ if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+ else nullcontext()
+ )
+ with init_ctx:
+ model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer)
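For reference, a minimal sketch of the lazy-initialization pattern these hunks introduce: the model is built under `LazyInitContext` only when the plugin shards parameters at init time, and the parameters are materialized when the booster wraps the model. The model name, learning rate, and plugin choice are illustrative assumptions; the distributed launch setup is abbreviated.

```python
from contextlib import nullcontext

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from transformers import AutoConfig, OPTForCausalLM

colossalai.launch_from_torch()  # expects the usual torchrun environment variables
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# Defer parameter allocation for plugins that shard weights during initialization.
init_ctx = (
    LazyInitContext(default_device=get_accelerator().get_current_device())
    if isinstance(plugin, GeminiPlugin)
    else nullcontext()
)
config = AutoConfig.from_pretrained("facebook/opt-125m")
with init_ctx:
    model = OPTForCausalLM(config=config)

optimizer = HybridAdam(model.parameters(), lr=1e-5)
# boost() materializes (and shards) the lazily initialized parameters.
model, optimizer, *_ = booster.boost(model, optimizer)
```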
From f585d4e38e4aa6145ec94fe4d6f0a1fe94bc4192 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 19 Jul 2024 13:40:07 +0800
Subject: [PATCH 034/116] [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
---
applications/ColossalChat/.gitignore | 3 +++
applications/ColossalChat/coati/trainer/sft.py | 2 +-
.../Qwen_Qwen1.5-32B-Chat.json | 9 +++++++++
.../conversation_template/tiny-llama.json | 8 ++++++++
applications/ColossalChat/examples/README.md | 18 +++++++++---------
.../examples/training_scripts/train_dpo.sh | 6 ++----
.../examples/training_scripts/train_orpo.sh | 8 +++-----
.../examples/training_scripts/train_ppo.sh | 5 ++---
.../examples/training_scripts/train_rm.sh | 6 ++----
.../examples/training_scripts/train_sft.py | 2 +-
.../examples/training_scripts/train_sft.sh | 5 ++---
11 files changed, 42 insertions(+), 30 deletions(-)
create mode 100644 applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json
create mode 100644 applications/ColossalChat/config/conversation_template/tiny-llama.json
diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore
index 33950adc0..757cbb5da 100755
--- a/applications/ColossalChat/.gitignore
+++ b/applications/ColossalChat/.gitignore
@@ -146,6 +146,9 @@ docs/.build
examples/wandb/
examples/logs/
examples/output/
+examples/training_scripts/logs
+examples/training_scripts/wandb
+examples/training_scripts/output
examples/awesome-chatgpt-prompts/
temp/
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index 1484f5057..c09d61034 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -102,7 +102,6 @@ class SFTTrainer(SLTrainer):
batch_size = batch["input_ids"].size(0)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
- step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
self.booster.backward(loss=loss, optimizer=self.optimizer)
@@ -115,6 +114,7 @@ class SFTTrainer(SLTrainer):
self.optimizer.zero_grad()
self.scheduler.step()
+ step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
diff --git a/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json
new file mode 100644
index 000000000..58941a591
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/Qwen_Qwen1.5-32B-Chat.json
@@ -0,0 +1,9 @@
+{
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 151645,
+ 151643
+ ],
+ "end_of_assistant": "<|im_end|>"
+}
diff --git a/applications/ColossalChat/config/conversation_template/tiny-llama.json b/applications/ColossalChat/config/conversation_template/tiny-llama.json
new file mode 100644
index 000000000..59196159f
--- /dev/null
+++ b/applications/ColossalChat/config/conversation_template/tiny-llama.json
@@ -0,0 +1,8 @@
+{
+ "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
+ "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ "stop_ids": [
+ 2
+ ],
+ "end_of_assistant": ""
+}
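For reference, a short sketch of how a `chat_template` string like the ones in these config files renders a conversation. It goes through the Hugging Face tokenizer API directly for illustration; the model ID is an assumption, and ColossalChat itself applies the template through its `Conversation` class.

```python
import json

from transformers import AutoTokenizer

with open("applications/ColossalChat/config/conversation_template/tiny-llama.json") as f:
    config = json.load(f)

# Model ID assumed for illustration; any tokenizer compatible with the template works.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.chat_template = config["chat_template"]

messages = [
    {"role": "system", "content": config["system_message"]},
    {"role": "user", "content": "Hello!"},
]
# Render the prompt text (no tokenization) with the generation prompt appended.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```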
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index bdf4d23f1..d6114c8d5 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -490,7 +490,7 @@ In this code we provide a flexible way for users to set the conversation templat
On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message),
-- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
+- Step 2: Run the data preparation script [prepare_sft_dataset.sh](./data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
- Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files.
@@ -510,7 +510,7 @@ Human: what are some pranks with a pen i can do? Assistant: Are you
#### Step 3: Training
-Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
### RLHF Training Stage2 - Training Reward Model
@@ -552,11 +552,11 @@ Below shows the preference dataset format used in training the reward model.
#### Step 2: Preprocessing
-Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
+Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
#### Step 3: Training
-You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+You can run [train_rm.sh](./training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
#### Features and Tricks in RM Training
@@ -627,14 +627,14 @@ The second dataset--- pretrained dataset is optional, provide it if you want to
]
```
#### Step 2: Preprocessing
-To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
+To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./data_preparation_scripts/prepare_prompt_dataset.sh)
You can use the SFT dataset you prepared in the SFT stage or prepare a new one from different source for the ptx dataset. The ptx data is used to calculate ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
#### Step 3: Training
-You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+You can run [train_ppo.sh](./training_scripts/train_ppo.sh) to start PPO training. Here are some arguments unique to PPO; please refer to the [training configuration](#training-configuration) section for details regarding the other supported training options.
```bash
@@ -718,7 +718,7 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training
-You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added option for the user to choose from, including whether to do length normalization , reward shaping and whether to use a reference model in calculating implicit reward. Here are those options,
+You can run [train_dpo.sh](./training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following recent research on DPO-like alignment methods, we added options for the user to choose from, including whether to apply length normalization, whether to use reward shaping, and whether to use a reference model when calculating the implicit reward. Here are those options:
```
--beta 0.1 \ # the temperature in DPO loss, Default to 0.1
@@ -735,7 +735,7 @@ You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to star
### Alternative Option For RLHF: Simple Preference Optimization
We support the method introduced in the paper [SimPO: Simple Preference Optimization
-with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which is a reference model free aligment method that add length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) script, set the `loss_type` to `simpo_loss`, you can also set the value for temperature (`beta`) and reward target margin (`gamma`) but it is optional.
+with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method does not deviate much from DPO, we added support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./training_scripts/train_dpo.sh) script, set `loss_type` to `simpo_loss`, and optionally set the temperature (`beta`) and reward target margin (`gamma`).
#### SimPO Result
@@ -744,7 +744,7 @@ with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which i
### Alternative Option For RLHF: Odds Ratio Preference Optimization
-We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional.
+We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning loss using the odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin, and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script; you can optionally set the value of `lambda` (which determines how strongly the reinforcement learning loss affects training).
#### ORPO Result
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
index f7bb45658..082d54ff0 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
@@ -15,9 +15,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
-PROJECT_NAME="dpo"
+PROJECT_NAME="DPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -38,11 +37,10 @@ declare -a dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
- --checkpoint_path $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.sh b/applications/ColossalChat/examples/training_scripts/train_orpo.sh
index ca80a14c1..482956b21 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.sh
@@ -13,11 +13,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
-set_n_least_used_CUDA_VISIBLE_DEVICES 8
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
-PROJECT_NAME="dpo"
+PROJECT_NAME="ORPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -38,11 +37,10 @@ declare -a dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
- --checkpoint_path $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.sh b/applications/ColossalChat/examples/training_scripts/train_ppo.sh
index 91633978e..277e75e6d 100755
--- a/applications/ColossalChat/examples/training_scripts/train_ppo.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.sh
@@ -15,10 +15,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
-PROJECT_NAME="ppo"
+PROJECT_NAME="PPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -54,7 +53,7 @@ declare -a ptx_dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \
--pretrain $PRETRAINED_MODEL_PATH \
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.sh b/applications/ColossalChat/examples/training_scripts/train_rm.sh
index e06d9092f..cd42afcc8 100755
--- a/applications/ColossalChat/examples/training_scripts/train_rm.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.sh
@@ -15,9 +15,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
-PROJECT_NAME="rm"
+PROJECT_NAME="RM"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -38,11 +37,10 @@ declare -a dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
--pretrain $PRETRAINED_MODEL_PATH \
- --checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index fe1506559..b89cbeb91 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -61,7 +61,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose
"""
- plugin = TorchDDPPlugin(find_unused_parameters=True)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh
index 18df09293..c7d38c1d8 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh
@@ -14,9 +14,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
-PROJECT_NAME="sft"
+PROJECT_NAME="SFT"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -36,7 +35,7 @@ declare -a dataset=(
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
echo $(which colossalai)
echo $(which python)
From d49550fb49994cd828a0d5da9b1eb473011dece5 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 19 Jul 2024 10:10:48 +0000
Subject: [PATCH 035/116] refactor tokenization
---
.../coati/dataset/conversation.py | 3 +-
.../coati/dataset/tokenization_utils.py | 234 ++++++------------
.../ColossalChat/coati/dataset/utils.py | 36 ++-
applications/ColossalChat/examples/README.md | 12 +-
.../tests/test_data/dpo/test_dpo_data.jsonl | 2 +-
.../tests/test_data/kto/test_kto_data.jsonl | 1 +
.../tests/test_data/sft/test_sft_data.jsonl | 2 +-
.../ColossalChat/tests/test_templating.sh | 36 ++-
.../ColossalChat/tests/verify_chat_data.py | 8 +
9 files changed, 159 insertions(+), 175 deletions(-)
create mode 100644 applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl
diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py
index 37900f3b8..a77c220d3 100755
--- a/applications/ColossalChat/coati/dataset/conversation.py
+++ b/applications/ColossalChat/coati/dataset/conversation.py
@@ -18,6 +18,7 @@ class Conversation:
chat_template: str
stop_ids: List[int]
end_of_assistant: str
+ roles = ["user", "assistant"]
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@@ -85,7 +86,7 @@ class Conversation:
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
- assert role in ["user", "assistant"]
+ assert role in self.roles
self.messages.append({"role": role, "content": message})
def copy(self):
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index cfacdf308..cf69b6fe7 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -39,7 +39,7 @@ def supervised_tokenize_sft(
Args:
data_point: the data point of the following format
- {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
@@ -52,41 +52,25 @@ def supervised_tokenize_sft(
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
-
- for mess in messages:
- from_str = mess["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
-
- template.append_message(from_str, mess["content"])
+ for idx, mess in enumerate(messages):
+ if mess["from"] != template.roles[idx % 2]:
+ raise ValueError(
+ f"Message should iterate between user and assistant and starts with a \
+ line from the user. Got the following data:\n{messages}"
+ )
+ template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 0:
+ # Force to end with assistant response
template.messages = template.messages[0:-1]
- # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
- turns = [i for i in range(1, len(messages) // 2 + 1)]
-
- lo, hi = 0, len(turns)
- while lo < hi:
- mid = (lo + hi) // 2
- prompt = template.get_prompt(2 * turns[mid] - 1)
- chunks, require_loss = split_templated_prompt_into_chunks(
- template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
- )
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
- if max_length - 1 < len(tokenized):
- hi = mid
- else:
- lo = mid + 1
- target_turn_index = lo
-
- # The tokenized length for first turn already exceeds `max_length - 1`.
- if target_turn_index - 1 < 0:
- warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
+ # tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
+ prompt = template.get_prompt()
+ chunks, require_loss = split_templated_prompt_into_chunks(
+ template.messages, prompt, conversation_template.end_of_assistant
+ )
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
+ if tokenized is None:
return dict(
input_ids=None,
labels=None,
@@ -96,45 +80,18 @@ def supervised_tokenize_sft(
seq_category=None,
)
- target_turn = turns[target_turn_index - 1]
- prompt = template.get_prompt(2 * target_turn)
- chunks, require_loss = split_templated_prompt_into_chunks(
- template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
- )
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
-
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
- if end == len(tokenized):
- tokenized = tokenized + [tokenizer.eos_token_id]
- labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end]
- # truncate the sequence at the last token that requires loss calculation
- to_truncate_len = 0
- for i in range(len(tokenized) - 1, -1, -1):
- if labels[i] == ignore_index:
- to_truncate_len += 1
- else:
- break
- to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
- tokenized = tokenized[: len(tokenized) - to_truncate_len]
- labels = labels[: len(labels) - to_truncate_len]
-
if tokenizer.bos_token_id is not None:
+        # Force-add the bos token at the beginning of the tokenized sequence if the input ids don't already start with bos
if tokenized[0] != tokenizer.bos_token_id:
+ # Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized
- labels = [ignore_index] + labels
+ labels = [-100] + labels
- if tokenizer.eos_token_id is not None:
- # Force to add eos token at the end of the tokenized sequence
- if tokenized[-1] != tokenizer.eos_token_id:
- tokenized = tokenized + [tokenizer.eos_token_id]
- labels = labels + [tokenizer.eos_token_id]
- else:
- labels[-1] = tokenizer.eos_token_id
-
- # For some model without bos/eos may raise the following errors
+ # log decoded inputs and labels for debugging
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
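For reference, the label-masking scheme used by the refactored `supervised_tokenize_sft` above, as a standalone sketch: positions outside the assistant spans (the `(start, end)` pairs returned by `tokenize_and_concatenate`) are set to the ignore index so they do not contribute to the loss. The helper name is illustrative.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # ignored by the cross-entropy loss

def mask_non_assistant_labels(input_ids: List[int], spans: List[Tuple[int, int]]) -> List[int]:
    labels = [IGNORE_INDEX] * len(input_ids)
    for start, end in spans:
        # keep loss only on tokens inside an assistant response
        labels[start:end] = input_ids[start:end]
    return labels

# Tokens at indices 2..3 belong to the assistant response.
print(mask_non_assistant_labels([5, 6, 7, 8, 9], [(2, 4)]))  # [-100, -100, 7, 8, -100]
```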
@@ -183,7 +140,7 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
- {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
+ {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
@@ -196,35 +153,28 @@ def tokenize_prompt_dataset(
template = deepcopy(conversation_template)
template.messages = []
- for mess in messages:
- from_str = mess["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
-
- template.append_message(from_str, mess["content"])
+ for idx, mess in enumerate(messages):
+ if mess["from"] != template.roles[idx % 2]:
+ raise ValueError(
+ f"Message should iterate between user and assistant and starts with a \
+ line from the user. Got the following data:\n{messages}"
+ )
+ template.append_message(mess["from"], mess["content"])
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
- target_turn = len(template.messages)
- if target_turn % 2 != 1:
+ if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
- target_turn = target_turn - 1
+ template.messages = template.messages[:-1]
# Prepare data
- prompt = template.get_prompt(target_turn, add_generation_prompt=True)
- chunks, require_loss = split_templated_prompt_into_chunks(
- template.messages[:target_turn], prompt, conversation_template.end_of_assistant
- )
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
+ prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
+ tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
+
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
- # Skip overlength data
- if max_length - 1 < len(tokenized):
+ if len(tokenized) > max_length:
return dict(
input_ids=None,
inputs_decode=None,
@@ -235,47 +185,32 @@ def tokenize_prompt_dataset(
# `inputs_decode` can be used to check whether the tokenization method is true.
return dict(
input_ids=tokenized,
- inputs_decode=tokenizer.decode(tokenized),
+ inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
-def apply_rlhf_data_format(
- template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
-):
+def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
- loss_mask = [0] * len(tokenized)
- mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
- if mask_token is None:
- mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
+ # no truncation applied
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
+ loss_mask = [0] * len(tokenized)
label_decode = []
- for start, end in zip(starts[-1:], ends[-1:]):
- # only the last round (chosen/rejected) counts
- if end == len(tokenized):
- tokenized = tokenized + [tokenizer.eos_token_id]
- loss_mask = loss_mask + [1]
- loss_mask[start:end] = [1] * len(loss_mask[start:end])
- label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
+ # only the last round (chosen/rejected) is used to calculate loss
+ for i in range(starts[-1], ends[-1]):
+ loss_mask[i] = 1
+ label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask
- if tokenizer.eos_token_id is not None:
- # Force to add eos token at the end of the tokenized sequence
- if tokenized[-1] != tokenizer.eos_token_id:
- tokenized = tokenized + [tokenizer.eos_token_id]
- loss_mask = loss_mask + [1]
- else:
- loss_mask[-1] = 1
-
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
@@ -288,7 +223,7 @@ def tokenize_rlhf(
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following:
- {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
+ {"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
if ignore_index is None:
@@ -298,24 +233,17 @@ def tokenize_rlhf(
template = deepcopy(conversation_template)
template.clear()
- for mess in context:
- from_str = mess["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
-
- if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
- # Concate adjacent message from the same role
- template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
- else:
- template.append_message(from_str, mess["content"])
+ for idx, mess in enumerate(context):
+ if mess["from"] != template.roles[idx % 2]:
+ raise ValueError(
+ f"Message should iterate between user and assistant and starts with a \
+ line from the user. Got the following data:\n{context}"
+ )
+ template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 1:
warnings.warn(
- "Please make sure leading context starts and ends with a line from human\nLeading context: "
+ "Please make sure leading context starts and ends with a line from user\nLeading context: "
+ str(template.messages)
)
return dict(
@@ -326,31 +254,27 @@ def tokenize_rlhf(
rejected_loss_mask=None,
rejected_label_decode=None,
)
- round_of_context = int((len(template.messages) - 1) / 2)
- assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
+ assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
chosen = deepcopy(template)
rejected = deepcopy(template)
+ chosen_continuation = data_point["chosen"]
+ rejected_continuation = data_point["rejected"]
+ for round in range(len(chosen_continuation)):
+ if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
+ raise ValueError(
+                f"Messages in the chosen continuation should alternate between assistant and user, \
+                    starting with a line from the assistant. Got the following data:\n{chosen_continuation}"
+ )
+ chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
- for round in range(len(data_point["chosen"])):
- from_str = data_point["chosen"][round]["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
- chosen.append_message(from_str, data_point["chosen"][round]["content"])
-
- for round in range(len(data_point["rejected"])):
- from_str = data_point["rejected"][round]["from"]
- if from_str.lower() == "human":
- from_str = "user"
- elif from_str.lower() == "assistant":
- from_str = "assistant"
- else:
- raise ValueError(f"Unsupported role {from_str.lower()}")
- rejected.append_message(from_str, data_point["rejected"][round]["content"])
+ for round in range(len(rejected_continuation)):
+ if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
+ raise ValueError(
+                f"Messages in the rejected continuation should alternate between assistant and user, \
+                    starting with a line from the assistant. Got the following data:\n{rejected_continuation}"
+ )
+ rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
(
chosen_input_ids,
@@ -361,16 +285,14 @@ def tokenize_rlhf(
rejected_label_decode,
) = (None, None, None, None, None, None)
- chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
+ chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
- rejected_data_packed = apply_rlhf_data_format(
- rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
- )
+ rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
@@ -387,7 +309,7 @@ def tokenize_rlhf(
rejected_label_decode=None,
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
- if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
+ if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
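The alternation rule enforced above can be summarised with a small self-contained sketch (`ROLES` and `check_alternation` are illustrative names; the real code reads the roles from the conversation template): the leading context must start with the user, while each chosen/rejected continuation must start with the assistant.

```python
from typing import Dict, List

ROLES = ("user", "assistant")  # assumption: mirrors template.roles

def check_alternation(messages: List[Dict[str, str]], first_role_index: int = 0) -> None:
    """Raise if messages do not alternate roles, starting from ROLES[first_role_index]."""
    for idx, mess in enumerate(messages):
        expected = ROLES[(idx + first_role_index) % 2]
        if mess["from"] != expected:
            raise ValueError(f"Expected a line from the {expected} at position {idx}, got:\n{messages}")

# the leading context starts with the user; a continuation that follows an
# odd-length context starts with the assistant
check_alternation([{"from": "user", "content": "hi"}, {"from": "assistant", "content": "hello"}])
check_alternation([{"from": "assistant", "content": "Sure."}], first_role_index=1)
```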
@@ -411,14 +333,13 @@ def tokenize_kto(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
Tokenize a dataset for KTO training
    The raw input data is a conversation that has the following format
{
- "prompt": [{"from": "human", "content": "xxx"}...],
+ "prompt": [{"from": "user", "content": "xxx"}...],
"completion": {"from": "assistant", "content": "xxx"},
"label": true/false
}
@@ -427,21 +348,18 @@ def tokenize_kto(
    the completion, which only contains the assistant's answer,
    and a binary label, which indicates whether the sample is preferred or not
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
-
prompt = data_point["prompt"]
completion = data_point["completion"]
template = deepcopy(conversation_template)
template.clear()
- if prompt[0].get("from", None) != "human":
- raise ValueError("conversation should start with human")
+ if prompt[0].get("from", None) != "user":
+ raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
raise ValueError("conversation should end with assistant")
for mess in prompt:
- if mess.get("from", None) == "human":
+ if mess.get("from", None) == "user":
template.append_message("user", mess["content"])
elif mess.get("from", None) == "assistant":
template.append_message("assistant", mess["content"])
diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py
index f41a4d772..cf767b444 100755
--- a/applications/ColossalChat/coati/dataset/utils.py
+++ b/applications/ColossalChat/coati/dataset/utils.py
@@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
return -1
-def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
+def tokenize_and_concatenate(
+ tokenizer: PreTrainedTokenizer,
+ text: List[str],
+ require_loss: List[bool],
+ max_length: int,
+ discard_non_loss_tokens_at_tail: bool = True,
+):
"""
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
@@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
+ max_length: used to truncate the input ids
+ discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
+
+    if the first round has already exceeded max length
+        - if the user query already exceeded max length, discard the sample
+        - if only the first assistant response exceeded max length, truncate the response to fit the max length
+    else keep the first several complete rounds of the conversation until max length is reached
Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
@@ -106,10 +119,17 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
- if r:
- loss_starts.append(len(input_ids))
- loss_ends.append(len(input_ids) + len(tokenized))
- input_ids.extend(tokenized)
+ if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
+ if r:
+ loss_starts.append(len(input_ids))
+ loss_ends.append(len(input_ids) + len(tokenized))
+ input_ids.extend(tokenized)
+ if loss_starts[0] >= max_length:
+ return None, None, None
+ if discard_non_loss_tokens_at_tail:
+ input_ids = input_ids[: loss_ends[-1]]
+ input_ids = input_ids[:max_length]
+ loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends
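The truncation policy added above can be exercised in isolation with the following sketch, which works on pre-tokenized chunks instead of a real tokenizer (a simplification; the signature differs from `tokenize_and_concatenate`):

```python
from typing import List, Optional, Tuple

def concat_with_truncation(
    chunks: List[List[int]],              # pre-tokenized chunks, in conversation order
    require_loss: List[bool],
    max_length: int,
    discard_non_loss_tokens_at_tail: bool = True,
) -> Tuple[Optional[List[int]], Optional[List[int]], Optional[List[int]]]:
    input_ids, loss_starts, loss_ends = [], [], []
    for tokenized, r in zip(chunks, require_loss):
        # keep whole rounds that fit; always keep the very first loss-bearing round
        if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
            if r:
                loss_starts.append(len(input_ids))
                loss_ends.append(len(input_ids) + len(tokenized))
            input_ids.extend(tokenized)
    if loss_starts and loss_starts[0] >= max_length:
        return None, None, None           # the user query alone already exceeds max_length
    if discard_non_loss_tokens_at_tail and loss_ends:
        input_ids = input_ids[: loss_ends[-1]]
    input_ids = input_ids[:max_length]
    if loss_ends:
        loss_ends[-1] = min(max_length, loss_ends[-1])
    return input_ids, loss_starts, loss_ends

# toy usage: a 3-token query plus a 4-token response, truncated to max_length=5
ids, starts, ends = concat_with_truncation([[1, 2, 3], [4, 5, 6, 7]], [False, True], max_length=5)
assert ids == [1, 2, 3, 4, 5] and starts == [3] and ends == [5]
```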
@@ -125,6 +145,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
content_length = (
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
)
+        # if the tokenized content starts with a leading space, we want to keep it in loss calculation
+        # e.g., Assistant: I am saying...
+        # if the tokenized content doesn't start with a leading space, we only keep the content itself in loss calculation
+ # e.g.,
+ # Assistant: # '\n' as line breaker
+ # I am saying...
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + content_length])
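One plausible reading of the leading-space branch above, as a tiny standalone sketch (`split_around_content` is hypothetical; the real function also consumes `end_of_assistant` markers and loops over all messages): a space that the template places right before the assistant content is kept inside the loss chunk so that tokenization matches the standalone content.

```python
def split_around_content(prompt: str, content: str, start_idx: int = 0):
    """Split `prompt` into (non-loss chunk, loss chunk) around one assistant `content`."""
    first_occur = prompt.find(content, start_idx)
    if prompt[first_occur - 1] != " ":
        # e.g. "Assistant:\nI am saying..." -> only the content itself carries loss
        return prompt[start_idx:first_occur], prompt[first_occur:first_occur + len(content)]
    # e.g. "Assistant: I am saying..." -> the leading space stays with the loss chunk
    return prompt[start_idx:first_occur - 1], prompt[first_occur - 1:first_occur + len(content)]

head, loss = split_around_content("Assistant: I am saying...", "I am saying...")
assert head == "Assistant:" and loss == " I am saying..."
```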
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index 54f1bdba6..4047553da 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -448,7 +448,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
- "from": "human",
+ "from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
@@ -529,7 +529,7 @@ Below shows the preference dataset format used in training the reward model.
[
{"context": [
{
- "from": "human",
+ "from": "user",
"content": "Introduce butterflies species in Oregon."
}
]
@@ -598,7 +598,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
#### Step 1: Data Collection
-PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
+PPO uses two kinds of training data: the prompt data and the pretrain data (optional). The first dataset is mandatory; data samples within the prompt dataset end with a line from the "user", and the "assistant" needs to generate a response to answer the "user". Note that you can still use conversations that end with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
@@ -606,7 +606,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
{"messages":
[
{
- "from": "human",
+ "from": "user",
"content": "what are some pranks with a pen i can do?"
}
...
@@ -761,12 +761,12 @@ For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examp
```jsonl
{
"prompt": [
- {"from": "human", "content": "What are some cuss words in english?"},
+ {"from": "user", "content": "What are some cuss words in english?"},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...",
},
- {"from": "human", "content": "What's your favorite one?"},
+ {"from": "user", "content": "What's your favorite one?"},
],
"completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant.
"label": False, # whether the response is favorable or not
diff --git a/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl
index 2e11a91c6..0f9a02ea3 100644
--- a/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl
+++ b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl
@@ -1 +1 @@
-{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
+{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
diff --git a/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl b/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl
new file mode 100644
index 000000000..f0e78b1c6
--- /dev/null
+++ b/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl
@@ -0,0 +1 @@
+{"prompt": [{"from": "user", "content": "What are some cuss words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "Ass."},"label": false}
diff --git a/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl
index 21c4d9dc7..759bba7a0 100644
--- a/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl
+++ b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl
@@ -1 +1 @@
-{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
+{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
diff --git a/applications/ColossalChat/tests/test_templating.sh b/applications/ColossalChat/tests/test_templating.sh
index d033c07f5..6ee10e8be 100755
--- a/applications/ColossalChat/tests/test_templating.sh
+++ b/applications/ColossalChat/tests/test_templating.sh
@@ -94,7 +94,7 @@ done
# Test DPO/PPO data Preparation
for model in ${MODELS[@]}; do
- echo "Testing DPO/PPO data templating for $model"
+ echo "Testing DPO/RM data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
@@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
- echo "[Test]: Failed in the DPO data templating for $model"
+ echo "[Test]: Failed in the DPO/RM data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
passed=$?
if [ $passed -ne 0 ]; then
- echo "[Test]: Failed in the DPO data templating test for $model"
+ echo "[Test]: Failed in the DPO/RM data templating test for $model"
+ exit 1
+ fi
+done
+
+
+# Test KTO data Preparation
+for model in ${MODELS[@]}; do
+ echo "Testing KTO data templating for $model"
+ SAVE_DIR=$DATA_SAVE_PATH/kto/$model
+ rm -rf $SAVE_DIR/cache
+ rm -rf $SAVE_DIR/jsonl
+ rm -rf $SAVE_DIR/arrow
+ pretrain=$(get_pretrain $model)
+ conversation_template_config=$(get_conversation_template_config $model)
+ python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
+ --tokenizer_dir $pretrain \
+ --conversation_template_config $conversation_template_config \
+ --data_cache_dir $SAVE_DIR/cache \
+ --data_jsonl_output_dir $SAVE_DIR/jsonl \
+ --data_arrow_output_dir $SAVE_DIR/arrow
+ passed=$?
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed in the KTO data templating for $model"
+ exit 1
+ fi
+ python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
+ --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
+ passed=$?
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed in the KTO data templating test for $model"
exit 1
fi
done
diff --git a/applications/ColossalChat/tests/verify_chat_data.py b/applications/ColossalChat/tests/verify_chat_data.py
index 98ae0c1b2..eb8f9ce46 100644
--- a/applications/ColossalChat/tests/verify_chat_data.py
+++ b/applications/ColossalChat/tests/verify_chat_data.py
@@ -62,3 +62,11 @@ if __name__ == "__main__":
assert any(
[rejected_lable in s for s in to_verify_lable_rejected]
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
+ elif args.data_type == "kto":
+ sample = data[0]
+ to_verify_data = to_verify_data[0]
+ for line in sample["prompt"]:
+ assert line["content"] in to_verify_data["input_id_decode"]
+ assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
+ assert sample["completion"]["content"] in to_verify_data["completion_decode"]
+ assert sample["label"] == to_verify_data["label"]
From 4ec17a7cdf07db4ec4dd6b6e01ba9b88d61b4f9f Mon Sep 17 00:00:00 2001
From: zhurunhua <1281592874@qq.com>
Date: Sun, 21 Jul 2024 19:46:01 +0800
Subject: [PATCH 036/116] [FIX BUG] UnboundLocalError: cannot access local
variable 'default_conversation' where it is not associated with a value
(#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
applications/Colossal-LLaMA/prepare_sft_dataset.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/applications/Colossal-LLaMA/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py
index a857d6c0c..fe5790760 100644
--- a/applications/Colossal-LLaMA/prepare_sft_dataset.py
+++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py
@@ -10,7 +10,7 @@ import math
import os
from multiprocessing import cpu_count
-from colossal_llama.dataset.conversation import LLaMA2_Conv
+from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer
@@ -75,6 +75,8 @@ def main():
# Prepare to the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+ default_conversation = LLaMA3_Conv
+
# Fix split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True)
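Stripped of project specifics, the shape of this fix is simply binding a default before the version-specific branch so the name is always defined (a generic sketch with placeholder values, not the actual script):

```python
def pick_conversation(llama_version: int) -> str:
    default_conversation = "llama3_conv"      # always bound: the fix
    if llama_version == 2:
        default_conversation = "llama2_conv"  # assumed version-specific override, shown only to illustrate the pattern
    # without the default above, reading the name here would raise
    # UnboundLocalError whenever the branch is skipped
    return default_conversation

assert pick_conversation(3) == "llama3_conv"
assert pick_conversation(2) == "llama2_conv"
```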
From c5f582f666b4953ae8e58a51bc79469ca799793e Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 22 Jul 2024 01:31:32 +0000
Subject: [PATCH 037/116] fix test data
---
.../tests/generate_dummy_datasets_for_testing.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
index facbcd2e6..6cce6563c 100644
--- a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
+++ b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
@@ -4,7 +4,7 @@ import os
sft_seed = {
"messages": [
- {"from": "human", "content": "Give three tips for staying healthy."},
+ {"from": "user", "content": "Give three tips for staying healthy."},
{
"from": "assistant",
"content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
@@ -13,7 +13,7 @@ sft_seed = {
}
prompt_seed = {
"messages": [
- {"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
+ {"from": "user", "content": "Describe the impacts of climate change on communities living in coastal areas."},
{
"from": "assistant",
"content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
@@ -22,29 +22,29 @@ prompt_seed = {
}
preference_seed = {
"context": [
- {"from": "human", "content": "What kind of noises did dinosaurs make?"},
+ {"from": "user", "content": "What kind of noises did dinosaurs make?"},
{
"from": "assistant",
"content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
},
- {"from": "human", "content": "yes they did"},
+ {"from": "user", "content": "yes they did"},
{
"from": "assistant",
"content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
},
- {"from": "human", "content": "you cant read"},
+ {"from": "user", "content": "you cant read"},
],
"chosen": [{"from": "assistant", "content": "You can read?"}],
"rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
}
kto_seed = {
"prompt": [
- {"from": "human", "content": "What are some cuss words in english?"},
+ {"from": "user", "content": "What are some cuss words in english?"},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama's boy, faggot, pervert, queer, scumbag, bitch,",
},
- {"from": "human", "content": "What's your favorite one?"},
+ {"from": "user", "content": "What's your favorite one?"},
],
"completion": {"from": "assistant", "content": "Ass."},
"label": False,
From 12fe8b58580104a2fdc7fadaad6ca6b06d9d8287 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 22 Jul 2024 05:57:39 +0000
Subject: [PATCH 038/116] refactor evaluation
---
.../ColossalChat/benchmarks/benchmark_dpo.py | 340 ------------------
.../ColossalChat/benchmarks/benchmark_dpo.sh | 25 +-
.../ColossalChat/benchmarks/benchmark_kto.py | 332 -----------------
.../ColossalChat/benchmarks/benchmark_kto.sh | 20 +-
.../ColossalChat/benchmarks/benchmark_orpo.py | 315 ----------------
.../ColossalChat/benchmarks/benchmark_orpo.sh | 21 +-
.../ColossalChat/benchmarks/benchmark_sft.py | 315 ----------------
.../ColossalChat/benchmarks/benchmark_sft.sh | 19 +-
.../benchmarks/benchmark_simpo.sh | 55 +++
.../ColossalChat/benchmarks/dummy_dataset.py | 3 +-
.../benchmarks/prepare_dummy_test_dataset.py | 105 ++++++
.../examples/training_scripts/train_dpo.py | 22 +-
.../examples/training_scripts/train_dpo.sh | 3 +
.../examples/training_scripts/train_kto.py | 41 ++-
.../examples/training_scripts/train_kto.sh | 4 +
.../examples/training_scripts/train_orpo.py | 22 +-
.../examples/training_scripts/train_orpo.sh | 3 +
.../examples/training_scripts/train_rm.py | 22 +-
.../examples/training_scripts/train_rm.sh | 3 +
.../examples/training_scripts/train_sft.py | 23 +-
.../examples/training_scripts/train_sft.sh | 3 +
applications/ColossalChat/tests/test_train.sh | 1 +
22 files changed, 309 insertions(+), 1388 deletions(-)
delete mode 100755 applications/ColossalChat/benchmarks/benchmark_dpo.py
delete mode 100755 applications/ColossalChat/benchmarks/benchmark_kto.py
delete mode 100755 applications/ColossalChat/benchmarks/benchmark_orpo.py
delete mode 100644 applications/ColossalChat/benchmarks/benchmark_sft.py
create mode 100755 applications/ColossalChat/benchmarks/benchmark_simpo.sh
create mode 100644 applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
diff --git a/applications/ColossalChat/benchmarks/benchmark_dpo.py b/applications/ColossalChat/benchmarks/benchmark_dpo.py
deleted file mode 100755
index f80d81566..000000000
--- a/applications/ColossalChat/benchmarks/benchmark_dpo.py
+++ /dev/null
@@ -1,340 +0,0 @@
-import argparse
-import json
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
-from coati.models import convert_to_lora_module, disable_dropout
-from coati.trainer import DPOTrainer
-from coati.utils import load_checkpoint
-from dummy_dataset import DummyLLMDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-logger = get_dist_logger()
-
-
-def train(args):
- # check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
- raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch()
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "ddp":
- """
- Default torch ddp plugin without any acceleration, for
- debugging purpose acceleration, for debugging purpose
- """
- plugin = TorchDDPPlugin(find_unused_parameters=True)
- elif args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="static",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_gradient_accumulation=True,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=args.pp,
- sp_size=args.sp,
- sequence_parallelism_mode=args.sp_mode,
- zero_stage=args.zero_stage,
- enable_flash_attention=args.use_flash_attn,
- enable_sequence_parallelism=args.enable_sequence_parallelism,
- cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
- ref_booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
-
- init_ctx = nullcontext()
- with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- use_flash_attention_2=True,
- )
- coordinator.print_on_master(msg="Flash-attention enabled successfully")
- else:
- model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(model)
- if not args.disable_reference_model:
- if args.use_flash_attn:
- ref_model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- use_flash_attention_2=True,
- )
- else:
- ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(ref_model)
- else:
- ref_model = None
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-
- # configure tokenizer
- tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
- if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
- try:
- # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
- tokenizer.pad_token = tokenizer.eos_token
- except AttributeError as e:
- logger.warning(f"Unable to set pad token to eos token, {str(e)}")
- if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
- logger.warning(
- "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
- )
-
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- # configure dataset
- mode_map = {"train": "train", "valid": "validation", "test": "test"}
- train_dataset = DummyLLMDataset(
- ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
- args.max_length,
- args.dataset_size,
- )
- data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
-
- train_dataloader = plugin.prepare_dataloader(
- dataset=train_dataset,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- distributed_sampler_cls=StatefulDistributedSampler,
- )
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- if args.warmup_steps is None:
- args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optim,
- total_steps=args.max_epochs * num_update_steps_per_epoch,
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optim, _, train_dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- dataloader=train_dataloader,
- )
- if ref_model is not None:
- ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- sampler_start_idx = 0
- start_step = 0
- if args.checkpoint_path is not None:
- if "modeling" in args.checkpoint_path:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
- booster.load_model(model, args.checkpoint_path)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.checkpoint_path,
- booster=booster,
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- )
- assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
- train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- coordinator.print_on_master(
- f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- trainer = DPOTrainer(
- actor=model,
- ref_model=ref_model,
- booster=booster,
- actor_optim=optim,
- actor_lr_scheduler=lr_scheduler,
- tokenizer=tokenizer,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- start_epoch=start_epoch,
- save_interval=None,
- save_dir=None,
- coordinator=coordinator,
- beta=args.beta,
- gamma=args.gamma,
- length_normalization=args.length_normalization,
- )
-
- trainer.fit(
- train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
- log_dir=None,
- use_wandb=False,
- )
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--pp", type=int, default=1)
- parser.add_argument("--sp", type=int, default=1)
- parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
- parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
- parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
- parser.add_argument("--length_normalization", default=False, action="store_true")
- parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
- parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
- parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
- parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--model_type", type=str, default=None)
- parser.add_argument("--tokenizer_dir", type=str, default=None)
- parser.add_argument(
- "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
- )
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument("--dataset_size", type=int, default=500)
- parser.add_argument(
- "--disable_reference_model",
- action="store_true",
- default=False,
- help="Disable the reference model (enabled by default)",
- )
- parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
- args = parser.parse_args()
-
- # fool proof hyperparameter setup
- if args.loss_type == "simpo_loss":
- args.length_normalization = True
- args.gamma = args.gamma if args.gamma > 0 else 1.4
-
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
- train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_dpo.sh b/applications/ColossalChat/benchmarks/benchmark_dpo.sh
index dfd0ff846..08ce0629c 100755
--- a/applications/ColossalChat/benchmarks/benchmark_dpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_dpo.sh
@@ -17,32 +17,35 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
+DATASET_SIZE=320
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
-SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
-colossalai run --nproc_per_node 4 --master_port 31313 benchmark_dpo.py \
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
+
+
+colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
- --config_file $CONFIG_FILE \
+ --dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
- --batch_size 8 \
+ --batch_size 4 \
--lr 1e-6 \
--beta 0.1 \
- --gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
- --dataset_size 640 \
--weight_decay 0.01 \
--warmup_steps 60 \
- --disable_reference_model \
- --length_normalization \
--grad_checkpoint \
--use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/benchmark_kto.py b/applications/ColossalChat/benchmarks/benchmark_kto.py
deleted file mode 100755
index 99f772ad3..000000000
--- a/applications/ColossalChat/benchmarks/benchmark_kto.py
+++ /dev/null
@@ -1,332 +0,0 @@
-import argparse
-import json
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler
-from coati.models import convert_to_lora_module, disable_dropout
-from coati.trainer import KTOTrainer
-from coati.utils import load_checkpoint
-from dummy_dataset import DummyLLMDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-logger = get_dist_logger()
-
-
-def train(args):
- # check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
- raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch()
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "ddp":
- """
- Default torch ddp plugin without any acceleration, for
- debugging purpose acceleration, for debugging purpose
- """
- plugin = TorchDDPPlugin(find_unused_parameters=True)
- elif args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="static",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_gradient_accumulation=True,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=args.pp,
- sp_size=args.sp,
- sequence_parallelism_mode=args.sp_mode,
- zero_stage=args.zero_stage,
- enable_flash_attention=args.use_flash_attn,
- enable_sequence_parallelism=args.enable_sequence_parallelism,
- cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
- ref_booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
-
- init_ctx = nullcontext()
- with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- use_flash_attention_2=True,
- )
- coordinator.print_on_master(msg="Flash-attention enabled successfully")
- else:
- model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(model)
- if not args.disable_reference_model:
- if args.use_flash_attn:
- ref_model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- use_flash_attention_2=True,
- )
- else:
- ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(ref_model)
- else:
- ref_model = None
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-
- # configure tokenizer
- tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
- if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
- try:
- # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
- tokenizer.pad_token = tokenizer.eos_token
- except AttributeError as e:
- logger.warning(f"Unable to set pad token to eos token, {str(e)}")
- if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
- logger.warning(
- "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
- )
-
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- # configure dataset
- train_dataset = DummyLLMDataset(
- ["prompt", "completion", "label"],
- args.max_length - 512,
- args.dataset_size,
- gen_fn={
- "completion": lambda x: torch.ones(512, dtype=torch.long),
- "label": lambda x: torch.tensor(x % 2, dtype=torch.long),
- },
- )
- data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
-
- train_dataloader = plugin.prepare_dataloader(
- dataset=train_dataset,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- distributed_sampler_cls=StatefulDistributedSampler,
- )
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- if args.warmup_steps is None:
- args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optim,
- total_steps=args.max_epochs * num_update_steps_per_epoch,
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optim, _, train_dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- dataloader=train_dataloader,
- )
- if ref_model is not None:
- ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- sampler_start_idx = 0
- start_step = 0
- if args.checkpoint_path is not None:
- if "modeling" in args.checkpoint_path:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
- booster.load_model(model, args.checkpoint_path)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.checkpoint_path,
- booster=booster,
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- )
- assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
- train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- coordinator.print_on_master(
- f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- trainer = KTOTrainer(
- actor=model,
- ref_model=ref_model,
- booster=booster,
- actor_optim=optim,
- actor_lr_scheduler=lr_scheduler,
- tokenizer=tokenizer,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- start_epoch=start_epoch,
- save_interval=None,
- save_dir=None,
- coordinator=coordinator,
- beta=args.beta,
- )
-
- trainer.fit(
- train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
- log_dir=None,
- use_wandb=False,
- )
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--pp", type=int, default=1)
- parser.add_argument("--sp", type=int, default=1)
- parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
- parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
- parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
- parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
- parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--tokenizer_dir", type=str, default=None)
- parser.add_argument(
- "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
- )
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument("--dataset_size", type=int, default=500)
- parser.add_argument(
- "--disable_reference_model",
- action="store_true",
- default=False,
- help="Disable the reference model (enabled by default)",
- )
- parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
- args = parser.parse_args()
-
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
- train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_kto.sh b/applications/ColossalChat/benchmarks/benchmark_kto.sh
index 571915c3b..41de40f13 100755
--- a/applications/ColossalChat/benchmarks/benchmark_kto.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_kto.sh
@@ -17,19 +17,26 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
+DATASET_SIZE=80
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
-SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
-colossalai run --nproc_per_node 2 --master_port 31313 benchmark_kto.py \
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto
+
+
+colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
--plugin "zero2_cpu" \
- --config_file $CONFIG_FILE \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 2 \
@@ -38,7 +45,6 @@ colossalai run --nproc_per_node 2 --master_port 31313 benchmark_kto.py \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
- --dataset_size 80 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
diff --git a/applications/ColossalChat/benchmarks/benchmark_orpo.py b/applications/ColossalChat/benchmarks/benchmark_orpo.py
deleted file mode 100755
index 1325bada2..000000000
--- a/applications/ColossalChat/benchmarks/benchmark_orpo.py
+++ /dev/null
@@ -1,315 +0,0 @@
-import argparse
-import json
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
-from coati.models import convert_to_lora_module, disable_dropout
-from coati.trainer import ORPOTrainer
-from coati.utils import load_checkpoint
-from dummy_dataset import DummyLLMDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-logger = get_dist_logger()
-
-
-def train(args):
- # check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
- raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch()
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "ddp":
- """
- Default torch ddp plugin without any acceleration, for
- debugging purpose acceleration, for debugging purpose
- """
- plugin = TorchDDPPlugin(find_unused_parameters=True)
- elif args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="static",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_gradient_accumulation=True,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=args.pp,
- sp_size=args.sp,
- sequence_parallelism_mode=args.sp_mode,
- zero_stage=args.zero_stage,
- enable_flash_attention=args.use_flash_attn,
- enable_sequence_parallelism=args.enable_sequence_parallelism,
- cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
-
- init_ctx = nullcontext()
- with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- use_flash_attention_2=True,
- )
- coordinator.print_on_master(msg="Flash-attention enabled successfully")
- else:
- model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(model)
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-
- # configure tokenizer
- tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
- if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
- try:
- # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
- tokenizer.pad_token = tokenizer.eos_token
- except AttributeError as e:
- logger.warning(f"Unable to set pad token to eos token, {str(e)}")
- if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
- logger.warning(
- "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
- )
-
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- # configure dataset
- coordinator.print_on_master(f"Load dataset: {args.dataset}")
- mode_map = {"train": "train", "valid": "validation", "test": "test"}
- train_dataset = DummyLLMDataset(
- ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
- args.max_length,
- args.dataset_size,
- )
- data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
-
- train_dataloader = plugin.prepare_dataloader(
- dataset=train_dataset,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- distributed_sampler_cls=StatefulDistributedSampler,
- )
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- if args.warmup_steps is None:
- args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optim,
- total_steps=args.max_epochs * num_update_steps_per_epoch,
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optim, _, train_dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- dataloader=train_dataloader,
- )
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- sampler_start_idx = 0
- start_step = 0
- if args.checkpoint_path is not None:
- if "modeling" in args.checkpoint_path:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
- booster.load_model(model, args.checkpoint_path)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.checkpoint_path,
- booster=booster,
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- )
- assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
- train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- coordinator.print_on_master(
- f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- trainer = ORPOTrainer(
- actor=model,
- booster=booster,
- actor_optim=optim,
- actor_lr_scheduler=lr_scheduler,
- tokenizer=tokenizer,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- start_epoch=start_epoch,
- save_interval=None,
- save_dir=None,
- coordinator=coordinator,
- lam=args.lam,
- )
-
- trainer.fit(
- train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
- log_dir=None,
- use_wandb=False,
- )
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--pp", type=int, default=1)
- parser.add_argument("--sp", type=int, default=1)
- parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
- parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
- parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
- parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
- parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--model_type", type=str, default=None)
- parser.add_argument("--tokenizer_dir", type=str, default=None)
- parser.add_argument("--dataset", nargs="+", default=[])
- parser.add_argument(
- "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
- )
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument(
- "--disable_reference_model",
- action="store_true",
- default=False,
- help="Disable the reference model (enabled by default)",
- )
- parser.add_argument("--dataset_size", type=int, default=500)
- parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
- args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
- train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_orpo.sh b/applications/ColossalChat/benchmarks/benchmark_orpo.sh
index cc6eef510..fa51a788f 100755
--- a/applications/ColossalChat/benchmarks/benchmark_orpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_orpo.sh
@@ -15,20 +15,28 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
-PROJECT_NAME="dpo"
+PROJECT_NAME="orpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data
+DATASET_SIZE=160
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
-CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
-colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
+
+
+colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
--plugin "zero2" \
- --config_file $CONFIG_FILE \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 4 \
@@ -39,6 +47,5 @@ colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
- --dataset_size 160 \
--grad_checkpoint \
--use_flash_attn
diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.py b/applications/ColossalChat/benchmarks/benchmark_sft.py
deleted file mode 100644
index b6438c503..000000000
--- a/applications/ColossalChat/benchmarks/benchmark_sft.py
+++ /dev/null
@@ -1,315 +0,0 @@
-import argparse
-import json
-import math
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler
-from coati.models import convert_to_lora_module
-from coati.trainer import SFTTrainer
-from coati.utils import load_checkpoint
-from dummy_dataset import DummyLLMDataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-logger = get_dist_logger()
-
-
-def train(args):
- # check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
- raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch()
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- init_ctx = nullcontext()
- with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- attn_implementation="flash_attention_2",
- trust_remote_code=True,
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrain,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- trust_remote_code=True,
- )
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.plugin == "ddp":
- """
- Default torch ddp plugin without any acceleration, for
- debugging purpose acceleration, for debugging purpose
- """
- plugin = TorchDDPPlugin(find_unused_parameters=True)
- elif args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="static",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=args.pp,
- sp_size=args.sp,
- sequence_parallelism_mode=args.sp_mode,
- zero_stage=args.zero_stage,
- enable_flash_attention=args.use_flash_attn,
- enable_sequence_parallelism=args.enable_sequence_parallelism,
- cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- microbatch_size=args.batch_size,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
-
- if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-
- # configure tokenizer
- tokenizer = AutoTokenizer.from_pretrained(
- args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
- )
- if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
- try:
- # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
- tokenizer.pad_token = tokenizer.eos_token
- except AttributeError as e:
- logger.warning(f"Unable to set pad token to eos token, {str(e)}")
- if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
- logger.warning(
- "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
- )
-
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
- tokenizer.padding_side = "right"
-
- coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
-
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- # configure dataset
- coordinator.print_on_master(
- f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_len, args.dataset_size)
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
-
- train_dataloader = plugin.prepare_dataloader(
- dataset=dataset,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- distributed_sampler_cls=StatefulDistributedSampler,
- )
- coordinator.print_on_master(
- f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
-
- num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
- math.ceil(args.max_epochs * num_update_steps_per_epoch)
-
- if args.warmup_steps is None:
- args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optim,
- total_steps=args.max_epochs * num_update_steps_per_epoch,
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- # Flash attention will be disabled because it does NOT support fp32.
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optim, _, train_dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- dataloader=train_dataloader,
- )
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- sampler_start_idx = 0
- start_step = 0
- if args.checkpoint_path is not None:
- if "modeling" in args.checkpoint_path:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
- booster.load_model(model, args.checkpoint_path)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.checkpoint_path,
- booster=booster,
- model=model,
- optimizer=optim,
- lr_scheduler=lr_scheduler,
- )
- train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- coordinator.print_on_master(
- f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- trainer = SFTTrainer(
- model=model,
- booster=booster,
- optim=optim,
- lr_scheduler=lr_scheduler,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps,
- start_epoch=start_epoch,
- save_interval=None,
- save_dir=None,
- coordinator=coordinator,
- )
-
- trainer.fit(
- train_dataloader=train_dataloader,
- eval_dataloader=None,
- log_dir=None,
- use_wandb=False,
- )
-
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--pp", type=int, default=1)
- parser.add_argument("--sp", type=int, default=1)
- parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
- parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
- parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
- parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
- parser.add_argument("--pretrain", type=str, default=None)
- parser.add_argument("--tokenizer_dir", type=str, default=None)
- parser.add_argument(
- "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
- )
- parser.add_argument("--max_epochs", type=int, default=3)
- parser.add_argument("--batch_size", type=int, default=4)
- parser.add_argument("--max_len", type=int, default=512)
- parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
- parser.add_argument("--dataset_size", type=int, default=500)
- args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
- train(args)
diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.sh b/applications/ColossalChat/benchmarks/benchmark_sft.sh
index 70c0e5873..3d7e2ec16 100755
--- a/applications/ColossalChat/benchmarks/benchmark_sft.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_sft.sh
@@ -17,18 +17,28 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data
+DATASET_SIZE=640
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
+
+
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type sft
+
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 1 --master_port 31312 benchmark_sft.py \
+colossalai run --nproc_per_node 1 --master_port 31312 ../examples/training_scripts/train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
- --config_file $CONFIG_FILE \
+ --dataset ${dataset[@]} \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
@@ -36,6 +46,5 @@ colossalai run --nproc_per_node 1 --master_port 31312 benchmark_sft.py \
--lr 5e-5 \
--lora_rank 32 \
--max_len 2048 \
- --dataset_size 640 \
--grad_checkpoint \
--use_flash_attn
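For reference, the "real batch size" comment in benchmark_sft.sh above works out as follows with the values this script currently passes; this is a quick arithmetic sketch, not part of the benchmark code:

# Effective optimizer batch size implied by benchmark_sft.sh:
# number_of_node_in_hostfile * nproc_per_node * train_batch_size (* accumulation_steps)
num_nodes = 1           # single node (no hostfile is used here)
nproc_per_node = 1      # from the colossalai run command above
batch_size = 8          # --batch_size
accumulation_steps = 1  # --accumulation_steps
effective_batch = num_nodes * nproc_per_node * batch_size * accumulation_steps
print(effective_batch)  # 8 sequences of up to --max_len 2048 tokens per optimizer step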
diff --git a/applications/ColossalChat/benchmarks/benchmark_simpo.sh b/applications/ColossalChat/benchmarks/benchmark_simpo.sh
new file mode 100755
index 000000000..5d41c34af
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/benchmark_simpo.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+PROJECT_NAME="simpo"
+PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data
+DATASET_SIZE=640
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+declare -a dataset=(
+ $BENCHMARK_DATA_DIR/arrow/part-0
+)
+
+# Generate dummy test data
+python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
+
+
+colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --plugin "zero2_cpu" \
+ --loss_type "simpo_loss" \
+ --max_epochs 1 \
+ --accumulation_steps 1 \
+ --batch_size 8 \
+ --lr 1e-6 \
+ --beta 0.1 \
+ --gamma 0.6 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --max_length 2048 \
+ --weight_decay 0.01 \
+ --warmup_steps 60 \
+ --disable_reference_model \
+ --length_normalization \
+ --grad_checkpoint \
+ --use_flash_attn
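For context, benchmark_simpo.sh reuses train_dpo.py with --loss_type "simpo_loss", --disable_reference_model and --length_normalization. Below is a minimal sketch of the reference-free SimPO objective those flags correspond to, following the SimPO paper (https://arxiv.org/abs/2405.14734); the exact placement of beta and gamma inside ColossalChat's loss implementation may differ, and the tensor names are illustrative only:

import torch
import torch.nn.functional as F

def simpo_loss(chosen_logp_sum, rejected_logp_sum, chosen_len, rejected_len, beta=0.1, gamma=0.6):
    # Length-normalized sequence log-probabilities act as implicit, reference-free rewards.
    chosen_reward = beta * chosen_logp_sum / chosen_len
    rejected_reward = beta * rejected_logp_sum / rejected_len
    # Preference loss with a target reward margin gamma.
    return -F.logsigmoid(chosen_reward - rejected_reward - gamma).mean()

# Example with dummy per-sequence log-prob sums and lengths (result is roughly 1.02)
print(simpo_loss(torch.tensor([-80.0]), torch.tensor([-120.0]), torch.tensor([100.0]), torch.tensor([110.0])))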
diff --git a/applications/ColossalChat/benchmarks/dummy_dataset.py b/applications/ColossalChat/benchmarks/dummy_dataset.py
index 5218e659b..9af0f1641 100644
--- a/applications/ColossalChat/benchmarks/dummy_dataset.py
+++ b/applications/ColossalChat/benchmarks/dummy_dataset.py
@@ -1,6 +1,5 @@
from typing import Callable
-import torch
from torch.utils.data import Dataset
@@ -18,7 +17,7 @@ class DummyLLMDataset(Dataset):
if key in self.gen_fn:
data[key] = self.gen_fn[key]
else:
- data[key] = torch.ones(self.seq_len, dtype=torch.long)
+ data[key] = [1] * self.seq_len
return data
def __len__(self):
diff --git a/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
new file mode 100644
index 000000000..70a422208
--- /dev/null
+++ b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
@@ -0,0 +1,105 @@
+import argparse
+import json
+import os
+import time
+from multiprocessing import cpu_count
+
+from datasets import load_dataset
+from dummy_dataset import DummyLLMDataset
+
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ required=True,
+ default=None,
+ help="The output dir",
+ )
+ parser.add_argument(
+ "--dataset_size",
+ type=int,
+ required=True,
+ default=None,
+        help="The number of data points to generate",
+ )
+ parser.add_argument(
+ "--max_length",
+ type=int,
+ required=True,
+ default=None,
+        help="The maximum sequence length of each data point",
+ )
+ parser.add_argument(
+ "--data_type",
+ type=str,
+ required=True,
+ default=None,
+        help="The dataset type: sft, prompt, preference, or kto",
+ )
+ args = parser.parse_args()
+ if args.data_type == "sft":
+ dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_length, args.dataset_size)
+ elif args.data_type == "prompt":
+        # prompt data for PPO is prepared separately, so nothing is generated here
+ pass
+ elif args.data_type == "preference":
+ dataset = DummyLLMDataset(
+ ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
+ args.max_length,
+ args.dataset_size,
+ )
+ elif args.data_type == "kto":
+ dataset = DummyLLMDataset(
+ ["prompt", "completion", "label"],
+ args.max_length - 512,
+ args.dataset_size,
+ gen_fn={
+ "completion": lambda x: [1] * 512,
+ "label": lambda x: x % 2,
+ },
+ )
+ else:
+ raise ValueError(f"Unknown data type {args.data_type}")
+
+ # Save each jsonl spliced dataset.
+ output_index = "0"
+ output_name = f"part-{output_index}"
+ os.makedirs(args.data_dir, exist_ok=True)
+ output_jsonl_path = os.path.join(args.data_dir, "json")
+ output_arrow_path = os.path.join(args.data_dir, "arrow")
+ output_cache_path = os.path.join(args.data_dir, "cache")
+ os.makedirs(output_jsonl_path, exist_ok=True)
+ os.makedirs(output_arrow_path, exist_ok=True)
+ output_jsonl_file_path = os.path.join(output_jsonl_path, output_name + ".jsonl")
+ st = time.time()
+ with open(file=output_jsonl_file_path, mode="w", encoding="utf-8") as fp_writer:
+ count = 0
+ for i in range(len(dataset)):
+ data_point = dataset[i]
+ if count % 500 == 0:
+ logger.info(f"processing {count} spliced data points for {fp_writer.name}")
+ count += 1
+ fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
+ logger.info(
+ f"Current file {fp_writer.name}; "
+ f"Data size: {len(dataset)}; "
+ f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
+ )
+ # Save each arrow spliced dataset
+ output_arrow_file_path = os.path.join(output_arrow_path, output_name)
+ logger.info(f"Start to save {output_arrow_file_path}")
+ dataset = load_dataset(
+ path="json",
+ data_files=[output_jsonl_file_path],
+ cache_dir=os.path.join(output_cache_path, "tokenized"),
+ keep_in_memory=False,
+ num_proc=cpu_count(),
+ split="train",
+ )
+ dataset.save_to_disk(dataset_path=output_arrow_file_path, num_proc=min(len(dataset), cpu_count()))
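As a quick sanity check, the arrow shard written by the script above can be loaded back with the datasets library. The paths and sizes below are the defaults used by benchmark_kto.sh (--max_length 2048, --data_type kto); this is a sketch, not part of the benchmark code:

from datasets import load_from_disk

# The KTO dummy data has a fixed-length prompt, a 512-token completion and an alternating label.
ds = load_from_disk("./temp/kto/arrow/part-0")
print(ds)                                   # features: prompt, completion, label
sample = ds[0]
print(len(sample["prompt"]), len(sample["completion"]), sample["label"])  # 1536, 512, 0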
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index 06d7133ca..3659fb868 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -287,9 +287,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
- coordinator.print_on_master("Start saving final model checkpoint")
- booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+ if args.save_dir is not None:
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
+ )
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -328,8 +331,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--save_dir", type=str, default="output")
+ parser.add_argument("--config_file", type=str, default=None, help="Config file")
+ parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@@ -351,7 +354,7 @@ if __name__ == "__main__":
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
@@ -362,7 +365,8 @@ if __name__ == "__main__":
args.length_normalization = True
args.gamma = args.gamma if args.gamma > 0 else 1.4
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
+ if args.config_file is not None:
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
index 082d54ff0..4d49bc218 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
@@ -18,6 +18,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="DPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PARENT_LOG_DIR="" # Path to a folder to save training logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -38,6 +39,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
@@ -47,6 +49,7 @@ colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
+ --log_dir $LOG_DIR \
--max_epochs 1 \
--accumulation_steps 2 \
--batch_size 16 \
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py
index e219974aa..8d1408423 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.py
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.py
@@ -174,14 +174,16 @@ def train(args):
# Check if the user specified weights fit into the theoratical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
- if actual_ratio <= 1:
- raise AssertionError(
- f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase desirable weight or decrease undesirable weight."
- )
- elif actual_ratio > 4 / 3:
- raise AssertionError(
- f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please decrease desirable weight or increase undesirable weight."
- )
+ if actual_ratio < 1 or actual_ratio > 4 / 3:
+ if not args.auto_weight:
+ raise AssertionError(
+                f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight."
+ )
+ else:
+ args.desirable_weight = args.desirable_weight / actual_ratio
+ coordinator.print_on_master(
+                f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
+ )
data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
@@ -304,9 +306,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
- coordinator.print_on_master("Start saving final model checkpoint")
- booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+ if args.save_dir is not None:
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
+ )
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -343,8 +348,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--save_dir", type=str, default="output")
+ parser.add_argument("--config_file", type=str, default=None, help="Config file")
+ parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@@ -359,14 +364,16 @@ if __name__ == "__main__":
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--auto_weight", default=False, action="store_true")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
+ if args.config_file is not None:
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
train(args)
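The hunk above replaces the hard failure with an optional --auto_weight rescaling. Below is a standalone sketch of that rule (the bound from Eq. (8) of https://arxiv.org/abs/2402.01306); the function name is illustrative and not part of the patch:

def check_kto_weight_ratio(desirable_weight, undesirable_weight, n_desirable, n_undesirable, auto_weight=False):
    # Eq. (8) of the KTO paper wants this ratio to lie inside [1, 4/3].
    ratio = (desirable_weight * n_desirable) / (undesirable_weight * n_undesirable)
    if 1 <= ratio <= 4 / 3:
        return desirable_weight, undesirable_weight
    if not auto_weight:
        raise AssertionError(f"Ratio {ratio:.3f} is outside [1, 4/3]; adjust the weights.")
    # Mirror the patch: rescale the desirable weight so the ratio becomes exactly 1.
    return desirable_weight / ratio, undesirable_weight

# 100 desirable vs. 400 undesirable samples with equal weights gives ratio 0.25,
# so auto-weighting bumps the desirable weight to 4.0.
print(check_kto_weight_ratio(1.0, 1.0, 100, 400, auto_weight=True))  # (4.0, 1.0)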
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.sh b/applications/ColossalChat/examples/training_scripts/train_kto.sh
index 3dcf6af02..c28338c22 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.sh
@@ -19,6 +19,7 @@ PROJECT_NAME="kto"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PARENT_LOG_DIR="" # Path to a folder to save training logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -39,6 +40,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
+LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
@@ -48,9 +50,11 @@ colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
+ --log_dir $LOG_DIR \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
+ --auto_weight \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index 886aa39dd..3efe5e96a 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -269,9 +269,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
- coordinator.print_on_master("Start saving final model checkpoint")
- booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+ if args.save_dir is not None:
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
+ )
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -307,8 +310,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--save_dir", type=str, default="output")
+ parser.add_argument("--config_file", type=str, default=None, help="Config file")
+ parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@@ -330,12 +333,13 @@ if __name__ == "__main__":
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
+ if args.config_file is not None:
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.sh b/applications/ColossalChat/examples/training_scripts/train_orpo.sh
index 482956b21..48327e014 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.sh
@@ -18,6 +18,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="ORPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PARENT_LOG_DIR="" # Path to a folder to save training logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -38,6 +39,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
@@ -47,6 +49,7 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
+ --log_dir $LOG_DIR \
--max_epochs 3 \
--accumulation_steps 1 \
--batch_size 16 \
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py
index f8e894e7e..09b569b7a 100755
--- a/applications/ColossalChat/examples/training_scripts/train_rm.py
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.py
@@ -284,9 +284,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
- coordinator.print_on_master("Start saving final model checkpoint")
- booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
+ if args.save_dir is not None:
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
+ )
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -320,8 +323,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--save_dir", type=str, default="output")
+ parser.add_argument("--config_file", type=str, default=None, help="Config file")
+ parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@@ -338,12 +341,13 @@ if __name__ == "__main__":
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
+ if args.config_file is not None:
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.sh b/applications/ColossalChat/examples/training_scripts/train_rm.sh
index cd42afcc8..274417c03 100755
--- a/applications/ColossalChat/examples/training_scripts/train_rm.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.sh
@@ -18,6 +18,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="RM"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PARENT_LOG_DIR="" # Path to a folder to save training logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@@ -38,6 +39,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
--pretrain $PRETRAINED_MODEL_PATH \
@@ -47,6 +49,7 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
+ --log_dir $LOG_DIR \
--max_epochs 3 \
--accumulation_steps 1 \
--batch_size 8 \
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index b89cbeb91..2579293c1 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -284,10 +284,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
- coordinator.print_on_master("Start saving final model checkpoint")
-
- booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
+ if args.save_path is not None:
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}"
+ )
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -321,7 +323,7 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
- parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
@@ -336,14 +338,15 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
- os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
+ if args.config_file is not None:
+ os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh
index a4c95f463..988c7f9c3 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh
@@ -17,6 +17,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="SFT"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PARENT_LOG_DIR="" # Path to a folder to save training logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
@@ -36,6 +37,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
echo $(which colossalai)
echo $(which python)
@@ -47,6 +49,7 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
--dataset ${dataset[@]} \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
+ --log_dir $LOG_DIR \
--lora_rank 0 \
--plugin zero2 \
--batch_size 8 \
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
index eb39a83ab..57584e406 100755
--- a/applications/ColossalChat/tests/test_train.sh
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -591,6 +591,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
+ --auto_weight \
--desirable_weight 1.2 \
$grad_ckpt \
--max_len 400 \
From b0e15d563ee9b010f97520c1cad7ec779945af7c Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 22 Jul 2024 06:11:38 +0000
Subject: [PATCH 039/116] remove real data path
---
applications/ColossalChat/benchmarks/benchmark_dpo.sh | 4 ++--
applications/ColossalChat/benchmarks/benchmark_kto.sh | 4 ++--
applications/ColossalChat/benchmarks/benchmark_orpo.sh | 4 ++--
applications/ColossalChat/benchmarks/benchmark_sft.sh | 4 ++--
applications/ColossalChat/benchmarks/benchmark_simpo.sh | 4 ++--
.../data_preparation_scripts/prepare_sft_dataset.sh | 6 +++---
6 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/applications/ColossalChat/benchmarks/benchmark_dpo.sh b/applications/ColossalChat/benchmarks/benchmark_dpo.sh
index 08ce0629c..44d821a87 100755
--- a/applications/ColossalChat/benchmarks/benchmark_dpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_dpo.sh
@@ -17,8 +17,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
DATASET_SIZE=320
diff --git a/applications/ColossalChat/benchmarks/benchmark_kto.sh b/applications/ColossalChat/benchmarks/benchmark_kto.sh
index 41de40f13..82d3e3421 100755
--- a/applications/ColossalChat/benchmarks/benchmark_kto.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_kto.sh
@@ -17,8 +17,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
DATASET_SIZE=80
diff --git a/applications/ColossalChat/benchmarks/benchmark_orpo.sh b/applications/ColossalChat/benchmarks/benchmark_orpo.sh
index fa51a788f..f8fb264ae 100755
--- a/applications/ColossalChat/benchmarks/benchmark_orpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_orpo.sh
@@ -17,8 +17,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="orpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data
DATASET_SIZE=160
diff --git a/applications/ColossalChat/benchmarks/benchmark_sft.sh b/applications/ColossalChat/benchmarks/benchmark_sft.sh
index 3d7e2ec16..efcd428dd 100755
--- a/applications/ColossalChat/benchmarks/benchmark_sft.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_sft.sh
@@ -17,8 +17,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data
DATASET_SIZE=640
diff --git a/applications/ColossalChat/benchmarks/benchmark_simpo.sh b/applications/ColossalChat/benchmarks/benchmark_simpo.sh
index 5d41c34af..47dfc8595 100755
--- a/applications/ColossalChat/benchmarks/benchmark_simpo.sh
+++ b/applications/ColossalChat/benchmarks/benchmark_simpo.sh
@@ -17,8 +17,8 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="simpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data
DATASET_SIZE=640
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
index dbd323438..c4b6fec9b 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
@@ -5,9 +5,9 @@ rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type sft \
- --data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/sft \
- --conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
- --tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
+ --data_input_dirs /PATH/TO/SFT/DATASET \
+ --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
+ --tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
From 9688e19b322510bd64956a75bd6227850817cc69 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 22 Jul 2024 06:13:02 +0000
Subject: [PATCH 040/116] remove real data path
---
.../examples/data_preparation_scripts/prepare_sft_dataset.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
index c4b6fec9b..84bae0027 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh
@@ -1,4 +1,4 @@
-SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft"
+SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
From a521ffc9f8accb547c46a3a2da854b6d3f00aa39 Mon Sep 17 00:00:00 2001
From: Insu Jang
Date: Tue, 23 Jul 2024 11:15:39 -0400
Subject: [PATCH 041/116] Add n_fused as an input from native_module (#5894)
---
colossalai/shardformer/layer/qkv_fused_linear.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 0f6595a7c..000934ad9 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -695,6 +695,7 @@ class FusedLinear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
+ n_fused=n_fused,
*args,
**kwargs,
)
From 5fb958cc83f9abc36d0fa06a443cd2b0b21b3bc6 Mon Sep 17 00:00:00 2001
From: "Gao, Ruiyuan" <905370712@qq.com>
Date: Wed, 24 Jul 2024 10:30:40 +0800
Subject: [PATCH 042/116] [FIX BUG] convert env param to int in (#5934)
---
colossalai/cluster/dist_coordinator.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
index 98191747e..14a8eabb4 100644
--- a/colossalai/cluster/dist_coordinator.py
+++ b/colossalai/cluster/dist_coordinator.py
@@ -44,7 +44,7 @@ class DistCoordinator(metaclass=SingletonMeta):
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
- self._local_rank = os.environ.get("LOCAL_RANK", -1)
+ self._local_rank = int(os.environ.get("LOCAL_RANK", -1))
@property
def rank(self) -> int:
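The one-character cast above matters because environment variables are always strings; `os.environ.get("LOCAL_RANK", -1)` returns an `int` only when the variable is unset. A minimal standalone sketch of the failure mode (the variable name follows torchrun's convention; everything else is illustrative, not from the patch):

```python
import os

# Launchers such as torchrun export LOCAL_RANK as a string.
os.environ["LOCAL_RANK"] = "3"

raw = os.environ.get("LOCAL_RANK", -1)         # "3" (str) once the launcher sets it, -1 (int) otherwise
fixed = int(os.environ.get("LOCAL_RANK", -1))  # always an int, matching the fix above

print(type(raw), type(fixed))  # <class 'str'> <class 'int'>
print(raw == 3, fixed == 3)    # False True -- the comparison bug the cast avoids
```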
From 2069472e96d0d177abd8766b20982c68ffe9d95e Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Thu, 25 Jul 2024 09:59:58 +0800
Subject: [PATCH 043/116] [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz
---
colossalai/zero/low_level/low_level_optim.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 6ff235b96..f1db7c540 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -338,14 +338,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
- recieved_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+ received_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
- if recieved_grad.dtype != grad_dtype:
- recieved_grad = recieved_grad.to(grad_dtype)
+ if received_grad.dtype != grad_dtype:
+ received_grad = received_grad.to(grad_dtype)
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank]
- self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1)
+ self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1)
bucket_store.reset()
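Beyond the spelling fix, the hunk shows the ZeRO gradient-sharding step: each rank contributes the full flattened gradient and receives only its own, already-reduced shard. A self-contained sketch of that pattern (the function name and the even-divisibility assumption are mine, not from the patch), to be run under a launcher after `dist.init_process_group`:

```python
import torch
import torch.distributed as dist

def reduce_scatter_flat_grads(flat_grads: torch.Tensor, grad_dtype: torch.dtype, group=None) -> torch.Tensor:
    """Each rank passes in the full flat gradient and gets back its own summed shard."""
    world_size = dist.get_world_size(group)
    # Assumes len(flat_grads) is divisible by world_size, as in the bucketed optimizer.
    flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size))
    received_grad = torch.zeros_like(flat_grads_list[0])
    dist.reduce_scatter(received_grad, flat_grads_list, group=group)
    # Communication may happen in a reduced precision; cast back to the working dtype.
    if received_grad.dtype != grad_dtype:
        received_grad = received_grad.to(grad_dtype)
    return received_grad
```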
From ad35a987d3f46c7ca01bd1db60601af279712a24 Mon Sep 17 00:00:00 2001
From: zhurunhua <1281592874@qq.com>
Date: Fri, 26 Jul 2024 11:15:20 +0800
Subject: [PATCH 044/116] [Feature] Add a switch to control whether the model
checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
applications/Colossal-LLaMA/train.py | 18 +++++++++++++++---
1 file changed, 15 insertions(+), 3 deletions(-)
diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py
index 43a360a9a..e74aad33c 100644
--- a/applications/Colossal-LLaMA/train.py
+++ b/applications/Colossal-LLaMA/train.py
@@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
+ parser.add_argument(
+ "--skip_save_each_epoch",
+ action="store_true",
+ default=False,
+ help="skip saving the model checkpoint after each epoch is completed.",
+ )
args = parser.parse_args()
with open(args.config_file, "w") as f:
@@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()
+
# Save modeling.
- if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
- step + 1
- ) == len(dataloader):
+ save_model_condition = (
+ args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+ )
+
+ if not args.skip_save_each_epoch:
+ save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+
+ if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
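The refactored condition is easier to see in isolation. A minimal sketch of the checkpointing rule the new switch introduces (the wrapper function is illustrative; the variable names follow the diff):

```python
def should_save_checkpoint(step: int, num_steps: int, save_interval: int,
                           accumulation_steps: int, skip_save_each_epoch: bool) -> bool:
    # Save every `save_interval` optimizer steps ...
    save_model_condition = (
        save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0
    )
    # ... and additionally at the end of the epoch, unless the user opted out.
    if not skip_save_each_epoch:
        save_model_condition = save_model_condition or (step + 1) == num_steps
    return save_model_condition

# With save_interval=0, a checkpoint is only written at the last step,
# and --skip_save_each_epoch suppresses even that.
assert should_save_checkpoint(99, 100, 0, 1, skip_save_each_epoch=False)
assert not should_save_checkpoint(99, 100, 0, 1, skip_save_each_epoch=True)
```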
From 8a3ff4f3153e4887587c9128c48a3f79c8727394 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 26 Jul 2024 09:55:15 +0000
Subject: [PATCH 045/116] fix style
---
.../ColossalChat/benchmarks/prepare_dummy_test_dataset.py | 2 +-
.../ColossalChat/coati/dataset/tokenization_utils.py | 7 +------
applications/ColossalChat/coati/trainer/dpo.py | 2 +-
applications/ColossalChat/coati/trainer/kto.py | 2 +-
applications/ColossalChat/coati/trainer/orpo.py | 2 +-
.../tests/generate_dummy_datasets_for_testing.py | 8 ++++----
.../ColossalChat/tests/test_data/kto/test_kto_data.jsonl | 2 +-
7 files changed, 10 insertions(+), 15 deletions(-)
diff --git a/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
index 70a422208..f501c5358 100644
--- a/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
+++ b/applications/ColossalChat/benchmarks/prepare_dummy_test_dataset.py
@@ -40,7 +40,7 @@ if __name__ == "__main__":
type=str,
required=True,
default=None,
- help="The type of data",
+ help="The type of data, choose one from ['sft', 'prompt', 'preference', 'kto']",
)
args = parser.parse_args()
if args.data_type == "sft":
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index cf69b6fe7..d573853d6 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -46,8 +46,7 @@ def supervised_tokenize_sft(
max_length: the maximum context length
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
+ ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
@@ -146,8 +145,6 @@ def tokenize_prompt_dataset(
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
@@ -226,8 +223,6 @@ def tokenize_rlhf(
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
- if ignore_index is None:
- ignore_index = IGNORE_INDEX
context = data_point["context"]
template = deepcopy(conversation_template)
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index 3daab54f6..c7ef2be8f 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -26,7 +26,7 @@ from .utils import is_rank_0, to_device
class DPOTrainer(SLTrainer):
"""
- Trainer for PPO algorithm.
+ Trainer for DPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
index 83bff3bbe..8ab0bc66b 100755
--- a/applications/ColossalChat/coati/trainer/kto.py
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -27,7 +27,7 @@ from .utils import is_rank_0, to_device
class KTOTrainer(SLTrainer):
"""
- Trainer for PPO algorithm.
+ Trainer for KTO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index 495bb332b..b039da4af 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -26,7 +26,7 @@ from .utils import is_rank_0, to_device
class ORPOTrainer(SLTrainer):
"""
- Trainer for PPO algorithm.
+ Trainer for ORPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
diff --git a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
index 6cce6563c..e50b20b6b 100644
--- a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
+++ b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py
@@ -39,15 +39,15 @@ preference_seed = {
}
kto_seed = {
"prompt": [
- {"from": "user", "content": "What are some cuss words in english?"},
+ {"from": "user", "content": "What are some praise words in english?"},
{
"from": "assistant",
- "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama's boy, faggot, pervert, queer, scumbag, bitch,",
+ "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ...",
},
{"from": "user", "content": "What's your favorite one?"},
],
- "completion": {"from": "assistant", "content": "Ass."},
- "label": False,
+ "completion": {"from": "assistant", "content": "Impressive."},
+ "label": True,
}
diff --git a/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl b/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl
index f0e78b1c6..4f4fce83d 100644
--- a/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl
+++ b/applications/ColossalChat/tests/test_data/kto/test_kto_data.jsonl
@@ -1 +1 @@
-{"prompt": [{"from": "user", "content": "What are some cuss words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "Ass."},"label": false}
+{"prompt": [{"from": "user", "content": "What are some praise words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "impressive."},"label": true}
From de1bf08ed06876b0f970880a4a7bf773b12a579a Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 26 Jul 2024 10:07:15 +0000
Subject: [PATCH 046/116] fix style
---
applications/ColossalChat/coati/dataset/__init__.py | 7 +++----
.../ColossalChat/coati/dataset/tokenization_utils.py | 7 ++-----
.../data_preparation_scripts/prepare_dataset.py | 12 +++---------
.../data_preparation_scripts/prepare_kto_dataset.sh | 2 +-
4 files changed, 9 insertions(+), 19 deletions(-)
diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py
index 25ce16ed1..8e9060a1a 100755
--- a/applications/ColossalChat/coati/dataset/__init__.py
+++ b/applications/ColossalChat/coati/dataset/__init__.py
@@ -7,10 +7,10 @@ from .loader import (
StatefulDistributedSampler,
load_tokenized_dataset,
)
-from .tokenization_utils import supervised_tokenize_sft, tokenize_kto, tokenize_prompt_dataset, tokenize_rlhf
+from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
__all__ = [
- "tokenize_prompt_dataset",
+ "tokenize_prompt",
"DataCollatorForPromptDataset",
"is_rank_0",
"DataCollatorForPreferenceDataset",
@@ -18,8 +18,7 @@ __all__ = [
"DataCollatorForKTODataset",
"StatefulDistributedSampler",
"load_tokenized_dataset",
- "supervised_tokenize_pretrain",
- "supervised_tokenize_sft",
+ "tokenize_sft",
"tokenize_rlhf",
"tokenize_kto",
"setup_conversation_template",
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index d573853d6..2cbf11d1f 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -23,11 +23,10 @@ IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
-def supervised_tokenize_sft(
+def tokenize_sft(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -127,11 +126,10 @@ def supervised_tokenize_sft(
)
-def tokenize_prompt_dataset(
+def tokenize_prompt(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@@ -215,7 +213,6 @@ def tokenize_rlhf(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
- ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
index 503b11419..a35f2bf52 100644
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
@@ -40,13 +40,7 @@ import random
import time
from multiprocessing import cpu_count
-from coati.dataset import (
- setup_conversation_template,
- supervised_tokenize_sft,
- tokenize_kto,
- tokenize_prompt_dataset,
- tokenize_rlhf,
-)
+from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer
@@ -205,9 +199,9 @@ def main():
)
if args.type == "sft":
- preparation_function = supervised_tokenize_sft
+ preparation_function = tokenize_sft
elif args.type == "prompt":
- preparation_function = tokenize_prompt_dataset
+ preparation_function = tokenize_prompt
elif args.type == "preference":
preparation_function = tokenize_rlhf
elif args.type == "kto":
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh
index 0450d570d..42c785289 100755
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_kto_dataset.sh
@@ -1,4 +1,4 @@
-SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"
+SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
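After this rename, downstream scripts import `tokenize_sft` / `tokenize_prompt` instead of the old names. A sketch of the resulting dispatch in `prepare_dataset.py`, written as a lookup table rather than the if/elif chain in the diff (assumes the repository's `coati` package is importable):

```python
from coati.dataset import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft

PREPARATION_FUNCTIONS = {
    "sft": tokenize_sft,        # formerly supervised_tokenize_sft
    "prompt": tokenize_prompt,  # formerly tokenize_prompt_dataset
    "preference": tokenize_rlhf,
    "kto": tokenize_kto,
}

def get_preparation_function(data_type: str):
    try:
        return PREPARATION_FUNCTIONS[data_type]
    except KeyError:
        raise ValueError(
            f"Unknown data type {data_type!r}, expected one of {sorted(PREPARATION_FUNCTIONS)}"
        ) from None
```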
From 6fd9e8686409ff6f96f49dd63570dfcadee2284e Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 29 Jul 2024 01:29:18 +0000
Subject: [PATCH 047/116] fix style
---
.../ColossalChat/coati/dataset/tokenization_utils.py | 2 +-
applications/ColossalChat/coati/dataset/utils.py | 9 +++++----
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 2cbf11d1f..9eb2eba87 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -193,7 +193,7 @@ def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
# no truncation applied
- tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
+ tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
loss_mask = [0] * len(tokenized)
label_decode = []
diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py
index cf767b444..42c3191db 100755
--- a/applications/ColossalChat/coati/dataset/utils.py
+++ b/applications/ColossalChat/coati/dataset/utils.py
@@ -119,17 +119,18 @@ def tokenize_and_concatenate(
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
- if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
+ if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
- if loss_starts[0] >= max_length:
+ if max_length and loss_starts[0] >= max_length:
return None, None, None
if discard_non_loss_tokens_at_tail:
input_ids = input_ids[: loss_ends[-1]]
- input_ids = input_ids[:max_length]
- loss_ends[-1] = min(max_length, loss_ends[-1])
+ if max_length:
+ input_ids = input_ids[:max_length]
+ loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends
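The point of this hunk is that `max_length=None` now means "no truncation" instead of relying on a huge sentinel value. A simplified, self-contained sketch of the length handling (it takes pre-tokenized chunks instead of a tokenizer, drops `discard_non_loss_tokens_at_tail`, and assumes at least one chunk requires loss):

```python
from typing import List, Optional, Tuple

def concatenate_chunks(chunk_token_ids: List[List[int]], require_loss: List[bool],
                       max_length: Optional[int] = None) -> Tuple[Optional[list], Optional[list], Optional[list]]:
    input_ids, loss_starts, loss_ends = [], [], []
    for tokenized, r in zip(chunk_token_ids, require_loss):
        # Keep appending while under the budget; with max_length=None there is no budget at all.
        if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
            if r:
                loss_starts.append(len(input_ids))
                loss_ends.append(len(input_ids) + len(tokenized))
            input_ids.extend(tokenized)
    if max_length and loss_starts and loss_starts[0] >= max_length:
        return None, None, None  # nothing that needs loss fits inside the window
    if max_length and loss_ends:
        input_ids = input_ids[:max_length]
        loss_ends[-1] = min(max_length, loss_ends[-1])
    return input_ids, loss_starts, loss_ends
```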
From 9664b1bc190c57518fd76f4a3740feea3dc38ffd Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 29 Jul 2024 13:58:27 +0800
Subject: [PATCH 048/116] [shardformer] hotfix attn mask (#5945)
---
colossalai/shardformer/modeling/command.py | 2 +-
colossalai/shardformer/modeling/llama.py | 2 +-
colossalai/shardformer/modeling/mistral.py | 2 +-
colossalai/shardformer/modeling/qwen2.py | 8 ++++++--
4 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 759c8d7b8..5b36fc7db 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -116,7 +116,7 @@ class CommandPipelineForwards:
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 54ff8e321..9ffbca517 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -643,7 +643,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
# in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
- mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+ mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 82e8ef5f9..ec1a8a00a 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -91,7 +91,7 @@ class MistralForwards:
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length, seq_length)
+ mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 55822b150..538e96c32 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -136,7 +136,7 @@ class Qwen2PipelineForwards:
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
@@ -651,6 +651,10 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
seq_length_with_past = seq_length
past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
@@ -668,7 +672,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
- mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
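The common thread in these hunks: with a KV cache, the query axis of the mask shape covers only the new tokens (`seq_length`), while the key/value axis covers past plus new tokens (`seq_length_with_past`). A small shape check (values are made up) showing why the corrected `(b, 1, q_len, kv_len)` bias broadcasts against the attention scores while the old square shape does not during decoding:

```python
import torch

batch_size, seq_length, past_seen_tokens = 2, 1, 16   # single-token decode step
seq_length_with_past = seq_length + past_seen_tokens

mask_shape = (batch_size, 1, seq_length, seq_length_with_past)  # shape used after the fix
attention_bias = torch.zeros(mask_shape)

# Attention scores are (batch, heads, q_len, kv_len).
scores = torch.randn(batch_size, 8, seq_length, seq_length_with_past)
print((scores + attention_bias).shape)  # torch.Size([2, 8, 1, 17]) -- broadcasts cleanly
```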
From 7b38964e3a977aa92d40200059d9076b622bf9b7 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 29 Jul 2024 19:10:06 +0800
Subject: [PATCH 049/116] [shardformer] hotfix attn mask (#5947)
---
colossalai/shardformer/layer/attn.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 141baf3d3..5872c6485 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -139,12 +139,11 @@ class ColoAttention:
# no padding
assert is_causal
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
- attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+ attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+ if s_q != 1:
+ attention_mask = attention_mask.tril(diagonal=0)
+ attention_mask = attention_mask.expand(b, s_q, s_kv)
else:
- assert q_padding_mask.shape == (
- b,
- s_q,
- ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
if kv_padding_mask is None:
# self attention
@@ -156,7 +155,7 @@ class ColoAttention:
b,
s_kv,
), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
- attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
+ attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
@@ -169,7 +168,8 @@ class ColoAttention:
)
if is_causal:
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
- attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+ if s_q != 1:
+ attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else:
outputs["attention_mask_type"] = AttnMaskType.PADDED
attention_mask = invert_mask(attention_mask).unsqueeze(1)
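Two fixes are bundled here: the padding pattern must come from the key/value side (broadcast over queries, not the other way around), and the causal triangle is skipped when there is only one query token, as in single-token decoding. A standalone sketch of the corrected padded-causal mask (the helper function is mine; the tensor operations mirror the hunk):

```python
import torch

def build_padded_causal_mask(kv_padding_mask: torch.Tensor, s_q: int, is_causal: bool) -> torch.Tensor:
    b, s_kv = kv_padding_mask.shape
    # Broadcast the key/value padding pattern over the query dimension.
    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).float()
    if is_causal and s_q != 1:
        # Apply the causal triangle only when there is more than one query token.
        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
    return attention_mask

# Decode step: one query token attending to 5 cached keys, the last key padded out.
print(build_padded_causal_mask(torch.tensor([[1, 1, 1, 1, 0]]), s_q=1, is_causal=True))
# tensor([[[1., 1., 1., 1., 0.]]]) -- no triangle applied for s_q == 1
```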
From bcf0181ecd3e1e639200b66d6e1aab6c6b3d5b7b Mon Sep 17 00:00:00 2001
From: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Date: Tue, 30 Jul 2024 10:43:26 +0800
Subject: [PATCH 050/116] [Feat] Distrifusion Acceleration Support for
Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
---
colossalai/inference/README.md | 12 +-
colossalai/inference/config.py | 16 +
colossalai/inference/core/diffusion_engine.py | 2 +-
.../modeling/{models => layers}/diffusion.py | 0
.../inference/modeling/layers/distrifusion.py | 626 ++++++++++++++++++
.../inference/modeling/models/pixart_alpha.py | 2 +-
.../modeling/models/stablediffusion3.py | 2 +-
.../inference/modeling/policy/pixart_alpha.py | 49 +-
.../modeling/policy/stablediffusion3.py | 48 +-
examples/inference/stable_diffusion/README.md | 22 +
.../stable_diffusion/benchmark_sd3.py | 179 +++++
.../stable_diffusion/compute_metric.py | 80 +++
.../stable_diffusion/requirements.txt | 3 +
.../stable_diffusion/run_benchmark.sh | 42 ++
.../stable_diffusion/sd3_generation.py | 22 +-
15 files changed, 1089 insertions(+), 16 deletions(-)
rename colossalai/inference/modeling/{models => layers}/diffusion.py (100%)
create mode 100644 colossalai/inference/modeling/layers/distrifusion.py
create mode 100644 examples/inference/stable_diffusion/README.md
create mode 100644 examples/inference/stable_diffusion/benchmark_sd3.py
create mode 100644 examples/inference/stable_diffusion/compute_metric.py
create mode 100644 examples/inference/stable_diffusion/requirements.txt
create mode 100644 examples/inference/stable_diffusion/run_benchmark.sh
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 0a9b5293d..76813a4a3 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -18,7 +18,7 @@
## 📌 Introduction
-ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
+ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below.
journal={arXiv},
year={2023}
}
+
+# Distrifusion
+@InProceedings{Li_2024_CVPR,
+ author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
+ title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month={June},
+ year={2024},
+ pages={7183-7193}
+}
```
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 1beb86874..072ddbcfd 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
+ patched_parallelism_size(int): The patched parallelism size, when using Distrifusion.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512
+ # Acceleration for Diffusion Model(PipeFusion or Distrifusion)
+ patched_parallelism_size: int = 1 # for distrifusion
+ # pipeFusion_m_size: int = 1 # for pipefusion
+ # pipeFusion_n_size: int = 1 # for pipefusion
+
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
+ # check Distrifusion
+ # TODO(@lry89757) need more detailed check
+ if self.patched_parallelism_size > 1:
+ # self.use_patched_parallelism = True
+ self.tp_size = (
+ self.patched_parallelism_size
+ ) # not real tensor parallelism; existing parallel-group checks require tp_size, so we set it to patched_parallelism_size
+
# check prompt template
if self.prompt_template is None:
return
@@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
+ patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
+ patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique
@dataclass
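A hedged usage sketch of the new field, assuming the remaining `InferenceConfig` fields keep their defaults; only `patched_parallelism_size` and the `tp_size` side effect come from the diff above:

```python
from colossalai.inference.config import InferenceConfig

# Run one image across two devices, each holding a patch (Distrifusion-style).
inference_config = InferenceConfig(patched_parallelism_size=2)

# __post_init__ forces tp_size to match so the existing parallel-group checks pass;
# it is not used as real tensor parallelism for diffusion models.
assert inference_config.tp_size == inference_config.patched_parallelism_size
```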
diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py
index 75b9889bf..8bed508cb 100644
--- a/colossalai/inference/core/diffusion_engine.py
+++ b/colossalai/inference/core/diffusion_engine.py
@@ -11,7 +11,7 @@ from torch import distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py
similarity index 100%
rename from colossalai/inference/modeling/models/diffusion.py
rename to colossalai/inference/modeling/layers/diffusion.py
diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py
new file mode 100644
index 000000000..ea97cceef
--- /dev/null
+++ b/colossalai/inference/modeling/layers/distrifusion.py
@@ -0,0 +1,626 @@
+# Code refer and adapted from:
+# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
+# https://github.com/PipeFusion/PipeFusion
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from diffusers.models import attention_processor
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from torch import nn
+from torch.distributed import ProcessGroup
+
+from colossalai.inference.config import ModelShardInferenceConfig
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from colossalai.utils import get_current_device
+
+try:
+ from flash_attn import flash_attn_func
+
+ HAS_FLASH_ATTN = True
+except ImportError:
+ HAS_FLASH_ATTN = False
+
+
+logger = get_dist_logger(__name__)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
+def PixArtAlphaTransformer2DModel_forward(
+ self: PixArtTransformer2DModel,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+):
+ assert hasattr(
+ self, "patched_parallel_size"
+ ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
+
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size = hidden_states.shape[0]
+ height, width = (
+ hidden_states.shape[-2] // self.config.patch_size,
+ hidden_states.shape[-1] // self.config.patch_size,
+ )
+ hidden_states = self.pos_embed(hidden_states)
+
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ if self.caption_projection is not None:
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
+ 2, dim=1
+ )
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ hidden_states = hidden_states.reshape(
+ shape=(
+ -1,
+ height // self.patched_parallel_size,
+ width,
+ self.config.patch_size,
+ self.config.patch_size,
+ self.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ -1,
+ self.out_channels,
+ height // self.patched_parallel_size * self.config.patch_size,
+ width * self.config.patch_size,
+ )
+ )
+
+ # enable Distrifusion Optimization
+ if hasattr(self, "patched_parallel_size"):
+ from torch import distributed as dist
+
+ if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+ self.output_buffer = torch.empty_like(output)
+ if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+ self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+ output = output.contiguous()
+ dist.all_gather(self.buffer_list, output, async_op=False)
+ torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+ output = self.output_buffer
+
+ return (output,)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
+def SD3Transformer2DModel_forward(
+ self: SD3Transformer2DModel,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+) -> Union[torch.FloatTensor]:
+
+ assert hasattr(
+ self, "patched_parallel_size"
+ ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
+
+ height, width = hidden_states.shape[-2:]
+
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ for block in self.transformer_blocks:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size // self.patched_parallel_size
+ width = width // patch_size
+
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ # enable Distrifusion Optimization
+ if hasattr(self, "patched_parallel_size"):
+ from torch import distributed as dist
+
+ if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+ self.output_buffer = torch.empty_like(output)
+ if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+ self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+ output = output.contiguous()
+ dist.all_gather(self.buffer_list, output, async_op=False)
+ torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+ output = self.output_buffer
+
+ return (output,)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
+class DistrifusionPatchEmbed(ParallelModule):
+ def __init__(
+ self,
+ module: PatchEmbed,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.module = module
+ self.rank = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+ @staticmethod
+ def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ distrifusion_embed = DistrifusionPatchEmbed(
+ module, process_group, model_shard_infer_config=model_shard_infer_config
+ )
+ return distrifusion_embed
+
+ def forward(self, latent):
+ module = self.module
+ if module.pos_embed_max_size is not None:
+ height, width = latent.shape[-2:]
+ else:
+ height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
+
+ latent = module.proj(latent)
+ if module.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if module.layer_norm:
+ latent = module.norm(latent)
+ if module.pos_embed is None:
+ return latent.to(latent.dtype)
+ # Interpolate or crop positional embeddings as needed
+ if module.pos_embed_max_size:
+ pos_embed = module.cropped_pos_embed(height, width)
+ else:
+ if module.height != height or module.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=module.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=module.base_size,
+ interpolation_scale=module.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = module.pos_embed
+
+ b, c, h = pos_embed.shape
+ pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
+
+ return (latent + pos_embed).to(latent.dtype)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
+class DistrifusionConv2D(ParallelModule):
+
+ def __init__(
+ self,
+ module: nn.Conv2d,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.module = module
+ self.rank = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+ @staticmethod
+ def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
+ return distrifusion_conv
+
+ def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
+
+ b, c, h, w = x.shape
+
+ stride = self.module.stride[0]
+ padding = self.module.padding[0]
+
+ output_h = x.shape[2] // stride // self.patched_parallelism_size
+ idx = dist.get_rank()
+ h_begin = output_h * idx * stride - padding
+ h_end = output_h * (idx + 1) * stride + padding
+ final_padding = [padding, padding, 0, 0]
+ if h_begin < 0:
+ h_begin = 0
+ final_padding[2] = padding
+ if h_end > h:
+ h_end = h
+ final_padding[3] = padding
+ sliced_input = x[:, :, h_begin:h_end, :]
+ padded_input = F.pad(sliced_input, final_padding, mode="constant")
+ return F.conv2d(
+ padded_input,
+ self.module.weight,
+ self.module.bias,
+ stride=stride,
+ padding="valid",
+ )
+
+ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ output = self.sliced_forward(input)
+ return output
+
+
+# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
+class DistrifusionFusedAttention(ParallelModule):
+
+ def __init__(
+ self,
+ module: attention_processor.Attention,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.counter = 0
+ self.module = module
+ self.buffer_list = None
+ self.kv_buffer_idx = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+ self.handle = None
+ self.process_group = process_group
+ self.warm_step = 5 # for warmup
+
+ @staticmethod
+ def from_native_module(
+ module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ return DistrifusionFusedAttention(
+ module=module,
+ process_group=process_group,
+ model_shard_infer_config=model_shard_infer_config,
+ )
+
+ def _forward(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
+ if self.patched_parallelism_size == 1:
+ full_kv = kv
+ else:
+ if self.buffer_list is None: # buffer not created
+ full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+ elif self.counter <= self.warm_step:
+ # logger.info(f"warmup: {self.counter}")
+ dist.all_gather(
+ self.buffer_list,
+ kv,
+ group=self.process_group,
+ async_op=False,
+ )
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ else:
+ # logger.info(f"use old kv to infer: {self.counter}")
+ self.buffer_list[self.kv_buffer_idx].copy_(kv)
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ assert self.handle is None, "we should maintain the kv of last step"
+ self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, is_causal=False
+ ) # NOTE(@lry89757) for torch >= 2.2, flash attention has already been integrated into scaled_dot_product_attention, see https://pytorch.org/blog/pytorch2-2/
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+
+ if self.handle is not None:
+ self.handle.wait()
+ self.handle = None
+
+ b, l, c = hidden_states.shape
+ kv_shape = (b, l, self.module.to_k.out_features * 2)
+ if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+ self.buffer_list = [
+ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+ for _ in range(self.patched_parallelism_size)
+ ]
+
+ self.counter = 0
+
+ attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ output = self._forward(
+ self.module,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ self.counter += 1
+
+ return output
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
+class DistriSelfAttention(ParallelModule):
+ def __init__(
+ self,
+ module: Attention,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.counter = 0
+ self.module = module
+ self.buffer_list = None
+ self.kv_buffer_idx = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+ self.handle = None
+ self.process_group = process_group
+ self.warm_step = 3 # for warmup
+
+ @staticmethod
+ def from_native_module(
+ module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ return DistriSelfAttention(
+ module=module,
+ process_group=process_group,
+ model_shard_infer_config=model_shard_infer_config,
+ )
+
+ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
+ attn = self.module
+ assert isinstance(attn, Attention)
+
+ residual = hidden_states
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ query = attn.to_q(hidden_states)
+
+ encoder_hidden_states = hidden_states
+ k = self.module.to_k(encoder_hidden_states)
+ v = self.module.to_v(encoder_hidden_states)
+ kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
+ if self.patched_parallelism_size == 1:
+ full_kv = kv
+ else:
+ if self.buffer_list is None: # buffer not created
+ full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+ elif self.counter <= self.warm_step:
+ # logger.info(f"warmup: {self.counter}")
+ dist.all_gather(
+ self.buffer_list,
+ kv,
+ group=self.process_group,
+ async_op=False,
+ )
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ else:
+ # logger.info(f"use old kv to infer: {self.counter}")
+ self.buffer_list[self.kv_buffer_idx].copy_(kv)
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ assert self.handle is None, "we should maintain the kv of last step"
+ self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+ if HAS_FLASH_ATTN:
+ # flash attn
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
+ else:
+ # naive attn
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+
+        # wait for the async KV all_gather launched in the previous step, then (re)allocate the KV buffer if its shape changed
+ if self.handle is not None:
+ self.handle.wait()
+ self.handle = None
+
+ b, l, c = hidden_states.shape
+ kv_shape = (b, l, self.module.to_k.out_features * 2)
+ if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+ self.buffer_list = [
+ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+ for _ in range(self.patched_parallelism_size)
+ ]
+
+ self.counter = 0
+
+ output = self._forward(hidden_states, scale=scale)
+
+ self.counter += 1
+ return output
diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py
index d5774946e..cc2bee5ef 100644
--- a/colossalai/inference/modeling/models/pixart_alpha.py
+++ b/colossalai/inference/modeling/models/pixart_alpha.py
@@ -14,7 +14,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retri
from colossalai.logging import get_dist_logger
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
logger = get_dist_logger(__name__)
diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py
index d1c63a6dc..b12316403 100644
--- a/colossalai/inference/modeling/models/stablediffusion3.py
+++ b/colossalai/inference/modeling/models/stablediffusion3.py
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
# TODO(@lry89757) temporarily image, please support more return output
diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py
index 356056ba7..1150b2432 100644
--- a/colossalai/inference/modeling/policy/pixart_alpha.py
+++ b/colossalai/inference/modeling/policy/pixart_alpha.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+ DistrifusionConv2D,
+ DistrifusionPatchEmbed,
+ DistriSelfAttention,
+ PixArtAlphaTransformer2DModel_forward,
+)
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
@@ -12,9 +20,46 @@ class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
def module_policy(self):
policy = {}
+
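+        # When Patched Parallelism is enabled (patched_parallelism_size > 1), replace the patch embedding
+        # and self-attention sub-modules with their Distrifusion counterparts and patch the transformer forward.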
+ if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+ policy[PixArtTransformer2DModel] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="pos_embed.proj",
+ target_module=DistrifusionConv2D,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ SubModuleReplacementDescription(
+ suffix="pos_embed",
+ target_module=DistrifusionPatchEmbed,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ ],
+ attribute_replacement={
+ "patched_parallel_size": self.shard_config.extra_kwargs[
+ "model_shard_infer_config"
+ ].patched_parallelism_size
+ },
+ method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
+ )
+
+ policy[BasicTransformerBlock] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn1",
+ target_module=DistriSelfAttention,
+ kwargs={
+ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+ },
+ )
+ ]
+ )
+
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
+
return policy
def preprocess(self) -> nn.Module:
diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py
index c9877f7dc..39b764b92 100644
--- a/colossalai/inference/modeling/policy/stablediffusion3.py
+++ b/colossalai/inference/modeling/policy/stablediffusion3.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.transformers import SD3Transformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+ DistrifusionConv2D,
+ DistrifusionFusedAttention,
+ DistrifusionPatchEmbed,
+ SD3Transformer2DModel_forward,
+)
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
@@ -12,6 +20,42 @@ class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
def module_policy(self):
policy = {}
+
+ if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+ policy[SD3Transformer2DModel] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="pos_embed.proj",
+ target_module=DistrifusionConv2D,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ SubModuleReplacementDescription(
+ suffix="pos_embed",
+ target_module=DistrifusionPatchEmbed,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ ],
+ attribute_replacement={
+ "patched_parallel_size": self.shard_config.extra_kwargs[
+ "model_shard_infer_config"
+ ].patched_parallelism_size
+ },
+ method_replacement={"forward": SD3Transformer2DModel_forward},
+ )
+
+ policy[JointTransformerBlock] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn",
+ target_module=DistrifusionFusedAttention,
+ kwargs={
+ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+ },
+ )
+ ]
+ )
+
self.append_or_create_method_replacement(
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
)
diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md
new file mode 100644
index 000000000..c11b98043
--- /dev/null
+++ b/examples/inference/stable_diffusion/README.md
@@ -0,0 +1,22 @@
+## File Structure
+```
+|- sd3_generation.py: an example of how to use the ColossalAI Inference Engine to generate images with a diffusion model.
+|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion.
+|- benchmark_sd3.py: benchmark the performance of our InferenceEngine.
+|- run_benchmark.sh: shell script that runs the benchmark commands.
+```
+Note: `compute_metric.py` requires some extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is located in `examples/inference/stable_diffusion/`).
+
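+For example, to compare two folders of images (the paths below are placeholders; pass `--is_gt` if the first folder contains ground-truth images), you can run:
+```bash
+python examples/inference/stable_diffusion/compute_metric.py --input_root0 /path/to/images_a --input_root1 /path/to/images_b
+```
+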
+## Run Inference
+
+The provided example `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. We use `DiffusionPipeline` as the model class, and the script runs inference with Stable Diffusion 3 out of the box.
+
+For a basic single-GPU run, you can launch the example with:
+```bash
+colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
+```
+
+To run multi-GPU inference (Patched Parallelism), for example on 2 GPUs:
+```bash
+colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
+```
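+
+## Run Benchmark
+
+`benchmark_sd3.py` compares the generation time of the ColossalAI engine (Patched Parallelism) against plain `diffusers`; the commands below mirror the usage examples in its docstring:
+```bash
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
+```
+`run_benchmark.sh` sweeps over the model, parallel-size, and resolution combinations listed at the top of the script.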
diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py
new file mode 100644
index 000000000..19db57c33
--- /dev/null
+++ b/examples/inference/stable_diffusion/benchmark_sd3.py
@@ -0,0 +1,179 @@
+import argparse
+import json
+import time
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from diffusers import DiffusionPipeline
+
+import colossalai
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+GIGABYTE = 1024**3
+MEGABYTE = 1024 * 1024
+
+_DTYPE_MAPPING = {
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32,
+}
+
+
+def log_generation_time(log_data, log_file):
+ with open(log_file, "a") as f:
+ json.dump(log_data, f, indent=2)
+ f.write("\n")
+
+
+def warmup(engine, args):
+ for _ in range(args.n_warm_up_steps):
+ engine.generate(
+ prompts=["hello world"],
+ generation_config=DiffusionGenerationConfig(
+ num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
+ ),
+ )
+
+
+def profile_context(args):
+ return (
+ torch.profiler.profile(
+ record_shapes=True,
+ with_stack=True,
+ with_modules=True,
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ )
+ if args.profile
+ else nullcontext()
+ )
+
+
+def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
+ log_data = {
+ "mode": mode,
+ "model": model_name,
+ "batch_size": args.batch_size,
+ "patched_parallel_size": args.patched_parallel_size,
+ "num_inference_steps": args.num_inference_steps,
+ "height": h,
+ "width": w,
+ "dtype": args.dtype,
+ "profile": args.profile,
+ "n_warm_up_steps": args.n_warm_up_steps,
+ "n_repeat_times": args.n_repeat_times,
+ "avg_generation_time": avg_time,
+ "log_message": log_msg,
+ }
+
+ if args.log:
+ log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
+ log_generation_time(log_data=log_data, log_file=log_file)
+
+ if args.profile:
+ file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
+ prof.export_chrome_trace(file)
+
+
+def benchmark_colossalai(rank, world_size, port, args):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ from colossalai.cluster.dist_coordinator import DistCoordinator
+
+ coordinator = DistCoordinator()
+
+ inference_config = InferenceConfig(
+ dtype=args.dtype,
+ patched_parallelism_size=args.patched_parallel_size,
+ )
+ engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
+
+ warmup(engine, args)
+
+ for h, w in zip(args.height, args.width):
+ with profile_context(args) as prof:
+ start = time.perf_counter()
+ for _ in range(args.n_repeat_times):
+ engine.generate(
+ prompts=["hello world"],
+ generation_config=DiffusionGenerationConfig(
+ num_inference_steps=args.num_inference_steps, height=h, width=w
+ ),
+ )
+ end = time.perf_counter()
+
+ avg_time = (end - start) / args.n_repeat_times
+ log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+ coordinator.print_on_master(log_msg)
+
+ if dist.get_rank() == 0:
+ log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
+
+
+def benchmark_diffusers(args):
+ model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
+
+ for _ in range(args.n_warm_up_steps):
+ model(
+ prompt="hello world",
+ num_inference_steps=args.num_inference_steps,
+ height=args.height[0],
+ width=args.width[0],
+ )
+
+ for h, w in zip(args.height, args.width):
+ with profile_context(args) as prof:
+ start = time.perf_counter()
+ for _ in range(args.n_repeat_times):
+ model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
+ end = time.perf_counter()
+
+ avg_time = (end - start) / args.n_repeat_times
+ log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+ print(log_msg)
+
+ log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def benchmark(args):
+ if args.mode == "colossalai":
+ spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
+ elif args.mode == "diffusers":
+ benchmark_diffusers(args)
+
+
+"""
+# enable log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
+
+# enable profiler
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+"""
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
+ parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
+ parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
+ parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
+ parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
+ parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
+ parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
+ parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
+ parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
+ parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
+ parser.add_argument(
+ "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
+ )
+ args = parser.parse_args()
+ benchmark(args)
diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py
new file mode 100644
index 000000000..14c92501b
--- /dev/null
+++ b/examples/inference/stable_diffusion/compute_metric.py
@@ -0,0 +1,80 @@
+# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
+import argparse
+import os
+
+import numpy as np
+import torch
+from cleanfid import fid
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
+from torchvision.transforms import Resize
+from tqdm import tqdm
+
+
+def read_image(path: str):
+ """
+ input: path
+ output: tensor (C, H, W)
+ """
+ img = np.asarray(Image.open(path))
+ if len(img.shape) == 2:
+ img = np.repeat(img[:, :, None], 3, axis=2)
+ img = torch.from_numpy(img).permute(2, 0, 1)
+ return img
+
+
+class MultiImageDataset(Dataset):
+ def __init__(self, root0, root1, is_gt=False):
+ super().__init__()
+ self.root0 = root0
+ self.root1 = root1
+ file_names0 = os.listdir(root0)
+ file_names1 = os.listdir(root1)
+
+ self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
+ self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
+ self.is_gt = is_gt
+ assert len(self.image_names0) == len(self.image_names1)
+
+ def __len__(self):
+ return len(self.image_names0)
+
+ def __getitem__(self, idx):
+ img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
+ if self.is_gt:
+ # resize to 1024 x 1024
+ img0 = Resize((1024, 1024))(img0)
+ img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
+
+ batch_list = [img0, img1]
+ return batch_list
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--num_workers", type=int, default=8)
+ parser.add_argument("--is_gt", action="store_true")
+ parser.add_argument("--input_root0", type=str, required=True)
+ parser.add_argument("--input_root1", type=str, required=True)
+ args = parser.parse_args()
+
+ psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
+ lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
+
+ dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+
+ progress_bar = tqdm(dataloader)
+ with torch.inference_mode():
+ for i, batch in enumerate(progress_bar):
+ batch = [img.to("cuda") / 255 for img in batch]
+ batch_size = batch[0].shape[0]
+ psnr.update(batch[0], batch[1])
+ lpips.update(batch[0], batch[1])
+ fid_score = fid.compute_fid(args.input_root0, args.input_root1)
+
+ print("PSNR:", psnr.compute().item())
+ print("LPIPS:", lpips.compute().item())
+ print("FID:", fid_score)
diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt
new file mode 100644
index 000000000..c4e74162d
--- /dev/null
+++ b/examples/inference/stable_diffusion/requirements.txt
@@ -0,0 +1,3 @@
+torchvision
+torchmetrics
+cleanfid
diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh
new file mode 100644
index 000000000..f3e45a335
--- /dev/null
+++ b/examples/inference/stable_diffusion/run_benchmark.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
+parallelism=(1 2 4 8)
+resolutions=(1024 2048 3840)
+modes=("colossalai" "diffusers")
+
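+# Pick the $1 GPUs with the lowest current memory usage and export them via CUDA_VISIBLE_DEVICES.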
+CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+ | tail -n +2 \
+ | nl -v 0 \
+ | tee /dev/tty \
+ | sort -g -k 2 \
+ | awk '{print $1}' \
+ | head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+for model in "${models[@]}"; do
+ for p in "${parallelism[@]}"; do
+ for resolution in "${resolutions[@]}"; do
+ for mode in "${modes[@]}"; do
+ if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
+ continue
+ fi
+ if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
+ continue
+ fi
+ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
+
+ cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
+
+ echo "Executing: $cmd"
+ eval $cmd
+ done
+ done
+ done
+done
diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py
index fe989eed7..9e146c34b 100644
--- a/examples/inference/stable_diffusion/sd3_generation.py
+++ b/examples/inference/stable_diffusion/sd3_generation.py
@@ -1,18 +1,17 @@
import argparse
-from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
-from torch import bfloat16, float16, float32
+from diffusers import DiffusionPipeline
+from torch import bfloat16
+from torch import distributed as dist
+from torch import float16, float32
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
-from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
-from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
# For Stable Diffusion 3, we'll use the following configuration
-MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
-POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+MODEL_CLS = DiffusionPipeline
TORCH_DTYPE_MAP = {
"fp16": float16,
@@ -43,20 +42,27 @@ def infer(args):
max_batch_size=args.max_batch_size,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
+ patched_parallelism_size=dist.get_world_size(),
)
- engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+ engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
# ==============================
# Generation
# ==============================
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
- out.save("cat.jpg")
+ if dist.get_rank() == 0:
+ out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
coordinator.print_on_master(out)
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
if __name__ == "__main__":
From 060892162a592c1abf239ad11edd9a91b47188e0 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 30 Jul 2024 13:36:00 +0800
Subject: [PATCH 051/116] [zero] hotfix update master params (#5951)
---
colossalai/zero/low_level/low_level_optim.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index f1db7c540..01382cd8e 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -807,8 +807,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
"""
for p in model.parameters():
p_id = id(p)
- pg = self.param_to_pg[p]
if p_id in self.working_to_master_param:
+ pg = self.param_to_pg[p]
master_param = self.working_to_master_param[p_id]
padding_size = self.get_param_padding_size(p)
working_param = p.data.view(-1)
From 09c5f72595228ad5f8e82005b8e442292bc063d1 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 31 Jul 2024 10:04:50 +0800
Subject: [PATCH 052/116] [release] update version (#5952)
---
version.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/version.txt b/version.txt
index 267577d47..2b7c5ae01 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.4.1
+0.4.2
From 30f4e31a33803d3f525568322c9442f75997b042 Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Wed, 31 Jul 2024 14:10:17 +0800
Subject: [PATCH 053/116] [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
---
.../ColossalChat/coati/models/__init__.py | 4 +-
.../ColossalChat/coati/models/lora.py | 354 ++++++++++++++----
applications/ColossalChat/examples/README.md | 84 +++--
.../training_scripts/lora_config.json | 9 +
.../examples/training_scripts/train_dpo.py | 35 +-
.../examples/training_scripts/train_kto.py | 33 +-
.../examples/training_scripts/train_orpo.py | 30 +-
.../examples/training_scripts/train_ppo.py | 41 +-
.../examples/training_scripts/train_rm.py | 37 +-
.../examples/training_scripts/train_sft.py | 50 ++-
.../examples/training_scripts/train_sft.sh | 31 +-
applications/ColossalChat/tests/test_lora.py | 49 ++-
applications/ColossalChat/tests/test_train.sh | 47 ++-
13 files changed, 552 insertions(+), 252 deletions(-)
create mode 100644 applications/ColossalChat/examples/training_scripts/lora_config.json
diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py
index f554cbfa5..fba0949e3 100755
--- a/applications/ColossalChat/coati/models/__init__.py
+++ b/applications/ColossalChat/coati/models/__init__.py
@@ -1,7 +1,7 @@
from .base import BaseModel
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
-from .lora import convert_to_lora_module
+from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout
@@ -14,6 +14,8 @@ __all__ = [
"ValueLoss",
"LogSigLoss",
"LogExpLoss",
+ "LoraConfig",
+ "lora_manager",
"convert_to_lora_module",
"DpoLoss",
"KTOLoss" "generate",
diff --git a/applications/ColossalChat/coati/models/lora.py b/applications/ColossalChat/coati/models/lora.py
index 116c5acec..aa5f6ecf8 100755
--- a/applications/ColossalChat/coati/models/lora.py
+++ b/applications/ColossalChat/coati/models/lora.py
@@ -5,10 +5,11 @@ LORA utils
import dataclasses
import math
import warnings
-from typing import Optional
+from typing import List, Optional, Union
import loralib as lora
import torch
+import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
@@ -18,148 +19,349 @@ logger = get_dist_logger()
@dataclasses.dataclass
-class LoRAManager:
- merge_weights: bool = False
+class LoraManager:
+ able_to_merge: bool = True
-LORA_MANAGER = LoRAManager()
+lora_manager = LoraManager()
-class LoraLinear(lora.LoRALayer, nn.Module):
- """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
+@dataclasses.dataclass
+class LoraConfig:
+ r: int = 0
+ lora_alpha: int = 32
+ linear_lora_dropout: float = 0.1
+ embedding_lora_dropout: float = 0.0
+ lora_train_bias: str = "none"
+ lora_initialization_method: str = "kaiming_uniform"
+ target_modules: List = None
+ @classmethod
+ def from_file(cls, config_file: str):
+ import json
+
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ return cls(**config)
+
+
+class LoraBase(lora.LoRALayer, nn.Module):
def __init__(
self,
- weight: nn.Parameter,
- bias: Optional[nn.Parameter],
r: int = 0,
- lora_alpha: int = 1,
- lora_dropout: float = 0.0,
- # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
- fan_in_fan_out: bool = False,
+ lora_alpha: int = 32,
+ lora_dropout: float = 0.1,
+ lora_initialization_method: str = "kaiming_uniform",
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
- self.weight = weight
- self.bias = bias
-
- out_features, in_features = weight.shape
- self.in_features = in_features
- self.out_features = out_features
-
- self.fan_in_fan_out = fan_in_fan_out
- # Actual trainable parameters
- if r > 0:
- self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=False)
- self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
- self.scaling = self.lora_alpha / self.r
- # Freezing the pre-trained weight matrix
- self.weight.requires_grad = False
- self.reset_parameters()
- if fan_in_fan_out:
- self.weight.data = self.weight.data.T
+ self.r = r
+ self.lora_alpha = lora_alpha
+ self.lora_dropout = nn.Dropout(lora_dropout)
+ self.merged = False
+ self.lora_initialization_method = lora_initialization_method
+ self.weight = None
+ self.bias = None
+ self.lora_A = None
+ self.lora_B = None
def reset_parameters(self):
if hasattr(self, "lora_A"):
- # Initialize A with the default values for nn.Linear and set B to zero.
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
- nn.init.zeros_(self.lora_B)
+ if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
+ self.out_features,
+ self.in_features,
+ ):
+ # Initialize A with the default values for nn.Linear and set B to zero.
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+ elif self.lora_initialization_method == "PiSSA":
+ # PiSSA method in this paper: https://arxiv.org/abs/2404.02948
+ # Assume the SVD of the original weights is W = USV^T
+            # Initialize the frozen weight to U[:,r:] S[r:,r:] V[:,r:]^T, which stores the less significant part of W
+            # Only A and B are trainable; they are initialized to S[:r,:r]^0.5 V[:,:r]^T and U[:,:r] S[:r,:r]^0.5 respectively
+ # self.scaling = 1.
+ # SVD
+ U, S, Vh = torch.svd_lowrank(
+ self.weight.to(torch.float32).data, self.r, niter=4
+            ) # U: [out_features, r], S: [r], V: [in_features, r]
+ # weight_backup = self.weight.clone()
+
+ # Initialize A, B
+ S = S / self.scaling
+ self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
+ self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
+ # Initialize weight
+ # To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :]
+ self.weight.data = (
+ ((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype)
+ )
+ self.lora_A.requires_grad = True
+ self.lora_B.requires_grad = True
+ else:
+ raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")
def train(self, mode: bool = True):
"""
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
"""
- def T(w):
- return w.T if self.fan_in_fan_out else w
-
self.training = mode
- if LORA_MANAGER.merge_weights:
- if mode and self.merged:
- warnings.warn("Invoke module.train() would unmerge LoRA weights.")
- raise NotImplementedError("LoRA unmerge is not tested.")
- # Make sure that the weights are not merged
- if self.r > 0:
- if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
- # FIXME(csric): temporary fix
- self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
- self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
- self.reset_parameters()
- else:
- self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
- self.merged = False
- elif not mode and not self.merged:
- warnings.warn("Invoke module.eval() would merge LoRA weights.")
- # Merge the weights and mark it
- if self.r > 0:
- self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
- delattr(self, "lora_A")
- delattr(self, "lora_B")
- self.merged = True
+ if mode and self.merged:
+ warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+ raise NotImplementedError("LoRA unmerge is not tested.")
+ elif not mode and not self.merged and lora_manager.able_to_merge:
+ warnings.warn("Invoke module.eval() would merge LoRA weights.")
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += self.lora_B @ self.lora_A * self.scaling
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
+ self.merged = True
return self
- def forward(self, x: torch.Tensor):
- def T(w):
- return w.T if self.fan_in_fan_out else w
+class LoraLinear(LoraBase):
+ """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
+
+ def __init__(
+ self,
+ weight: nn.Parameter,
+ bias: Union[nn.Parameter, bool],
+ r: int = 0,
+ lora_alpha: int = 32,
+ lora_dropout: float = 0.0,
+ lora_initialization_method: str = "kaiming_uniform",
+ ):
+ super().__init__(
+ r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
+ )
+ self.weight = weight
+ self.bias = bias
+ if bias is True:
+ self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
+ if bias is not None:
+ self.bias.requires_grad = True
+
+ out_features, in_features = weight.shape
+ self.in_features = in_features
+ self.out_features = out_features
+ assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
+ self.lora_initialization_method = lora_initialization_method
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(torch.randn((r, in_features)))
+ self.lora_B = nn.Parameter(torch.randn((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+
+ def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
- result = F.linear(x, T(self.weight), bias=self.bias)
- if self.r > 0:
- result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
+ result = F.linear(x, self.weight, bias=self.bias)
+ result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result
else:
- return F.linear(x, T(self.weight), bias=self.bias)
+ return F.linear(x, self.weight, bias=self.bias)
-def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+class LoraEmbedding(LoraBase):
+ """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
+
+ def __init__(
+ self,
+ weight: nn.Parameter,
+ r: int = 0,
+ lora_alpha: int = 32,
+ lora_dropout: float = 0.1,
+ num_embeddings: int = None,
+ embedding_dim: int = None,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ lora_initialization_method: str = "kaiming_uniform",
+ ):
+ super().__init__(
+ r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
+ )
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ self.sparse = sparse
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+
+ self.weight = weight
+
+ in_features, out_features = num_embeddings, embedding_dim
+ self.in_features = in_features
+ self.out_features = out_features
+ assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
+ self.lora_initialization_method = lora_initialization_method
+
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(torch.randn((r, in_features)))
+ self.lora_B = nn.Parameter(torch.randn((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+
+ # reset parameters
+ nn.init.zeros_(self.lora_A)
+ nn.init.normal_(self.lora_B)
+
+ def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
+ return F.embedding(
+ x,
+ weight,
+ padding_idx=self.padding_idx,
+ max_norm=self.max_norm,
+ norm_type=self.norm_type,
+ scale_grad_by_freq=self.scale_grad_by_freq,
+ sparse=self.sparse,
+ )
+
+ def forward(self, x: torch.Tensor):
+ base_embedding = self._embed(x, self.weight)
+ # base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
+ if self.r > 0 and not self.merged:
+ lora_A_embedding = self._embed(x, self.lora_A.t())
+ embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
+ return embedding
+ else:
+ return base_embedding
+
+ def train(self, mode: bool = True):
+ """
+ This function runs when model.train() is invoked. It is used to prepare the linear layer for training
+ """
+
+ self.training = mode
+ if mode and self.merged:
+ warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+ raise NotImplementedError("LoRA unmerge is not tested.")
+ elif not mode and not self.merged and lora_manager.able_to_merge:
+ warnings.warn("Invoke module.eval() would merge LoRA weights.")
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
+ self.merged = True
+
+ return self
+
+
+def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
"""
Wraps a linear layer with LoRA functionality.
Args:
linear (nn.Linear): The linear layer to be wrapped.
lora_rank (int): The rank of the LoRA decomposition.
+        lora_config (LoraConfig): The LoRA configuration (rank, bias training mode, initialization method, etc.).
Returns:
LoraLinear: The wrapped linear layer with LoRA functionality.
"""
assert (
- lora_rank <= linear.in_features
- ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
- lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
+ lora_config.r <= linear.in_features
+ ), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
+ bias = None
+ if lora_config.lora_train_bias in ["all", "lora"]:
+ bias = linear.bias
+ if bias is None:
+ bias = True
+ lora_linear = LoraLinear(
+ linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
+ )
return lora_linear
-def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
"""
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
Args:
module (nn.Module): The module to convert to LoRA form.
lora_rank (int): The rank of the LoRA approximation.
+        parent_name (str): The name of the parent module, used for logging.
+        lora_config (LoraConfig): The LoRA configuration to apply.
Returns:
None
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
- setattr(module, name, _lora_linear_wrapper(child, lora_rank))
+ if lora_config.target_modules is None or any(
+ [name in target_module for target_module in lora_config.target_modules]
+ ):
+ if dist.is_initialized() and dist.get_rank() == 0:
+ logger.info(f"Converting {parent_name}.{name} to LoRA")
+ setattr(module, name, _lora_linear_wrapper(child, lora_config))
+ elif isinstance(child, nn.Embedding):
+ if lora_config.target_modules is None or any(
+ [name in target_module for target_module in lora_config.target_modules]
+ ):
+ if dist.is_initialized() and dist.get_rank() == 0:
+ logger.info(f"Converting {parent_name}.{name} to LoRA")
+ setattr(
+ module,
+ name,
+ LoraEmbedding(
+ child.weight,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ lora_dropout=lora_config.embedding_lora_dropout,
+ num_embeddings=child.num_embeddings,
+ embedding_dim=child.embedding_dim,
+ padding_idx=child.padding_idx,
+ max_norm=child.max_norm,
+ norm_type=child.norm_type,
+ scale_grad_by_freq=child.scale_grad_by_freq,
+ sparse=child.sparse,
+ lora_initialization_method=lora_config.lora_initialization_method,
+ ),
+ )
else:
- _convert_to_lora_recursively(child, lora_rank)
+ _convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)
-def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
+def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
module (nn.Module): The module to convert.
lora_rank (int): LoRA rank.
+        lora_config (LoraConfig): The LoRA configuration (rank, target modules, bias training mode, etc.).
Returns:
nn.Module: The converted module.
"""
- if lora_rank <= 0:
+ if lora_config.r <= 0:
return module
- _convert_to_lora_recursively(module, lora_rank)
- lora.mark_only_lora_as_trainable(module, lora_train_bias)
+    # Make all parameters non-trainable; if lora_train_bias is "all", keep biases trainable.
+ total_parameter_size = 0
+ for name, p in module.named_parameters():
+ p.requires_grad = False
+ if "bias" in name and lora_config.lora_train_bias == "all":
+ p.requires_grad = True
+ total_parameter_size += p.numel()
+ _convert_to_lora_recursively(module, "", lora_config)
+ trainable_parameter_size = 0
+ for name, p in module.named_parameters():
+        if p.requires_grad:
+ trainable_parameter_size += p.numel()
+ if dist.is_initialized() and dist.get_rank() == 0:
+ logger.info(
+ f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%"
+ )
return module
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index f68875568..4e80b3e47 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -9,6 +9,7 @@
- [Install Requirements](#install-requirements)
- [Get Start with ColossalRun](#get-start-with-colossalrun)
- [Training Configuration](#training-configuration)
+ - [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft)
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
- [Step 1: Data Collection](#step-1-data-collection)
- [Step 2: Preprocessing](#step-2-preprocessing)
@@ -377,35 +378,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
-Low Rank Adaption
-
-
-Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
-
-
-To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
-```
-colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
- --pretrain $PRETRAINED_MODEL_PATH \
- --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
- --dataset ${dataset[@]} \
- --save_interval 5000 \
- --save_path $SAVE_DIR \
- --config_file $CONFIG_FILE \
- --plugin zero2_cpu \
- --batch_size 4 \
- --max_epochs 1 \
- --accumulation_steps 4 \
- --lr 2e-5 \
- --max_len 2048 \
- --lora_rank 32 \ # This enables LoRA
- --use_wandb
-```
-
-
-
-
-
Other Training Arguments
@@ -430,6 +402,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- use_wandb: if this flag is up, you can view logs on wandb.
+
+
+### Parameter Efficient Finetuning (PEFT)
+
+Currently, we support LoRA (low-rank adaptation) and PiSSA (principal singular values and singular vectors adaptation). Both reduce run-time VRAM consumption and training time, at the cost of some model performance.
+
+
+Low Rank Adaptation and PiSSA
+
+
+Details about Low Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both are suitable for training LLMs with constrained resources.
+
+To use LoRA/PiSSA during training, create a config file like the following example and pass its path via `--lora_config`.
+
+```json
+{
+ "r": 128,
+ "embedding_lora_dropout": 0.0,
+ "linear_lora_dropout": 0.1,
+ "lora_alpha": 32,
+ "lora_train_bias": "all",
+ "lora_initialization_method": "PiSSA",
+ "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
+}
+```
+#### LoRA Parameters
+- r: LoRA rank
+- embedding_lora_dropout: dropout probability for the embedding layer
+- linear_lora_dropout: dropout probability for linear layers
+- lora_alpha: LoRA alpha; controls how much the adapter can deviate from the pretrained model.
+- lora_train_bias: whether to add trainable biases to LoRA layers. Choose from "all" (all layers, including but not limited to LoRA layers, have trainable biases), "none" (no trainable biases), and "lora" (only LoRA layers have trainable biases).
+- lora_initialization_method: how to initialize LoRA weights. Choose from ["kaiming_uniform", "PiSSA"]; defaults to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA.
+- target_modules: which modules should be converted to LoRA layers. A module is converted if its name contains one of the keywords in target_modules and it is a linear or embedding layer; otherwise it is frozen. Setting this field to None converts all linear and embedding layers to their LoRA counterparts. Note that this example only works for LLaMA; for other models you need to adjust the module names.
+
+
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+ --pretrain $PRETRAINED_MODEL_PATH \
+ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+ --dataset ${dataset[@]} \
+ --save_interval 5000 \
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --plugin zero2_cpu \
+ --batch_size 4 \
+ --max_epochs 1 \
+ --accumulation_steps 4 \
+ --lr 2e-5 \
+ --max_len 2048 \
+ --lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA
+ --use_wandb
+```
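+
+The same configuration file is consumed programmatically roughly as follows (a minimal sketch based on `coati.models.lora`; the model path is only an example):
+
+```python
+from coati.models import LoraConfig, convert_to_lora_module
+from transformers import AutoModelForCausalLM
+
+# Load the JSON config shown above and wrap matching linear/embedding layers with LoRA layers.
+lora_config = LoraConfig.from_file("lora_config.json")
+model = AutoModelForCausalLM.from_pretrained("/PATH/TO/PRETRAINED/MODEL")  # example path
+model = convert_to_lora_module(model, lora_config=lora_config)
+```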
+
+
diff --git a/applications/ColossalChat/examples/training_scripts/lora_config.json b/applications/ColossalChat/examples/training_scripts/lora_config.json
new file mode 100644
index 000000000..4565f9e9b
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/lora_config.json
@@ -0,0 +1,9 @@
+{
+ "r": 128,
+ "embedding_lora_dropout": 0.0,
+ "linear_lora_dropout": 0.1,
+ "lora_alpha": 32,
+ "lora_train_bias": "all",
+ "lora_initialization_method": "PiSSA",
+ "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
+}
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index 3659fb868..44131f572 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module, disable_dropout
+from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import DPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args):
+ lora_config = None
+ if args.lora_config is not None:
+ lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
+ if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -115,7 +118,7 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(model)
+
if not args.disable_reference_model:
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
@@ -125,15 +128,19 @@ def train(args):
)
else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(ref_model)
else:
ref_model = None
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+ if args.lora_config is not None:
+ model = convert_to_lora_module(model, lora_config=lora_config)
+ for name, module in model.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
+ disable_dropout(model)
+ disable_dropout(ref_model)
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
@@ -280,11 +287,8 @@ def train(args):
use_wandb=args.use_wandb,
)
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
+ if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
@@ -343,15 +347,8 @@ if __name__ == "__main__":
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py
index 8d1408423..d063b82bb 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.py
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.py
@@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module, disable_dropout
+from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args):
+ lora_config = None
+ if args.lora_config is not None:
+ lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
+ if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -115,7 +118,7 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
- disable_dropout(model)
+
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
@@ -124,13 +127,17 @@ def train(args):
)
else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+ if args.lora_config is not None:
+ model = convert_to_lora_module(model, lora_config=lora_config)
+ for name, module in model.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
disable_dropout(ref_model)
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+ disable_dropout(model)
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
@@ -299,11 +306,8 @@ def train(args):
use_wandb=args.use_wandb,
)
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
+ if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
@@ -355,15 +359,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--auto_weight", default=False, action="store_true")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index 3efe5e96a..f06524507 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module, disable_dropout
+from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import ORPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args):
+ lora_config = None
+ if args.lora_config is not None:
+ lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
+ if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -114,13 +117,16 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+ if args.lora_config is not None:
+ model = convert_to_lora_module(model, lora_config=lora_config)
+ for name, module in model.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
disable_dropout(model)
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
@@ -262,11 +268,8 @@ def train(args):
use_wandb=args.use_wandb,
)
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
+ if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
@@ -322,15 +325,8 @@ if __name__ == "__main__":
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py
index 3da3e9ca6..333be9963 100755
--- a/applications/ColossalChat/examples/training_scripts/train_ppo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py
@@ -13,7 +13,7 @@ from coati.dataset import (
load_tokenized_dataset,
setup_conversation_template,
)
-from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
+from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -31,8 +31,11 @@ logger = get_dist_logger()
def train(args):
+ lora_config = None
+ if args.lora_config is not None:
+ lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
+ if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -81,20 +84,26 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain)
critic = Critic(args.rm_pretrain)
+
+ if args.lora_config is not None:
+ actor = convert_to_lora_module(actor, lora_config=lora_config)
+ critic = convert_to_lora_module(critic, lora_config=lora_config)
+ for name, module in actor.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
+ for name, module in critic.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
+ lora_manager.able_to_merge = False
+
# Disable dropout
disable_dropout(actor)
disable_dropout(critic)
- if args.lora_rank > 0:
- actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
- critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
- if args.grad_checkpoint and args.lora_rank == 0:
- actor.gradient_checkpointing_enable()
- critic.model.gradient_checkpointing_enable()
+ if args.grad_checkpoint:
+ actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+ critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
- elif args.lora_rank > 0:
- coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@@ -421,11 +430,9 @@ def train(args):
use_wandb=args.use_wandb,
)
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
+ if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
+ lora_manager.able_to_merge = True
actor.eval()
critic.eval()
# save model checkpoint after fitting on only rank0
@@ -484,11 +491,9 @@ if __name__ == "__main__":
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=4)
- parser.add_argument("--lora_train_bias", type=str, default="none")
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1)
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py
index 09b569b7a..4c0a782b4 100755
--- a/applications/ColossalChat/examples/training_scripts/train_rm.py
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.py
@@ -7,7 +7,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
+from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint
from transformers import AutoTokenizer
@@ -25,8 +25,11 @@ logger = get_dist_logger()
def train(args):
+ lora_config = None
+ if args.lora_config is not None:
+ lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
+ if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -58,9 +61,11 @@ def train(args):
args.pretrain,
)
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
+ if lora_config is not None:
+ model = convert_to_lora_module(model, lora_config=lora_config)
+ for name, module in model.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
# ==============================
# Initialize Booster
# ==============================
@@ -122,11 +127,9 @@ def train(args):
booster = Booster(plugin=plugin)
- if args.grad_checkpoint and args.lora_rank == 0:
- model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
+ if args.grad_checkpoint:
+ model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
- elif args.lora_rank > 0:
- coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@@ -272,16 +275,13 @@ def train(args):
trainer.fit(
train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
+ eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
+ if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
@@ -330,15 +330,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index 2579293c1..6007a8599 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -7,7 +7,7 @@ from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module
+from coati.models import LoraConfig, convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -24,8 +24,11 @@ logger = get_dist_logger()
def train(args):
+ lora_config = None
+ if args.lora_config is not None:
+ lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
- if "gemini" in args.plugin and args.lora_rank > 0:
+ if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -53,8 +56,12 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
- if args.lora_rank > 0:
- model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+ if lora_config is not None:
+ model = convert_to_lora_module(model, lora_config=lora_config)
+ for name, module in model.named_modules():
+ if "norm" in name or "gate" in name:
+ module = module.to(torch.float32)
if args.plugin == "ddp":
"""
@@ -114,6 +121,15 @@ def train(args):
booster = Booster(plugin=plugin)
+ # configure optimizer
+ optim = HybridAdam(
+ model_params=model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
@@ -124,7 +140,7 @@ def train(args):
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
- model.gradient_checkpointing_enable()
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
@@ -149,15 +165,6 @@ def train(args):
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
- # configure optimizer
- optim = HybridAdam(
- model_params=model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
# configure dataset
coordinator.print_on_master(
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -217,6 +224,7 @@ def train(args):
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
+
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -277,11 +285,8 @@ def train(args):
use_wandb=args.use_wandb,
)
- if args.lora_rank > 0 and args.merge_lora_weights:
- from coati.models.lora import LORA_MANAGER
-
+ if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
- LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_path is not None:
@@ -328,15 +333,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument(
- "--lora_train_bias",
- type=str,
- default="none",
- help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
- )
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
- parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh
index 988c7f9c3..e87184c81 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh
@@ -21,16 +21,16 @@ PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
- /Your/SFT/Data/arrow/part-00000
- /Your/SFT/Data/arrow/part-00001
- /Your/SFT/Data/arrow/part-00002
- /Your/SFT/Data/arrow/part-00003
- /Your/SFT/Data/arrow/part-00004
- /Your/SFT/Data/arrow/part-00005
- /Your/SFT/Data/arrow/part-00006
- /Your/SFT/Data/arrow/part-00007
- /Your/SFT/Data/arrow/part-00008
- /Your/SFT/Data/arrow/part-00009
+ YOUR/SFT/DATA/DIR/arrow/part-00000
+ YOUR/SFT/DATA/DIR/arrow/part-00001
+ YOUR/SFT/DATA/DIR/arrow/part-00002
+ YOUR/SFT/DATA/DIR/arrow/part-00003
+ YOUR/SFT/DATA/DIR/arrow/part-00004
+ YOUR/SFT/DATA/DIR/arrow/part-00005
+ YOUR/SFT/DATA/DIR/arrow/part-00006
+ YOUR/SFT/DATA/DIR/arrow/part-00007
+ YOUR/SFT/DATA/DIR/arrow/part-00008
+ YOUR/SFT/DATA/DIR/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -47,15 +47,14 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--save_interval 2000 \
--dataset ${dataset[@]} \
- --save_path $SAVE_DIR \
- --config_file $CONFIG_FILE \
- --log_dir $LOG_DIR \
- --lora_rank 0 \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
- --accumulation_steps 2 \
+ --accumulation_steps 1 \
--lr 5e-5 \
--max_len 4096 \
+ --use_flash_attn \
--grad_checkpoint \
- --use_flash_attn
+ --save_path $SAVE_DIR \
+ --config_file $CONFIG_FILE \
+ --log_dir $LOG_DIR \
diff --git a/applications/ColossalChat/tests/test_lora.py b/applications/ColossalChat/tests/test_lora.py
index 4ea9e1a15..778759210 100755
--- a/applications/ColossalChat/tests/test_lora.py
+++ b/applications/ColossalChat/tests/test_lora.py
@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from coati.models import convert_to_lora_module
+from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
from torch.utils.data import DataLoader, TensorDataset
@@ -38,7 +39,7 @@ def test_overfit():
# Build and convert model
model = SimpleNN(input_size, hidden_size, num_classes)
weight_to_compare = model.fc1.weight.detach().clone()
- model = convert_to_lora_module(model, lora_rank=30)
+ model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
@@ -50,7 +51,6 @@ def test_overfit():
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
- print(loss)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
@@ -65,5 +65,50 @@ def test_overfit():
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
+def test_lora_linear_accuracy():
+
+ weight = torch.randn(10, 5)
+ linear = nn.Linear(5, 10)
+ linear.weight.data = weight
+ x = torch.randn(10, 5)
+ out_linear = linear(x)
+
+ # lora linear Pissa
+ linear.weight.data = weight
+ lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
+ out_lora = lora_linear(x)
+ assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
+
+ # lora linear
+ linear.weight.data = weight
+ lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
+ out_lora = lora_linear(x)
+ assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
+
+
+def test_lora_embedding_accuracy():
+ weight = torch.randn(10, 5)
+ embedding = nn.Embedding(10, 5)
+ embedding.weight.data = weight
+ x = torch.randint(0, 10, (10,))
+ out_embedding = embedding(x)
+
+ # lora embedding Pissa
+ embedding.weight.data = weight
+ lora_embedding = LoraEmbedding(
+ embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
+ )
+ out_lora = lora_embedding(x)
+ assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
+
+ # lora embedding
+ embedding.weight.data = weight
+ lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
+ out_lora = lora_embedding(x)
+ assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
+
+
if __name__ == "__main__":
test_overfit()
+ test_lora_linear_accuracy()
+ test_lora_embedding_accuracy()
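The two accuracy tests above rely on a freshly initialized adapter leaving the module output unchanged. For the PiSSA variant this holds because the adapter is built from the top-r singular directions of the weight while the same component is removed from the frozen residual. A standalone sketch of that idea (illustrative only, not the coati.models.lora implementation):

```python
import torch

# Sketch of the PiSSA initialization idea the accuracy tests rely on.
W = torch.randn(10, 5)
r = 2
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
B = U[:, :r] * S[:r].sqrt()              # (10, r) adapter factor
A = S[:r].sqrt().unsqueeze(1) * Vh[:r]   # (r, 5) adapter factor
W_res = W - B @ A                        # frozen residual weight
# At step 0 the adapted layer computes x @ (W_res + B @ A).T == x @ W.T,
# so its output matches the original linear layer exactly.
assert torch.allclose(W_res + B @ A, W, atol=1e-5)
```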
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
index 57584e406..c26b25c83 100755
--- a/applications/ColossalChat/tests/test_train.sh
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -30,9 +30,10 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
-ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
-PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
+ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
+PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
+LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
export OMP_NUM_THREADS=8
@@ -112,6 +113,11 @@ for lora_rank in ${LORA_RANK[@]}; do
sp='1'
sp_mode='split_gather'
enable_sequence_parallelism=''
+ if [[ $plugin == "zero2" ]]; then
+ lora_config=$LORA_CONFIG_ENABLE
+ else
+ lora_config=""
+ fi
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
@@ -176,7 +182,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
- --lora_rank $lora_rank \
+ $lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@@ -230,6 +236,11 @@ for lora_rank in ${LORA_RANK[@]}; do
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
+ if [[ $plugin == "zero2" ]]; then
+ lora_config=$LORA_CONFIG_ENABLE
+ else
+ lora_config=""
+ fi
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
@@ -252,7 +263,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
- --lora_rank $lora_rank \
+ $lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@@ -308,6 +319,11 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='4'
ebs='8'
conversation_template=$(get_conversation_template_config $model)
+ if [[ $plugin == "zero2" ]]; then
+ lora_config=$LORA_CONFIG_ENABLE
+ else
+ lora_config=""
+ fi
if [[ $plugin == "3d" ]]; then
tp='4'
bs='16'
@@ -344,7 +360,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
- --lora_rank $lora_rank \
+ $lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
@@ -404,6 +420,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4'
bs='8'
fi
+ if [[ $plugin == "zero2" ]]; then
+ lora_config=$LORA_CONFIG_ENABLE
+ else
+ lora_config=""
+ fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
@@ -428,7 +449,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
- --lora_rank $lora_rank \
+ $lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@@ -482,6 +503,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4'
bs='8'
fi
+ if [[ $plugin == "zero2" ]]; then
+ lora_config=$LORA_CONFIG_ENABLE
+ else
+ lora_config=""
+ fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
@@ -506,7 +532,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
- --lora_rank $lora_rank \
+ $lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
@@ -560,6 +586,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4'
bs='8'
fi
+ if [[ $plugin == "zero2" ]]; then
+ lora_config=$LORA_CONFIG_ENABLE
+ else
+ lora_config=""
+ fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
@@ -584,7 +615,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
- --lora_rank $lora_rank \
+ $lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
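For reference, the CI script above points `--lora_config` at `examples/training_scripts/lora_config.json`. A minimal sketch of such a file, assuming `LoraConfig.from_file` reads plain JSON and using only `r` (the adaptation rank, as in `test_lora.py`); any further keys supported by `coati.models.lora.LoraConfig` are omitted here:

```python
import json

# Minimal sketch of a LoRA config file for the new --lora_config flag.
# Assumption: LoraConfig.from_file parses plain JSON; `r` is the rank (0 disables LoRA).
lora_config = {"r": 32}
with open("lora_config.json", "w") as f:
    json.dump(lora_config, f, indent=4)
```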
From 66fbf2ecb7a981c95e3dc5ea598fc530cf84265f Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Wed, 31 Jul 2024 17:44:09 +0800
Subject: [PATCH 054/116] Update README.md (#5958)
---
applications/ColossalChat/examples/README.md | 30 +++++++++++++-------
1 file changed, 20 insertions(+), 10 deletions(-)
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index 4e80b3e47..4629dd107 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -786,17 +786,27 @@ For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examp
```jsonl
{
- "prompt": [
- {"from": "user", "content": "What are some cuss words in english?"},
- {
- "from": "assistant",
- "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...",
- },
- {"from": "user", "content": "What's your favorite one?"},
- ],
- "completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant.
- "label": False, # whether the response is favorable or not
+ "prompt": [
+ {
+ "from": "user",
+ "content": "What are some praise words in english?"
+ },
+ {
+ "from": "assistant",
+ "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."
+ },
+ {
+ "from": "user",
+ "content": "What's your favorite one?"
+ }
+ ],
+ "completion": {
+ "from": "assistant",
+ "content": "impressive."
+ },
+ "label": true
}
+
```
For training, use the [train_kto.sh](./examples/training_scripts/train_kto.sh) script. You may need to set the value for `beta` (which determines how strongly the reinforcement learning loss affects the training), as well as `desirable_weight` and `undesirable_weight` if your data is biased (has an unequal number of chosen and rejected samples).
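As a quick sanity check of the format above, each KTO sample is a single JSON object per line of the jsonl file. A minimal Python sketch for writing one such line (file name illustrative):

```python
import json

# Minimal sketch: serialize one KTO sample in the format documented above.
sample = {
    "prompt": [
        {"from": "user", "content": "What are some praise words in english?"},
        {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},
        {"from": "user", "content": "What's your favorite one?"},
    ],
    # the completion must be a single assistant turn
    "completion": {"from": "assistant", "content": "impressive."},
    # true marks the completion as desirable (favorable), false as undesirable
    "label": True,
}

with open("kto_data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(sample) + "\n")
```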
From 1aeb5e8847d40ce0c403dd546e71c01889ba5a14 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Wed, 31 Jul 2024 17:47:46 +0800
Subject: [PATCH 055/116] [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
---
applications/ColossalChat/README.md | 24 ++---------------
applications/ColossalChat/examples/README.md | 27 ++++++++++----------
2 files changed, 15 insertions(+), 36 deletions(-)
diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index 4fbe290ba..de27ebaf6 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -139,17 +139,15 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
- "from": "human",
+ "from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
},
- ...
]
},
- ...
]
```
@@ -175,23 +173,20 @@ Below shows the preference dataset format used in training the reward model.
"from": "human",
"content": "Introduce butterflies species in Oregon."
}
- ]
+ ],
"chosen": [
{
"from": "assistant",
"content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
},
- ...
],
"rejected": [
{
"from": "assistant",
"content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
},
- ...
]
},
- ...
]
```
@@ -220,7 +215,6 @@ PPO uses two kind of training data--- the prompt data and the sft data (optional
"from": "human",
"content": "what are some pranks with a pen i can do?"
}
- ...
]
},
]
@@ -453,20 +447,6 @@ If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be
If you have multiple GPUs, each with very limited VRAM (say 8GB), you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs that you have.
-## The Plan
-
-- [x] implement PPO fine-tuning
-- [x] implement training reward model
-- [x] support LoRA
-- [x] support inference
-- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
-- [x] implement PPO-ptx fine-tuning
-- [x] support flash-attention
-- [x] implement DPO fine-tuning
-- [ ] integrate with Ray
-- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
-- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)
-
### Real-time progress
You will find our progress on the GitHub [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index 4629dd107..b749f197e 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -49,9 +49,6 @@
pip install -r requirements.txt
```
-
-
-
## Get Start with ColossalRun
@@ -85,8 +82,6 @@ Make sure the master node can access all nodes (including itself) by ssh without
This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more details regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins).
-
-
Gemini (Zero3)
@@ -499,9 +494,15 @@ In this code we provide a flexible way for users to set the conversation templat
- Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields.
```json
{
- "chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating,
- "system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added,
- "end_of_assistant": The token(s) in string that denotes the end of assistance's response. For example, in the ChatGLM2 prompt format,
+ "chat_template": "A string of chat_template used for formatting chat data",
+      "system_message": "A string of system message to be added at the beginning of the prompt. If none is provided (None), no system message will be added",
+      "end_of_assistant": "The token(s) in string that denote the end of the assistant's response",
+      "stop_ids": "A list of integers corresponding to the `end_of_assistant` tokens that indicate the end of the assistant's response during the rollout stage of PPO training"
+ }
+ ```
+ * `chat_template`: (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating.
+  * `system_message`: A string of system message to be added at the beginning of the prompt. If none is provided (None), no system message will be added.
+  * `end_of_assistant`: The token(s) in string that denote the end of the assistant's response. For example, in the ChatGLM2 prompt format,
```
<|im_start|>system
system messages
@@ -510,13 +511,11 @@ In this code we provide a flexible way for users to set the conversation templat
<|im_start|>user
How far is the moon? <|im_end|>
<|im_start|>assistant\n The moon is about 384,400 kilometers away from Earth.<|im_end|>...
- ```
- the end_of_assistant tokens are "<|im_end|>"
- "stop_ids": (Optional), A list of integers corresponds to the `end_of_assistant` tokens that indicate the end of assistance's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically
- }
- ```
- On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message),
+ ```
+ the `end_of_assistant` tokens are "<|im_end|>"
+  * `stop_ids`: (Optional), A list of integers corresponding to the `end_of_assistant` tokens that indicate the end of the assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, it will be set to tokenizer.eos_token_ids automatically.
+  On your first run of the data preparation script, you only need to define the `chat_template` (if you want to use a custom chat template) and the `system_message` (if you want to use a custom system message); a minimal example config is sketched after Step 2 below.
- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path.
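A minimal sketch of such a conversation template config, with placeholder values for every field described above (the system message, end-of-assistant token, and stop ids are illustrative and must be adapted to your tokenizer):

```python
import json

# Minimal sketch of a conversation template config file; all values are placeholders.
conversation_template = {
    "chat_template": None,                 # None -> fall back to the tokenizer's default chat template
    "system_message": "You are a helpful assistant.",
    "end_of_assistant": "<|im_end|>",      # token(s) that close the assistant's turn
    "stop_ids": [2],                       # usually the tokenizer's eos_token_id
}
with open("my_conversation_template.json", "w") as f:
    json.dump(conversation_template, f, indent=4)
```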
From f9b6fcf81f3fd62a083184d11681f55bdcce1291 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 2 Jul 2024 09:02:21 +0000
Subject: [PATCH 056/116] [test] add mixtral for sequence classification
---
.../shardformer/policies/auto_policy.py | 3 +
colossalai/shardformer/policies/mixtral.py | 131 +++++++++++++++---
2 files changed, 118 insertions(+), 16 deletions(-)
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index ae9f3603c..1e0af031a 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -200,6 +200,9 @@ _POLICY_LIST = {
"transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
file_name="mixtral", class_name="MixtralForCausalLMPolicy"
),
+ "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+ file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+ ),
# Qwen2
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
file_name="qwen2", class_name="Qwen2ModelPolicy"
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index ad93e9469..e3cc48043 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
)
if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
- if getattr(self.shard_config, "ep_group", None) is not None:
- # expert parallel
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="block_sparse_moe",
- target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group},
- )
- ],
- policy=policy,
- target_key=MixtralDecoderLayer,
- )
+ raise NotImplementedError
+ # assert (
+ # self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ # ), f"The number of attention heads must be divisible by tensor parallel size."
+ # assert (
+ # self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ # ), f"The number of key_value heads must be divisible by tensor parallel size."
+ # decoder_attribute_replacement = {
+ # "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ # "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ # "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+ # // self.shard_config.tensor_parallel_size,
+ # }
+
+ # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+ # attribute_replacement=decoder_attribute_replacement,
+ # sub_module_replacement=[
+ # SubModuleReplacementDescription(
+ # suffix="self_attn.q_proj",
+ # target_module=Linear1D_Col,
+ # kwargs={
+ # 'process_group': self.shard_config.tensor_parallel_process_group,
+ # }
+ # ),
+ # SubModuleReplacementDescription(
+ # suffix="self_attn.k_proj",
+ # target_module=Linear1D_Col,
+ # kwargs={
+ # 'process_group': self.shard_config.tensor_parallel_process_group,
+ # }
+ # ),
+ # SubModuleReplacementDescription(
+ # suffix="self_attn.v_proj",
+ # target_module=Linear1D_Col,
+ # kwargs={
+ # 'process_group': self.shard_config.tensor_parallel_process_group,
+ # }
+ # ),
+ # SubModuleReplacementDescription(
+ # suffix="self_attn.o_proj",
+ # target_module=Linear1D_Row,
+ # kwargs={
+ # 'process_group': self.shard_config.tensor_parallel_process_group,
+ # }
+ # ),
+ # # SubModuleReplacementDescription(
+ # # suffix="mlp.gate_proj",
+ # # target_module=Linear1D_Col,
+ # # ),
+ # # SubModuleReplacementDescription(
+ # # suffix="mlp.up_proj",
+ # # target_module=Linear1D_Col,
+ # # ),
+ # # SubModuleReplacementDescription(
+ # # suffix="mlp.down_proj",
+ # # target_module=Linear1D_Row,
+ # # ),
+ # ],
+ # )
+
+ if getattr(self.shard_config, "ep_group", None) is None:
+ raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe",
+ target_module=EPMixtralSparseMoeBlock,
+ kwargs={"ep_group": self.shard_config.ep_group},
+ )
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
)
if self.shard_config.enable_flash_attention:
- raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+ warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
return policy
@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
- """No shared params in llama model"""
+ """No shared params in mixtral model"""
return []
@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
}
]
return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+ def module_policy(self):
+ from transformers import MixtralForSequenceClassification
+
+ policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for sequence classification
+ new_item = {
+ MixtralForSequenceClassification: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ raise NotImplementedError
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mixtral for sequence classification model"""
+ return []
From 0b76b57cd64be15f2ae7e5ebd8dd4d327c4685c3 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 2 Jul 2024 09:08:41 +0000
Subject: [PATCH 057/116] [test] add mixtral transformer test
---
colossalai/shardformer/modeling/mixtral.py | 20 +-
tests/kit/model_zoo/transformers/__init__.py | 20 +-
tests/kit/model_zoo/transformers/mixtral.py | 82 ++++++++
tests/test_moe/moe_utils.py | 2 -
tests/test_shardformer/test_model/_utils.py | 12 +-
.../test_model/test_shard_mixtral.py | 175 ++++++++++++++++++
6 files changed, 281 insertions(+), 30 deletions(-)
create mode 100644 tests/kit/model_zoo/transformers/mixtral.py
create mode 100644 tests/test_shardformer/test_model/test_shard_mixtral.py
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 2fbc34302..334bd13fc 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -4,8 +4,6 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.mixtral.modeling_mixtral import (
@@ -23,30 +21,34 @@ from colossalai.shardformer.shard.utils import set_tensors_to_none
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
- def __init__(self, config):
- self.moe_info = None
+ def __init__(self, config, ep_group):
super().__init__(config)
+ self.setup_ep(ep_group)
def setup_ep(self, ep_group: ProcessGroup):
ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
- assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
+
+ if self.num_experts % self.ep_size != 0:
+ raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
+
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
@staticmethod
- def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+ def from_native_module(
+ module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
+ ) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
- # if "ep_group" in kwargs:
- assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
- module.setup_ep(kwargs["ep_group"])
+ module.setup_ep(ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
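To make the expert-parallel sharding arithmetic in `setup_ep` concrete, a standalone sketch with illustrative numbers (8 experts over an expert-parallel group of size 4):

```python
# Standalone sketch of the expert partitioning performed in setup_ep above.
num_experts, ep_size = 8, 4   # illustrative numbers
if num_experts % ep_size != 0:
    raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
num_experts_per_ep = num_experts // ep_size
for ep_rank in range(ep_size):
    expert_start_idx = ep_rank * num_experts_per_ep
    held = list(range(expert_start_idx, expert_start_idx + num_experts_per_ep))
    print(f"ep_rank {ep_rank} keeps experts {held}")  # rank 0 -> [0, 1], rank 1 -> [2, 3], ...
```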
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 05c17f562..ac5184065 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -3,28 +3,16 @@ from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm2 import *
+from .command import *
from .falcon import *
from .gpt import *
from .gptj import *
from .llama import *
+from .mistral import *
+from .mixtral import *
from .opt import *
+from .qwen2 import *
from .sam import *
from .t5 import *
from .vit import *
from .whisper import *
-
-try:
- from .mistral import *
-except ImportError:
- print("This version of transformers doesn't support mistral.")
-
-try:
- from .qwen2 import *
-except ImportError:
- print("This version of transformers doesn't support qwen2.")
-
-
-try:
- from .command import *
-except ImportError:
- print("This version of transformers doesn't support Command-R.")
diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py
new file mode 100644
index 000000000..b82a4b939
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/mixtral.py
@@ -0,0 +1,82 @@
+# modified from tests/kit/model_zoo/transformers/mistral.py
+import torch
+import transformers
+from transformers import MixtralConfig
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence Mixtral
+# ===============================
+
+
+def data_gen():
+ # Generated from following code snippet
+ #
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
+ # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
+ # tokenized_input = tokenizer([input], return_tensors="pt")
+ # input_ids = tokenized_input['input_ids']
+ # attention_mask = tokenized_input['attention_mask']
+ input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+ # LM data gen
+ # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
+ data = data_gen()
+ data["labels"] = data["input_ids"].clone()
+ return data
+
+
+def data_gen_for_sequence_classification():
+ # sequence classification data gen
+ data = data_gen()
+ data["labels"] = torch.tensor([1], dtype=torch.int64)
+ return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss function
+loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(
+ x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+)
+loss_fn = lambda x: x.loss
+loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+config = MixtralConfig(
+ hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
+)
+
+if hasattr(config, "pad_token_id"):
+ config.pad_token_id = config.eos_token_id
+
+model_zoo.register(
+ name="transformers_mixtral",
+ model_fn=lambda: transformers.MixtralModel(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_mixtral_model,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+ name="transformers_mixtral_for_casual_lm",
+ model_fn=lambda: transformers.MixtralForCausalLM(config),
+ data_gen_fn=data_gen_for_lm,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+ name="transformers_mixtral_for_sequence_classification",
+ model_fn=lambda: transformers.MixtralForSequenceClassification(config),
+ data_gen_fn=data_gen_for_sequence_classification,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_seq_classification,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 131932dcb..ba6a0e8a9 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -10,8 +10,6 @@ from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
-
-# from colossalai.shardformer.layer.moe import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 1ffcc541a..190fee129 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,6 +1,6 @@
import copy
from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type
import torch
import torch.distributed as dist
@@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
def build_model_from_hybrid_plugin(
- model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+ model_fn: Callable,
+ loss_fn: Callable,
+ test_config: Dict[str, Any],
+ optim_class=Adam,
+ sharded_optim_class=Adam,
+ pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
):
use_lazy_init = False
if "use_lazy_init" in test_config:
@@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
else:
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
+
criterion = loss_fn
- plugin = HybridParallelPlugin(**test_config)
+ plugin = pluggin_cls(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
new file mode 100644
index 000000000..bf2d2bb1b
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -0,0 +1,175 @@
+# modified from test_shard_mistral.py
+import os
+
+import pytest
+import torch
+
+import colossalai
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_all_grad_tensors,
+ check_loss,
+ check_weight,
+ get_grad_tensors_for_check,
+ run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
+)
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+ model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
+ )
+
+ org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+ org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+ )
+
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
+
+ # unwrap model
+ mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
+ shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
+
+ row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
+ col_layer_for_check = ["layers[0].self_attn.o_proj"]
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config["precision"] == "fp32":
+ atol, rtol = 5e-5, 1e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
+ row_layer_grads = get_grad_tensors_for_check(
+ mixtral_model,
+ shard_mixtral_model,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False,
+ )
+ col_layer_grads = get_grad_tensors_for_check(
+ mixtral_model,
+ shard_mixtral_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if test_config["precision"] == "fp32":
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+ # check weights
+ if stage_manager is None or stage_manager.is_first_stage():
+ if test_config["precision"] == "fp32":
+ atol, rtol = 2e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ check_weight(
+ mixtral_model,
+ shard_mixtral_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ torch.cuda.empty_cache()
+
+
+@parameterize(
+ "test_config",
+ [
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "ep_size": 4,
+ "num_microbatches": 2,
+ "zero_stage": 0,
+ "enable_all_optimization": True,
+ "use_lazy_init": False,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "ep_size": 4,
+ "num_microbatches": 2,
+ "zero_stage": 1,
+ "enable_all_optimization": True,
+ "use_lazy_init": False,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "ep_size": 4,
+ "num_microbatches": 2,
+ "zero_stage": 2,
+ "enable_all_optimization": True,
+ "use_lazy_init": False,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ ],
+)
+def run_mixtral_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral")
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+def check_mixtral(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_mixtral_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_mixtral():
+ spawn(check_mixtral, 4)
+
+
+if __name__ == "__main__":
+ test_mixtral()
From 8ae8525bdf4134c48c001f4a7cb6660e10d9e787 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 2 Jul 2024 09:09:00 +0000
Subject: [PATCH 058/116] [moe] fix plugin
---
colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 ++++----
.../booster/plugin/moe_hybrid_parallel_plugin.py | 13 ++++++++++---
2 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 2c8cb6ba1..92bab29ec 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -254,7 +254,7 @@ def get_param_info(optim: Optimizer):
return param_info
-def init_pipeline_optimizer(optim: Optimizer, model: Module):
+def reinitialize_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
@@ -276,7 +276,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
):
self.param_info = param_info
if use_pipeline:
- init_pipeline_optimizer(optim, model)
+ reinitialize_optimizer(optim, model)
self.model = model
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@@ -497,7 +497,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
if use_pipeline:
- init_pipeline_optimizer(optim, model)
+ reinitialize_optimizer(optim, model)
super().__init__(
optim,
precision=precision,
@@ -678,7 +678,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
- init_pipeline_optimizer(optimizer, model)
+ reinitialize_optimizer(optimizer, model)
super().__init__(
optimizer=optimizer,
initial_scale=initial_scale,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 2cfdd000a..3d4250ac8 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -19,7 +19,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
get_param_info,
- init_pipeline_optimizer,
+ reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
@@ -67,7 +67,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
- init_pipeline_optimizer(optimizer, model)
+ reinitialize_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: [],
@@ -400,12 +400,19 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
- use_ddp=use_ddp,
+ use_ddp=use_ddp, # TODO fix why this failed
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ if self.ep_size > 1:
+                # if ep is enabled, the number of (MoE) parameters changes since they are sharded among ep groups
+ # but the optimizer is not aware of ep, so we need to update the optimizer
+ reinitialize_optimizer(optimizer, model)
+
if self.zero_stage == 0:
+ assert self.ep_size > 1
+
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
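
The comment in the hunk above is the key idea of this commit: once expert parallelism shards the MoE weights, the parameter set seen by the optimizer no longer matches model.parameters(). A hedged sketch of what reinitialize_optimizer does, reconstructed from the partially visible rename hunk earlier; it may differ in detail from the actual file.

from torch.nn import Module
from torch.optim import Optimizer

def reinitialize_optimizer(optim: Optimizer, model: Module) -> None:
    # keep only parameters that still belong to the (possibly sharded) model
    model_params = set(model.parameters())
    new_param_groups = []
    for group in optim.param_groups:
        params = [p for p in group["params"] if p in model_params]
        new_param_groups.append({**group, "params": params})
    # rebuild the optimizer's param_groups in place
    optim.__setstate__({"param_groups": new_param_groups})
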
From a249e719463ad7b9d66e98ba4c88eb9c9e44689c Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 4 Jul 2024 06:39:01 +0000
Subject: [PATCH 059/116] [test] mixtral pp shard test
---
.../plugin/moe_hybrid_parallel_plugin.py | 2 +-
tests/kit/model_zoo/transformers/mixtral.py | 43 ++++++++--------
.../test_model/test_shard_mixtral.py | 50 +++++++++----------
3 files changed, 49 insertions(+), 46 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 3d4250ac8..98b206479 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -411,7 +411,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
- assert self.ep_size > 1
+ # assert self.ep_size > 1
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py
index b82a4b939..0ac6a75ce 100644
--- a/tests/kit/model_zoo/transformers/mixtral.py
+++ b/tests/kit/model_zoo/transformers/mixtral.py
@@ -43,14 +43,17 @@ def data_gen_for_sequence_classification():
output_transform_fn = lambda x: x
# define loss function
-loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(
- x.last_hidden_state, torch.ones_like(x.last_hidden_state)
-)
+loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(x[0], torch.ones_like(x[0]))
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = MixtralConfig(
- hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
+ hidden_size=256,
+ intermediate_size=256,
+ num_attention_heads=64,
+ num_hidden_layers=2,
+ vocab_size=50258,
+ output_router_logits=True,
)
if hasattr(config, "pad_token_id"):
@@ -64,19 +67,19 @@ model_zoo.register(
loss_fn=loss_fn_for_mixtral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
-model_zoo.register(
- name="transformers_mixtral_for_casual_lm",
- model_fn=lambda: transformers.MixtralForCausalLM(config),
- data_gen_fn=data_gen_for_lm,
- output_transform_fn=output_transform_fn,
- loss_fn=loss_fn,
- model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
- name="transformers_mixtral_for_sequence_classification",
- model_fn=lambda: transformers.MixtralForSequenceClassification(config),
- data_gen_fn=data_gen_for_sequence_classification,
- output_transform_fn=output_transform_fn,
- loss_fn=loss_fn_for_seq_classification,
- model_attribute=ModelAttribute(has_control_flow=True),
-)
+# model_zoo.register(
+# name="transformers_mixtral_for_casual_lm",
+# model_fn=lambda: transformers.MixtralForCausalLM(config),
+# data_gen_fn=data_gen_for_lm,
+# output_transform_fn=output_transform_fn,
+# loss_fn=loss_fn,
+# model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+# name="transformers_mixtral_for_sequence_classification",
+# model_fn=lambda: transformers.MixtralForSequenceClassification(config),
+# data_gen_fn=data_gen_for_sequence_classification,
+# output_transform_fn=output_transform_fn,
+# loss_fn=loss_fn_for_seq_classification,
+# model_attribute=ModelAttribute(has_control_flow=True),
+# )
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index bf2d2bb1b..f8deb2e8a 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -114,37 +114,37 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
{
"tp_size": 1,
- "pp_size": 1,
- "ep_size": 4,
- "num_microbatches": 2,
+ "pp_size": 4,
+ "ep_size": 1,
+ "num_microbatches": 4,
"zero_stage": 0,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"initial_scale": 1,
},
- {
- "tp_size": 1,
- "pp_size": 1,
- "ep_size": 4,
- "num_microbatches": 2,
- "zero_stage": 1,
- "enable_all_optimization": True,
- "use_lazy_init": False,
- "precision": "fp16",
- "initial_scale": 1,
- },
- {
- "tp_size": 1,
- "pp_size": 1,
- "ep_size": 4,
- "num_microbatches": 2,
- "zero_stage": 2,
- "enable_all_optimization": True,
- "use_lazy_init": False,
- "precision": "fp16",
- "initial_scale": 1,
- },
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 4,
+ # "num_microbatches": 2,
+ # "zero_stage": 1,
+ # "enable_all_optimization": True,
+ # "use_lazy_init": False,
+ # "precision": "fp16",
+ # "initial_scale": 1,
+ # },
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 4,
+ # "num_microbatches": 2,
+ # "zero_stage": 2,
+ # "enable_all_optimization": True,
+ # "use_lazy_init": False,
+ # "precision": "fp16",
+ # "initial_scale": 1,
+ # },
],
)
def run_mixtral_test(test_config):
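
To make the loss-function change above concrete, here is a small hedged example of the pattern the registry now uses: the loss is computed on the first element of the model output (the last hidden state), and the tiny config mirrors the one registered above, including output_router_logits=True.

import torch
from transformers import MixtralConfig, MixtralModel

config = MixtralConfig(
    hidden_size=256, intermediate_size=256, num_attention_heads=64,
    num_hidden_layers=2, vocab_size=50258, output_router_logits=True,
)
model = MixtralModel(config)
input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
# out[0] is the last hidden state; regress it to ones as the model zoo does
loss = torch.nn.functional.mse_loss(out[0], torch.ones_like(out[0]))
loss.backward()
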
From 0fad23c691aff3fb9fe8da72641c1e50fd65d99e Mon Sep 17 00:00:00 2001
From: hxwang
Date: Fri, 5 Jul 2024 07:03:45 +0000
Subject: [PATCH 060/116] [chore] handle non member group
---
colossalai/cluster/process_group_mesh.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index b6aff0d72..7f1ef9fce 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import GroupMember
def prod(nums: List[int]) -> int:
@@ -47,7 +48,7 @@ class ProcessGroupMesh:
self._shape = size
self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
- self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
+ self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def destroy_mesh_process_groups(self):
@@ -150,7 +151,8 @@ class ProcessGroupMesh:
if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
- self._group_to_ranks[group] = tuple(ranks_in_group)
+ if group is not GroupMember.NON_GROUP_MEMBER:
+ self._group_to_ranks[group] = tuple(ranks_in_group)
return self._ranks_to_group[tuple(ranks_in_group)]
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
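
Background on this guard, as a hedged minimal illustration (not part of the patch): torch.distributed.new_group returns the sentinel GroupMember.NON_GROUP_MEMBER on ranks that are not in ranks_in_group, and that sentinel must not be recorded as a key in the group-to-ranks table.

import torch.distributed as dist
from torch.distributed.distributed_c10d import GroupMember

def new_group_or_none(ranks_in_group, backend=None):
    # runs on every rank; only member ranks get a usable ProcessGroup handle
    group = dist.new_group(ranks_in_group, backend=backend)
    if group is GroupMember.NON_GROUP_MEMBER:
        return None  # this rank is not part of the group
    return group
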
From 46c069b0db83d35174490951dd6e51e79fb62144 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Fri, 5 Jul 2024 07:19:37 +0000
Subject: [PATCH 061/116] [zero] solve hang
---
.../booster/plugin/hybrid_parallel_plugin.py | 12 +-
.../plugin/moe_hybrid_parallel_plugin.py | 333 ++----------------
colossalai/cluster/process_group_mesh.py | 4 +-
colossalai/moe/_operation.py | 3 +
colossalai/shardformer/policies/mixtral.py | 27 +-
.../low_level/bookkeeping/bucket_store.py | 10 +-
.../low_level/bookkeeping/gradient_store.py | 2 +-
colossalai/zero/low_level/low_level_optim.py | 16 +-
tests/kit/model_zoo/transformers/mixtral.py | 6 +-
tests/test_moe/test_moe_checkpoint.py | 1 -
tests/test_moe/test_moe_zero_fwd_bwd_optim.py | 37 +-
.../test_model/test_shard_mixtral.py | 52 +--
12 files changed, 113 insertions(+), 390 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 92bab29ec..983ddfc97 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1058,17 +1058,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
- (
- self.dp_axis,
- self.pp_axis,
- self.tp_axis,
- self.sp_axis,
- ) = (
- 0,
- 1,
- 2,
- 3,
- )
+ self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 98b206479..02a87ff11 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,9 +1,7 @@
-import random
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
-import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -11,7 +9,6 @@ from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelAMPOptimizer,
@@ -22,13 +19,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
-from colossalai.cluster import ProcessGroupMesh
+from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.logging import get_dist_logger
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig
-from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -39,6 +31,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
+ dp_process_group: ProcessGroup, # the dp pg for comm
+ moe_dp_group: ProcessGroup, # the moe dp pg for gomm
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
@@ -54,30 +48,20 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
- dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
- tp_process_group: Optional[ProcessGroup] = None, # if using tp
- pp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
- moe_extra_dp_process_group: Optional[ProcessGroup] = None,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
- self.tp_pg = tp_process_group
- self.pp_pg = pp_process_group
+
if use_pipeline:
reinitialize_optimizer(optimizer, model)
pg_param_list = {
- dp_process_group: [],
- moe_extra_dp_process_group: [],
+ dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
+ moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
- for param in model.parameters():
- if is_moe_tensor(param):
- pg_param_list[moe_extra_dp_process_group].append(param)
- else:
- pg_param_list[dp_process_group].append(param)
super().__init__(
optimizer=optimizer,
@@ -102,285 +86,43 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
- Plugin for Moe Hybrid Parallel Training.
- Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
- The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
-
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import HybridParallelPlugin
-
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
-
- >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
-
- Args:
- pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
- tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
- precision (str, optional): Specifies the precision of parameters during training.
- Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
- Defaults to 'fp16'.
- zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
- When set to 0, ZeRO will not be used. Defaults to 0.
- enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
- Currently all the optimization methods include fused normalization, flash attention and JIT.
- Defaults to False.
- enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
- enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
- enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
- enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
- num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
- microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
- Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
- If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
- initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
- min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
- growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
- backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
- growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
- hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
- max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
- max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
- broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
- ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
- find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
- check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
- gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
- static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
- zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
- cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
- communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
- overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
- use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
+ TODO: add docstring
"""
- def __init__(
- self,
- pp_size: int,
- ep_size: int,
- tp_size: int = 1,
- sp_size: int = 1,
- precision: str = "fp16",
- zero_stage: int = 0,
- enable_all_optimization: bool = False,
- enable_fused_normalization: bool = False,
- enable_flash_attention: bool = False,
- enable_jit_fused: bool = False,
- enable_sequence_parallelism: bool = False,
- enable_sequence_overlap: bool = False,
- num_microbatches: Optional[int] = None,
- microbatch_size: Optional[int] = None,
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0,
- broadcast_buffers: bool = True,
- ddp_bucket_cap_mb: int = 25,
- find_unused_parameters: bool = False,
- check_reduction: bool = False,
- gradient_as_bucket_view: bool = False,
- static_graph: bool = False,
- zero_bucket_size_in_m: int = 12,
- cpu_offload: bool = False,
- communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True,
- use_ep_inside: bool = True,
- custom_policy: Policy = None,
- checkpoint_io: Optional[MoECheckpointIO] = None,
- ) -> None:
- world_size = dist.get_world_size()
- assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
- assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet"
+ def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
- assert (
- world_size % (tp_size * pp_size) == 0
- ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
- assert (
- world_size % (tp_size * pp_size * ep_size) == 0
- ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
-
- self.dp_size = world_size // (tp_size * pp_size)
- self.tp_size = tp_size
- self.pp_size = pp_size
- self.ep_size = ep_size
- self.sp_size = sp_size
- self.precision = precision
- self.zero_stage = zero_stage
- self.cpu_offload = cpu_offload
- self.enable_all_optimization = enable_all_optimization
- self.enable_fused_normalization = enable_fused_normalization
- self.enable_flash_attention = enable_flash_attention
- self.enable_jit_fused = enable_jit_fused
- self.enable_sequence_parallelism = enable_sequence_parallelism
- self.checkpoint_io = checkpoint_io
-
- logger = get_dist_logger()
-
- # NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
- # See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
- # we change pg mesh to (pp, dp, tp) for better moe performance
- assert (
- self.ep_size <= self.dp_size
- ), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
-
- self.moe_dp_size = self.dp_size // self.ep_size
- self.use_ep_inside = use_ep_inside
- if self.use_ep_inside:
- logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
- self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
- self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
- else:
- logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
- warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
- self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
- self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
-
- self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
- self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
- logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
- logger.info(
- f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
- )
-
- self.tp_group = self.pg_mesh.get_group_along_axis(
- self.tp_axis
- ) # TODO: support custom tp size for mixtral lm head
- self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
- self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
- # TODO: Currently moe only support partially sequence parallel
- self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
-
- self.custom_policy = custom_policy
- self.stage_manager = None
- self.schedule = None
-
- assert zero_stage in (0, 1, 2)
- if self.pp_size > 1:
- assert (
- num_microbatches is not None or microbatch_size is not None
- ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
- assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
- self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
- self.schedule = OneForwardOneBackwardSchedule(
- self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ if self.use_ddp:
+ warnings.warn(
+ f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
+ self.ddp_config["find_unused_parameters"] = True
- self.shard_config = ShardConfig(
- tensor_parallel_process_group=self.tp_group,
- pipeline_stage_manager=self.stage_manager,
- enable_tensor_parallelism=self.tp_size > 1,
- enable_all_optimization=self.enable_all_optimization,
- enable_fused_normalization=self.enable_fused_normalization,
- enable_flash_attention=self.enable_flash_attention,
- enable_jit_fused=self.enable_jit_fused,
- enable_sequence_parallelism=enable_sequence_parallelism,
- enable_sequence_overlap=enable_sequence_overlap,
- ep_group=self.ep_group,
- )
- self.amp_config = dict(
- initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- )
+ if ep_tp_size != 1:
+ raise NotImplementedError
- self.ddp_config = dict(
- broadcast_buffers=broadcast_buffers,
- bucket_cap_mb=ddp_bucket_cap_mb,
- find_unused_parameters=find_unused_parameters,
- check_reduction=check_reduction,
- gradient_as_bucket_view=gradient_as_bucket_view,
- static_graph=static_graph,
- )
+ world_size = dist.get_world_size()
- self.zero_config = dict(
- reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
- communication_dtype=communication_dtype,
- overlap_communication=overlap_communication,
- cpu_offload=cpu_offload,
- partition_grad=(self.zero_stage == 2),
- )
+ self.moe_dp_size = world_size // (ep_size * ep_tp_size)
+ self.ep_size = ep_size
+ self.moe_tp_size = ep_tp_size
- self.max_norm = max_norm
+ self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+ self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
- def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
- ):
- r"""
- Prepare a dataloader for distributed training. The dataloader will be wrapped by
- `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+ self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
+ self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
+ self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
-
- Args:
- dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
- shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
- seed (int, optional): Random worker seed for sampling, defaults to 1024.
- add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
- drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
- is not divisible by the batch size. If False and the size of dataset is not divisible by
- the batch size, then the last batch will be smaller, defaults to False.
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
- num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
- kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
- `DataLoader `_.
-
- Returns:
- :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
- """
- _kwargs = kwargs.copy()
- sampler = DistributedSampler(
- dataset,
- num_replicas=self.dp_size,
- rank=dist.get_rank(self.global_dp_group),
- shuffle=shuffle,
- )
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(
- dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs,
- )
+ # set ep_group after super init
+ # TODO do it in a better way
+ self.shard_config.ep_group = self.ep_group
def get_checkpoint_io(self) -> MoECheckpointIO:
- if self.checkpoint_io is None:
- self.checkpoint_io = MoECheckpointIO(
- self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
- )
- else:
- self.checkpoint_io = self.checkpoint_io(
- self.global_dp_group,
- self.pp_group,
- self.tp_group,
- ep_group=self.ep_group,
- moe_dp_group=self.moe_dp_group,
- zero_stage=self.zero_stage,
- )
- if hasattr(self.checkpoint_io, "moe_info"):
- self.checkpoint_io.moe_info = self.moe_info
- return self.checkpoint_io
+ return MoECheckpointIO(
+ self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+ )
def configure(
self,
@@ -392,15 +134,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
- use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=self.global_dp_group,
+ dp_group=self.dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
- use_ddp=use_ddp, # TODO fix why this failed
+ use_ddp=self.use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -411,8 +152,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
- # assert self.ep_size > 1
-
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@@ -435,10 +174,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
- dp_process_group=self.global_dp_group,
- tp_process_group=self.tp_group,
- pp_process_group=self.pp_group,
- moe_extra_dp_process_group=self.moe_dp_group,
+ dp_process_group=self.dp_group,
+ moe_dp_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
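
A hedged restatement of the core of this optimizer change: each parameter is assigned to exactly one ZeRO bucket group, the plain data-parallel group for dense weights and the smaller moe_dp group for expert weights that are already sharded across the EP group (is_moe_tensor is the tag used above; the helper name is illustrative).

from colossalai.tensor.moe_tensor.api import is_moe_tensor

def build_pg_param_list(model, dp_process_group, moe_dp_group):
    # non-MoE params are reduced over the full dp group,
    # MoE (expert) params over the per-expert-replica moe_dp group
    return {
        dp_process_group: [p for p in model.parameters() if not is_moe_tensor(p)],
        moe_dp_group: [p for p in model.parameters() if is_moe_tensor(p)],
    }
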
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 7f1ef9fce..c09c7a2cc 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -137,7 +137,7 @@ class ProcessGroupMesh:
assert mode in ["raise", "wrap", "clip"]
return int(np.ravel_multi_index(coord, shape, mode))
- def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
+ def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
Args:
@@ -240,7 +240,7 @@ class ProcessGroupMesh:
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
- group = self.get_group(ranks_in_group, backend=backend)
+ group = self._get_group(ranks_in_group, backend=backend)
if self._rank in ranks_in_group:
target_group = group
return target_group
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 01c837ee3..3df349182 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -393,4 +393,7 @@ def all_to_all_uneven(
group=None,
overlap: bool = False,
):
+ assert (
+ inputs.requires_grad
+ ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
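
The assertion added above guards against a subtle hang: all_to_all_uneven also issues a collective in its backward pass, so every EP rank must actually reach backward. A hedged sketch of a caller-side guard with the same effect; the helper name is illustrative.

import torch
from colossalai.moe._operation import all_to_all_uneven

def dispatch_tokens(tokens: torch.Tensor, input_splits, output_splits, ep_group):
    if not tokens.requires_grad:
        # keep backward symmetric across ranks; otherwise some ranks skip the
        # backward collective and the others block forever
        tokens = tokens.clone().requires_grad_(True)
    return all_to_all_uneven(tokens, input_splits, output_splits, group=ep_group)
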
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index e3cc48043..98554c906 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -101,20 +101,18 @@ class MixtralPolicy(Policy):
# )
if getattr(self.shard_config, "ep_group", None) is None:
- raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
- # expert parallel
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="block_sparse_moe",
- target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group},
- )
- ],
- policy=policy,
- target_key=MixtralDecoderLayer,
- )
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe",
+ target_module=EPMixtralSparseMoeBlock,
+ kwargs={"ep_group": self.shard_config.ep_group},
+ )
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -144,6 +142,7 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_flash_attention:
warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
+ self.shard_config.enable_flash_attention = False
return policy
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 19d20de2b..0d0a606c0 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -100,7 +100,7 @@ class BucketStore(BaseStore):
return self._grad_in_bucket
- def get_flatten_grad(self) -> Tensor:
+ def get_flatten_grad(self, dtype=None) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data organization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
@@ -110,8 +110,12 @@ class BucketStore(BaseStore):
flat_grad = []
for grad_list in self._grad_in_bucket.values():
- flat_grad.append(_flatten_dense_tensors(grad_list))
- flat_grad = _flatten_dense_tensors(flat_grad)
+ if len(grad_list) > 0:
+ flat_grad.append(_flatten_dense_tensors(grad_list))
+ if len(flat_grad) > 0:
+ flat_grad = _flatten_dense_tensors(flat_grad)
+ else:
+ flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
return flat_grad
def get_param_id_of_grad(self, grad: Tensor) -> int:
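
Why the new dtype argument matters, as a hedged standalone sketch: with parameters split across two process groups, a rank can end up with an empty bucket for one of them; an empty list cannot be flattened, so an empty tensor of the communication dtype is returned so the later all-reduce still sees a tensor of the right type.

import torch
from torch._utils import _flatten_dense_tensors

def flatten_bucket(grad_lists, dtype, device="cuda"):
    flat = [_flatten_dense_tensors(gl) for gl in grad_lists if len(gl) > 0]
    if not flat:
        # empty bucket on this rank: keep dtype consistent for the collective
        return torch.tensor([], dtype=dtype, device=device)
    return _flatten_dense_tensors(flat)
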
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index e24a67f9d..a13fa120a 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -91,7 +91,7 @@ class GradientStore(BaseStore):
return grad_list
- def get_working_grad_by_param_id(self, param_id) -> Tensor:
+ def get_working_grad_by_param_id(self, param_id) -> Optional[Tensor]:
"""
Return the working gradient for the specified parameter.
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 01382cd8e..54c6caf41 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -301,12 +301,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def _run_reduction(self):
for bucket_store in self.pg_to_bucket_store.values():
- if bucket_store.num_elements_in_bucket() <= 0:
- continue
-
bucket_store.build_grad_in_bucket()
- flat_grads = bucket_store.get_flatten_grad()
+ flat_grads = bucket_store.get_flatten_grad(self._dtype)
flat_grads /= bucket_store.world_size
# ready to add other tensors to bucket
@@ -353,6 +350,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
) -> None:
for rank, grad_list in enumerate(origin_grad_list):
+ if len(grad_list) == 0:
+ continue
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = bucket_store.get_param_id_of_grad(grad)
@@ -869,12 +868,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
grad_store = self.pid_to_grad_store[id(working_param)]
- partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
- if partial_grad is None:
+ grad = grad_store.get_working_grad_by_param_id(id(working_param))
+ if grad is None:
return None
- tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)]
- dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
- grad_flat = torch.cat(tensor_list, dim=0)
+ grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
+ dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
return grad_flat[: working_param.numel()].reshape_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
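
A hedged sketch of the gather rewrite above: rather than all_gather into a Python list followed by torch.cat, a single flat buffer of world_size times the shard size is pre-allocated and filled with dist.all_gather_into_tensor, then cut down to the working parameter.

import torch
import torch.distributed as dist

def gather_full_grad(partial_grad: torch.Tensor, working_param: torch.Tensor, pg) -> torch.Tensor:
    world_size = dist.get_world_size(pg)
    flat = torch.empty(world_size * partial_grad.numel(),
                       dtype=partial_grad.dtype, device=partial_grad.device)
    dist.all_gather_into_tensor(flat, partial_grad.contiguous(), group=pg)
    # the gathered buffer may be padded; keep only what the param needs
    return flat[: working_param.numel()].reshape_as(working_param)
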
diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py
index 0ac6a75ce..7fa4ff335 100644
--- a/tests/kit/model_zoo/transformers/mixtral.py
+++ b/tests/kit/model_zoo/transformers/mixtral.py
@@ -19,7 +19,7 @@ def data_gen():
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
- input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
+ input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -43,7 +43,7 @@ def data_gen_for_sequence_classification():
output_transform_fn = lambda x: x
# define loss function
-loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(x[0], torch.ones_like(x[0]))
+loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
@@ -52,7 +52,7 @@ config = MixtralConfig(
intermediate_size=256,
num_attention_heads=64,
num_hidden_layers=2,
- vocab_size=50258,
+ vocab_size=1000,
output_router_logits=True,
)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 164301695..773036358 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -141,7 +141,6 @@ def check_moe_checkpoint(test_config):
if dist.get_rank() == 0:
saved_model = model_cls.from_pretrained(model_dir).cuda()
check_model_equal(orig_model, saved_model)
- # check_model_equal(model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
index 042b3d8ae..2e6d0d786 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
@@ -31,16 +31,17 @@ def split_grad(grad, world_size):
return splited_grad
-@parameterize("dtype", [torch.float16, torch.bfloat16])
-@parameterize("master_weights", [True, False])
@parameterize("stage", [1, 2])
-def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
+@parameterize("ep_size", [1, 2, 4])
+def run_zero_with_original_model(stage: int, ep_size: int):
+ dtype = torch.float16
+
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
- ep_size=dist.get_world_size() // 2,
+ ep_size=ep_size,
)
seed_all(10086)
@@ -53,26 +54,30 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
- ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
+ ori_model = DDP(
+ orig_model.cuda(),
+ process_group=plugin.dp_group,
+ find_unused_parameters=True, # important for torch ddp, not all experts are routed
+ ).cuda()
zero_model = deepcopy(orig_model).to(dtype)
zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
- pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
+ pg_param_list = {plugin.dp_group: [], plugin.moe_dp_group: []}
for p in zero_model.parameters():
if is_moe_tensor(p):
pg_param_list[plugin.moe_dp_group].append(p)
else:
- pg_param_list[plugin.global_dp_group].append(p)
+ pg_param_list[plugin.dp_group].append(p)
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
pg_to_param_list=pg_param_list,
- master_weights=master_weights,
+ master_weights=False,
initial_scale=1,
- overlap_communication=False,
- partition_grad=True,
+ overlap_communication=True,
+ partition_grad=stage == 2,
)
ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
@@ -82,11 +87,11 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
for _ in range(2):
# zero-dp forward
- input_data = torch.rand(1, tokens, hidden_size).cuda()
- zero_output, zero_logits = zero_model(input_data.to(dtype))
+ input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
+ zero_output, _ = zero_model(input_data.to(dtype))
# torch-ddp forward
- ori_output, ori_logits = ori_model(input_data.to(dtype))
+ ori_output, _ = ori_model(input_data.to(dtype))
loose_close(zero_output, ori_output, dtype=dtype)
# zero-dp backward
@@ -115,14 +120,16 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
for n, p in zero_model.named_parameters():
loose_close(p.data, name_to_p[n].data, dtype=dtype)
+ print(f"{dist.get_rank()} test passed")
+
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_zero_with_original_model(world_size=world_size)
+ run_zero_with_original_model()
@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)
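
The baseline-model change above is worth spelling out, since the same flag is also forced in the plugin: with top-k routing, an expert that receives no tokens in a step produces no gradient, and vanilla DDP errors out unless told to expect unused parameters. A hedged sketch of the wrapper (name illustrative):

from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_moe_baseline(module, dp_group):
    return DDP(
        module.cuda(),
        process_group=dp_group,
        # experts that are not routed any tokens this step have no grad,
        # so DDP must tolerate unused parameters
        find_unused_parameters=True,
    )
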
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index f8deb2e8a..98f7213a3 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -25,13 +25,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+ # TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
- model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
- )
-
- org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
- org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+ model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
)
+ with torch.autograd.set_detect_anomaly(True):
+ org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+ org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+ )
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
@@ -73,6 +74,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
@@ -103,9 +107,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
verbose=False,
)
- # check grads
- check_all_grad_tensors(grads_to_check)
-
torch.cuda.empty_cache()
@@ -114,37 +115,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
{
"tp_size": 1,
- "pp_size": 4,
+ "pp_size": 2,
+ "num_microbatches": 2,
"ep_size": 1,
- "num_microbatches": 4,
"zero_stage": 0,
- "enable_all_optimization": True,
- "use_lazy_init": False,
- "precision": "fp16",
- "initial_scale": 1,
- },
- # {
+ "precision": "fp32",
+ }, # pp + ep
+ # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "fp16"}, # full dp for moe and non-moe
+ # { # moe_dp = 2, non_moe_dp = 4
# "tp_size": 1,
# "pp_size": 1,
- # "ep_size": 4,
- # "num_microbatches": 2,
+ # "ep_size": 2,
# "zero_stage": 1,
- # "enable_all_optimization": True,
- # "use_lazy_init": False,
# "precision": "fp16",
- # "initial_scale": 1,
- # },
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 4,
- # "num_microbatches": 2,
- # "zero_stage": 2,
- # "enable_all_optimization": True,
- # "use_lazy_init": False,
- # "precision": "fp16",
- # "initial_scale": 1,
- # },
+ # }, # moe_dp = 1, non_moe_dp = 4
+ # {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp16"},
+ # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
],
)
def run_mixtral_test(test_config):
From 37443cc7e499aa836d4897bf51b1119815da45b3 Mon Sep 17 00:00:00 2001
From: botbw
Date: Mon, 8 Jul 2024 05:13:49 +0000
Subject: [PATCH 062/116] [test] pass mixtral shardformer test
---
.../booster/plugin/hybrid_parallel_plugin.py | 6 +++
.../plugin/moe_hybrid_parallel_plugin.py | 4 +-
colossalai/shardformer/policies/mixtral.py | 14 ++---
.../test_model/test_shard_mixtral.py | 54 ++++++++++++-------
4 files changed, 51 insertions(+), 27 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 983ddfc97..ddfe0b2d9 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -38,6 +38,7 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
+from colossalai.logging import get_dist_logger
from .pp_plugin_base import PipelinePluginBase
@@ -1016,6 +1017,9 @@ class HybridParallelPlugin(PipelinePluginBase):
overlap_allgather: bool = False,
) -> None:
super().__init__()
+
+ self.logger = get_dist_logger(type(self).__name__)
+
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -1064,6 +1068,8 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
+ self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}")
+
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 02a87ff11..b2ee9f650 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -24,7 +24,6 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
-
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
@@ -115,6 +114,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
+ self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}")
+
# set ep_group after super init
# TODO do it in a better way
self.shard_config.ep_group = self.ep_group
@@ -168,7 +169,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
- assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 98554c906..410515362 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -20,13 +20,15 @@ class MixtralPolicy(Policy):
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
+ raise NotImplementedError
+
+ # # Resize embedding
+ # vocab_size = self.model.config.vocab_size
+ # world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ # if vocab_size % world_size != 0:
+ # new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ # self.model.resize_token_embeddings(new_vocab_size)
return self.model
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 98f7213a3..4a5f3e14d 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -37,6 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if test_config["precision"] == "fp32":
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
# unwrap model
mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
@@ -81,15 +90,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_optimizer.step()
sharded_optimizer.step()
- # check last hidden state & loss
- if stage_manager is None or stage_manager.is_last_stage():
- if test_config["precision"] == "fp32":
- atol, rtol = 1e-5, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
-
- check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
@@ -121,16 +121,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"zero_stage": 0,
"precision": "fp32",
}, # pp + ep
- # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "fp16"}, # full dp for moe and non-moe
- # { # moe_dp = 2, non_moe_dp = 4
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 2,
- # "zero_stage": 1,
- # "precision": "fp16",
- # }, # moe_dp = 1, non_moe_dp = 4
- # {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp16"},
- # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "ep_size": 1,
+ "zero_stage": 0,
+ "precision": "fp32",
+ }, # pp + ep
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "ep_size": 4,
+ "zero_stage": 0,
+ "precision": "fp32",
+ }, # pp + ep
+ {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "bf16"}, # full dp for moe and non-moe
+ { # moe_dp = 2, non_moe_dp = 4
+ "tp_size": 1,
+ "pp_size": 1,
+ "ep_size": 2,
+ "zero_stage": 1,
+ "precision": "fp32",
+        },
+ {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp32"}, # full dp for non-moe and full ep for moe
+ {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
],
)
def run_mixtral_test(test_config):
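
The reordering in this commit, checking loss and grads before either optimizer steps, follows from a simple invariant sketched below in hedged form: once step() runs, master weights and scaler state diverge and the comparison is no longer apples-to-apples. The helper name and signature here are illustrative, not the harness's own check_loss.

import torch

def check_losses_match(org_loss: torch.Tensor, sharded_loss: torch.Tensor, precision: str) -> None:
    # call this before org_optimizer.step() / sharded_optimizer.step()
    atol, rtol = (1e-5, 1e-3) if precision == "fp32" else (5e-3, 5e-3)
    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), (org_loss, sharded_loss)
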
From b5bfeb2efd8ef213a03deaa57a175886fbe3e112 Mon Sep 17 00:00:00 2001
From: botbw
Date: Mon, 8 Jul 2024 09:59:46 +0000
Subject: [PATCH 063/116] [moe] implement transit between non moe tp and ep
---
.../booster/plugin/hybrid_parallel_plugin.py | 2 +-
.../plugin/moe_hybrid_parallel_plugin.py | 16 +--
colossalai/moe/_operation.py | 103 +++++++++++++++
colossalai/shardformer/modeling/mixtral.py | 32 +++--
colossalai/shardformer/policies/mixtral.py | 125 ++++++++----------
colossalai/shardformer/shard/shard_config.py | 2 +
.../test_model/test_shard_mixtral.py | 55 +++++---
7 files changed, 234 insertions(+), 101 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index ddfe0b2d9..cad9ca95c 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1068,7 +1068,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
- self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}")
+ self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
self.stage_manager = None
self.schedule = None
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index b2ee9f650..0b0d50e28 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -30,8 +30,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
- dp_process_group: ProcessGroup, # the dp pg for comm
- moe_dp_group: ProcessGroup, # the moe dp pg for gomm
+ dp_process_group: ProcessGroup, # dp pg for comm
+ moe_dp_group: ProcessGroup, # moe dp pg for comm
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
@@ -44,7 +44,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True,
+ overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None,
@@ -88,7 +88,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
TODO: add docstring
"""
- def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None:
+ def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@@ -98,14 +98,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.ddp_config["find_unused_parameters"] = True
- if ep_tp_size != 1:
+ if moe_tp_size != 1:
raise NotImplementedError
world_size = dist.get_world_size()
- self.moe_dp_size = world_size // (ep_size * ep_tp_size)
+ self.moe_dp_size = world_size // (ep_size * moe_tp_size)
self.ep_size = ep_size
- self.moe_tp_size = ep_tp_size
+ self.moe_tp_size = moe_tp_size
self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
@@ -114,7 +114,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
- self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}")
+ self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
# set ep_group after super init
# TODO do it in a better way
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 3df349182..cad9573fb 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -397,3 +397,106 @@ def all_to_all_uneven(
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+
+# ===========================================================
+# This code section was modified from
+# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# The file has been adapted from the following Megatron-LM file:
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
+# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
+# We retain the following copyright from the original files:
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def _gather_tokens(input_, dim: int, tp_group: ProcessGroup):
+ """Gather tensors and concatenate them along a dimension"""
+
+ input_ = input_.contiguous()
+ # Size and dimension.
+ rank = tp_group.rank()
+
+ tensor_list = [torch.empty_like(input_) for _ in range(tp_group.size())]
+ tensor_list[rank] = input_
+ dist.all_gather(tensor_list, input_, group=tp_group)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+
+ return output
+
+
+def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
+ """Divide a tensor among the tensor parallel ranks"""
+
+ total_chunks = tp_group.size()
+ this_chunk = tp_group.rank()
+    assert input_.shape[dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
+ chunk_size = input_.shape[dim] // total_chunks
+
+ return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
+
+
+class _GatherTokens(torch.autograd.Function):
+ """All gather tokens among the tensor parallel ranks"""
+
+ @staticmethod
+ def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
+ ctx.dim = dim
+ ctx.tp_group = tp_group
+ return _gather_tokens(input_, dim, tp_group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _drop_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
+
+
+class _DropTokens(torch.autograd.Function):
+ "Divide tokens equally among the tensor parallel ranks"
+
+ @staticmethod
+ def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
+ ctx.dim = dim
+ ctx.tp_group = tp_group
+ return _drop_tokens(input_, dim, tp_group)
+
+ @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+        return _gather_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
+
+
+def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
+ if tp_group.size() == 1:
+ # no tensor parallelism for non-experts
+ return input_
+ assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+    return _GatherTokens.apply(input_, dim, tp_group)
+
+
+def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
+ if tp_group.size() == 1:
+ # no tensor parallelism for non-experts
+ return input_
+ assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+ return _DropTokens.apply(input_, dim, tp_group)
+
+# ===========================================================
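
For intuition, a minimal single-process sketch (not part of the patch) of the chunk arithmetic behind _drop_tokens/_gather_tokens: each tensor-parallel rank keeps one torch.narrow slice along the split dimension, and concatenating all slices (what the all_gather + torch.cat path does across ranks) recovers the original tensor. Real group sizes come from tp_group; the sizes here are illustrative.

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 12)
tp_size, dim = 4, 1
chunk = x.shape[dim] // tp_size
shards = [torch.narrow(x, dim, r * chunk, chunk) for r in range(tp_size)]  # what each rank keeps after drop_tokens
restored = torch.cat(shards, dim=dim)  # what gather_tokens rebuilds from all ranks
assert torch.equal(restored, x)
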
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 334bd13fc..5d2dc1dc3 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -14,21 +14,21 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
- def __init__(self, config, ep_group):
+ def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
super().__init__(config)
- self.setup_ep(ep_group)
+ self.setup_process_groups(ep_group, tp_group, moe_tp_group)
- def setup_ep(self, ep_group: ProcessGroup):
- ep_group = ep_group
- self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
- self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
+ # setup ep group
+ self.ep_size = dist.get_world_size(ep_group)
+ self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
if self.num_experts % self.ep_size != 0:
@@ -42,13 +42,19 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
for p in self.experts.parameters():
p.ep_group = ep_group
+ # setup global tp group
+ self.tp_group = tp_group
+
+ # setup moe tp group
+ self.moe_tp_group = moe_tp_group
+
@staticmethod
def from_native_module(
- module: MixtralSparseMoeBlock, ep_group: ProcessGroup, *args, **kwargs
+ module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
- module.setup_ep(ep_group)
+        module.setup_process_groups(ep_group, tp_group, moe_tp_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -72,6 +78,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+
+ if self.tp_group is not None and self.tp_group.size() > 1:
+ dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
@@ -94,6 +104,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+
+ if self.tp_group is not None and self.tp_group.size() > 1:
+ dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
+
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 410515362..14d57c79d 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -8,6 +8,7 @@ from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -20,15 +21,15 @@ class MixtralPolicy(Policy):
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError
-
- # # Resize embedding
- # vocab_size = self.model.config.vocab_size
- # world_size = self.shard_config.tensor_parallel_size
+ # non-moe params tensor parallelism
- # if vocab_size % world_size != 0:
- # new_vocab_size = vocab_size + world_size - vocab_size % world_size
- # self.model.resize_token_embeddings(new_vocab_size)
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
@@ -42,74 +43,62 @@ class MixtralPolicy(Policy):
)
if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError
- # assert (
- # self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
- # ), f"The number of attention heads must be divisible by tensor parallel size."
- # assert (
- # self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
- # ), f"The number of key_value heads must be divisible by tensor parallel size."
- # decoder_attribute_replacement = {
- # "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- # "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- # "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
- # // self.shard_config.tensor_parallel_size,
- # }
+ # tensor parallelism for non-moe params
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+ // self.shard_config.tensor_parallel_size,
+ }
- # policy[MixtralDecoderLayer] = ModulePolicyDescription(
- # attribute_replacement=decoder_attribute_replacement,
- # sub_module_replacement=[
- # SubModuleReplacementDescription(
- # suffix="self_attn.q_proj",
- # target_module=Linear1D_Col,
- # kwargs={
- # 'process_group': self.shard_config.tensor_parallel_process_group,
- # }
- # ),
- # SubModuleReplacementDescription(
- # suffix="self_attn.k_proj",
- # target_module=Linear1D_Col,
- # kwargs={
- # 'process_group': self.shard_config.tensor_parallel_process_group,
- # }
- # ),
- # SubModuleReplacementDescription(
- # suffix="self_attn.v_proj",
- # target_module=Linear1D_Col,
- # kwargs={
- # 'process_group': self.shard_config.tensor_parallel_process_group,
- # }
- # ),
- # SubModuleReplacementDescription(
- # suffix="self_attn.o_proj",
- # target_module=Linear1D_Row,
- # kwargs={
- # 'process_group': self.shard_config.tensor_parallel_process_group,
- # }
- # ),
- # # SubModuleReplacementDescription(
- # # suffix="mlp.gate_proj",
- # # target_module=Linear1D_Col,
- # # ),
- # # SubModuleReplacementDescription(
- # # suffix="mlp.up_proj",
- # # target_module=Linear1D_Col,
- # # ),
- # # SubModuleReplacementDescription(
- # # suffix="mlp.down_proj",
- # # target_module=Linear1D_Row,
- # # ),
- # ],
- # )
+ policy[MixtralDecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=Linear1D_Row,
+ ),
+ # SubModuleReplacementDescription( # TODO: enable moe tp parallel
+ # suffix="mlp.gate_proj",
+ # target_module=Linear1D_Col,
+ # ),
+ # SubModuleReplacementDescription(
+ # suffix="mlp.up_proj",
+ # target_module=Linear1D_Col,
+ # ),
+ # SubModuleReplacementDescription(
+ # suffix="mlp.down_proj",
+ # target_module=Linear1D_Row,
+ # ),
+ ],
+ )
- if getattr(self.shard_config, "ep_group", None) is None:
+ if self.shard_config.ep_group:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group},
+ kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
)
],
policy=policy,
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index b64300366..d1aebd5b2 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -47,6 +47,8 @@ class ShardConfig:
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
ep_group: Optional[ProcessGroup] = None
+ moe_tp_group: Optional[ProcessGroup] = None
+
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 4a5f3e14d..70b576908 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -114,39 +114,64 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"test_config",
[
{
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 2,
+ "tp_size": 2,
+ "pp_size": 1,
"ep_size": 1,
- "zero_stage": 0,
+ "zero_stage": 2,
"precision": "fp32",
- }, # pp + ep
+ }, # [dp(2) + tp(2)] + [moe_dp(4)]
+ {
+ "tp_size": 2,
+ "pp_size": 1,
+ "ep_size": 2,
+ "zero_stage": 2,
+ "precision": "fp32",
+ }, # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
- "zero_stage": 0,
+ "zero_stage": 2,
"precision": "fp32",
- }, # pp + ep
+ }, # [dp(2) + pp(2)] + [moe_dp(4)]
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "ep_size": 1,
+ "zero_stage": 2,
+ "precision": "fp32",
+ }, # [dp(2) + pp(2)] + [moe_dp(4)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 4,
- "zero_stage": 0,
+ "zero_stage": 2,
"precision": "fp32",
- }, # pp + ep
- {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "bf16"}, # full dp for moe and non-moe
- { # moe_dp = 2, non_moe_dp = 4
+        }, # [dp(2) + pp(2)] + [ep(4)]
+ {
"tp_size": 1,
"pp_size": 1,
"ep_size": 2,
- "zero_stage": 1,
+ "zero_stage": 2,
"precision": "fp32",
- }, # moe_dp = 1, non_moe_dp = 4
- {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp32"}, # full dp for non-moe and full ep for moe
- {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
+ }, # [dp(4)] + [ep(2) + moe_tp(2)]
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "ep_size": 4,
+ "zero_stage": 2,
+ "precision": "fp32"
+ }, # full dp for non-moe and full ep for moe
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "ep_size": 1,
+ "zero_stage": 2,
+ "precision": "fp32"
+ }, # full dp for moe and non-moe
],
)
def run_mixtral_test(test_config):
From 13b48ac0aaccd2e7132f9be544ca19bdda159646 Mon Sep 17 00:00:00 2001
From: botbw
Date: Tue, 9 Jul 2024 08:14:00 +0000
Subject: [PATCH 064/116] [zero] solve hang
---
.../plugin/moe_hybrid_parallel_plugin.py | 23 +-
.../low_level/bookkeeping/bucket_store.py | 8 +-
.../low_level/bookkeeping/gradient_store.py | 1 -
colossalai/zero/low_level/low_level_optim.py | 6 +-
tests/test_moe/moe_utils.py | 4 +-
tests/test_moe/test_moe_ep_tp.py | 320 ++++++------------
tests/test_moe/test_moe_zero_fwd_bwd_optim.py | 107 +++---
.../test_model/test_shard_mixtral.py | 84 ++---
8 files changed, 218 insertions(+), 335 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 0b0d50e28..56b731d13 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -30,6 +30,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
+ force_overlap_comm: bool, # force overlap comm
dp_process_group: ProcessGroup, # dp pg for comm
moe_dp_group: ProcessGroup, # moe dp pg for comm
param_info: OrderedDict,
@@ -48,7 +49,16 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None,
- ):
+ ):
+
+        WARN_STR = "Note that you need to make sure every expert is routed (i.e. every expert receives a gradient in backward), otherwise this might lead to a program hang or inconsistent results"
+ if not force_overlap_comm and (overlap_communication or partition_grad):
+ raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
+
+ if force_overlap_comm:
+ overlap_communication = True
+ warnings.warn(WARN_STR + " Please make sure of this.")
+
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@@ -88,7 +98,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
TODO: add docstring
"""
- def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None:
+ def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@@ -120,6 +130,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# TODO do it in a better way
self.shard_config.ep_group = self.ep_group
+ self.force_overlap_comm = force_overlap_comm
+
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
@@ -168,11 +180,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
- assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+ if not(self.dp_size > 1 or self.moe_dp_size > 1):
+ warnings.warn(
+ "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
+ "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+ )
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
+ force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_group,
moe_dp_group=self.moe_dp_group,
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 0d0a606c0..78c34046a 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -110,12 +110,8 @@ class BucketStore(BaseStore):
flat_grad = []
for grad_list in self._grad_in_bucket.values():
- if len(grad_list) > 0:
- flat_grad.append(_flatten_dense_tensors(grad_list))
- if len(flat_grad) > 0:
- flat_grad = _flatten_dense_tensors(flat_grad)
- else:
- flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
+ flat_grad.append(_flatten_dense_tensors(grad_list))
+ flat_grad = _flatten_dense_tensors(flat_grad)
return flat_grad
def get_param_id_of_grad(self, grad: Tensor) -> int:
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index a13fa120a..b84be034a 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -19,7 +19,6 @@ class GradientStore(BaseStore):
"""
self._grads_of_params = dict()
# stage 2
- self._partition_grads = partition_grad
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 54c6caf41..5a41345f0 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -648,7 +648,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
- if param.requires_grad and param.grad is not None:
+ if param.requires_grad:
+ if param.grad is None:
+ # for moe params, all experts should have gradient
+ # TODO better way of doing this
+ param.grad = torch.zeros_like(param)
self._add_to_bucket(param, group_id)
self._run_reduction()
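
The zero-fill above exists because an expert that receives no tokens in a step leaves param.grad as None on its owner rank, which would desynchronize the bucketed reduction across ranks. A minimal sketch (not part of the patch) of the substitution:

import torch

param = torch.nn.Parameter(torch.randn(4, 4))  # e.g. an expert weight that was never routed to
if param.requires_grad and param.grad is None:
    param.grad = torch.zeros_like(param)  # placeholder gradient keeps every rank's reduce bucket aligned
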
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index ba6a0e8a9..b7332a937 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -137,7 +137,7 @@ def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) ->
local_param.data.copy_(all_param.data)
-def loose_close(a, b, dtype: torch.dtype = torch.float32):
+def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
rtol = None
atol = None
if dtype is torch.float16:
@@ -150,4 +150,4 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
- assert_close(a, b, rtol=rtol, atol=atol)
+ assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"
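
For context, a minimal standalone sketch of the dtype-aware comparison that loose_close performs (the tolerances below are illustrative placeholders, not the exact values used in moe_utils.py):

import torch

def loose_close_sketch(a, b, dtype=torch.float32, name=""):
    # wider tolerances for half-precision dtypes, tighter ones for fp32 (illustrative values)
    rtol, atol = (1e-2, 1e-2) if dtype in (torch.float16, torch.bfloat16) else (1e-5, 1e-8)
    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)
    assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close: {a.mean()} vs {b.mean()}"

loose_close_sketch(torch.ones(4), torch.ones(4) + 1e-3, dtype=torch.bfloat16, name="demo")
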
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 9bc11033a..24fc0a0eb 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -1,238 +1,134 @@
-import os
-import warnings
-from typing import Dict
+from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
-from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import sync_moe_model_param
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
-# from colossalai.shardformer.layer import SparseMLP
-from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
-from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from tests.test_moe.moe_utils import MoeGradientHandler
+NUM_BATCH=4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS=2
+TOP_K = 2
-def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from local model
-
- Args:
- tp_model (MoeModule)
- local_model (MoeModule)
- """
- for (tp_name, tp_param), (local_name, local_param) in zip(
- tp_model.named_parameters(), local_model.named_parameters()
- ):
- assert tp_name == local_name
- if not is_moe_tensor(tp_param):
- if assert_grad_flag:
- assert torch.allclose(tp_param, local_param)
- assert torch.allclose(tp_param.grad, local_param.grad)
- else:
- tp_param.data.copy_(local_param.data)
- continue
-
- tp_rank = get_ep_rank(tp_param)
- tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
- tp_slice = [slice(None)] * tp_dim + [
- slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
- ]
-
- if assert_grad_flag:
- assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
- assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
- else:
- tp_param.data.copy_(local_param[tuple(tp_slice)].data)
+def split_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+ splited_grad = grad.split(grad.numel() // world_size)
+ return splited_grad
-def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from ep model
+@parameterize("stage", [1])
+@parameterize("ep_size", [1, 2, 4])
+@parameterize("tp_size", [1, 2, 4])
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
+ dtype = torch.bfloat16
- Args:
- tp_model (MoeModule)
- ep_model (MoeModule)
- """
- for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
- assert tp_name == ep_name
- if not is_moe_tensor(tp_param):
- if assert_grad_flag:
- assert torch.allclose(tp_param, ep_param)
- assert torch.allclose(tp_param.grad, ep_param.grad)
- else:
- tp_param.data.copy_(ep_param.data)
- continue
+ rank = torch.distributed.get_rank()
+ torch.cuda.set_device(dist.get_rank())
- # gather param from ep model
- param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
- all_param = torch.cat(param_list, dim=0)
- if assert_grad_flag:
- grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
- all_grad = torch.cat(grad_list, dim=0)
+ seed_all(10086)
- # get tp param
- tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
- tp_rank = get_ep_rank(tp_param)
- tp_slice = [slice(None)] * tp_dim + [
- slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
- ]
- new_tp_param = all_param[tuple(tp_slice)]
- if assert_grad_flag:
- new_grad = all_grad[tuple(tp_slice)]
- if assert_grad_flag:
- assert torch.allclose(tp_param, new_tp_param)
- assert torch.allclose(tp_param.grad, new_grad)
- else:
- tp_param.data.copy_(new_tp_param.data)
-
-
-def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from ep model
-
- Args:
- local_model (MoeModule)
- ep_model (MoeModule)
- """
- for (local_name, local_param), (ep_name, ep_param) in zip(
- local_model.named_parameters(), ep_model.named_parameters()
- ):
- assert local_name == ep_name
- if "experts" not in local_name:
- if assert_grad_flag:
- assert torch.allclose(local_param, ep_param)
- assert torch.allclose(local_param.grad, ep_param.grad)
- else:
- local_param.data.copy_(ep_param.data)
- continue
-
- # gather param from ep model
- param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
- all_param = torch.cat(param_list, dim=0)
- if assert_grad_flag:
- grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
- all_grad = torch.cat(grad_list, dim=0)
-
- if assert_grad_flag:
- assert torch.allclose(local_param, all_param)
- assert torch.allclose(local_param.grad, all_grad)
- else:
- local_param.data.copy_(all_param.data)
-
-
-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
- assert batch_size % world_size == 0
-
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel=None)
- local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel="EP")
- enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
- if enable_hierarchical_comm:
- os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
- ep_model = SparseMLP(
- num_experts=num_experts,
- hidden_size=dim,
- intermediate_size=dim * 2,
- enable_hierarchical_comm=enable_hierarchical_comm,
+ config = MixtralConfig(
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=2,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
)
- MOE_MANAGER.__init__()
- MOE_MANAGER.setup(parallel="TP")
- tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
- ep_model = ep_model.to(get_accelerator().get_current_device())
- tp_model = tp_model.to(get_accelerator().get_current_device())
- local_model = local_model.to(get_accelerator().get_current_device())
+ torch_model = MixtralModel(config).to(dtype).cuda()
- # sync ep param
- sync_moe_model_param(ep_model)
- dist_dict = MOE_MANAGER.parallel_info_dict
- assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
- assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
- ep_grad_handler = MoeGradientHandler(ep_model)
- # sync local param
- sync_local_from_ep(local_model, ep_model)
- # sync tp param
- sync_tp_from_ep(tp_model, ep_model)
- tp_grad_handler = MoeGradientHandler(tp_model)
-
- rank = dist.get_rank()
- input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
- micro_batch_size = batch_size // world_size
- index = rank * micro_batch_size
- # NOTE: ep & tp takes in sharded data for each process
- shard_data = input_data.detach()[index : index + micro_batch_size]
-
- out_local = local_model(input_data)
- MOE_MANAGER.reset_loss()
- out_tp = tp_model(shard_data)
- MOE_MANAGER.reset_loss()
- out_ep = ep_model(shard_data)
- MOE_MANAGER.reset_loss()
-
- assert torch.allclose(
- out_tp, out_ep, atol=1e-6
- ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
- try:
- out_local_slice = out_local[index : index + micro_batch_size]
- assert torch.allclose(
- out_ep, out_local_slice, atol=1e-6
- ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
- except AssertionError:
- """
- e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
- router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
- However, in ep mode, there are 2 separate routers dealing with sharded data.
- Assume router 0 handles token [01] and router 1 handles token [23].
- Note that for each router the capacity is only 1 !!!
- Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both.
- The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
- """
- warnings.warn(
- "EP & TP may result in different behavior from local model. " "Please check the comments for details."
+ zero_model = deepcopy(torch_model).to(dtype)
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+ booster = Booster(
+ plugin=MoeHybridParallelPlugin(
+ tp_size=tp_size,
+ pp_size=1,
+ ep_size=ep_size,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1
)
+ )
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
- out_local.mean().backward()
- out_tp.mean().backward()
- tp_grad_handler.handle_gradient()
- out_ep.mean().backward()
- ep_grad_handler.handle_gradient()
-
- assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
- assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
- sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
- try:
- sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
- except AssertionError:
- warnings.warn(
- "EP & TP may result in different behavior from local model. " "Please check the comments for details."
+ booster = Booster(
+ plugin=HybridParallelPlugin(
+ tp_size=tp_size,
+ pp_size=1,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1,
)
+ )
+ hybrid_model, hybrid_optimizer, _, _, _ = booster.boost(torch_model, torch.optim.SGD(torch_model.parameters(), lr=1))
+
+ # create different input
+ seed_all(1453 + rank)
+
+ hybrid_model.train()
+ zero_model.train()
+ for _ in range(2):
+ # zero-dp forward
+ input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+ zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ # zero-dp backward
+ zero_optimizer.backward(zero_output)
+ # torch-ddp forward
+ hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ loose_close(zero_output, hybrid_output, dtype=dtype)
+ # torch-ddp backward
+ hybrid_optimizer.backward(hybrid_output)
+
+ # check grad
+ name_to_p = {n: p for n, p in hybrid_model.named_parameters()}
+ for n, p in zero_model.named_parameters():
+ zero_grad = zero_optimizer.get_param_grad(p)
+ if name_to_p[n].grad is None:
+ name_to_p[n].grad = torch.zeros_like(name_to_p[n])
+ continue
+ loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+
+ # zero-dp step
+ zero_optimizer.step()
+
+ # original model step
+ hybrid_optimizer.step()
+
+ # check updated param
+ for n, p in zero_model.named_parameters():
+ loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+
+ print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
-@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
-@pytest.mark.parametrize("num_experts", [4, 64])
-@pytest.mark.parametrize("batch_size", [16])
-@pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize(
- "config",
- [
- {"enable_hierarchical_comm": False},
- {"enable_hierarchical_comm": True},
- ],
-)
+@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
- spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
+def test_moe_ep_tp(world_size):
+ spawn(run_dist, world_size)
if __name__ == "__main__":
- test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
+ test_moe_ep_tp(world_size=4)
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
index 2e6d0d786..3d6af2b1a 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
@@ -5,20 +5,20 @@ import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
+from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
-from colossalai.zero import LowLevelZeroOptimizer
from tests.test_moe.moe_utils import loose_close
-tokens, n_experts = 7, 4
-hidden_size = 8
-top_k = 2
+NUM_BATCH=4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS=2
+TOP_K = 2
def split_grad(grad, world_size):
@@ -31,94 +31,87 @@ def split_grad(grad, world_size):
return splited_grad
-@parameterize("stage", [1, 2])
+@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
- dtype = torch.float16
+ dtype = torch.bfloat16
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
+
plugin = MoeHybridParallelPlugin(
- tp_size=1,
pp_size=1,
+ tp_size=1,
ep_size=ep_size,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1
)
+ booster = Booster(plugin=plugin)
seed_all(10086)
+
config = MixtralConfig(
- hidden_size=hidden_size,
- intermediate_size=hidden_size * 2,
- num_local_experts=n_experts,
- num_experts_per_tok=top_k,
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=2,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
)
- orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
+ torch_model = MixtralModel(config).to(dtype).cuda()
- ori_model = DDP(
- orig_model.cuda(),
+ zero_model = deepcopy(torch_model).to(dtype)
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ ddp_model = DDP(
+ torch_model.cuda(),
process_group=plugin.dp_group,
find_unused_parameters=True, # important for torch ddp, not all experts are routed
).cuda()
+ ddp_optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1)
- zero_model = deepcopy(orig_model).to(dtype)
- zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
-
- zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
- pg_param_list = {plugin.dp_group: [], plugin.moe_dp_group: []}
- for p in zero_model.parameters():
- if is_moe_tensor(p):
- pg_param_list[plugin.moe_dp_group].append(p)
- else:
- pg_param_list[plugin.dp_group].append(p)
-
- zero_optimizer = LowLevelZeroOptimizer(
- zero_optimizer,
- pg_to_param_list=pg_param_list,
- master_weights=False,
- initial_scale=1,
- overlap_communication=True,
- partition_grad=stage == 2,
- )
-
- ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
-
- # create
+ # create different input
seed_all(1453 + rank)
+ ddp_model.train()
+ zero_model.train()
for _ in range(2):
# zero-dp forward
- input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
- zero_output, _ = zero_model(input_data.to(dtype))
+ input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+ zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ # zero-dp backward
+ zero_optimizer.backward(zero_output)
# torch-ddp forward
- ori_output, _ = ori_model(input_data.to(dtype))
- loose_close(zero_output, ori_output, dtype=dtype)
-
- # zero-dp backward
- zero_optimizer.backward(zero_output.mean().float())
-
+ ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ loose_close(zero_output, ddp_output, dtype=dtype)
# torch-ddp backward
- ori_output.mean().backward()
+ ddp_output.backward()
# check grad
- name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
+ name_to_p = {n: p for n, p in ddp_model.named_parameters()}
for n, p in zero_model.named_parameters():
+ print(f"rank {dist.get_rank()} {n}")
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
- assert zero_grad is None
+ name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
-
- loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
+ loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
# original model step
- ori_optimizer.step()
+ ddp_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
- loose_close(p.data, name_to_p[n].data, dtype=dtype)
+ loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@@ -131,9 +124,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
+def test_moe_ep_tp(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
- test_moe_zero_model(world_size=4)
+ test_moe_ep_tp(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 70b576908..4e9d3878b 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -113,65 +113,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
- {
- "tp_size": 2,
- "pp_size": 1,
- "ep_size": 1,
- "zero_stage": 2,
- "precision": "fp32",
- }, # [dp(2) + tp(2)] + [moe_dp(4)]
- {
- "tp_size": 2,
- "pp_size": 1,
- "ep_size": 2,
- "zero_stage": 2,
- "precision": "fp32",
- }, # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)]
{
"tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 2,
+ "pp_size": 1,
"ep_size": 1,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [moe_dp(4)]
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 2,
- "ep_size": 1,
- "zero_stage": 2,
- "precision": "fp32",
- }, # [dp(2) + pp(2)] + [moe_dp(4)]
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 2,
- "ep_size": 4,
- "zero_stage": 2,
- "precision": "fp32",
-        }, # [dp(2) + pp(2)] + [ep(4)]
- {
- "tp_size": 1,
- "pp_size": 1,
- "ep_size": 2,
- "zero_stage": 2,
- "precision": "fp32",
- }, # [dp(4)] + [ep(2) + moe_tp(2)]
- {
- "tp_size": 1,
- "pp_size": 1,
- "ep_size": 4,
- "zero_stage": 2,
- "precision": "fp32"
- }, # full dp for non-moe and full ep for moe
- {
- "tp_size": 1,
- "pp_size": 1,
- "ep_size": 1,
- "zero_stage": 2,
- "precision": "fp32"
- }, # full dp for moe and non-moe
+ # {
+ # "tp_size": 1,
+ # "pp_size": 2,
+ # "num_microbatches": 2,
+ # "ep_size": 1,
+ # "zero_stage": 1,
+ # "precision": "fp32",
+ # }, # [dp(2) + pp(2)] + [moe_dp(4)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 2,
+ # "num_microbatches": 2,
+ # "ep_size": 4,
+ # "zero_stage": 1,
+ # "precision": "fp32",
+        # }, # [dp(2) + pp(2)] + [ep(4)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 2,
+ # "zero_stage": 0,
+ # "precision": "fp32",
+ # }, # [dp(4)] + [ep(2) + moe_tp(2)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 4,
+ # "zero_stage": 0,
+ # "precision": "fp32"
+ # }, # full dp for non-moe and full ep for moe
],
)
def run_mixtral_test(test_config):
From fe24789eb178236ad77112824a7d6081ed50dabc Mon Sep 17 00:00:00 2001
From: haze188
Date: Tue, 9 Jul 2024 09:44:04 +0000
Subject: [PATCH 065/116] [misc] solve booster hang by renaming the variable
---
tests/test_moe/test_moe_ep_tp.py | 31 ++++++++++++++-----------------
1 file changed, 14 insertions(+), 17 deletions(-)
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 24fc0a0eb..632a8ce38 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -3,22 +3,21 @@ from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
-NUM_BATCH=4
+NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS=2
+NUM_HEADS = 4
TOP_K = 2
@@ -35,7 +34,7 @@ def split_grad(grad, world_size):
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
dtype = torch.bfloat16
rank = torch.distributed.get_rank()
@@ -56,19 +55,14 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
- booster = Booster(
+ moe_booster = Booster(
plugin=MoeHybridParallelPlugin(
- tp_size=tp_size,
- pp_size=1,
- ep_size=ep_size,
- zero_stage=stage,
- overlap_communication=False,
- initial_scale=1
+ tp_size=tp_size, pp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
)
)
- zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+ zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer)
- booster = Booster(
+    hybrid_booster = Booster(
plugin=HybridParallelPlugin(
tp_size=tp_size,
pp_size=1,
@@ -77,8 +71,9 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
initial_scale=1,
)
)
- hybrid_model, hybrid_optimizer, _, _, _ = booster.boost(torch_model, torch.optim.SGD(torch_model.parameters(), lr=1))
-
+    hybrid_model, hybrid_optimizer, _, _, _ = hybrid_booster.boost(
+ torch_model, torch.optim.SGD(torch_model.parameters(), lr=1)
+ )
# create different input
seed_all(1453 + rank)
@@ -86,7 +81,9 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
zero_model.train()
for _ in range(2):
# zero-dp forward
- input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+ input_data = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
From 5ed5e8cfbaa7cc56a2f8716a2fbbf7139b9071f1 Mon Sep 17 00:00:00 2001
From: haze188
Date: Thu, 11 Jul 2024 02:12:44 +0000
Subject: [PATCH 066/116] solve hang when parallel mode = pp + dp
---
.../booster/plugin/hybrid_parallel_plugin.py | 7 ++-
.../plugin/moe_hybrid_parallel_plugin.py | 40 ++++++++++-------
.../test_model/test_shard_mixtral.py | 44 ++++++++++++-------
3 files changed, 57 insertions(+), 34 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index cad9ca95c..03b7bebb1 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -27,6 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1068,8 +1069,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
- self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
-
+ self.logger.info(
+ f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
+ )
+
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 56b731d13..a02deb80d 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -15,6 +15,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelModule,
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
+ HybridParallelZeroOptimizer,
get_param_info,
reinitialize_optimizer,
)
@@ -22,16 +23,18 @@ from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.zero.low_level import LowLevelZeroOptimizer
-class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+
+class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
force_overlap_comm: bool, # force overlap comm
- dp_process_group: ProcessGroup, # dp pg for comm
+ dp_process_group: Optional[ProcessGroup], # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup], # if using tp
+ pp_process_group: Optional[ProcessGroup], # if using pp
moe_dp_group: ProcessGroup, # moe dp pg for comm
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
@@ -49,32 +52,28 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None,
- ):
-
+ ):
WARN_STR = "Note that you need to make sure every expert is routed (i.e. every expert receives a gradient in backward), otherwise this might lead to a program hang or inconsistent results"
if not force_overlap_comm and (overlap_communication or partition_grad):
- raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
-
+ raise RuntimeError(
+ WARN_STR
+ + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
+ )
+
if force_overlap_comm:
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
- self.param_info = param_info
- self.stage_manager = model.stage_manager
- self.shared_params = model.shared_params
- self.dp_pg = dp_process_group
-
- if use_pipeline:
- reinitialize_optimizer(optimizer, model)
-
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
super().__init__(
+ model=model,
optimizer=optimizer,
- pg_to_param_list=pg_param_list,
+ use_pipeline=use_pipeline,
+ param_info=param_info,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@@ -89,7 +88,12 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
+ # dp_process_group=dp_process_group,
+ tp_process_group=tp_process_group,
+ pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
+ ## moe args
+ pg_to_param_list=pg_param_list,
)
@@ -180,7 +184,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
- if not(self.dp_size > 1 or self.moe_dp_size > 1):
+ if not (self.dp_size > 1 or self.moe_dp_size > 1):
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -192,6 +196,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_group,
+ tp_process_group=self.tp_group,
+ pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 4e9d3878b..123e590c9 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -117,23 +117,35 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 1,
"ep_size": 1,
- "zero_stage": 2,
+ "zero_stage": 1,
+ "overlap_communication": False,
"precision": "fp32",
- }, # [dp(2) + pp(2)] + [moe_dp(4)]
+ }, # [dp(4)] + [moe_dp(4)]
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "ep_size": 1,
+ "zero_stage": 1,
+ "overlap_communication": False,
+ "precision": "fp32",
+ }, # [dp(2) + pp(2)] + [moe_pp(2)]
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "ep_size": 1,
+ "zero_stage": 1,
+ "overlap_communication": False,
+ "precision": "fp32",
+ }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
- # "ep_size": 1,
- # "zero_stage": 1,
- # "precision": "fp32",
- # }, # [dp(2) + pp(2)] + [moe_dp(4)]
- # {
- # "tp_size": 1,
- # "pp_size": 2,
- # "num_microbatches": 2,
- # "ep_size": 4,
+ # "ep_size": 2,
# "zero_stage": 1,
+ # "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [ep(4)]
# {
@@ -141,13 +153,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# "pp_size": 1,
# "ep_size": 2,
# "zero_stage": 0,
+ # "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(4)] + [ep(2) + moe_tp(2)]
# {
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 4,
- # "zero_stage": 0,
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 4,
+ # "overlap_communication": False,
+ # "zero_stage": 0,
# "precision": "fp32"
# }, # full dp for non-moe and full ep for moe
],
From e28e05345b3d3f6409dc44905f2c9ae34139175c Mon Sep 17 00:00:00 2001
From: botbw
Date: Thu, 11 Jul 2024 05:50:20 +0000
Subject: [PATCH 067/116] [moe] implement submesh initialization
---
.../plugin/moe_hybrid_parallel_plugin.py | 68 +++++++++++++++++--
colossalai/cluster/process_group_mesh.py | 40 +++++++----
.../test_model/test_shard_mixtral.py | 17 ++---
3 files changed, 98 insertions(+), 27 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index a02deb80d..f689fe988 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,6 +1,7 @@
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
+import numpy as np
import torch
import torch.distributed as dist
@@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
+ self.param_info = param_info
+ self.stage_manager = model.stage_manager
+ self.shared_params = model.shared_params
+ self.dp_pg = dp_process_group
+
+ if use_pipeline:
+ reinitialize_optimizer(optimizer, model)
+
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
raise NotImplementedError
world_size = dist.get_world_size()
-
- self.moe_dp_size = world_size // (ep_size * moe_tp_size)
+ self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
- self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
- self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
+ if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+ raise ValueError(
+ f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+ )
- self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
- self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
- self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
+ self._init_moe_param_comm()
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
@@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.force_overlap_comm = force_overlap_comm
+ def _init_moe_param_comm(self):
+ self.moe_dp_group = None
+ self.ep_group = None
+ self.moe_tp_group = None
+
+ # create submesh for ep, moe_dp, moe_tp
+ ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+ [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+ )
+
+ global_rank = self.pg_mesh.rank
+ pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+ # create groups from submesh
+ for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+            # stage_rank is flattened along (dp, tp, sp); regroup it into (moe_dp, ep, moe_tp)
+ submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+ # hardcode here since we only have 3 axis
+ # moe_dp_group
+ for ep_idx in range(self.ep_size):
+ for moe_tp_idx in range(self.moe_tp_size):
+ moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+ group = dist.new_group(moe_dp_ranks)
+ if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+ assert self.moe_dp_group is None
+ self.moe_dp_group = group
+ # ep_group
+ for moe_dp_idx in range(self.moe_dp_size):
+ for moe_tp_idx in range(self.moe_tp_size):
+ ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+ group = dist.new_group(ep_ranks)
+ if pp_rank == stage_idx and global_rank in ep_ranks:
+ assert self.ep_group is None
+ self.ep_group = group
+ # moe_tp_group
+ for moe_dp_idx in range(self.moe_dp_size):
+ for ep_idx in range(self.ep_size):
+ moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+ group = dist.new_group(moe_tp_ranks)
+ if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+ assert self.moe_tp_group is None
+ self.moe_tp_group = group
+
+ self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
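
Note (illustrative, not part of the patch): _init_moe_param_comm above slices each pipeline
stage's rank list, reshaped to (moe_dp_size, ep_size, moe_tp_size), along one axis at a time to
build the three kinds of groups. As a worked example of the size check, assuming a 4-rank run
with pp_size=2, ep_size=2 and moe_tp_size left at 1, moe_dp_size = 4 // (2 * 2 * 1) = 1 and
2 * 1 * 2 * 1 == 4, so the new ValueError is not raised. The standalone sketch below reproduces
the submesh slicing with plain rank lists and assumed toy sizes, without any dist.new_group calls:

    import numpy as np

    # assumed toy pipeline stage: 8 ranks = moe_dp 2 x ep 2 x moe_tp 2
    moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
    stage_ranks = list(range(8))
    submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

    # moe_dp groups: vary axis 0 while fixing (ep, moe_tp)
    moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
    # ep groups: vary axis 1 while fixing (moe_dp, moe_tp)
    ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
    # moe_tp groups: vary axis 2 while fixing (moe_dp, ep)
    moe_tp_groups = [submesh[d, e, :].tolist() for d in range(moe_dp_size) for e in range(ep_size)]

    print(moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(ep_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
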
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index c09c7a2cc..ee9e2d71d 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -209,13 +209,15 @@ class ProcessGroupMesh:
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
- ) -> ProcessGroup:
+ return_ranks_by_group: bool = False
+ ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
axis (int): Axis along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
+ return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -235,25 +237,35 @@ class ProcessGroupMesh:
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
for ax in axis:
reduced_shape[ax] = 1
- target_group = None
- # use Cartesian product to generate all combinations of coordinates
- for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
- coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
- ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
- group = self._get_group(ranks_in_group, backend=backend)
- if self._rank in ranks_in_group:
- target_group = group
- return target_group
+ if return_ranks_by_group:
+ ranks_by_group = []
+ # use Cartesian product to generate all combinations of coordinates
+ for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ ranks_by_group.append(ranks_in_group)
+ return ranks_by_group
+ else:
+ target_group = None
+ # use Cartesian product to generate all combinations of coordinates
+ for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ group = self._get_group(ranks_in_group, backend=backend)
+ if self._rank in ranks_in_group:
+ target_group = group
+ return target_group
def get_group_along_axis(
- self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
- ) -> ProcessGroup:
+ self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
+ ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
+ return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -267,6 +279,10 @@ class ProcessGroupMesh:
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+
+ if return_ranks_by_group:
+ return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
+
if ranks_in_group not in self._ranks_to_group:
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
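
Note (illustrative, not part of the patch): with return_ranks_by_group=True the mesh returns the
rank tuple of every group instead of a single ProcessGroup. The standalone snippet below
re-implements that enumeration for a single axis on an assumed row-major 2x2x2 mesh, so the shape
of the return value can be inspected without launching a distributed job:

    import itertools

    import numpy as np

    shape = (2, 2, 2)  # assumed toy mesh: 8 ranks laid out in row-major order
    mesh = np.arange(np.prod(shape)).reshape(shape)

    def ranks_by_group(axis: int):
        # enumerate every group obtained by varying `axis` while fixing the other coordinates
        reduced = [range(s) if i != axis else [0] for i, s in enumerate(shape)]
        groups = []
        for base in itertools.product(*reduced):
            coords = [base[:axis] + (j,) + base[axis + 1:] for j in range(shape[axis])]
            groups.append(tuple(int(mesh[c]) for c in coords))
        return groups

    print(ranks_by_group(axis=0))  # [(0, 4), (1, 5), (2, 6), (3, 7)]
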
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 123e590c9..f268d1686 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -29,10 +29,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
)
- with torch.autograd.set_detect_anomaly(True):
- org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
- org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
- )
+
+ org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+ org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+ )
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
@@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
{
"tp_size": 1,
- "pp_size": 1,
- "ep_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
@@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
- "ep_size": 1,
+ "ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
@@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
- "ep_size": 1,
+ "ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
From 9b9b76bdcd4b2130796f59ce830c781b54762744 Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 12 Jul 2024 03:27:20 +0000
Subject: [PATCH 068/116] [moe] add mixtral dp grad scaling when not all
experts are activated
---
.../plugin/moe_hybrid_parallel_plugin.py | 6 ++-
colossalai/moe/_operation.py | 51 +++++++++++++++++--
colossalai/shardformer/layer/moe/experts.py | 6 +--
colossalai/shardformer/layer/moe/routers.py | 6 +--
colossalai/shardformer/modeling/mixtral.py | 46 ++++++++++++-----
colossalai/shardformer/policies/mixtral.py | 14 +----
colossalai/shardformer/shard/shard_config.py | 3 ++
...o_fwd_bwd_optim.py => test_moe_ep_zero.py} | 8 ++-
8 files changed, 98 insertions(+), 42 deletions(-)
rename tests/test_moe/{test_moe_zero_fwd_bwd_optim.py => test_moe_ep_zero.py} (97%)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index f689fe988..902500e42 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -141,6 +141,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# set ep_group after super init
# TODO do it in a better way
self.shard_config.ep_group = self.ep_group
+ self.shard_config.moe_dp_group = self.moe_dp_group
+ self.shard_config.moe_tp_group = self.moe_tp_group
self.force_overlap_comm = force_overlap_comm
@@ -159,7 +161,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
- # axis 0 is dp, axis 1 is tp, axis 2 is sp
+ # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
# hardcode here since we only have 3 axis
@@ -188,7 +190,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
assert self.moe_tp_group is None
self.moe_tp_group = group
- self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+ self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index cad9573fb..abec2aa6e 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
return torch.cumsum(inputs, dim=0) - 1
-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
"""
Scale the gradient back by the number of experts
because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
- if ctx is not None:
- ctx.ep_size = ep_size
+ ctx.ep_size = ep_size
return inputs
@staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
return grad, None
-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
"""
Scale the gradient by the number of experts
because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
return grad, None
+class DPGradScalerIn(torch.autograd.Function):
+ """
+    Scale the gradient by activated_experts / moe_dp_size in the backward pass,
+    cancelling the up-scaling applied by DPGradScalerOut so that gradients
+    upstream of the expert keep their original scale
+ """
+
+ @staticmethod
+ def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, "DPGradScalerIn should not be called when no expert is activated"
+ ctx.moe_dp_size = moe_dp_size
+ ctx.activated_experts = activated_experts
+ return inputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+ assert len(grad_outputs) == 1
+ grad = grad_outputs[0]
+ if ctx.moe_dp_size != ctx.activated_experts:
+ grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+ return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+ """
+    Scale the gradient by moe_dp_size / activated_experts in the backward pass,
+    compensating experts that received tokens from only a subset of the
+    moe-dp replicas
+ """
+
+ @staticmethod
+ def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, "DPGradScalerOut should not be called when no expert is activated"
+ ctx.moe_dp_size = moe_dp_size
+ ctx.activated_experts = activated_experts
+ return inputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+ assert len(grad_outputs) == 1
+ grad = grad_outputs[0]
+ if ctx.moe_dp_size != ctx.activated_experts:
+ grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+ return grad, None, None
+
+
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
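
Note (illustrative, not part of the patch): the effect of the new DP scalers can be checked with a
toy copy of the DPGradScalerOut logic and made-up sizes (moe_dp_size=4 replicas, of which only 3
routed tokens to the expert). The sketch below only reproduces the backward scaling rule:

    from typing import Any

    import torch

    class ToyDPGradScalerOut(torch.autograd.Function):
        # standalone copy of the DPGradScalerOut rule above, for illustration only
        @staticmethod
        def forward(ctx: Any, inputs: torch.Tensor, moe_dp_size: int, activated_experts: int) -> torch.Tensor:
            ctx.moe_dp_size, ctx.activated_experts = moe_dp_size, activated_experts
            return inputs

        @staticmethod
        def backward(ctx: Any, grad: torch.Tensor):
            if ctx.moe_dp_size != ctx.activated_experts:
                grad = grad * (ctx.moe_dp_size / ctx.activated_experts)
            return grad, None, None

    x = torch.ones(2, requires_grad=True)
    y = ToyDPGradScalerOut.apply(x, 4, 3)  # moe_dp_size=4, activated_experts=3
    y.sum().backward()
    print(x.grad)  # tensor([1.3333, 1.3333]): upstream gradient of 1 scaled by 4/3
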
diff --git a/colossalai/shardformer/layer/moe/experts.py b/colossalai/shardformer/layer/moe/experts.py
index 1be7a2754..109740dbb 100644
--- a/colossalai/shardformer/layer/moe/experts.py
+++ b/colossalai/shardformer/layer/moe/experts.py
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
- x = MoeInGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerIn.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
- x = MoeOutGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerOut.apply(x, self.ep_size)
return x
diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/shardformer/layer/moe/routers.py
index 1be7a2754..109740dbb 100644
--- a/colossalai/shardformer/layer/moe/routers.py
+++ b/colossalai/shardformer/layer/moe/routers.py
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
- x = MoeInGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerIn.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
- x = MoeOutGradScaler.apply(x, self.ep_size)
+ x = EPGradScalerOut.apply(x, self.ep_size)
return x
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 5d2dc1dc3..609fc6f3e 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
- def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
- super().__init__(config)
- self.setup_process_groups(ep_group, tp_group, moe_tp_group)
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
+
+ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+ assert tp_group is not None
+ assert moe_dp_group is not None
+ assert ep_group is not None
+ assert moe_tp_group is not None
- def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
# setup ep group
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
@@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
- p.ep_group = ep_group
+ set_moe_tensor_ep_group(p, ep_group)
+
+ # setup moe_dp group
+ self.moe_dp_group = moe_dp_group
+ self.moe_dp_size = moe_dp_group.size()
# setup global tp group
self.tp_group = tp_group
@@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
@staticmethod
def from_native_module(
- module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
+ module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
) -> "EPMixtralSparseMoeBlock":
+ # TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
- module.setup_process_groups(ep_group)
+ module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+ with torch.no_grad():
+ activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+ for i in range(1, self.ep_size):
+ activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+ activate_experts = (activate_experts > 0).float()
+ dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
- if self.tp_group is not None and self.tp_group.size() > 1:
+ if self.tp_group.size() > 1:
dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
- output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
+ output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
+ output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
+ split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
+ split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
- output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+
+ output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
- if self.tp_group is not None and self.tp_group.size() > 1:
+ if self.tp_group.size() > 1:
dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
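
Note (illustrative, not part of the patch): the activate_experts bookkeeping above can be
reproduced on a single process with made-up split sizes; the dist.all_reduce over the moe-dp
group is omitted here:

    import torch

    ep_size, num_experts_per_ep = 2, 2  # assumed toy sizes
    # tokens this rank receives for each (remote ep rank, local expert) pair after all_to_all_single
    output_split_sizes = torch.tensor([3, 0, 2, 0])

    activate_experts = output_split_sizes[:num_experts_per_ep].clone()
    for i in range(1, ep_size):
        activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
    activate_experts = (activate_experts > 0).float()
    print(activate_experts)  # tensor([1., 0.]): local expert 0 received tokens, local expert 1 did not
    # in the patch this mask is then summed across the moe-dp group, so each entry counts how many
    # moe-dp replicas activated that local expert
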
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 14d57c79d..69bcc54ed 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -76,18 +76,6 @@ class MixtralPolicy(Policy):
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
- # SubModuleReplacementDescription( # TODO: enable moe tp parallel
- # suffix="mlp.gate_proj",
- # target_module=Linear1D_Col,
- # ),
- # SubModuleReplacementDescription(
- # suffix="mlp.up_proj",
- # target_module=Linear1D_Col,
- # ),
- # SubModuleReplacementDescription(
- # suffix="mlp.down_proj",
- # target_module=Linear1D_Row,
- # ),
],
)
@@ -98,7 +86,7 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
+ kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group},
)
],
policy=policy,
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index d1aebd5b2..f12c78526 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -46,6 +46,9 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+ # for moe related
+ moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
moe_tp_group: Optional[ProcessGroup] = None
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_ep_zero.py
similarity index 97%
rename from tests/test_moe/test_moe_zero_fwd_bwd_optim.py
rename to tests/test_moe/test_moe_ep_zero.py
index 3d6af2b1a..c5adaad06 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py
+++ b/tests/test_moe/test_moe_ep_zero.py
@@ -18,8 +18,7 @@ NUM_BATCH=4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS=2
-TOP_K = 2
-
+TOP_K = 1
def split_grad(grad, world_size):
with torch.no_grad():
@@ -96,7 +95,6 @@ def run_zero_with_original_model(stage: int, ep_size: int):
# check grad
name_to_p = {n: p for n, p in ddp_model.named_parameters()}
for n, p in zero_model.named_parameters():
- print(f"rank {dist.get_rank()} {n}")
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
@@ -124,9 +122,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
-def test_moe_ep_tp(world_size):
+def test_moe_ep_zero(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
- test_moe_ep_tp(world_size=4)
+ test_moe_ep_zero(world_size=4)
From 014faf6c5a072fd27d9df945a02f90755a369ecc Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 12 Jul 2024 03:29:16 +0000
Subject: [PATCH 069/116] [chore] manually revert unintended commit
---
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 8 --------
1 file changed, 8 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 902500e42..22157b5cf 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -65,14 +65,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
- self.param_info = param_info
- self.stage_manager = model.stage_manager
- self.shared_params = model.shared_params
- self.dp_pg = dp_process_group
-
- if use_pipeline:
- reinitialize_optimizer(optimizer, model)
-
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
From 8dbb86899dc1910ae3262243f0eef473e9545cfb Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 12 Jul 2024 07:04:17 +0000
Subject: [PATCH 070/116] [chore] trivial fix
---
colossalai/cluster/process_group_mesh.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index ee9e2d71d..66b77f7a2 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -48,7 +48,7 @@ class ProcessGroupMesh:
self._shape = size
self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
- self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
+ self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def destroy_mesh_process_groups(self):
From 102b784a10f0cd1c740d9ceba343a78166314290 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Fri, 12 Jul 2024 09:08:16 +0000
Subject: [PATCH 071/116] [chore] arg pass & remove drop token
---
.../plugin/moe_hybrid_parallel_plugin.py | 15 ++++----
colossalai/shardformer/modeling/mixtral.py | 34 +++++++++++++------
tests/test_moe/test_mixtral_layer.py | 8 ++++-
tests/test_moe/test_moe_checkpoint.py | 2 --
tests/test_moe/test_moe_ep_tp.py | 2 +-
5 files changed, 41 insertions(+), 20 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 22157b5cf..047782aa9 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,8 +1,8 @@
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
-import numpy as np
+import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -21,7 +21,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
-from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -89,11 +88,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
- # dp_process_group=dp_process_group,
tp_process_group=tp_process_group,
pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
- ## moe args
pg_to_param_list=pg_param_list,
)
@@ -104,6 +101,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
+ if "overlap_communication" not in kwargs:
+ kwargs["overlap_communication"] = False
+
super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@@ -142,7 +142,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
-
+
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
@@ -182,7 +182,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
assert self.moe_tp_group is None
self.moe_tp_group = group
- self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
+ self.logger.info(
+ f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+ ranks=[0],
+ )
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
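
Note (illustrative, not part of the patch): the constructor change above only fills in a default;
an explicit overlap_communication passed by the caller is preserved, and MoeHybridParallelZeroOptimizer
may still force it back on with a warning. A toy reproduction of the kwargs handling:

    def moe_plugin_kwargs_sketch(**kwargs):
        # mirror of the constructor guard added above: default overlap_communication to False
        if "overlap_communication" not in kwargs:
            kwargs["overlap_communication"] = False
        return kwargs

    print(moe_plugin_kwargs_sketch(ep_size=2))                   # {'ep_size': 2, 'overlap_communication': False}
    print(moe_plugin_kwargs_sketch(overlap_communication=True))  # the caller's explicit choice is kept
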
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 609fc6f3e..5a42a1073 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -14,7 +14,13 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import (
+ DPGradScalerIn,
+ DPGradScalerOut,
+ EPGradScalerIn,
+ EPGradScalerOut,
+ all_to_all_uneven,
+)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -25,7 +31,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
- def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+ def setup_process_groups(
+ self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
+ ):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -59,7 +67,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
@staticmethod
def from_native_module(
- module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
+ module: MixtralSparseMoeBlock,
+ tp_group: ProcessGroup,
+ moe_dp_group: ProcessGroup,
+ ep_group: ProcessGroup,
+ moe_tp_group: ProcessGroup,
+ *args,
+ **kwargs,
) -> "EPMixtralSparseMoeBlock":
# TODO: better init
LazyInitContext.materialize(module)
@@ -96,8 +110,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
- if self.tp_group.size() > 1:
- dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+ # TODO drop tokens to reduce tp group redundant communication
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
@@ -116,20 +129,21 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
- split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+ split_states = DPGradScalerIn.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+ )
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
- split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+ split_states = DPGradScalerOut.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+ )
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
- if self.tp_group.size() > 1:
- dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
-
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
index b7b0322e0..de34b5c7b 100644
--- a/tests/test_moe/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -36,7 +36,13 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
- model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
+ model = EPMixtralSparseMoeBlock.from_native_module(
+ model,
+ ep_group=plugin.ep_group,
+ tp_group=plugin.tp_group,
+ moe_dp_group=plugin.moe_dp_group,
+ moe_tp_group=plugin.moe_tp_group,
+ )
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 773036358..6f3c5b299 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
@@ -115,7 +114,6 @@ def check_moe_checkpoint(test_config):
pp_size=2,
ep_size=2,
tp_size=1,
- checkpoint_io=MoECheckpointIO,
microbatch_size=1,
zero_stage=1,
)
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 632a8ce38..cc5448e51 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -34,7 +34,7 @@ def split_grad(grad, world_size):
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
dtype = torch.bfloat16
rank = torch.distributed.get_rank()
From 0b5bbe9ce456a17cea00b46ea0255a308a02ecba Mon Sep 17 00:00:00 2001
From: botbw
Date: Mon, 15 Jul 2024 06:43:27 +0000
Subject: [PATCH 072/116] [test] add mixtral modelling test
---
tests/test_moe/modelling/test_mixtral.py | 140 +++++++++++++++++++++++
tests/test_moe/moe_utils.py | 5 +-
2 files changed, 144 insertions(+), 1 deletion(-)
create mode 100644 tests/test_moe/modelling/test_mixtral.py
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
new file mode 100644
index 000000000..26fa81921
--- /dev/null
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -0,0 +1,140 @@
+import os
+import shutil
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal
+
+NUM_BATCH = 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 2
+TOP_K = 1
+
+
+def split_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+ splited_grad = grad.split(grad.numel() // world_size)
+ return splited_grad
+
+
+@parameterize("stage", [1])
+@parameterize("ep_size", [1, 2, 4])
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
+ dtype = torch.float32
+
+ rank = torch.distributed.get_rank()
+ torch.cuda.set_device(dist.get_rank())
+
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1,
+ tp_size=1,
+ ep_size=ep_size,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1,
+ precision="fp32",
+ )
+ booster = Booster(plugin=plugin)
+
+ seed_all(10086)
+
+ config = MixtralConfig(
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=2,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
+ )
+
+ torch_model = MixtralModel(config).to(dtype).cuda()
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+ zero_model = deepcopy(torch_model).to(dtype)
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ # create different input
+ seed_all(1453 + rank)
+
+ torch_model.train()
+ zero_model.train()
+ for _ in range(1):
+ # zero-dp forward
+ input_data = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ # zero-dp backward
+ print(zero_output.dtype)
+ zero_optimizer.backward(zero_output)
+ zero_optimizer.step()
+
+ dist.all_reduce(zero_output)
+
+ all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
+ dist.all_gather(all_inputs, input_data)
+
+ torch_output_sum = 0
+ for input_data_ in all_inputs:
+ torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+ torch_output.backward()
+ torch_output_sum += torch_output.detach()
+
+ # avg dp grads
+ for p in torch_model.parameters():
+ if p.grad is not None:
+ p.grad /= dist.get_world_size()
+
+ loose_close(zero_output, torch_output_sum, dtype=dtype)
+ torch_optimizer.step()
+
+ # use checkpoint to load sharded zero model
+ model_dir = "./test_mixtral"
+ if dist.get_rank() == 0:
+ os.makedirs(model_dir, exist_ok=True)
+
+ dist.barrier()
+ booster.save_model(zero_model, model_dir, shard=True)
+ dist.barrier()
+
+ if dist.get_rank() == 0:
+ saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+ check_model_equal(torch_model, saved_model)
+ shutil.rmtree(model_dir)
+
+ print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_mixtral(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_mixtral(world_size=4)
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index b7332a937..e49edb6f4 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -2,7 +2,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
-from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
@@ -146,6 +145,10 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
+ else:
+ assert dtype is torch.float32
+ rtol = 1e-5
+ atol = 1e-5
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
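
Note (illustrative, not part of the patch): the new branch makes loose_close pick tolerances per
dtype. A minimal standalone version of the comparison (fp16 branch omitted) behaves like this:

    import torch

    def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype = torch.float32):
        # tolerance selection mirroring the branch added above
        if dtype is torch.bfloat16:
            rtol = atol = 4e-3
        else:
            assert dtype is torch.float32
            rtol = atol = 1e-5
        a = a.detach().to(dtype)
        b = b.detach().to(dtype).to(a.device)
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    loose_close_sketch(torch.tensor([1.0]), torch.tensor([1.0 + 5e-6]))                   # within fp32 tolerance
    loose_close_sketch(torch.tensor([1.0]), torch.tensor([1.002]), dtype=torch.bfloat16)  # within bf16 tolerance
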
From dc583aa576c3e2c98925613ae5cc2f3a9147ce3d Mon Sep 17 00:00:00 2001
From: botbw
Date: Tue, 16 Jul 2024 06:03:57 +0000
Subject: [PATCH 073/116] [moe] implement tp
---
.../plugin/moe_hybrid_parallel_plugin.py | 10 ++--
colossalai/checkpoint_io/moe_checkpoint.py | 5 +-
colossalai/moe/_operation.py | 18 ++++---
colossalai/shardformer/modeling/mixtral.py | 6 +++
colossalai/shardformer/policies/mixtral.py | 12 ++++-
.../low_level/bookkeeping/gradient_store.py | 1 +
tests/test_moe/modelling/test_mixtral.py | 51 ++++++++++---------
tests/test_moe/test_moe_ep_tp.py | 16 ++++--
8 files changed, 79 insertions(+), 40 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 047782aa9..bf450534f 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -113,9 +113,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.ddp_config["find_unused_parameters"] = True
- if moe_tp_size != 1:
- raise NotImplementedError
-
world_size = dist.get_world_size()
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
self.ep_size = ep_size
@@ -182,6 +179,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
assert self.moe_tp_group is None
self.moe_tp_group = group
+ if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
+ # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
+ # this assertion implies that dp_size == moe_dp_size * ep_size
+ raise NotImplementedError(
+ f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
+ )
+
self.logger.info(
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
ranks=[0],
diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
index a0b625008..9181956b7 100644
--- a/colossalai/checkpoint_io/moe_checkpoint.py
+++ b/colossalai/checkpoint_io/moe_checkpoint.py
@@ -151,13 +151,10 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
# ep_rank 0 saves all the parameters and buffers.
# other ep_ranks save only experts
- ep_param_pattern = "experts." if self.ep_rank != 0 else None
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
- state_dict_shard = MoECheckpointIO._model_sharder(
- model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
- )
+ state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index abec2aa6e..230b40530 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -443,7 +443,7 @@ def all_to_all_uneven(
# ===========================================================
-# This code section was modified from
+# This code section was modified from
# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
# Copyright (c) Microsoft Corporation.
@@ -492,8 +492,9 @@ def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
total_chunks = tp_group.size()
this_chunk = tp_group.rank()
- assert input_.shape[
- dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
+ assert (
+ input_.shape[dim] % total_chunks == 0
+ ), f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
chunk_size = input_.shape[dim] // total_chunks
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
@@ -531,15 +532,20 @@ def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
if tp_group.size() == 1:
# no tensor parallelism for non-experts
return input_
- assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
- return _GatherTokens.apply(input_, dim)
+ assert (
+ input_.requires_grad
+ ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+ return _GatherTokens.apply(input_, dim, tp_group)
def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
if tp_group.size() == 1:
# no tensor parallelism for non-experts
return input_
- assert input_.requires_grad, "Input must require grad to assure that backward is executed, otherwise it might hang the program."
+ assert (
+ input_.requires_grad
+ ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return _DropTokens.apply(input_, dim, tp_group)
+
# ===========================================================
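
Note (illustrative, not part of the patch): _drop_tokens keeps only this tp rank's chunk of the
tensor along the given dim via torch.narrow, and gather_tokens is the inverse direction (an
all-gather along the same dim). A single-process sketch with assumed sizes, acting as tp rank 1
in a tp world of 2:

    import torch

    total_chunks, this_chunk = 2, 1    # assumed tp world size and tp rank
    x = torch.arange(8).reshape(2, 4)  # the narrowed dim (4) must be divisible by total_chunks
    assert x.shape[-1] % total_chunks == 0
    chunk_size = x.shape[-1] // total_chunks
    kept = torch.narrow(x, -1, this_chunk * chunk_size, chunk_size)
    print(kept)  # tensor([[2, 3], [6, 7]]): rank 1 keeps columns 2..3
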
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 5a42a1073..86ef6c959 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -22,6 +22,7 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@@ -64,6 +65,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# setup moe tp group
self.moe_tp_group = moe_tp_group
+ if self.moe_tp_group.size() > 1:
+ for expert in held_experts:
+ expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group)
+ expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
+ expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
@staticmethod
def from_native_module(
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 69bcc54ed..4b77a167f 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -76,9 +76,14 @@ class MixtralPolicy(Policy):
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
+ ),
],
)
+ # TODO shard vocab embedding
+
if self.shard_config.ep_group:
# expert parallel
self.append_or_create_submodule_replacement(
@@ -86,7 +91,12 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
- kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group},
+ kwargs={
+ "ep_group": self.shard_config.ep_group,
+ "tp_group": self.shard_config.tensor_parallel_process_group,
+ "moe_dp_group": self.shard_config.moe_dp_group,
+ "moe_tp_group": self.shard_config.moe_tp_group,
+ },
)
],
policy=policy,
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index b84be034a..8b6d403f1 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -111,6 +111,7 @@ class GradientStore(BaseStore):
def reset_all_gradients(self):
self._grads_of_params = dict()
+ self.grad_to_param_mapping = dict()
def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
"""Return the id of a parameter which the gradient slice belongs to
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index 26fa81921..8309bfb22 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -1,6 +1,7 @@
import os
import shutil
from copy import deepcopy
+from typing import Tuple
import pytest
import torch
@@ -19,7 +20,7 @@ from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS = 2
+NUM_HEADS = 4
TOP_K = 1
@@ -33,9 +34,9 @@ def split_grad(grad, world_size):
return splited_grad
-@parameterize("stage", [1])
-@parameterize("ep_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
+@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+def run_zero_with_original_model(config: Tuple[int, ...]):
+ stage, ep_size, tp_size = config
dtype = torch.float32
rank = torch.distributed.get_rank()
@@ -43,7 +44,8 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
plugin = MoeHybridParallelPlugin(
pp_size=1,
- tp_size=1,
+ tp_size=tp_size,
+ moe_tp_size=tp_size,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
@@ -77,17 +79,16 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
torch_model.train()
zero_model.train()
- for _ in range(1):
- # zero-dp forward
+ for _ in range(2):
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
+ dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
+
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
- # zero-dp backward
- print(zero_output.dtype)
zero_optimizer.backward(zero_output)
zero_optimizer.step()
-
+ zero_optimizer.zero_grad()
dist.all_reduce(zero_output)
all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
@@ -98,28 +99,32 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
-
# avg dp grads
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dist.get_world_size()
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
loose_close(zero_output, torch_output_sum, dtype=dtype)
- torch_optimizer.step()
- # use checkpoint to load sharded zero model
- model_dir = "./test_mixtral"
- if dist.get_rank() == 0:
- os.makedirs(model_dir, exist_ok=True)
+ # use checkpoint to load sharded zero model
+ model_dir = "./test_mixtral"
+ if dist.get_rank() == 0:
+ os.makedirs(model_dir, exist_ok=True)
- dist.barrier()
- booster.save_model(zero_model, model_dir, shard=True)
- dist.barrier()
+ dist.barrier()
- if dist.get_rank() == 0:
- saved_model = MixtralModel.from_pretrained(model_dir).cuda()
- check_model_equal(torch_model, saved_model)
- shutil.rmtree(model_dir)
+ booster.save_model(zero_model, model_dir, shard=True)
+
+ dist.barrier()
+
+ saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+ check_model_equal(torch_model, saved_model)
+
+ dist.barrier()
+ if dist.get_rank() == 0:
+ shutil.rmtree(model_dir)
print(f"{dist.get_rank()} test passed")
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index cc5448e51..e944a8c0a 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -33,8 +33,8 @@ def split_grad(grad, world_size):
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
-@parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
+def run_zero_with_original_model(stage: int, ep_size: int):
+ tp_size = dist.get_world_size() // ep_size
dtype = torch.bfloat16
rank = torch.distributed.get_rank()
@@ -57,7 +57,13 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
moe_booster = Booster(
plugin=MoeHybridParallelPlugin(
- tp_size=tp_size, pp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
+ tp_size=tp_size,
+ moe_tp_size=tp_size,
+ pp_size=1,
+ ep_size=ep_size,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1,
)
)
zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer)
@@ -100,6 +106,8 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n])
continue
+ if zero_grad.shape != name_to_p[n].grad.shape: # TODO check sharded and sliced moe
+ continue
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
@@ -110,6 +118,8 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
# check updated param
for n, p in zero_model.named_parameters():
+ if p.data.shape != name_to_p[n].data.shape: # TODO check sharded and sliced moe
+ continue
loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
From 74eccac0db4f281b14a7042378e4158a4147fc0d Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 16 Jul 2024 10:10:40 +0000
Subject: [PATCH 074/116] [moe] test deepseek
---
colossalai/shardformer/modeling/deepseek.py | 81 +++++++++--
colossalai/shardformer/modeling/mixtral.py | 14 +-
.../shardformer/policies/auto_policy.py | 2 +-
colossalai/shardformer/policies/deepseek.py | 46 +++++-
colossalai/shardformer/policies/mixtral.py | 21 ++-
tests/test_moe/modelling/test_deepseek.py | 133 ++++++++++++++++++
tests/test_moe/modelling/test_mixtral.py | 10 --
tests/test_moe/test_moe_checkpoint.py | 3 +-
tests/test_moe/test_moe_ep_tp.py | 10 --
tests/test_moe/test_moe_ep_zero.py | 24 +---
10 files changed, 276 insertions(+), 68 deletions(-)
create mode 100644 tests/test_moe/modelling/test_deepseek.py
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 6e79ce144..33fac9b93 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -1,21 +1,27 @@
-from typing import List, Optional, Union
+from typing import List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import (
+ DPGradScalerIn,
+ DPGradScalerOut,
+ EPGradScalerIn,
+ EPGradScalerOut,
+ all_to_all_uneven,
+)
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
# copied from modeling_deepseek.py
@@ -42,30 +48,60 @@ class AddAuxiliaryLoss(torch.autograd.Function):
class EPDeepseekMoE(nn.Module):
def __init__(self):
- super(EPDeepseekMoE, self).__init__()
+ raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
- def setup_ep(self, ep_group: ProcessGroup):
- ep_group = ep_group
- self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
- self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ def setup_process_groups(
+ self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
+ ):
+ assert tp_group is not None
+ assert moe_dp_group is not None
+ assert ep_group is not None
+ assert moe_tp_group is not None
+
+ self.ep_size = dist.get_world_size(ep_group)
+ self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
+
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
- p.ep_group = ep_group
+ set_moe_tensor_ep_group(p, ep_group)
+
+ # setup moe_dp group
+ self.moe_dp_group = moe_dp_group
+ self.moe_dp_size = moe_dp_group.size()
+
+ # setup global tp group
+ self.tp_group = tp_group
+
+ # setup moe tp group
+ self.moe_tp_group = moe_tp_group
+ if self.moe_tp_group.size() > 1:
+ for expert in held_experts:
+ expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
+ expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
+ expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
@staticmethod
- def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
+ def from_native_module(
+ module,
+ tp_group: ProcessGroup,
+ moe_dp_group: ProcessGroup,
+ ep_group: ProcessGroup,
+ moe_tp_group: ProcessGroup,
+ *args,
+ **kwargs,
+ ) -> "EPDeepseekMoE":
LazyInitContext.materialize(module)
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
- assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
- module.setup_ep(kwargs["ep_group"])
+ module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -91,15 +127,24 @@ class EPDeepseekMoE(nn.Module):
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+ with torch.no_grad():
+ activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+ for i in range(1, self.ep_size):
+ activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+ activate_experts = (activate_experts > 0).float()
+ dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
- output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
expert = self.experts[self.expert_start_idx]
+ output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
output_states = expert(output_states)
+ output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
@@ -107,10 +152,16 @@ class EPDeepseekMoE(nn.Module):
            if split_states.size(0) == 0:  # no token routed to this expert
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = DPGradScalerIn.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
split_states = expert(split_states)
+ split_states = DPGradScalerOut.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
- output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+ output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
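
To make the new bookkeeping above easier to follow, here is a standalone sketch of the activate-expert mask with made-up split sizes; it skips the `moe_dp_group` all-reduce and runs without any process group, so it is illustrative only.

import torch

# Illustrative sizes only; the real values come from all_to_all_single above.
ep_size, num_experts_per_ep = 2, 2
# Per (source EP rank, local expert) token counts routed to this rank.
output_split_sizes = torch.tensor([3, 0, 1, 2])

with torch.no_grad():
    activate_experts = output_split_sizes[:num_experts_per_ep].clone()
    for i in range(1, ep_size):
        activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
    # 1.0 for every local expert that received at least one token, else 0.0;
    # the real code then all-reduces this mask over the MoE DP group.
    activate_experts = (activate_experts > 0).float()

print(activate_experts)  # tensor([1., 1.]) for the sizes above
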
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 86ef6c959..cfa7da6c0 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -116,8 +116,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
- # TODO drop tokens to reduce tp group redundant communication
-
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
@@ -125,24 +123,24 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
- output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
+ output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
- output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
+ output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
- split_states = DPGradScalerIn.apply(
- split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
- )
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = DPGradScalerIn.apply(
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+ )
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
split_states = DPGradScalerOut.apply(
- split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+ split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
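
The hunk above switches the third argument of the DP grad scalers from `activate_experts[...].item()` to the tensor itself; one plausible reading (an assumption, not stated in the patch) is that this avoids the host synchronization `.item()` forces inside the expert loop. A two-line illustration of that difference on a plain tensor:

import torch

mask = (torch.ones(4, device="cuda" if torch.cuda.is_available() else "cpu") > 0).float()
on_device = mask[0]         # still a tensor; the value can stay on the GPU
as_python = mask[0].item()  # device-to-host copy, which synchronizes the CUDA stream
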
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 1e0af031a..f2533da4b 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -161,7 +161,7 @@ _POLICY_LIST = {
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Deepseek
- "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
+ "transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation(
file_name="deepseek", class_name="DeepseekModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 8ebda357b..5a67d653d 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -7,6 +7,7 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -39,16 +40,55 @@ class DeepseekPolicy(Policy):
)
if self.shard_config.enable_tensor_parallelism:
- raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
+ # tensor parallelism for non-moe params
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), "The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), "The number of key_value heads must be divisible by tensor parallel size."
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+ // self.shard_config.tensor_parallel_size,
+ }
- if getattr(self.shard_config, "ep_group", None) is not None:
+ policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=Linear1D_Row,
+ ),
+ ],
+ )
+
+ if self.shard_config.ep_group:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
- kwargs={"ep_group": self.shard_config.ep_group},
+ kwargs={
+ "ep_group": self.shard_config.ep_group,
+ "tp_group": self.shard_config.tensor_parallel_process_group,
+ "moe_dp_group": self.shard_config.moe_dp_group,
+ "moe_tp_group": self.shard_config.moe_tp_group,
+ },
)
],
policy=policy,
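
Putting this policy change together with the new `from_native_module` signature earlier in the patch, the replacement boils down to a call of the following shape; `mlp_module` and `shard_config` are placeholders here, not objects taken from the patch.

from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE

# mlp_module: the original Deepseek MoE block; shard_config: a populated ShardConfig.
ep_mlp = EPDeepseekMoE.from_native_module(
    mlp_module,
    tp_group=shard_config.tensor_parallel_process_group,
    moe_dp_group=shard_config.moe_dp_group,
    ep_group=shard_config.ep_group,
    moe_tp_group=shard_config.moe_tp_group,
)
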
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 4b77a167f..8905b5696 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -8,6 +8,7 @@ from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -42,6 +43,13 @@ class MixtralPolicy(Policy):
                 "Mixtral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@@ -76,13 +84,22 @@ class MixtralPolicy(Policy):
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
- SubModuleReplacementDescription(
+ SubModuleReplacementDescription( # or replicate?
suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
),
],
)
- # TODO shard vocab embedding
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=MixtralModel,
+ )
if self.shard_config.ep_group:
# expert parallel
diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py
new file mode 100644
index 000000000..42daea512
--- /dev/null
+++ b/tests/test_moe/modelling/test_deepseek.py
@@ -0,0 +1,133 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed as dist
+from transformers import AutoConfig, AutoModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal
+
+NUM_BATCH = 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
+
+
+@parameterize("config", [(1, 1, 1)])
+def run_zero_with_original_model(config: Tuple[int, ...]):
+ stage, ep_size, tp_size = config
+ dtype = torch.float16
+
+ rank = torch.distributed.get_rank()
+ torch.cuda.set_device(dist.get_rank())
+
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1,
+ tp_size=tp_size,
+ moe_tp_size=tp_size,
+ ep_size=ep_size,
+ zero_stage=stage,
+ overlap_communication=False,
+ initial_scale=1,
+ precision="fp32",
+ )
+ booster = Booster(plugin=plugin)
+
+ seed_all(10086)
+
+ config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
+ config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
+ config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
+ config.num_hidden_layers = 2
+ config.num_attention_heads = NUM_HEADS
+ config.num_key_value_heads = NUM_HEADS
+ config.n_routed_experts = NUM_EXPERTS
+ config.num_experts_per_tok = TOP_K
+ torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
+
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+ zero_model = deepcopy(torch_model).to(dtype)
+ zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ # create different input
+ seed_all(1453 + rank)
+
+ torch_model.train()
+ zero_model.train()
+ for _ in range(2):
+ input_data = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
+
+ zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+ zero_optimizer.backward(zero_output)
+ zero_optimizer.step()
+ zero_optimizer.zero_grad()
+ dist.all_reduce(zero_output)
+
+ all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
+ dist.all_gather(all_inputs, input_data)
+
+ torch_output_sum = 0
+ for input_data_ in all_inputs:
+ torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+ torch_output.backward()
+ torch_output_sum += torch_output.detach()
+ # avg dp grads
+ for p in torch_model.parameters():
+ if p.grad is not None:
+ p.grad /= dist.get_world_size()
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+
+ loose_close(zero_output, torch_output_sum, dtype=dtype)
+
+ # use checkpoint to load sharded zero model
+ model_dir = "./test_deepseek"
+ if dist.get_rank() == 0:
+ os.makedirs(model_dir, exist_ok=True)
+
+ dist.barrier()
+
+ booster.save_model(zero_model, model_dir, shard=True)
+
+ dist.barrier()
+
+ saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
+ check_model_equal(torch_model, saved_model)
+
+ dist.barrier()
+ if dist.get_rank() == 0:
+ shutil.rmtree(model_dir)
+
+ print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_deepseek(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_deepseek(world_size=4)
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index 8309bfb22..6e6f0b2b5 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -24,16 +24,6 @@ NUM_HEADS = 4
TOP_K = 1
-def split_grad(grad, world_size):
- with torch.no_grad():
- grad = grad.clone().detach().flatten()
- padding_size = (world_size - grad.numel() % world_size) % world_size
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- splited_grad = grad.split(grad.numel() // world_size)
- return splited_grad
-
-
@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 6f3c5b299..4bcf701de 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -16,6 +16,7 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
+from tests.test_moe.moe_utils import loose_close
tokens, n_experts = 7, 4
hidden_size = 8
@@ -25,7 +26,7 @@ top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
- if not torch.equal(p1.half(), p2.half()):
+ if loose_close(p1, p2, p1.dtype):
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index e944a8c0a..29881c9ab 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -21,16 +21,6 @@ NUM_HEADS = 4
TOP_K = 2
-def split_grad(grad, world_size):
- with torch.no_grad():
- grad = grad.clone().detach().flatten()
- padding_size = (world_size - grad.numel() % world_size) % world_size
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- splited_grad = grad.split(grad.numel() // world_size)
- return splited_grad
-
-
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
diff --git a/tests/test_moe/test_moe_ep_zero.py b/tests/test_moe/test_moe_ep_zero.py
index c5adaad06..40e3bacb3 100644
--- a/tests/test_moe/test_moe_ep_zero.py
+++ b/tests/test_moe/test_moe_ep_zero.py
@@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
-NUM_BATCH=4
+NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS=2
+NUM_HEADS = 2
TOP_K = 1
-def split_grad(grad, world_size):
- with torch.no_grad():
- grad = grad.clone().detach().flatten()
- padding_size = (world_size - grad.numel() % world_size) % world_size
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- splited_grad = grad.split(grad.numel() // world_size)
- return splited_grad
-
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
- pp_size=1,
- tp_size=1,
- ep_size=ep_size,
- zero_stage=stage,
- overlap_communication=False,
- initial_scale=1
+ pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
)
booster = Booster(plugin=plugin)
@@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int):
zero_model.train()
for _ in range(2):
# zero-dp forward
- input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+ input_data = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
From 3e2b6132b7c8543324685e527a645f6f33962f38 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 16 Jul 2024 09:08:31 +0000
Subject: [PATCH 075/116] [moe] clean legacy code
---
.../moe => legacy/moe/layer}/__init__.py | 0
.../layer/moe => legacy/moe/layer}/experts.py | 6 +-
.../layer/moe => legacy/moe/layer}/layers.py | 6 +-
.../layer/moe => legacy/moe/layer}/routers.py | 6 +-
colossalai/{ => legacy}/moe/load_balance.py | 2 +-
colossalai/{ => legacy}/moe/manager.py | 0
.../legacy/moe}/openmoe/README.md | 0
.../moe}/openmoe/benchmark/benchmark_cai.py | 4 +-
.../moe}/openmoe/benchmark/benchmark_cai.sh | 0
.../openmoe/benchmark/benchmark_cai_dist.sh | 0
.../moe}/openmoe/benchmark/benchmark_fsdp.py | 2 +-
.../moe}/openmoe/benchmark/benchmark_fsdp.sh | 0
.../moe}/openmoe/benchmark/hostfile.txt | 0
.../legacy/moe}/openmoe/benchmark/utils.py | 0
.../legacy/moe}/openmoe/infer.py | 0
.../legacy/moe}/openmoe/infer.sh | 0
.../legacy/moe}/openmoe/model/__init__.py | 0
.../openmoe/model/convert_openmoe_ckpt.py | 0
.../openmoe/model/convert_openmoe_ckpt.sh | 0
.../moe}/openmoe/model/modeling_openmoe.py | 4 +-
.../moe}/openmoe/model/openmoe_8b_config.json | 0
.../openmoe/model/openmoe_base_config.json | 0
.../moe}/openmoe/model/openmoe_policy.py | 2 +-
.../legacy/moe}/openmoe/requirements.txt | 0
.../legacy/moe}/openmoe/test_ci.sh | 0
.../legacy/moe}/openmoe/train.py | 2 +-
.../legacy/moe}/openmoe/train.sh | 0
colossalai/{ => legacy}/moe/utils.py | 2 +-
colossalai/moe/__init__.py | 5 -
.../moe/{_operation.py => operators.py} | 2 +
colossalai/shardformer/modeling/mixtral.py | 8 +-
tests/test_legacy/test_moe/moe_utils.py | 136 ++++++++++++++++++
.../test_moe/test_grad_handler.py | 2 +-
.../test_moe/test_mixtral_layer.py | 0
.../test_moe/test_moe_group.py | 4 +-
.../test_moe/test_moe_hybrid_zero.py | 2 +-
.../test_moe/test_moe_load_balance.py | 2 +-
tests/test_moe/moe_utils.py | 135 -----------------
tests/test_moe/test_kernel.py | 4 +-
39 files changed, 163 insertions(+), 173 deletions(-)
rename colossalai/{shardformer/layer/moe => legacy/moe/layer}/__init__.py (100%)
rename colossalai/{shardformer/layer/moe => legacy/moe/layer}/experts.py (97%)
rename colossalai/{shardformer/layer/moe => legacy/moe/layer}/layers.py (98%)
rename colossalai/{shardformer/layer/moe => legacy/moe/layer}/routers.py (97%)
rename colossalai/{ => legacy}/moe/load_balance.py (99%)
rename colossalai/{ => legacy}/moe/manager.py (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/README.md (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/benchmark_cai.py (99%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/benchmark_cai.sh (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/benchmark_cai_dist.sh (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/benchmark_fsdp.py (98%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/benchmark_fsdp.sh (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/hostfile.txt (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/benchmark/utils.py (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/infer.py (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/infer.sh (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/__init__.py (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/convert_openmoe_ckpt.py (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/convert_openmoe_ckpt.sh (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/modeling_openmoe.py (99%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/openmoe_8b_config.json (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/openmoe_base_config.json (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/model/openmoe_policy.py (99%)
rename {examples/language => colossalai/legacy/moe}/openmoe/requirements.txt (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/test_ci.sh (100%)
rename {examples/language => colossalai/legacy/moe}/openmoe/train.py (99%)
rename {examples/language => colossalai/legacy/moe}/openmoe/train.sh (100%)
rename colossalai/{ => legacy}/moe/utils.py (99%)
rename colossalai/moe/{_operation.py => operators.py} (99%)
create mode 100644 tests/test_legacy/test_moe/moe_utils.py
rename tests/{ => test_legacy}/test_moe/test_grad_handler.py (98%)
rename tests/{ => test_legacy}/test_moe/test_mixtral_layer.py (100%)
rename tests/{ => test_legacy}/test_moe/test_moe_group.py (95%)
rename tests/{ => test_legacy}/test_moe/test_moe_hybrid_zero.py (98%)
rename tests/{ => test_legacy}/test_moe/test_moe_load_balance.py (99%)
diff --git a/colossalai/shardformer/layer/moe/__init__.py b/colossalai/legacy/moe/layer/__init__.py
similarity index 100%
rename from colossalai/shardformer/layer/moe/__init__.py
rename to colossalai/legacy/moe/layer/__init__.py
diff --git a/colossalai/shardformer/layer/moe/experts.py b/colossalai/legacy/moe/layer/experts.py
similarity index 97%
rename from colossalai/shardformer/layer/moe/experts.py
rename to colossalai/legacy/moe/layer/experts.py
index 109740dbb..c16fc77bb 100644
--- a/colossalai/shardformer/layer/moe/experts.py
+++ b/colossalai/legacy/moe/layer/experts.py
@@ -5,9 +5,9 @@ import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_activation
+from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
diff --git a/colossalai/shardformer/layer/moe/layers.py b/colossalai/legacy/moe/layer/layers.py
similarity index 98%
rename from colossalai/shardformer/layer/moe/layers.py
rename to colossalai/legacy/moe/layer/layers.py
index e5b0ef97f..8681b5972 100644
--- a/colossalai/shardformer/layer/moe/layers.py
+++ b/colossalai/legacy/moe/layer/layers.py
@@ -7,9 +7,9 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
-from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
-from colossalai.moe.load_balance import LoadBalancer
-from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
+from colossalai.legacy.moe.load_balance import LoadBalancer
+from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator
+from colossalai.moe.operators import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/legacy/moe/layer/routers.py
similarity index 97%
rename from colossalai/shardformer/layer/moe/routers.py
rename to colossalai/legacy/moe/layer/routers.py
index 109740dbb..c16fc77bb 100644
--- a/colossalai/shardformer/layer/moe/routers.py
+++ b/colossalai/legacy/moe/layer/routers.py
@@ -5,9 +5,9 @@ import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_activation
+from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
diff --git a/colossalai/moe/load_balance.py b/colossalai/legacy/moe/load_balance.py
similarity index 99%
rename from colossalai/moe/load_balance.py
rename to colossalai/legacy/moe/load_balance.py
index 3dc6c02c7..7339b1a7b 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/legacy/moe/load_balance.py
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer
diff --git a/colossalai/moe/manager.py b/colossalai/legacy/moe/manager.py
similarity index 100%
rename from colossalai/moe/manager.py
rename to colossalai/legacy/moe/manager.py
diff --git a/examples/language/openmoe/README.md b/colossalai/legacy/moe/openmoe/README.md
similarity index 100%
rename from examples/language/openmoe/README.md
rename to colossalai/legacy/moe/openmoe/README.md
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py
similarity index 99%
rename from examples/language/openmoe/benchmark/benchmark_cai.py
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py
index b9ef915c3..5f9447246 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.py
@@ -18,9 +18,9 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import skip_init
from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.sh
similarity index 100%
rename from examples/language/openmoe/benchmark/benchmark_cai.sh
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai.sh
diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_cai_dist.sh
similarity index 100%
rename from examples/language/openmoe/benchmark/benchmark_cai_dist.sh
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_cai_dist.sh
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py
similarity index 98%
rename from examples/language/openmoe/benchmark/benchmark_fsdp.py
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py
index b00fbd001..1ae94dd90 100644
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.py
+++ b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.py
@@ -14,7 +14,7 @@ from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
class RandomDataset(Dataset):
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.sh
similarity index 100%
rename from examples/language/openmoe/benchmark/benchmark_fsdp.sh
rename to colossalai/legacy/moe/openmoe/benchmark/benchmark_fsdp.sh
diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/colossalai/legacy/moe/openmoe/benchmark/hostfile.txt
similarity index 100%
rename from examples/language/openmoe/benchmark/hostfile.txt
rename to colossalai/legacy/moe/openmoe/benchmark/hostfile.txt
diff --git a/examples/language/openmoe/benchmark/utils.py b/colossalai/legacy/moe/openmoe/benchmark/utils.py
similarity index 100%
rename from examples/language/openmoe/benchmark/utils.py
rename to colossalai/legacy/moe/openmoe/benchmark/utils.py
diff --git a/examples/language/openmoe/infer.py b/colossalai/legacy/moe/openmoe/infer.py
similarity index 100%
rename from examples/language/openmoe/infer.py
rename to colossalai/legacy/moe/openmoe/infer.py
diff --git a/examples/language/openmoe/infer.sh b/colossalai/legacy/moe/openmoe/infer.sh
similarity index 100%
rename from examples/language/openmoe/infer.sh
rename to colossalai/legacy/moe/openmoe/infer.sh
diff --git a/examples/language/openmoe/model/__init__.py b/colossalai/legacy/moe/openmoe/model/__init__.py
similarity index 100%
rename from examples/language/openmoe/model/__init__.py
rename to colossalai/legacy/moe/openmoe/model/__init__.py
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.py
similarity index 100%
rename from examples/language/openmoe/model/convert_openmoe_ckpt.py
rename to colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.py
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.sh
similarity index 100%
rename from examples/language/openmoe/model/convert_openmoe_ckpt.sh
rename to colossalai/legacy/moe/openmoe/model/convert_openmoe_ckpt.sh
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/colossalai/legacy/moe/openmoe/model/modeling_openmoe.py
similarity index 99%
rename from examples/language/openmoe/model/modeling_openmoe.py
rename to colossalai/legacy/moe/openmoe/model/modeling_openmoe.py
index 1febacd7d..5d6e91765 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/colossalai/legacy/moe/openmoe/model/modeling_openmoe.py
@@ -50,8 +50,8 @@ try:
except:
HAS_FLASH_ATTN = False
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation, set_moe_args
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_activation, set_moe_args
from colossalai.shardformer.layer.moe import SparseMLP
if HAS_TRITON:
diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/colossalai/legacy/moe/openmoe/model/openmoe_8b_config.json
similarity index 100%
rename from examples/language/openmoe/model/openmoe_8b_config.json
rename to colossalai/legacy/moe/openmoe/model/openmoe_8b_config.json
diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/colossalai/legacy/moe/openmoe/model/openmoe_base_config.json
similarity index 100%
rename from examples/language/openmoe/model/openmoe_base_config.json
rename to colossalai/legacy/moe/openmoe/model/openmoe_base_config.json
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
similarity index 99%
rename from examples/language/openmoe/model/openmoe_policy.py
rename to colossalai/legacy/moe/openmoe/model/openmoe_policy.py
index f46062128..ccd566b08 100644
--- a/examples/language/openmoe/model/openmoe_policy.py
+++ b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
@@ -9,7 +9,7 @@ from torch.nn import Module
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
diff --git a/examples/language/openmoe/requirements.txt b/colossalai/legacy/moe/openmoe/requirements.txt
similarity index 100%
rename from examples/language/openmoe/requirements.txt
rename to colossalai/legacy/moe/openmoe/requirements.txt
diff --git a/examples/language/openmoe/test_ci.sh b/colossalai/legacy/moe/openmoe/test_ci.sh
similarity index 100%
rename from examples/language/openmoe/test_ci.sh
rename to colossalai/legacy/moe/openmoe/test_ci.sh
diff --git a/examples/language/openmoe/train.py b/colossalai/legacy/moe/openmoe/train.py
similarity index 99%
rename from examples/language/openmoe/train.py
rename to colossalai/legacy/moe/openmoe/train.py
index ff0e4bad6..0173f0964 100644
--- a/examples/language/openmoe/train.py
+++ b/colossalai/legacy/moe/openmoe/train.py
@@ -19,7 +19,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
-from colossalai.moe.utils import skip_init
+from colossalai.legacy.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.layer.moe import apply_load_balance
diff --git a/examples/language/openmoe/train.sh b/colossalai/legacy/moe/openmoe/train.sh
similarity index 100%
rename from examples/language/openmoe/train.sh
rename to colossalai/legacy/moe/openmoe/train.sh
diff --git a/colossalai/moe/utils.py b/colossalai/legacy/moe/utils.py
similarity index 99%
rename from colossalai/moe/utils.py
rename to colossalai/legacy/moe/utils.py
index 3d08ab7dd..d91c41363 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/legacy/moe/utils.py
@@ -9,7 +9,7 @@ import torch.nn.functional as F
from torch.distributed.distributed_c10d import get_process_group_ranks
from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index 0623d19ef..e69de29bb 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,5 +0,0 @@
-from .manager import MOE_MANAGER
-
-__all__ = [
- "MOE_MANAGER",
-]
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/operators.py
similarity index 99%
rename from colossalai/moe/_operation.py
rename to colossalai/moe/operators.py
index 230b40530..f24191c16 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/operators.py
@@ -469,6 +469,8 @@ def all_to_all_uneven(
# See the License for the specific language governing permissions and
# limitations under the License.
+# TODO: used when non-MoE params are tensor-parallel but MoE params are not
+
def _gather_tokens(input_, dim: int, tp_group: ProcessGroup):
"""Gather tensors and concatenate them along a dimension"""
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index cfa7da6c0..f8745c1d0 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -14,13 +14,7 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import (
- DPGradScalerIn,
- DPGradScalerOut,
- EPGradScalerIn,
- EPGradScalerOut,
- all_to_all_uneven,
-)
+from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
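
For downstream code, the moves in this patch reduce to import-path changes like the one just above; the before/after pairs below are copied from hunks in this patch rather than verified against the full tree.

# from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut   # before this patch
from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut      # after

# from colossalai.moe.manager import MOE_MANAGER          # before this patch
from colossalai.legacy.moe.manager import MOE_MANAGER     # after
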
diff --git a/tests/test_legacy/test_moe/moe_utils.py b/tests/test_legacy/test_moe/moe_utils.py
new file mode 100644
index 000000000..8c133849b
--- /dev/null
+++ b/tests/test_legacy/test_moe/moe_utils.py
@@ -0,0 +1,136 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
+from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import get_moe_epsize_param_dict
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
+
+
+def delete_moe_info(model):
+ for _, param in model.named_parameters():
+ if hasattr(param, "ep_group"):
+ delattr(param, "ep_group")
+
+
+class MoeModel(nn.Module):
+ def __init__(self, ep_group: ProcessGroup = None):
+ super().__init__()
+ self.test_embed = nn.Linear(4, 16, bias=False)
+ self.w1 = torch.nn.Parameter(torch.randn(16, 8))
+ if ep_group:
+ set_moe_tensor_ep_group(self.w1, ep_group)
+
+ def forward(self, x):
+ x = self.test_embed(x)
+ x = torch.matmul(x, self.w1)
+
+ return x
+
+
+@GRADIENT_HANDLER.register_module
+class MoeGradientHandler(BaseGradientHandler):
+    """A helper class that handles all-reduce operations in the data parallel group and
+    the MoE model parallel group. An all-reduce collective communication is performed in
+    :func:`handle_gradient` within the data parallel group.
+    For better performance, it bucketizes the gradients of parameters of the same type
+    to improve communication efficiency.
+
+ Args:
+ model (Module): Model where the gradients accumulate.
+ optimizer (Optimizer): Optimizer for updating the parameters.
+ """
+
+ def __init__(self, model, optimizer=None):
+ super().__init__(model, optimizer)
+
+    def handle_gradient(self):
+        """Run an all-reduce over gradients in the data parallel group, then run an
+        all-reduce for all expert parameters across the MoE model parallel group.
+ """
+ if dist.get_world_size() > 1:
+ epsize_param_dict = get_moe_epsize_param_dict(self._model)
+
+ # epsize is 1, indicating the params are replicated among processes in data parallelism
+ # use the ParallelMode.DATA to get data parallel group
+ # reduce gradients for all parameters in data parallelism
+ if 1 in epsize_param_dict:
+ bucket_allreduce(param_list=epsize_param_dict[1])
+
+ for ep_size in epsize_param_dict:
+ if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
+ bucket_allreduce(
+ param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
+ )
+
+
+def assert_not_equal_in_group(tensor, process_group=None):
+ # all gather tensors from different ranks
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+ dist.all_gather(tensor_list, tensor, group=process_group)
+
+ # check if they are equal one by one
+ for i in range(world_size - 1):
+ a = tensor_list[i]
+ b = tensor_list[i + 1]
+ assert not torch.allclose(a, b), (
+ f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
+ )
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
+ optimizer.backward(loss)
+ else:
+ loss.backward()
+ return y
+
+
+def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
+    """Sync the parameters of the local model from the EP-sharded model.
+
+ Args:
+ local_model (MoeModule)
+ ep_model (MoeModule)
+ """
+ for (local_name, local_param), (ep_name, ep_param) in zip(
+ local_model.named_parameters(), ep_model.named_parameters()
+ ):
+ if "experts" not in local_name:
+ if assert_grad_flag:
+ assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
+ assert torch.allclose(local_param.grad, ep_param.grad)
+ else:
+ local_param.data.copy_(ep_param.data)
+ continue
+
+ # gather param from ep model
+ param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+ all_param = torch.cat(param_list, dim=0)
+ if assert_grad_flag:
+ grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+ all_grad = torch.cat(grad_list, dim=0)
+
+ if assert_grad_flag:
+ assert torch.allclose(local_param, all_param)
+ assert torch.allclose(local_param.grad, all_grad)
+ else:
+ local_param.data.copy_(all_param.data)
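
A tiny smoke check of the relocated `MoeModel` helper above; with `ep_group` left as `None` it needs no distributed initialisation, and the import path is simply the new file location created by this patch.

import torch

from tests.test_legacy.test_moe.moe_utils import MoeModel

model = MoeModel()              # no ep_group: behaves like a plain nn.Module
out = model(torch.randn(2, 4))  # Linear(4 -> 16) followed by matmul with a (16, 8) weight
print(out.shape)                # torch.Size([2, 8])
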
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_legacy/test_moe/test_grad_handler.py
similarity index 98%
rename from tests/test_moe/test_grad_handler.py
rename to tests/test_legacy/test_moe/test_grad_handler.py
index 25e61b091..3a782a6dd 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_legacy/test_moe/test_grad_handler.py
@@ -5,7 +5,7 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe.layers import SparseMLP
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_legacy/test_moe/test_mixtral_layer.py
similarity index 100%
rename from tests/test_moe/test_mixtral_layer.py
rename to tests/test_legacy/test_moe/test_mixtral_layer.py
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_legacy/test_moe/test_moe_group.py
similarity index 95%
rename from tests/test_moe/test_moe_group.py
rename to tests/test_legacy/test_moe/test_moe_group.py
index 89baf1d37..68dac4828 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_legacy/test_moe/test_moe_group.py
@@ -4,8 +4,8 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import sync_moe_model_param
+from colossalai.legacy.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.utils import sync_moe_model_param
# from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_legacy/test_moe/test_moe_hybrid_zero.py
similarity index 98%
rename from tests/test_moe/test_moe_hybrid_zero.py
rename to tests/test_legacy/test_moe/test_moe_hybrid_zero.py
index 513c4ebda..fdd6d956e 100644
--- a/tests/test_moe/test_moe_hybrid_zero.py
+++ b/tests/test_legacy/test_moe/test_moe_hybrid_zero.py
@@ -6,7 +6,7 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeModel
diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_legacy/test_moe/test_moe_load_balance.py
similarity index 99%
rename from tests/test_moe/test_moe_load_balance.py
rename to tests/test_legacy/test_moe/test_moe_load_balance.py
index ddd3ea368..adf2dbc1c 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_legacy/test_moe/test_moe_load_balance.py
@@ -6,7 +6,7 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.legacy.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe import apply_load_balance
from colossalai.tensor.moe_tensor.api import is_moe_tensor
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index e49edb6f4..55aee6fb0 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,139 +1,4 @@
import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.distributed import ProcessGroup
-
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
-from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
-from colossalai.legacy.registry import GRADIENT_HANDLER
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_moe_epsize_param_dict
-from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
-
-
-def delete_moe_info(model):
- for _, param in model.named_parameters():
- if hasattr(param, "ep_group"):
- delattr(param, "ep_group")
-
-
-class MoeModel(nn.Module):
- def __init__(self, ep_group: ProcessGroup = None):
- super().__init__()
- self.test_embed = nn.Linear(4, 16, bias=False)
- self.w1 = torch.nn.Parameter(torch.randn(16, 8))
- if ep_group:
- set_moe_tensor_ep_group(self.w1, ep_group)
-
- def forward(self, x):
- x = self.test_embed(x)
- x = torch.matmul(x, self.w1)
-
- return x
-
-
-@GRADIENT_HANDLER.register_module
-class MoeGradientHandler(BaseGradientHandler):
- """A helper class to handle all-reduce operations in a data parallel group and
- moe model parallel. A all-reduce collective communication will be operated in
- :func:`handle_gradient` among a data parallel group.
- For better performance, it bucketizes the gradients of all parameters that are
- the same type to improve the efficiency of communication.
-
- Args:
- model (Module): Model where the gradients accumulate.
- optimizer (Optimizer): Optimizer for updating the parameters.
- """
-
- def __init__(self, model, optimizer=None):
- super().__init__(model, optimizer)
-
- def handle_gradient(self):
- """A method running an all-reduce operation in a data parallel group.
- Then running an all-reduce operation for all parameters in experts
- across moe model parallel group
- """
- if dist.get_world_size() > 1:
- epsize_param_dict = get_moe_epsize_param_dict(self._model)
-
- # epsize is 1, indicating the params are replicated among processes in data parallelism
- # use the ParallelMode.DATA to get data parallel group
- # reduce gradients for all parameters in data parallelism
- if 1 in epsize_param_dict:
- bucket_allreduce(param_list=epsize_param_dict[1])
-
- for ep_size in epsize_param_dict:
- if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
- bucket_allreduce(
- param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
- )
-
-
-def assert_not_equal_in_group(tensor, process_group=None):
- # all gather tensors from different ranks
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
- dist.all_gather(tensor_list, tensor, group=process_group)
-
- # check if they are equal one by one
- for i in range(world_size - 1):
- a = tensor_list[i]
- b = tensor_list[i + 1]
- assert not torch.allclose(a, b), (
- f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
- )
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
- model.train()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
- loss = loss.float()
-
- if isinstance(model, LowLevelZeroModel):
- optimizer.backward(loss)
- else:
- loss.backward()
- return y
-
-
-def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
- """Sync the parameters of tp model from ep model
-
- Args:
- local_model (MoeModule)
- ep_model (MoeModule)
- """
- for (local_name, local_param), (ep_name, ep_param) in zip(
- local_model.named_parameters(), ep_model.named_parameters()
- ):
- if "experts" not in local_name:
- if assert_grad_flag:
- assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
- assert torch.allclose(local_param.grad, ep_param.grad)
- else:
- local_param.data.copy_(ep_param.data)
- continue
-
- # gather param from ep model
- param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
- all_param = torch.cat(param_list, dim=0)
- if assert_grad_flag:
- grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
- dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
- all_grad = torch.cat(grad_list, dim=0)
-
- if assert_grad_flag:
- assert torch.allclose(local_param, all_param)
- assert torch.allclose(local_param.grad, all_grad)
- else:
- local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 28e6db441..f2c6d206f 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -4,9 +4,7 @@ import pytest
import torch
from colossalai.accelerator import get_accelerator
-
-# from colossalai.moe import SparseMLP
-from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
+from colossalai.moe.operators import MoeCombine, MoeDispatch, moe_cumsum
NUM_EXPERTS = 4
BATCH_SIZE = 4
From 404b16faf3b7ec44dee7e9d389edc9c80ec9a59c Mon Sep 17 00:00:00 2001
From: Haze188
Date: Thu, 18 Jul 2024 11:37:56 +0800
Subject: [PATCH 076/116] [Feature] MoE Ulysses Support (#5918)
* moe sp support
* fix moe sp bugs
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.../plugin/moe_hybrid_parallel_plugin.py | 52 ++-
colossalai/cluster/process_group_mesh.py | 8 +-
colossalai/shardformer/modeling/mixtral.py | 360 +++++++++++++++++-
colossalai/shardformer/policies/mixtral.py | 85 ++++-
tests/kit/model_zoo/transformers/mixtral.py | 8 +-
.../test_model/test_shard_mixtral.py | 130 +++++--
6 files changed, 571 insertions(+), 72 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index bf450534f..0ad3889ae 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,4 +1,6 @@
import warnings
+from collections import defaultdict
+from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@@ -22,6 +24,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import cast_to_distributed
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -114,21 +118,25 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.ddp_config["find_unused_parameters"] = True
world_size = dist.get_world_size()
- self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
+ self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
- if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+ if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
- self._init_moe_param_comm()
+ # self._init_moe_param_comm()
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
# set ep_group after super init
# TODO do it in a better way
+ self.moe_dp_group = self.pp_group
+ self.ep_group = self.pp_group
+ self.moe_tp_group = self.pp_group
+
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
@@ -205,15 +213,32 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
+
+ # TODO: Support Galore + ZeRO
+ self.zero_stage
+ deepcopy(self.zero_config)
+ # Replace with distributed implementation if exists
+ optimizer = cast_to_distributed(optimizer)
+
if not isinstance(model, ModelWrapper):
+ use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+ self.dp_size == 1
+ and self.pp_size == 1
+ and self.enable_sequence_parallelism
+ and self.sequence_parallelism_mode == "all_to_all"
+ )
+ if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+ dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+ else:
+ dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=self.dp_group,
+ dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
- use_ddp=self.use_ddp,
+ use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -224,6 +249,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
+ is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@@ -236,7 +262,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
else:
optimizer = HybridParallelNaiveOptimizer(
- optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ max_norm=self.max_norm,
+ pp_process_group=self.pp_group,
+ tp_process_group=self.tp_group,
)
else:
if not (self.dp_size > 1 or self.moe_dp_size > 1):
@@ -244,6 +276,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)
+ assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
@@ -262,4 +295,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
+ # Setup optimizers that require global states
+ optim = optimizer.optim
+ if isinstance(optim, DistributedOptim):
+ shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
+ padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
+ optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
+
return model, optimizer, criterion, dataloader, lr_scheduler
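Note on the plugin change above: under sequence_parallelism_mode == "all_to_all", every rank holds a different sequence shard, so gradients of the dense (non-MoE) parameters must be reduced over both the data-parallel and sequence-parallel ranks; that is why the boost path builds dp_group along [self.dp_axis, self.sp_axis] and turns DDP on even when dp_size == 1. The snippet below is a minimal, standalone sketch of that rank bookkeeping only; the axis order [pp, dp, sp, tp], row-major rank layout, and toy sizes are assumptions for illustration, not the plugin's actual mesh.

# Hedged sketch: which ranks would end up in one flattened dp x sp reduction group,
# assuming axis order [pp, dp, sp, tp] with row-major rank numbering.
import itertools

def flattened_dp_sp_groups(pp: int, dp: int, sp: int, tp: int):
    """Return the rank lists that would back the merged dp x sp process groups."""
    groups = []
    for pp_idx, tp_idx in itertools.product(range(pp), range(tp)):
        ranks = [
            ((pp_idx * dp + dp_idx) * sp + sp_idx) * tp + tp_idx
            for dp_idx in range(dp)
            for sp_idx in range(sp)
        ]
        groups.append(ranks)
    return groups

# 8 GPUs, pp=1, dp=2, sp=2, tp=2 -> each gradient-reduction group spans 4 ranks
print(flattened_dp_sp_groups(pp=1, dp=2, sp=2, tp=2))
# [[0, 2, 4, 6], [1, 3, 5, 7]]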
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 66b77f7a2..a9d341efa 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -209,7 +209,7 @@ class ProcessGroupMesh:
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
- return_ranks_by_group: bool = False
+ return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
@@ -257,7 +257,11 @@ class ProcessGroupMesh:
return target_group
def get_group_along_axis(
- self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
+ self,
+ axis: Union[int, List[int]],
+ indices_at_axis: Optional[List[int]] = None,
+ backend: Optional[str] = None,
+ return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index f8745c1d0..2b50f013d 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -1,26 +1,47 @@
-from typing import List, Optional
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from transformers.models.mixtral.modeling_mixtral import (
MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+ apply_rotary_pos_emb,
load_balancing_loss_func,
+ repeat_kv,
)
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+ all_to_all_comm,
+ gather_forward_split_backward,
+ split_forward_gather_backward,
+)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func
+
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
@@ -97,6 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+ dist.get_rank()
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
@@ -157,7 +179,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
class MixtralPipelineForwards:
"""
- This class serves as a micro library for forward function substitution of Llama models
+ This class serves as a micro library for forward function substitution of Mixtral models
under pipeline setting.
"""
@@ -491,3 +513,335 @@ class MixtralPipelineForwards:
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
+
+
+def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+ if sp_mode is not None:
+ assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert (sp_size is not None) and (
+ sp_group is not None
+ ), "Must specify sp_size and sp_group for sequence parallel"
+
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+ bsz, q_len, _ = hidden_states.size()
+
+ # sp: modify sp_len when sequence parallel mode is ring
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+        # sp: all-to-all communication when introducing sequence parallelism
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group)
+ key_states = all_to_all_comm(key_states, sp_group)
+ value_states = all_to_all_comm(value_states, sp_group)
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
+ " make sure to upgrade flash-attn library."
+ )
+ if past_key_value is not None:
+            # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability,
+        # so the input hidden states may get silently cast to float32. Hence, we need to
+        # cast them back to float16 to make sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+        # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+        # sp: all-to-all communication when introducing sequence parallelism
+ if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (bsz, full seq_len, hidden_size // sp_size)
+            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (bsz, seq_len // sp_size, hidden_size)
+ else:
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+ return attn_output, attn_weights, past_key_value
+
+ return forward
+
+
+def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = 0
+
+ if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+ if self._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._attn_implementation == "sdpa" and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if sp_mode in ["ring", "split_gather"]:
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if sp_mode == "ring" or sp_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ return forward
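For readers new to Ulysses-style sequence parallelism, which the forward functions above implement: before attention, all_to_all_comm trades sequence shards for head shards, so each rank then sees the full sequence but only num_heads / sp_size heads; after attention, the inverse call (scatter_dim=1, gather_dim=2) restores sequence sharding with all heads. The pure-Python sketch below only illustrates this shape bookkeeping with assumed toy sizes; the real exchange happens across the sp_group ranks.

# Hedged sketch of per-rank shapes around the two all-to-alls (toy sizes assumed).
def ulysses_shapes(bsz: int, sub_seq: int, hidden: int, sp_size: int):
    """Shape bookkeeping only: what one rank holds before, during, and after attention."""
    before_attn = (bsz, sub_seq, hidden)                       # local sequence shard, all heads
    during_attn = (bsz, sub_seq * sp_size, hidden // sp_size)  # full sequence, local heads
    after_attn = (bsz, sub_seq, hidden)                        # back to the local sequence shard
    return before_attn, during_attn, after_attn

print(ulysses_shapes(bsz=1, sub_seq=2, hidden=256, sp_size=4))
# ((1, 2, 256), (1, 8, 64), (1, 2, 256))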
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 8905b5696..10f54e1a4 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -5,12 +5,17 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
-from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
+from colossalai.shardformer.modeling.mixtral import (
+ EPMixtralSparseMoeBlock,
+ MixtralPipelineForwards,
+ get_mixtral_flash_attention_forward,
+ get_mixtral_flash_attention_model_forward,
+)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
@@ -21,27 +26,72 @@ class MixtralPolicy(Policy):
pass
def preprocess(self):
- if self.shard_config.enable_tensor_parallelism:
- # non-moe params tensor parallelism
+ self.origin_attn_implement = self.model.config._attn_implementation
+ # if self.shard_config.enable_tensor_parallelism:
+ # # non-moe params tensor parallelism
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
+ # # Resize embedding
+ # vocab_size = self.model.config.vocab_size
+ # world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ # if vocab_size % world_size != 0:
+ # new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ # self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- policy = {}
+ from transformers.models.mixtral.modeling_mixtral import (
+ MixtralAttention,
+ MixtralDecoderLayer,
+ MixtralFlashAttention2,
+ MixtralModel,
+ MixtralSdpaAttention,
+ )
- if self.shard_config.enable_sequence_parallelism:
- self.shard_config.enable_sequence_parallelism = False
- raise NotImplementedError(
- "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ ATTN_IMPLEMENTATION = {
+ "eager": MixtralAttention,
+ "flash_attention_2": MixtralFlashAttention2,
+ "sdpa": MixtralSdpaAttention,
+ }
+ policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
+ sp_mode = self.shard_config.sequence_parallelism_mode or None
+ sp_size = self.shard_config.sequence_parallel_size or None
+ sp_group = self.shard_config.sequence_parallel_process_group or None
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "all_to_all":
+ decoder_attribute_replacement = {
+ "num_heads": self.model.config.num_attention_heads // sp_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+ policy[attn_cls] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
)
+ if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=attn_cls,
+ )
+ if self.pipeline_stage_manager is None:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_mixtral_flash_attention_model_forward(
+ self.shard_config,
+ sp_mode=sp_mode,
+ sp_size=sp_size,
+ sp_group=sp_group,
+ ),
+ },
+ policy=policy,
+ target_key=MixtralModel,
+ )
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
@@ -127,10 +177,12 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -141,6 +193,7 @@ class MixtralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=MixtralModel,
@@ -308,5 +361,5 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
- """No shared params in llama for sequence classification model"""
+ """No shared params in mixtral for sequence classification model"""
return []
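The policy change above replaces attention attributes when sp_mode == "all_to_all": each sequence-parallel rank runs attention on num_heads // sp_size heads (and num_key_value_heads // sp_size KV heads), so the head counts must divide evenly by sp_size. A hedged restatement of that arithmetic, with illustrative values rather than a real Mixtral config:

# Hedged sketch of the attribute_replacement arithmetic (values are illustrative).
def split_heads_for_ulysses(num_attention_heads: int, num_key_value_heads: int, sp_size: int):
    if num_attention_heads % sp_size or num_key_value_heads % sp_size:
        raise ValueError(
            f"sp_size={sp_size} must divide num_attention_heads={num_attention_heads} "
            f"and num_key_value_heads={num_key_value_heads}"
        )
    return {
        "num_heads": num_attention_heads // sp_size,
        "num_key_value_heads": num_key_value_heads // sp_size,
    }

print(split_heads_for_ulysses(num_attention_heads=32, num_key_value_heads=8, sp_size=4))
# {'num_heads': 8, 'num_key_value_heads': 2}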
diff --git a/tests/kit/model_zoo/transformers/mixtral.py b/tests/kit/model_zoo/transformers/mixtral.py
index 7fa4ff335..40e5a7b02 100644
--- a/tests/kit/model_zoo/transformers/mixtral.py
+++ b/tests/kit/model_zoo/transformers/mixtral.py
@@ -48,11 +48,13 @@ loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = MixtralConfig(
- hidden_size=256,
- intermediate_size=256,
- num_attention_heads=64,
+ hidden_size=32,
+ intermediate_size=32,
+ num_attention_heads=8,
num_hidden_layers=2,
vocab_size=1000,
+ attn_implementation="flash_attention_2",
+ torch_dtype="float16",
output_router_logits=True,
)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index f268d1686..2e2b675a4 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -3,6 +3,8 @@ import os
import pytest
import torch
+import torch.distributed as dist
+from torch.testing import assert_close
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
@@ -15,6 +17,7 @@ from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
+ check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
@@ -27,13 +30,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
- model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
+ model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
)
+ org_model = org_model.to(torch.float16)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
-
+ print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
@@ -45,6 +49,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
# unwrap model
mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
@@ -53,6 +58,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
+ # Check the grad when using ZeRO-1 and ZeRO-2
+ if (
+ # booster.plugin.zero_stage in [1, 2]
+ booster.plugin.shard_config.enable_sequence_parallelism
+ and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
+ ):
+ rank = dist.get_rank()
+ # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+ for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
+ try:
+ assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+ print(f"{rank=},passed grad: {n1}, {n2}")
+ except Exception as e:
+ print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
+ raise e
+
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
@@ -84,28 +105,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check.update(row_layer_grads)
# check grads
+ # print(grads_to_check)
check_all_grad_tensors(grads_to_check)
-
+ for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
+ try:
+ assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+ print(f"{rank=},passed param before step: {n1}, {n2}")
+ except Exception:
+ print(
+ f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
+ )
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
-
+ for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
+ try:
+ assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+ print(f"{rank=},passed param after step: {n1}, {n2}")
+ except Exception as e:
+ print(
+ f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
+ )
+ raise e
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
- check_weight(
- mixtral_model,
- shard_mixtral_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False,
- )
+ try:
+ check_weight(
+ mixtral_model,
+ shard_mixtral_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+ except Exception as e:
+ rank = dist.get_rank()
+ print(f"{rank=}, Failed config: {test_config}")
+ raise e
torch.cuda.empty_cache()
@@ -113,33 +155,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 2,
- "ep_size": 2,
- "zero_stage": 1,
- "overlap_communication": False,
- "precision": "fp32",
- }, # [dp(4)] + [moe_dp(4)]
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 2,
- "ep_size": 2,
- "zero_stage": 1,
- "overlap_communication": False,
- "precision": "fp32",
- }, # [dp(2) + pp(2)] + [moe_pp(2)]
- {
- "tp_size": 2,
- "pp_size": 2,
- "num_microbatches": 2,
- "ep_size": 2,
- "zero_stage": 1,
- "overlap_communication": False,
- "precision": "fp32",
- }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
# {
# "tp_size": 1,
# "pp_size": 2,
@@ -148,7 +163,38 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
- # }, # [dp(2) + pp(2)] + [ep(4))]
+ # }, # [dp(4)] + [moe_dp(4)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 2,
+ # "num_microbatches": 2,
+ # "ep_size": 2,
+ # "zero_stage": 1,
+ # "overlap_communication": False,
+ # "precision": "fp32",
+ # }, # [dp(2) + pp(2)] + [moe_pp(2)]
+ # {
+ # "tp_size": 2,
+ # "pp_size": 2,
+ # "num_microbatches": 2,
+ # "ep_size": 2,
+ # "zero_stage": 1,
+ # "overlap_communication": False,
+ # "precision": "fp32",
+ # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
+        {  # Ulysses + Flash attention
+ "tp_size": 1,
+ "pp_size": 1,
+ "sp_size": 4,
+ "ep_size": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "all_to_all",
+ "zero_stage": 0,
+ "overlap_communication": False,
+ "precision": "fp16",
+ "initial_scale": 1,
+ "find_unused_parameters": True,
+ },
# {
# "tp_size": 1,
# "pp_size": 1,
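The new per-parameter checks in this test lean on torch.testing.assert_close rather than exact equality, because fp16 training and resharding make bitwise matches unrealistic. A hedged mini-example of that comparison pattern with synthetic tensors (check_dtype=False matters when the reference and sharded params differ in dtype):

# Hedged example of tolerance-based comparison with synthetic tensors.
import torch
from torch.testing import assert_close

ref = torch.randn(4, 4)                 # fp32 reference
shard = ref.half() + 1e-4               # simulate an fp16 copy with tiny drift
assert_close(shard, ref, atol=5e-3, rtol=5e-3, check_dtype=False)
print("within tolerance")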
From 09d6280d3ebedf829b4b7d8831b49bb81ba4e9fb Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 18 Jul 2024 03:53:51 +0000
Subject: [PATCH 077/116] [chore] minor fix
---
colossalai/shardformer/policies/mixtral.py | 12 +--------
.../test_model/test_shard_mixtral.py | 26 ++++---------------
2 files changed, 6 insertions(+), 32 deletions(-)
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 10f54e1a4..8fed5ee5c 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -26,18 +26,8 @@ class MixtralPolicy(Policy):
pass
def preprocess(self):
+ self.tie_weight = self.tie_weight_check()
self.origin_attn_implement = self.model.config._attn_implementation
- # if self.shard_config.enable_tensor_parallelism:
- # # non-moe params tensor parallelism
-
- # # Resize embedding
- # vocab_size = self.model.config.vocab_size
- # world_size = self.shard_config.tensor_parallel_size
-
- # if vocab_size % world_size != 0:
- # new_vocab_size = vocab_size + world_size - vocab_size % world_size
- # self.model.resize_token_embeddings(new_vocab_size)
-
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 2e2b675a4..e873f46f7 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rank = dist.get_rank()
# for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
- try:
- assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
- print(f"{rank=},passed grad: {n1}, {n2}")
- except Exception as e:
- print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
- raise e
+ assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# print(grads_to_check)
check_all_grad_tensors(grads_to_check)
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
- try:
- assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
- print(f"{rank=},passed param before step: {n1}, {n2}")
- except Exception:
- print(
- f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
- )
+ assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
- try:
- assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
- print(f"{rank=},passed param after step: {n1}, {n2}")
- except Exception as e:
- print(
- f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
- )
- raise e
+ assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
From 877d94bb8cf763f469ff93b0911d9e05d596a6cf Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 18 Jul 2024 08:37:06 +0000
Subject: [PATCH 078/116] [moe] init moe plugin comm setting with sp
---
.../plugin/moe_hybrid_parallel_plugin.py | 163 ++++++++++--------
colossalai/shardformer/modeling/deepseek.py | 8 +-
colossalai/shardformer/modeling/mixtral.py | 2 +-
tests/test_moe/modelling/test_deepseek.py | 2 +-
tests/test_moe/modelling/test_mixtral.py | 9 +-
tests/test_moe/test_moe_checkpoint.py | 4 +-
.../test_model/test_shard_mixtral.py | 8 +-
7 files changed, 101 insertions(+), 95 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 0ad3889ae..fc3340981 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,6 +1,5 @@
import warnings
from collections import defaultdict
-from copy import deepcopy
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@@ -106,37 +105,35 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
- kwargs["overlap_communication"] = False
+            kwargs["overlap_communication"] = False  # defaults to True in the superclass
super().__init__(*args, **kwargs)
- self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ self.ep_size = ep_size
+ self.moe_tp_size = moe_tp_size
+
+ self._init_moe_param_comm()
+
+ self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+ self.dp_size == 1
+ and self.pp_size == 1
+ and self.enable_sequence_parallelism
+ and self.sequence_parallelism_mode == "all_to_all"
+ )
+
if self.use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True
- world_size = dist.get_world_size()
- self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
- self.ep_size = ep_size
- self.moe_tp_size = moe_tp_size
+ if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+ raise ValueError(
+ f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
+ )
- if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
- raise ValueError(
- f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
- )
-
- # self._init_moe_param_comm()
-
- self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
-
- # set ep_group after super init
+ # set ep_group after super().__init__()
# TODO do it in a better way
- self.moe_dp_group = self.pp_group
- self.ep_group = self.pp_group
- self.moe_tp_group = self.pp_group
-
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
@@ -144,48 +141,77 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.force_overlap_comm = force_overlap_comm
def _init_moe_param_comm(self):
- self.moe_dp_group = None
- self.ep_group = None
- self.moe_tp_group = None
+ world_size = dist.get_world_size()
- # create submesh for ep, moe_dp, moe_tp
- ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
- [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
- )
+ if self.enable_sequence_parallelism:
+ # if sequence parallelism is enabled, we reuse the same group for ep and sp
+ if self.sequence_parallelism_mode == "all_to_all":
+ # when sequence parallelism is enabled, ep_group reuses sp_group
+ if self.ep_size != self.sp_size:
+ raise ValueError(
+ f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+ )
- global_rank = self.pg_mesh.rank
- pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+ self.moe_dp_size = self.dp_size
+ self.moe_dp_group = self.dp_group # NOTE: sequence of value assignment matters
+ self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+ self.ep_group = self.sp_group
+ self.moe_tp_group = self.tp_group
+ else:
+ raise NotImplementedError(
+ f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
+ )
- # create groups from submesh
- for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
- # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
- submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+ else:
+ self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
- # hardcode here since we only have 3 axis
- # moe_dp_group
- for ep_idx in range(self.ep_size):
- for moe_tp_idx in range(self.moe_tp_size):
- moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
- group = dist.new_group(moe_dp_ranks)
- if pp_rank == stage_idx and global_rank in moe_dp_ranks:
- assert self.moe_dp_group is None
- self.moe_dp_group = group
- # ep_group
- for moe_dp_idx in range(self.moe_dp_size):
- for moe_tp_idx in range(self.moe_tp_size):
- ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
- group = dist.new_group(ep_ranks)
- if pp_rank == stage_idx and global_rank in ep_ranks:
- assert self.ep_group is None
- self.ep_group = group
- # moe_tp_group
- for moe_dp_idx in range(self.moe_dp_size):
+ if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+ raise ValueError(
+ f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+ )
+
+ self.moe_dp_group = None
+ self.ep_group = None
+ self.moe_tp_group = None
+
+ # create submesh for ep, moe_dp, moe_tp
+ ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+ [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+ )
+
+ global_rank = self.pg_mesh.rank
+ pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+ # create groups from submesh
+ for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+ # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
+ submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+ # hardcode here since we only have 3 axis
+ # moe_dp_group
for ep_idx in range(self.ep_size):
- moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
- group = dist.new_group(moe_tp_ranks)
- if pp_rank == stage_idx and global_rank in moe_tp_ranks:
- assert self.moe_tp_group is None
- self.moe_tp_group = group
+ for moe_tp_idx in range(self.moe_tp_size):
+ moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+ group = dist.new_group(moe_dp_ranks)
+ if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+ assert self.moe_dp_group is None
+ self.moe_dp_group = group
+ # ep_group
+ for moe_dp_idx in range(self.moe_dp_size):
+ for moe_tp_idx in range(self.moe_tp_size):
+ ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+ group = dist.new_group(ep_ranks)
+ if pp_rank == stage_idx and global_rank in ep_ranks:
+ assert self.ep_group is None
+ self.ep_group = group
+ # moe_tp_group
+ for moe_dp_idx in range(self.moe_dp_size):
+ for ep_idx in range(self.ep_size):
+ moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+ group = dist.new_group(moe_tp_ranks)
+ if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+ assert self.moe_tp_group is None
+ self.moe_tp_group = group
if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
# NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
@@ -195,7 +221,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.logger.info(
- f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+ f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+ f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
ranks=[0],
)
@@ -215,30 +242,18 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
param_info = get_param_info(optimizer)
# TODO: Support Galore + ZeRO
- self.zero_stage
- deepcopy(self.zero_config)
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
- use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
- self.dp_size == 1
- and self.pp_size == 1
- and self.enable_sequence_parallelism
- and self.sequence_parallelism_mode == "all_to_all"
- )
- if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
- dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
- else:
- dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=dp_group,
+ dp_group=self.dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
- use_ddp=use_ddp,
+ use_ddp=self.use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -271,7 +286,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
tp_process_group=self.tp_group,
)
else:
- if not (self.dp_size > 1 or self.moe_dp_size > 1):
+ if self.dp_size <= 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 33fac9b93..a90cd8726 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -10,13 +10,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import (
- DPGradScalerIn,
- DPGradScalerOut,
- EPGradScalerIn,
- EPGradScalerOut,
- all_to_all_uneven,
-)
+from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 2b50f013d..f51e690d1 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -118,7 +118,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
- dist.get_rank()
+
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py
index 42daea512..74c72dd06 100644
--- a/tests/test_moe/modelling/test_deepseek.py
+++ b/tests/test_moe/modelling/test_deepseek.py
@@ -23,7 +23,7 @@ NUM_HEADS = 4
TOP_K = 1
-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype = torch.float16
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index 6e6f0b2b5..fe13b5b30 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -24,11 +24,10 @@ NUM_HEADS = 4
TOP_K = 1
-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
- dtype = torch.float32
-
+ dtype, precision = torch.float16, "fp16"
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
- precision="fp32",
+ precision=precision,
)
booster = Booster(plugin=plugin)
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
dist.barrier()
- saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+ saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(torch_model, saved_model)
dist.barrier()
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 4bcf701de..1ab52b371 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -26,9 +26,7 @@ top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
- if loose_close(p1, p2, p1.dtype):
- print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
- raise AssertionError(f"Model parameter {name} is not equal")
+ loose_close(p1, p2, p1.dtype)
def get_optimizer_snapshot(optim):
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index e873f46f7..232e16f3b 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
# {
# "tp_size": 1,
- # "pp_size": 2,
+ # "pp_size": 1,
# "num_microbatches": 2,
# "ep_size": 2,
- # "zero_stage": 1,
+ # "zero_stage": 0,
# "overlap_communication": False,
- # "precision": "fp32",
+ # "precision": "fp16",
# }, # [dp(4)] + [moe_dp(4)]
# {
# "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{  # Ulysses + Flash attention
"tp_size": 1,
"pp_size": 1,
- "sp_size": 4,
+ "sp_size": 2,
"ep_size": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
From 2cddeac7174c5617b7a35cd83925a161173afe1b Mon Sep 17 00:00:00 2001
From: haze188
Date: Thu, 18 Jul 2024 10:08:06 +0000
Subject: [PATCH 079/116] moe sp + ep bug fix
---
.../plugin/moe_hybrid_parallel_plugin.py | 2 +-
.../test_model/test_shard_mixtral.py | 26 +++++++++++--------
2 files changed, 16 insertions(+), 12 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index fc3340981..d4226b108 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -221,7 +221,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.logger.info(
- f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+ f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
ranks=[0],
)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 232e16f3b..4e9c594d2 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -37,7 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
- print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
@@ -65,9 +64,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
rank = dist.get_rank()
- # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
- for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
- assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+ name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
+ for n, p in shard_mixtral_model.named_parameters():
+ zero_grad = sharded_optimizer.get_param_grad(p)
+ if name_to_p[n].grad is None:
+ name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
+ continue
+ assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@@ -100,16 +103,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check.update(row_layer_grads)
# check grads
- # print(grads_to_check)
check_all_grad_tensors(grads_to_check)
- for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
- assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+ for n, p in shard_mixtral_model.named_parameters():
+ assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
- for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
- assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+ for n, p in shard_mixtral_model.named_parameters():
+ assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
@@ -170,10 +174,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
- "ep_size": 1,
+ "ep_size": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
- "zero_stage": 0,
+ "zero_stage": 1,
"overlap_communication": False,
"precision": "fp16",
"initial_scale": 1,
From 7077d38d5a5b9243521f44d10d4dabc012044dbb Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 18 Jul 2024 13:36:18 +0000
Subject: [PATCH 080/116] [moe] finalize test (no pp)
---
.../plugin/moe_hybrid_parallel_plugin.py | 18 ++++++++-----
tests/test_moe/modelling/test_mixtral.py | 27 ++++++++++++-------
2 files changed, 29 insertions(+), 16 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index d4226b108..31b346b10 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -109,6 +109,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
super().__init__(*args, **kwargs)
+ if ep_size <= 1:
+ raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
+
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
@@ -128,12 +131,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+ # TODO it might make sense to support non-moe with tp on but moe with tp off
raise ValueError(
- f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
+ f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
)
- # set ep_group after super().__init__()
- # TODO do it in a better way
+ # set param group in shard config
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
@@ -149,9 +152,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# when sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
- f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+ f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
)
+ # since we are reusing sp_group, moe_dp_group will be derived as dp_group
self.moe_dp_size = self.dp_size
self.moe_dp_group = self.dp_group # NOTE: sequence of value assignment matters
self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -165,7 +169,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
else:
self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
- if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+ if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
@@ -214,8 +218,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.moe_tp_group = group
if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
- # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
- # this assertion implies that dp_size == moe_dp_size * ep_size
+ # NOTE: different tp settings between moe and non-moe params are complex to handle
+ # we simply reuse tp_group as moe_tp_group, which implies that dp_size == moe_dp_size * ep_size
raise NotImplementedError(
f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
)
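
For reference, the divisibility rule enforced above (world_size == pp_size * moe_dp_size * ep_size * moe_tp_size, with moe_dp_size derived from the other axes) can be checked in isolation. A minimal arithmetic sketch, not part of the plugin, using a hypothetical helper name:

def derive_moe_dp_size(world_size: int, pp_size: int, ep_size: int, moe_tp_size: int) -> int:
    # moe_dp_size is whatever remains of world_size after pp/ep/moe_tp take their share
    moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)
    if moe_dp_size * pp_size * ep_size * moe_tp_size != world_size:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"pp_size={pp_size} * ep_size={ep_size} * moe_tp_size={moe_tp_size}"
        )
    return moe_dp_size

# e.g. an 8-GPU run with pp=1, ep=2, moe_tp=2 leaves moe_dp=2
assert derive_moe_dp_size(8, 1, 2, 2) == 2
assert derive_moe_dp_size(8, 1, 4, 1) == 2
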
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index fe13b5b30..69d9fa5d4 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -18,28 +18,34 @@ from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 4
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
-@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
- stage, ep_size, tp_size = config
- dtype, precision = torch.float16, "fp16"
+ ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
+ print(config)
rank = torch.distributed.get_rank()
+ dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
- pp_size=1,
+ pp_size=pp_size,
+ num_microbatches=pp_size,
tp_size=tp_size,
- moe_tp_size=tp_size,
+ sp_size=sp_size,
ep_size=ep_size,
+ moe_tp_size=tp_size,
zero_stage=stage,
+ enable_sequence_parallelism=sp_size > 1,
+ sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
+ find_unused_parameters=True,
)
booster = Booster(plugin=plugin)
@@ -53,6 +59,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
+ attn_implementation="flash_attention_2",
)
torch_model = MixtralModel(config).to(dtype).cuda()
@@ -72,7 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
- dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
+
+ dist.all_reduce(input_data, group=plugin.tp_group) # tp group requires duplicate input
+ dist.all_reduce(input_data, group=plugin.sp_group) # sp group requires duplicate input
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
zero_optimizer.backward(zero_output)
@@ -124,11 +133,11 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
- test_mistral(world_size=4)
+ test_mistral(world_size=8)
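
The updated test seeds each rank differently (1453 + rank) and then all-reduces the random input over the tp and sp groups, so ranks that must see identical activations do, while data-parallel ranks still receive distinct batches. A minimal sketch of that duplication idea, assuming an initialized NCCL process group and pre-built tp_group/sp_group handles (the helper name is illustrative):

import torch
import torch.distributed as dist


def make_dp_unique_group_shared_input(shape, tp_group, sp_group, seed_base=1453):
    # every rank draws its own tensor ...
    torch.manual_seed(seed_base + dist.get_rank())
    x = torch.rand(*shape, device="cuda")
    # ... then summing over tp/sp makes all ranks inside those groups identical,
    # while ranks in different dp slices keep different values
    dist.all_reduce(x, group=tp_group)
    dist.all_reduce(x, group=sp_group)
    return x
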
From 803878b2fd55ae5a95088723fc75bc7976d6a801 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Fri, 19 Jul 2024 06:11:11 +0000
Subject: [PATCH 081/116] [moe] full test for deepseek and mixtral (pp + sp to
fix)
---
.../booster/plugin/hybrid_parallel_plugin.py | 4 +
.../plugin/moe_hybrid_parallel_plugin.py | 16 +-
colossalai/legacy/moe/layer/experts.py | 2 +-
colossalai/legacy/moe/layer/layers.py | 2 +-
colossalai/legacy/moe/layer/routers.py | 2 +-
.../moe/{operators.py => _operation.py} | 0
colossalai/shardformer/modeling/deepseek.py | 8 +-
colossalai/shardformer/modeling/mixtral.py | 40 +-
colossalai/shardformer/policies/deepseek.py | 9 +-
colossalai/shardformer/policies/mixtral.py | 33 +-
.../low_level/bookkeeping/bucket_store.py | 2 +-
colossalai/zero/low_level/low_level_optim.py | 2 +-
tests/test_moe/modelling/test_deepseek.py | 133 -------
tests/test_moe/modelling/test_mixtral.py | 143 -------
tests/test_moe/test_kernel.py | 2 +-
.../test_model/test_shard_deepseek.py | 186 +++++++++
.../test_model/test_shard_mixtral.py | 363 ++++++++----------
17 files changed, 430 insertions(+), 517 deletions(-)
rename colossalai/moe/{operators.py => _operation.py} (100%)
delete mode 100644 tests/test_moe/modelling/test_deepseek.py
delete mode 100644 tests/test_moe/modelling/test_mixtral.py
create mode 100644 tests/test_shardformer/test_model/test_shard_deepseek.py
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 03b7bebb1..4b1bd0f47 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1122,6 +1122,10 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
+ self.logger.info(
+ f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
+ ranks=[0, 1, 2, 3, 4, 5, 6, 7],
+ )
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 31b346b10..32673169a 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -147,9 +147,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
world_size = dist.get_world_size()
if self.enable_sequence_parallelism:
- # if sequence parallelism is enabled, we reuse the same group for ep and sp
if self.sequence_parallelism_mode == "all_to_all":
- # when sequence parallelism is enabled, ep_group reuses sp_group
+ # if sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
@@ -157,8 +156,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# since we are reusing sp_group, moe_dp_group will be derived as dp_group
self.moe_dp_size = self.dp_size
- self.moe_dp_group = self.dp_group # NOTE: sequence of value assignment matters
- self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+ self.moe_dp_group = self.dp_group
+ self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.ep_group = self.sp_group
self.moe_tp_group = self.tp_group
else:
@@ -177,6 +176,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
+ self.dp_sp_group = self.dp_group
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
@@ -225,8 +225,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.logger.info(
- f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
- f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
+ f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
+ f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
ranks=[0],
)
@@ -254,7 +254,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=self.dp_group,
+ dp_group=self.dp_sp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=self.use_ddp,
@@ -302,7 +302,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
- dp_process_group=self.dp_group,
+ dp_process_group=self.dp_sp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,
diff --git a/colossalai/legacy/moe/layer/experts.py b/colossalai/legacy/moe/layer/experts.py
index c16fc77bb..8088cf44e 100644
--- a/colossalai/legacy/moe/layer/experts.py
+++ b/colossalai/legacy/moe/layer/experts.py
@@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_activation
-from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
diff --git a/colossalai/legacy/moe/layer/layers.py b/colossalai/legacy/moe/layer/layers.py
index 8681b5972..e43966f68 100644
--- a/colossalai/legacy/moe/layer/layers.py
+++ b/colossalai/legacy/moe/layer/layers.py
@@ -9,7 +9,7 @@ import torch.nn.functional as F
from colossalai.legacy.moe.load_balance import LoadBalancer
from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.moe.operators import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
diff --git a/colossalai/legacy/moe/layer/routers.py b/colossalai/legacy/moe/layer/routers.py
index c16fc77bb..8088cf44e 100644
--- a/colossalai/legacy/moe/layer/routers.py
+++ b/colossalai/legacy/moe/layer/routers.py
@@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_activation
-from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
diff --git a/colossalai/moe/operators.py b/colossalai/moe/_operation.py
similarity index 100%
rename from colossalai/moe/operators.py
rename to colossalai/moe/_operation.py
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index a90cd8726..33fac9b93 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -10,7 +10,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
+from colossalai.moe._operation import (
+ DPGradScalerIn,
+ DPGradScalerOut,
+ EPGradScalerIn,
+ EPGradScalerOut,
+ all_to_all_uneven,
+)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index f51e690d1..90616351a 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -23,7 +23,13 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
-from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
+from colossalai.moe._operation import (
+ DPGradScalerIn,
+ DPGradScalerOut,
+ EPGradScalerIn,
+ EPGradScalerOut,
+ all_to_all_uneven,
+)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
@@ -245,6 +251,7 @@ class MixtralPipelineForwards:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
+ print("input_ids", input_ids.shape)
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
@@ -372,16 +379,29 @@ class MixtralPipelineForwards:
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
- return tuple(
- v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
- if v is not None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
)
- # always return dict for imediate stage
- return {
- "hidden_states": hidden_states,
- "past_router_logits": all_router_logits,
- }
+ else:
+ if output_router_logits:
+ return {
+ "hidden_states": hidden_states,
+ "past_router_logits": all_router_logits,
+ }
+ else:
+ return {
+ "hidden_states": hidden_states,
+ }
@staticmethod
def mixtral_for_causal_lm_forward(
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 5a67d653d..04d1dcd41 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -34,7 +34,10 @@ class DeepseekPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
- self.shard_config.enable_sequence_parallelism = False
+ if self.pipeline_stage_manager is not None:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
raise NotImplementedError(
"Deepseek doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
@@ -136,6 +139,10 @@ class DeepseekPolicy(Policy):
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
+ if self.shard_config.enable_sequence_parallelism:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "DeepseekModel":
module = self.model
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 8fed5ee5c..4de982f44 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -62,6 +62,10 @@ class MixtralPolicy(Policy):
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+ if self.pipeline_stage_manager is not None:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
self.append_or_create_method_replacement(
description={
"forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
@@ -69,19 +73,18 @@ class MixtralPolicy(Policy):
policy=policy,
target_key=attn_cls,
)
- if self.pipeline_stage_manager is None:
- self.append_or_create_method_replacement(
- description={
- "forward": get_mixtral_flash_attention_model_forward(
- self.shard_config,
- sp_mode=sp_mode,
- sp_size=sp_size,
- sp_group=sp_group,
- ),
- },
- policy=policy,
- target_key=MixtralModel,
- )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_mixtral_flash_attention_model_forward(
+ self.shard_config,
+ sp_mode=sp_mode,
+ sp_size=sp_size,
+ sp_group=sp_group,
+ ),
+ },
+ policy=policy,
+ target_key=MixtralModel,
+ )
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
@@ -202,6 +205,10 @@ class MixtralPolicy(Policy):
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
+ if self.shard_config.enable_sequence_parallelism:
+ # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
+ # if both are enabled, one of them will be ignored
+ raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 78c34046a..19d20de2b 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -100,7 +100,7 @@ class BucketStore(BaseStore):
return self._grad_in_bucket
- def get_flatten_grad(self, dtype=None) -> Tensor:
+ def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data organization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 5a41345f0..c2f9e98a1 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -303,7 +303,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for bucket_store in self.pg_to_bucket_store.values():
bucket_store.build_grad_in_bucket()
- flat_grads = bucket_store.get_flatten_grad(self._dtype)
+ flat_grads = bucket_store.get_flatten_grad()
flat_grads /= bucket_store.world_size
# ready to add other tensors to bucket
diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py
deleted file mode 100644
index 74c72dd06..000000000
--- a/tests/test_moe/modelling/test_deepseek.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import os
-import shutil
-from copy import deepcopy
-from typing import Tuple
-
-import pytest
-import torch
-import torch.distributed as dist
-from transformers import AutoConfig, AutoModel
-
-import colossalai
-from colossalai.booster.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
-
-NUM_BATCH = 4
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
-HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS = 4
-TOP_K = 1
-
-
-@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
-def run_zero_with_original_model(config: Tuple[int, ...]):
- stage, ep_size, tp_size = config
- dtype = torch.float16
-
- rank = torch.distributed.get_rank()
- torch.cuda.set_device(dist.get_rank())
-
- plugin = MoeHybridParallelPlugin(
- pp_size=1,
- tp_size=tp_size,
- moe_tp_size=tp_size,
- ep_size=ep_size,
- zero_stage=stage,
- overlap_communication=False,
- initial_scale=1,
- precision="fp32",
- )
- booster = Booster(plugin=plugin)
-
- seed_all(10086)
-
- config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
- config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
- config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
- config.num_hidden_layers = 2
- config.num_attention_heads = NUM_HEADS
- config.num_key_value_heads = NUM_HEADS
- config.n_routed_experts = NUM_EXPERTS
- config.num_experts_per_tok = TOP_K
- torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
-
- torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
-
- zero_model = deepcopy(torch_model).to(dtype)
- zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
-
- zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
-
- # create different input
- seed_all(1453 + rank)
-
- torch_model.train()
- zero_model.train()
- for _ in range(2):
- input_data = torch.rand(
- NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
- ).cuda()
- dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
-
- zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
- zero_optimizer.backward(zero_output)
- zero_optimizer.step()
- zero_optimizer.zero_grad()
- dist.all_reduce(zero_output)
-
- all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
- dist.all_gather(all_inputs, input_data)
-
- torch_output_sum = 0
- for input_data_ in all_inputs:
- torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
- torch_output.backward()
- torch_output_sum += torch_output.detach()
- # avg dp grads
- for p in torch_model.parameters():
- if p.grad is not None:
- p.grad /= dist.get_world_size()
- torch_optimizer.step()
- torch_optimizer.zero_grad()
-
- loose_close(zero_output, torch_output_sum, dtype=dtype)
-
- # use checkpoint to load sharded zero model
- model_dir = "./test_deepseek"
- if dist.get_rank() == 0:
- os.makedirs(model_dir, exist_ok=True)
-
- dist.barrier()
-
- booster.save_model(zero_model, model_dir, shard=True)
-
- dist.barrier()
-
- saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
- check_model_equal(torch_model, saved_model)
-
- dist.barrier()
- if dist.get_rank() == 0:
- shutil.rmtree(model_dir)
-
- print(f"{dist.get_rank()} test passed")
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_zero_with_original_model()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
-@rerun_if_address_is_in_use()
-def test_mistral(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_mistral(world_size=4)
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
deleted file mode 100644
index 69d9fa5d4..000000000
--- a/tests/test_moe/modelling/test_mixtral.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import os
-import shutil
-from copy import deepcopy
-from typing import Tuple
-
-import pytest
-import torch
-import torch.distributed as dist
-from transformers.models.mixtral.configuration_mixtral import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralModel
-
-import colossalai
-from colossalai.booster.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
-
-NUM_BATCH = 4
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
-HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS = 4
-TOP_K = 1
-
-
-@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
-def run_zero_with_original_model(config: Tuple[int, ...]):
- ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
- print(config)
- rank = torch.distributed.get_rank()
- dtype, precision = torch.float16, "fp16"
- torch.cuda.set_device(dist.get_rank())
-
- plugin = MoeHybridParallelPlugin(
- pp_size=pp_size,
- num_microbatches=pp_size,
- tp_size=tp_size,
- sp_size=sp_size,
- ep_size=ep_size,
- moe_tp_size=tp_size,
- zero_stage=stage,
- enable_sequence_parallelism=sp_size > 1,
- sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
- overlap_communication=False,
- initial_scale=1,
- precision=precision,
- find_unused_parameters=True,
- )
- booster = Booster(plugin=plugin)
-
- seed_all(10086)
-
- config = MixtralConfig(
- hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
- intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
- num_hidden_layers=2,
- num_attention_heads=NUM_HEADS,
- num_key_value_heads=NUM_HEADS,
- num_local_experts=NUM_EXPERTS,
- num_experts_per_tok=TOP_K,
- attn_implementation="flash_attention_2",
- )
-
- torch_model = MixtralModel(config).to(dtype).cuda()
- torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
-
- zero_model = deepcopy(torch_model).to(dtype)
- zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
-
- zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
-
- # create different input
- seed_all(1453 + rank)
-
- torch_model.train()
- zero_model.train()
- for _ in range(2):
- input_data = torch.rand(
- NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
- ).cuda()
-
- dist.all_reduce(input_data, group=plugin.tp_group) # tp group requires duplicate input
- dist.all_reduce(input_data, group=plugin.sp_group) # sp group requires duplicate input
-
- zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
- zero_optimizer.backward(zero_output)
- zero_optimizer.step()
- zero_optimizer.zero_grad()
- dist.all_reduce(zero_output)
-
- all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
- dist.all_gather(all_inputs, input_data)
-
- torch_output_sum = 0
- for input_data_ in all_inputs:
- torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
- torch_output.backward()
- torch_output_sum += torch_output.detach()
- # avg dp grads
- for p in torch_model.parameters():
- if p.grad is not None:
- p.grad /= dist.get_world_size()
- torch_optimizer.step()
- torch_optimizer.zero_grad()
-
- loose_close(zero_output, torch_output_sum, dtype=dtype)
-
- # use checkpoint to load sharded zero model
- model_dir = "./test_mixtral"
- if dist.get_rank() == 0:
- os.makedirs(model_dir, exist_ok=True)
-
- dist.barrier()
-
- booster.save_model(zero_model, model_dir, shard=True)
-
- dist.barrier()
-
- saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
- check_model_equal(torch_model, saved_model)
-
- dist.barrier()
- if dist.get_rank() == 0:
- shutil.rmtree(model_dir)
-
- print(f"{dist.get_rank()} test passed")
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_zero_with_original_model()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
-@rerun_if_address_is_in_use()
-def test_mistral(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_mistral(world_size=8)
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index f2c6d206f..c81023988 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -4,7 +4,7 @@ import pytest
import torch
from colossalai.accelerator import get_accelerator
-from colossalai.moe.operators import MoeCombine, MoeDispatch, moe_cumsum
+from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
NUM_EXPERTS = 4
BATCH_SIZE = 4
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
new file mode 100644
index 000000000..96edfb487
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -0,0 +1,186 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed
+import torch.distributed as dist
+from transformers import AutoConfig, AutoModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_LAYERS = 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
+
+
+# TODO only need to keep one or two cases
+@parameterize(
+ "config",
+ [
+ (2, 1, 1, 4, 1),
+ # (2, 1, 2, 1, 1), # TODO debug deepseek pp
+ # (2, 1, 2, 2, 1), # TODO debug deepseek pp
+ (2, 1, 1, 2, 1),
+ # (2, 1, 1, 1, 2), # TODO support deepseek sp
+ # (2, 1, 4, 1, 1), # TODO debug deepseek pp
+ (4, 1, 1, 1, 1),
+ (4, 1, 1, 2, 1),
+ # (4, 1, 2, 1, 1), # TODO debug deepseek pp
+ ],
+)
+def run_zero_with_original_model(config: Tuple[int, ...]):
+ ep_size, stage, pp_size, tp_size, sp_size = config
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ dtype, precision = torch.float16, "fp16"
+ torch.cuda.set_device(dist.get_rank())
+
+ print(config)
+ plugin = MoeHybridParallelPlugin(
+ pp_size=pp_size,
+ num_microbatches=pp_size,
+ tp_size=tp_size,
+ sp_size=sp_size,
+ ep_size=ep_size,
+ moe_tp_size=tp_size,
+ zero_stage=stage,
+ enable_sequence_parallelism=sp_size > 1,
+ sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+ overlap_communication=False,
+ initial_scale=1,
+ precision=precision,
+ find_unused_parameters=True,
+ )
+ dp_size = plugin.dp_size
+
+ booster = Booster(plugin=plugin)
+
+ # init model with the same seed
+ seed_all(10086)
+
+ assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+ config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
+ config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
+ config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
+ config.num_hidden_layers = 2
+ config.num_attention_heads = NUM_HEADS
+ config.num_key_value_heads = NUM_HEADS
+ config.n_routed_experts = NUM_EXPERTS
+ config.num_experts_per_tok = TOP_K
+
+ torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+ parallel_model = deepcopy(torch_model)
+ parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+ parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+
+ # create different input along dp axis
+ seed_all(1453 + rank)
+
+ torch_model.train()
+ parallel_model.train()
+ for _ in range(2):
+ # gen random input
+ input_embeddings = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ dist.all_reduce(
+ input_embeddings, group=plugin.pp_group
+ ) # pp inputs on stages other than the first don't matter, but they need to be replicated for the torch model check
+
+ dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
+ dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
+
+ # run the model with hybrid parallel
+ if booster.plugin.stage_manager is not None:
+ # for test with pp
+ data_iter = iter([{"inputs_embeds": input_embeddings}])
+ sharded_output = booster.execute_pipeline(
+ data_iter,
+ parallel_model,
+ lambda x, y: x[0].mean(),
+ parallel_optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ if booster.plugin.stage_manager.is_last_stage():
+ parallel_output = sharded_output["loss"]
+ else:
+ parallel_output = torch.tensor(12345.0, device="cuda")
+
+ # broadcast along pp axis
+ dist.broadcast(
+ parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
+ )
+ else:
+ # for test without pp
+ parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+ parallel_optimizer.backward(parallel_output)
+ parallel_optimizer.step()
+ parallel_optimizer.zero_grad()
+ dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+ # ===================================================================================
+ # run normal model with all dp(different) inputs
+ all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ torch_output_sum = 0
+ for input_data_ in all_inputs:
+ torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+ torch_output.backward()
+ torch_output_sum += torch_output.detach()
+ # average dp grads, following the zero optimizer
+ for p in torch_model.parameters():
+ if p.grad is not None:
+ p.grad /= dp_size
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+
+ loose_close(parallel_output, torch_output_sum, dtype=dtype)
+
+ # use checkpoint to load sharded zero model
+ model_dir = "./test_deepseek"
+ if rank == world_size - 1:
+ os.makedirs(model_dir, exist_ok=True)
+
+ dist.barrier()
+ booster.save_model(parallel_model, model_dir, shard=True)
+ dist.barrier()
+
+ saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
+ check_model_equal(torch_model, saved_model)
+ dist.barrier()
+
+ if rank == world_size - 1:
+ shutil.rmtree(model_dir)
+
+ print(f"rank {dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_mistral(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_mistral(world_size=8)
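
With pipeline parallelism on, only the last stage produces the loss, so the test plants a placeholder on the other stages and broadcasts the real value from the final pp rank before the dp all-reduce. A condensed sketch of that broadcast, assuming pp_group already exists and the loss tensor lives on CUDA:

import torch
import torch.distributed as dist


def broadcast_loss_from_last_stage(loss, pp_group, is_last_stage: bool):
    # non-last stages hold a dummy scalar; the real loss is broadcast
    # from the last rank of the pipeline group
    if not is_last_stage:
        loss = torch.tensor(0.0, device="cuda")
    last_rank = dist.get_process_group_ranks(pp_group)[-1]
    dist.broadcast(loss, src=last_rank, group=pp_group)
    return loss
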
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 4e9c594d2..e0ef3bfaf 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -1,229 +1,188 @@
-# modified from test_shard_mistral.py
import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
import pytest
import torch
+import torch.distributed
import torch.distributed as dist
-from torch.testing import assert_close
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
+from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import (
- build_model_from_hybrid_plugin,
- check_all_grad_tensors,
- check_loss,
- check_output_hidden_state,
- check_weight,
- get_grad_tensors_for_check,
- run_forward_backward_with_hybrid_plugin,
- unwrap_model,
-)
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
- # TODO: SGD failed for full dp
- org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
- model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
- )
-
- org_model = org_model.to(torch.float16)
- org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
- org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
- )
- stage_manager = booster.plugin.stage_manager
- tp_group = booster.plugin.tp_group
-
- # check last hidden state & loss
- if stage_manager is None or stage_manager.is_last_stage():
- if test_config["precision"] == "fp32":
- atol, rtol = 1e-5, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
-
- check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
-
- # unwrap model
- mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
- shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")
-
- row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
- col_layer_for_check = ["layers[0].self_attn.o_proj"]
-
- # Check the grad when using ZeRO-1 and ZeRO-2
- if (
- # booster.plugin.zero_stage in [1, 2]
- booster.plugin.shard_config.enable_sequence_parallelism
- and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
- ):
- rank = dist.get_rank()
- name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
- for n, p in shard_mixtral_model.named_parameters():
- zero_grad = sharded_optimizer.get_param_grad(p)
- if name_to_p[n].grad is None:
- name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
- continue
- assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
-
- # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
- grads_to_check = {}
- if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
- if test_config["precision"] == "fp32":
- atol, rtol = 5e-5, 1e-4
- else:
- atol, rtol = 5e-3, 5e-3
- row_layer_grads = get_grad_tensors_for_check(
- mixtral_model,
- shard_mixtral_model,
- row_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=0,
- verbose=False,
- )
- col_layer_grads = get_grad_tensors_for_check(
- mixtral_model,
- shard_mixtral_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False,
- )
- grads_to_check.update(col_layer_grads)
- grads_to_check.update(row_layer_grads)
-
- # check grads
- check_all_grad_tensors(grads_to_check)
-
- for n, p in shard_mixtral_model.named_parameters():
- assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
-
- # optimizer executes step
- org_optimizer.step()
- sharded_optimizer.step()
-
- for n, p in shard_mixtral_model.named_parameters():
- assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
-
- # check weights
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config["precision"] == "fp32":
- atol, rtol = 2e-4, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
- try:
- check_weight(
- mixtral_model,
- shard_mixtral_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False,
- )
- except Exception as e:
- rank = dist.get_rank()
- print(f"{rank=}, Failed config: {test_config}")
- raise e
-
- torch.cuda.empty_cache()
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_LAYERS = 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
+# TODO only need to keep one or two cases
@parameterize(
- "test_config",
+ "config",
[
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "num_microbatches": 2,
- # "ep_size": 2,
- # "zero_stage": 0,
- # "overlap_communication": False,
- # "precision": "fp16",
- # }, # [dp(4)] + [moe_dp(4)]
- # {
- # "tp_size": 1,
- # "pp_size": 2,
- # "num_microbatches": 2,
- # "ep_size": 2,
- # "zero_stage": 1,
- # "overlap_communication": False,
- # "precision": "fp32",
- # }, # [dp(2) + pp(2)] + [moe_pp(2)]
- # {
- # "tp_size": 2,
- # "pp_size": 2,
- # "num_microbatches": 2,
- # "ep_size": 2,
- # "zero_stage": 1,
- # "overlap_communication": False,
- # "precision": "fp32",
- # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
- { # Ulysess + Flash attention
- "tp_size": 1,
- "pp_size": 1,
- "sp_size": 2,
- "ep_size": 2,
- "enable_sequence_parallelism": True,
- "sequence_parallelism_mode": "all_to_all",
- "zero_stage": 1,
- "overlap_communication": False,
- "precision": "fp16",
- "initial_scale": 1,
- "find_unused_parameters": True,
- },
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 2,
- # "zero_stage": 0,
- # "overlap_communication": False,
- # "precision": "fp32",
- # }, # [dp(4)] + [ep(2) + moe_tp(2)]
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 4,
- # "overlap_communication": False,
- # "zero_stage": 0,
- # "precision": "fp32"
- # }, # full dp for non-moe and full ep for moe
+ (2, 1, 1, 4, 1),
+ (2, 1, 2, 1, 1),
+ (2, 1, 2, 2, 1),
+ (2, 1, 1, 2, 1),
+ (2, 1, 1, 1, 2),
+ (2, 1, 4, 1, 1),
+ (4, 1, 1, 1, 1),
+ (4, 1, 1, 2, 1),
+ (4, 1, 2, 1, 1),
],
)
-def run_mixtral_test(test_config):
- sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral")
+def run_zero_with_original_model(config: Tuple[int, ...]):
+ ep_size, stage, pp_size, tp_size, sp_size = config
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ dtype, precision = torch.float16, "fp16"
+ torch.cuda.set_device(dist.get_rank())
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+ plugin = MoeHybridParallelPlugin(
+ pp_size=pp_size,
+ num_microbatches=pp_size,
+ tp_size=tp_size,
+ sp_size=sp_size,
+ ep_size=ep_size,
+ moe_tp_size=tp_size,
+ zero_stage=stage,
+ enable_sequence_parallelism=sp_size > 1,
+ sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+ overlap_communication=False,
+ initial_scale=1,
+ precision=precision,
+ find_unused_parameters=True,
+ )
+ dp_size = plugin.dp_size
- clear_layout_converter()
- Randomizer.reset_index()
- torch.cuda.empty_cache()
+ booster = Booster(plugin=plugin)
+
+ # init model with the same seed
+ seed_all(10086)
+
+ assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+ config = MixtralConfig(
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=NUM_LAYERS,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ num_local_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
+ attn_implementation="flash_attention_2",
+ )
+
+ torch_model = MixtralModel(config).to(dtype).cuda()
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+ parallel_model = deepcopy(torch_model)
+ parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+ parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+
+ # create different input along dp axis
+ seed_all(1453 + rank)
+
+ torch_model.train()
+ parallel_model.train()
+ for _ in range(2):
+ # gen random input
+ input_embeddings = torch.rand(
+ NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+ ).cuda()
+ dist.all_reduce(
+ input_embeddings, group=plugin.pp_group
+ ) # pp inputs on stages other than the first don't matter, but they need to be replicated for the torch model check
+
+ dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
+ dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
+
+ # run the model with hybrid parallel
+ if booster.plugin.stage_manager is not None:
+ # for test with pp
+ data_iter = iter([{"inputs_embeds": input_embeddings}])
+ sharded_output = booster.execute_pipeline(
+ data_iter,
+ parallel_model,
+ lambda x, y: x.last_hidden_state.mean(),
+ parallel_optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ if booster.plugin.stage_manager.is_last_stage():
+ parallel_output = sharded_output["loss"]
+ else:
+ parallel_output = torch.tensor(12345.0, device="cuda")
+
+ # broadcast along pp axis
+ dist.broadcast(
+ parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
+ )
+ else:
+ # for test without pp
+ parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+ parallel_optimizer.backward(parallel_output)
+ parallel_optimizer.step()
+ parallel_optimizer.zero_grad()
+ dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+ # ===================================================================================
+ # run normal model with all dp(different) inputs
+ all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+ torch_output_sum = 0
+ for input_data_ in all_inputs:
+ torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+ torch_output.backward()
+ torch_output_sum += torch_output.detach()
+ # average dp grads, following the zero optimizer
+ for p in torch_model.parameters():
+ if p.grad is not None:
+ p.grad /= dp_size
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+
+ loose_close(parallel_output, torch_output_sum, dtype=dtype)
+
+ # use checkpoint to load sharded zero model
+ model_dir = "./test_mixtral"
+ if rank == world_size - 1:
+ os.makedirs(model_dir, exist_ok=True)
+
+ dist.barrier()
+ booster.save_model(parallel_model, model_dir, shard=True)
+ dist.barrier()
+
+ saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
+ check_model_equal(torch_model, saved_model)
+ dist.barrier()
+
+ if rank == world_size - 1:
+ shutil.rmtree(model_dir)
+
+ print(f"rank {dist.get_rank()} test passed")
-def check_mixtral(rank, world_size, port):
- disable_existing_loggers()
+def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_mixtral_test()
+ run_zero_with_original_model()
@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_mixtral():
- spawn(check_mixtral, 4)
+def test_mistral(world_size):
+ spawn(run_dist, world_size)
if __name__ == "__main__":
- test_mixtral()
+ test_mistral(world_size=8)
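
The final comparison in both rewritten tests rests on a simple identity: all-reducing the parallel loss over the dp group yields the sum of per-rank losses, and the reference model reproduces that sum by running every dp input serially, while dividing its gradients by dp_size to match the averaging done by the zero optimizer. A self-contained numeric check of that identity (illustrative only):

import torch

dp_size = 4
losses = torch.rand(dp_size)

all_reduced = losses.sum()            # what dist.all_reduce over the dp group produces
serial_sum = sum(l for l in losses)   # what the reference loop accumulates
assert torch.isclose(all_reduced, serial_sum)

# summed grads divided by dp_size equal the mean grad, i.e. the zero optimizer's averaging
grads = torch.rand(dp_size, 3)
assert torch.allclose(grads.sum(dim=0) / dp_size, grads.mean(dim=0))
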
From 46037c2ccddfd36accc725da4179c66fa729532c Mon Sep 17 00:00:00 2001
From: hxwang
Date: Fri, 19 Jul 2024 07:53:40 +0000
Subject: [PATCH 082/116] [chore] minor fix after rebase
---
colossalai/booster/plugin/hybrid_parallel_plugin.py | 5 +++--
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 2 ++
tests/test_shardformer/test_model/test_shard_deepseek.py | 1 -
3 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 4b1bd0f47..5d0dedeed 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -39,7 +39,6 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
-from colossalai.logging import get_dist_logger
from .pp_plugin_base import PipelinePluginBase
@@ -653,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
+ pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -685,6 +685,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer=optimizer,
initial_scale=initial_scale,
min_scale=min_scale,
+ pg_to_param_list=pg_to_param_list,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
@@ -1124,7 +1125,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.logger.info(
f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
- ranks=[0, 1, 2, 3, 4, 5, 6, 7],
+ ranks=[0],
)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
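
The new pg_to_param_list argument lets the caller hand HybridParallelZeroOptimizer an explicit mapping from process group to the parameters that group should reduce, instead of a single dp_process_group. A minimal sketch of how such a mapping might be assembled for a MoE model, assuming distinct dp_group and moe_dp_group handles (the helper name is illustrative, not the plugin's actual code):

from colossalai.tensor.moe_tensor.api import is_moe_tensor


def build_pg_to_param_list(model, dp_group, moe_dp_group):
    # regular params are reduced over the full dp group,
    # expert params over the smaller moe_dp group
    mapping = {dp_group: [], moe_dp_group: []}
    for param in model.parameters():
        if is_moe_tensor(param):
            mapping[moe_dp_group].append(param)
        else:
            mapping[dp_group].append(param)
    return mapping
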
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 32673169a..9b88a8838 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -55,6 +55,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None,
+ overlap_allgather: bool = False,
):
WARN_STR = "Note that you need to make sure every expert is routed (i.e. every expert has a backward pass), otherwise this might lead to a program hang or inconsistent results"
if not force_overlap_comm and (overlap_communication or partition_grad):
@@ -95,6 +96,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
pg_to_param_list=pg_param_list,
+ overlap_allgather=overlap_allgather,
)
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 96edfb487..7b57ba039 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -47,7 +47,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
- print(config)
plugin = MoeHybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
From 52d346f2a53c08c18a738ef68aad194f95f37af2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 19 Jul 2024 07:54:40 +0000
Subject: [PATCH 083/116] [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---
tests/test_moe/test_moe_checkpoint.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 1ab52b371..3d6096471 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.utils import spawn
from tests.test_moe.moe_utils import loose_close
From 70c9924d0d3103b515a9d7d60b7c4e729150a00f Mon Sep 17 00:00:00 2001
From: hxwang
Date: Mon, 22 Jul 2024 03:40:34 +0000
Subject: [PATCH 084/116] [chore] solve moe ckpt test failure and some other
arg pass failure
---
.../booster/plugin/low_level_zero_plugin.py | 2 +-
colossalai/shardformer/modeling/deepseek.py | 5 +--
colossalai/shardformer/modeling/mixtral.py | 5 +--
colossalai/zero/low_level/low_level_optim.py | 21 ++++++-----
tests/test_moe/moe_utils.py | 18 +++++++---
tests/test_moe/test_deepseek_layer.py | 11 ++++--
.../test_moe/test_mixtral_layer.py | 4 ++-
tests/test_moe/test_moe_checkpoint.py | 36 ++++++++-----------
tests/test_moe/test_moe_ep_tp.py | 11 +++---
tests/test_moe/test_moe_ep_zero.py | 11 +++---
.../test_model/test_shard_deepseek.py | 24 +++++++------
.../test_model/test_shard_mixtral.py | 32 ++++++++---------
12 files changed, 101 insertions(+), 79 deletions(-)
rename tests/{test_legacy => }/test_moe/test_mixtral_layer.py (94%)
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 1a6547796..66491821c 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -446,7 +446,7 @@ class LowLevelZeroPlugin(DPPluginBase):
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
- "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+ f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 33fac9b93..854bbf4f2 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -69,8 +69,6 @@ class EPDeepseekMoE(nn.Module):
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
- for p in self.experts.parameters():
- set_moe_tensor_ep_group(p, ep_group)
# setup moe_dp group
self.moe_dp_group = moe_dp_group
@@ -87,6 +85,9 @@ class EPDeepseekMoE(nn.Module):
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
+ for p in self.experts.parameters():
+ set_moe_tensor_ep_group(p, ep_group)
+
@staticmethod
def from_native_module(
module,
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 90616351a..9148a9fba 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -74,8 +74,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
- for p in self.experts.parameters():
- set_moe_tensor_ep_group(p, ep_group)
# setup moe_dp group
self.moe_dp_group = moe_dp_group
@@ -92,6 +90,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
+ for p in self.experts.parameters():
+ set_moe_tensor_ep_group(p, ep_group)
+
@staticmethod
def from_native_module(
module: MixtralSparseMoeBlock,
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index c2f9e98a1..947dec51b 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -66,7 +67,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(
self,
optimizer: Optimizer,
- pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None,
+ pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -92,7 +93,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._logger = get_dist_logger()
self._verbose = verbose
- if dp_process_group is not None and pg_to_param_list is not None:
+ if (dp_process_group is not None) and (pg_to_param_list is not None):
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
if pg_to_param_list is None:
@@ -301,6 +302,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def _run_reduction(self):
for bucket_store in self.pg_to_bucket_store.values():
+ if bucket_store.num_elements_in_bucket() <= 0:
+ continue
+
bucket_store.build_grad_in_bucket()
flat_grads = bucket_store.get_flatten_grad()
@@ -350,8 +354,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
) -> None:
for rank, grad_list in enumerate(origin_grad_list):
- if len(grad_list) == 0:
- continue
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = bucket_store.get_param_id_of_grad(grad)
@@ -648,11 +650,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
- if param.requires_grad:
- if param.grad is None:
- # for moe params, all experts should have gradient
- # TODO better way of doing this
- param.grad = torch.zeros_like(param)
+ if is_moe_tensor(param) and param.requires_grad and param.grad is None:
+ # TODO better way of doing this
+ # assign zero grad to unrouted expert to avoid hang during grad reduction
+ param.grad = torch.zeros_like(param)
+
+ if param.requires_grad and param.grad is not None:
self._add_to_bucket(param, group_id)
self._run_reduction()
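The hunk above changes the pre-reduction pass so that only MoE parameters with a missing gradient receive a zero placeholder, and only parameters that actually carry a gradient are added to a bucket. A minimal sketch of the idea (`fill_missing_moe_grads` is a hypothetical helper, not part of the codebase; `is_moe_tensor` is passed in as a predicate, mirroring the import added above):

import torch

def fill_missing_moe_grads(params, is_moe_tensor):
    # If a rank routed no tokens to an expert, that expert's grad is None locally while
    # other ranks hold real grads; filling in zeros keeps the collective reduction in sync.
    for p in params:
        if is_moe_tensor(p) and p.requires_grad and p.grad is None:
            p.grad = torch.zeros_like(p)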
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 55aee6fb0..8c411a33f 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,7 +1,11 @@
import torch
-def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
+def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
+ assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
@@ -12,10 +16,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
atol = 4e-3
else:
assert dtype is torch.float32
- rtol = 1e-5
- atol = 1e-5
+ rtol = 1e-05
+ atol = 1e-08
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
- assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"
+ return torch.allclose(a, b, rtol=rtol, atol=atol)
+
+
+def check_model_equal(model1, model2):
+ assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+ for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
+ assert_loose_close(p1, p2, p1.dtype)
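Usage sketch for the refactored helpers above (imports assume the test package layout shown in this diff): `loose_close` now returns a bool with dtype-dependent tolerances, while `assert_loose_close` raises with a readable message.

import torch
from tests.test_moe.moe_utils import assert_loose_close, loose_close

a = torch.randn(4, 8)
b = a * (1 + 1e-6)                                  # relative error well under the fp32 rtol of 1e-05
assert loose_close(a, b, dtype=torch.float32)       # boolean form
assert_loose_close(a, b, dtype=torch.float32, name="demo")  # raises AssertionError if not close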
diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index 85cc98695..e633cdd07 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -22,6 +22,7 @@ def check_deepseek_moe_layer():
precision="bf16",
tp_size=1,
pp_size=1,
+ zero_stage=1,
ep_size=dist.get_world_size(),
)
@@ -42,7 +43,13 @@ def check_deepseek_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output = orig_model(x)
model = deepcopy(orig_model)
- model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
+ model = EPDeepseekMoE.from_native_module(
+ model,
+ ep_group=plugin.ep_group,
+ moe_dp_group=plugin.moe_dp_group,
+ moe_tp_group=plugin.moe_tp_group,
+ tp_group=plugin.tp_group,
+ )
ep_output = model(x)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
@@ -62,7 +69,7 @@ def run_dist(rank: int, world_size: int, port: int):
check_deepseek_moe_layer()
-# @pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
spawn(run_dist, world_size)
diff --git a/tests/test_legacy/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
similarity index 94%
rename from tests/test_legacy/test_moe/test_mixtral_layer.py
rename to tests/test_moe/test_mixtral_layer.py
index de34b5c7b..5d9ca622a 100644
--- a/tests/test_legacy/test_moe/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -23,6 +23,7 @@ def check_mixtral_moe_layer():
precision="bf16",
tp_size=1,
pp_size=1,
+ zero_stage=1,
ep_size=dist.get_world_size(),
)
config = MixtralConfig(
@@ -63,7 +64,8 @@ def run_dist(rank: int, world_size: int, port: int):
check_mixtral_moe_layer()
-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("tested in corresponding sharderformer")
+@pytest.mark.parametrize("world_size", [2])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 3d6096471..89f5d1c64 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -6,7 +6,7 @@ from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
-from torch.optim import Adam
+from torch.optim import SGD, Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
@@ -14,20 +14,15 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, spawn
+from colossalai.testing.random import seed_all
from colossalai.testing.utils import spawn
-from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.moe_utils import check_model_equal
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
-def check_model_equal(model1, model2):
- assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
- for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
- loose_close(p1, p2, p1.dtype)
-
-
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
@@ -86,34 +81,33 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
+ num_hidden_layers=2,
),
MixtralForCausalLM,
],
],
)
def check_moe_checkpoint(test_config):
+ dtype, precision = torch.float16, "fp16"
+ config, model_cls = test_config
+ torch.cuda.set_device(dist.get_rank())
+
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
with context as f:
- torch.cuda.set_device(dist.get_rank())
if dist.get_rank() == 0:
broadcast_objects = [f] # any picklable object
else:
broadcast_objects = [None]
dist.broadcast_object_list(broadcast_objects, src=0)
- config = test_config[0]
- model_cls = test_config[1]
- torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
- orig_model = model_cls(config).cuda()
+ orig_model = model_cls(config).cuda().to(dtype)
+
+ seed_all(10086)
model = deepcopy(orig_model)
- optimizer = Adam(model.parameters(), lr=1e-3)
+ optimizer = SGD(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
- pp_size=2,
- ep_size=2,
- tp_size=1,
- microbatch_size=1,
- zero_stage=1,
+ pp_size=2, ep_size=2, tp_size=1, microbatch_size=1, zero_stage=1, precision=precision
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
@@ -135,12 +129,12 @@ def check_moe_checkpoint(test_config):
booster.save_model(model, model_dir, shard=True)
dist.barrier()
if dist.get_rank() == 0:
- saved_model = model_cls.from_pretrained(model_dir).cuda()
+ saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(orig_model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model
- new_model = model_cls(config).cuda()
+ new_model = model_cls(config).cuda().to(dtype)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, hf_model_dir)
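The checkpoint test above relies on a small pattern worth noting: rank 0 owns the temporary directory and broadcasts its path so that every rank saves and loads shards from the same location. A standalone sketch (hypothetical helper, assumes an initialized process group):

import tempfile
import torch.distributed as dist

def get_shared_tmpdir() -> str:
    # rank 0 creates the directory; every other rank receives the same path string
    objs = [tempfile.mkdtemp()] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(objs, src=0)
    return objs[0]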
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 29881c9ab..e6d2609ee 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -12,7 +12,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.moe_utils import assert_loose_close
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
@@ -22,7 +22,7 @@ TOP_K = 2
@parameterize("stage", [1])
-@parameterize("ep_size", [1, 2, 4])
+@parameterize("ep_size", [2])
def run_zero_with_original_model(stage: int, ep_size: int):
tp_size = dist.get_world_size() // ep_size
dtype = torch.bfloat16
@@ -85,7 +85,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
zero_optimizer.backward(zero_output)
# torch-ddp forward
hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
- loose_close(zero_output, hybrid_output, dtype=dtype)
+ assert_loose_close(zero_output, hybrid_output, dtype=dtype)
# torch-ddp backward
hybrid_optimizer.backward(hybrid_output)
@@ -98,7 +98,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
continue
if zero_grad.shape != name_to_p[n].grad.shape: # TODO check sharded and sliced moe
continue
- loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+ assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
@@ -110,7 +110,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
for n, p in zero_model.named_parameters():
if p.data.shape != name_to_p[n].data.shape: # TODO check sharded and sliced moe
continue
- loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+ assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@@ -120,6 +120,7 @@ def run_dist(rank, world_size, port):
run_zero_with_original_model()
+@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
diff --git a/tests/test_moe/test_moe_ep_zero.py b/tests/test_moe/test_moe_ep_zero.py
index 40e3bacb3..2d4e638b6 100644
--- a/tests/test_moe/test_moe_ep_zero.py
+++ b/tests/test_moe/test_moe_ep_zero.py
@@ -12,7 +12,7 @@ from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.moe_utils import assert_loose_close
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
@@ -22,7 +22,7 @@ TOP_K = 1
@parameterize("stage", [1])
-@parameterize("ep_size", [1, 2, 4])
+@parameterize("ep_size", [2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
dtype = torch.bfloat16
@@ -76,7 +76,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
# torch-ddp forward
ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
- loose_close(zero_output, ddp_output, dtype=dtype)
+ assert_loose_close(zero_output, ddp_output, dtype=dtype)
# torch-ddp backward
ddp_output.backward()
@@ -87,7 +87,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
- loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+ assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
@@ -97,7 +97,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
# check updated param
for n, p in zero_model.named_parameters():
- loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+ assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@@ -107,6 +107,7 @@ def run_dist(rank, world_size, port):
run_zero_with_original_model()
+@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 7b57ba039..16513b2f5 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -14,8 +14,7 @@ from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
@@ -25,18 +24,21 @@ NUM_HEADS = 4
TOP_K = 1
-# TODO only need to keep one or two cases
+CHECKED_CONFIG = [ # FOR_WORLD=8
+ (2, 1, 1, 4, 1),
+ (4, 1, 1, 2, 1),
+ (4, 1, 1, 1, 1),
+]
+
+
@parameterize(
"config",
[
- (2, 1, 1, 4, 1),
# (2, 1, 2, 1, 1), # TODO debug deepseek pp
# (2, 1, 2, 2, 1), # TODO debug deepseek pp
(2, 1, 1, 2, 1),
# (2, 1, 1, 1, 2), # TODO support deepseek sp
# (2, 1, 4, 1, 1), # TODO debug deepseek pp
- (4, 1, 1, 1, 1),
- (4, 1, 1, 2, 1),
# (4, 1, 2, 1, 1), # TODO debug deepseek pp
],
)
@@ -66,9 +68,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
- # init model with the same seed
- seed_all(10086)
-
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
@@ -79,6 +78,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
config.n_routed_experts = NUM_EXPERTS
config.num_experts_per_tok = TOP_K
+ # init model with the same seed
+ seed_all(10086)
+
torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -148,7 +150,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
torch_optimizer.step()
torch_optimizer.zero_grad()
- loose_close(parallel_output, torch_output_sum, dtype=dtype)
+ assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
@@ -175,7 +177,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index e0ef3bfaf..2b8623e13 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -15,8 +15,7 @@ from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
@@ -25,20 +24,21 @@ HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
+CHECKED_CONFIG = [ # FOR WORLD=4
+ (2, 1, 2, 2, 1),
+ (2, 1, 1, 2, 1),
+ (2, 1, 4, 1, 1),
+ (4, 1, 1, 1, 1),
+ (4, 1, 1, 2, 1),
+ (4, 1, 2, 1, 1),
+ (2, 1, 2, 1, 1),
+]
+
-# TODO only need to keep one or two cases
@parameterize(
"config",
[
- (2, 1, 1, 4, 1),
- (2, 1, 2, 1, 1),
- (2, 1, 2, 2, 1),
(2, 1, 1, 2, 1),
- (2, 1, 1, 1, 2),
- (2, 1, 4, 1, 1),
- (4, 1, 1, 1, 1),
- (4, 1, 1, 2, 1),
- (4, 1, 2, 1, 1),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
@@ -67,9 +67,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
- # init model with the same seed
- seed_all(10086)
-
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
@@ -82,6 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
attn_implementation="flash_attention_2",
)
+ # init model with the same seed
+ seed_all(10086)
+
torch_model = MixtralModel(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -151,7 +151,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
torch_optimizer.step()
torch_optimizer.zero_grad()
- loose_close(parallel_output, torch_output_sum, dtype=dtype)
+ assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
@@ -178,7 +178,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
From 74b03de3f907f7237e3e60811ba4f35b9d74cec3 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Mon, 22 Jul 2024 04:00:42 +0000
Subject: [PATCH 085/116] [moe] remove ops
---
colossalai/moe/_operation.py | 111 -----------------------------------
1 file changed, 111 deletions(-)
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index f24191c16..ac422a4da 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -440,114 +440,3 @@ def all_to_all_uneven(
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
-
-
-# ===========================================================
-# This code section was modified from
-# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py
-
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-# The file has been adapted from the following Megatron-LM file:
-# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
-# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
-# We retain the following copyright from the original files:
-
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# TODO: used when non-moe are tp but moe are not
-
-
-def _gather_tokens(input_, dim: int, tp_group: ProcessGroup):
- """Gather tensors and concatenate them along a dimension"""
-
- input_ = input_.contiguous()
- # Size and dimension.
- rank = tp_group.rank()
-
- tensor_list = [torch.empty_like(input_) for _ in range(tp_group.size())]
- tensor_list[rank] = input_
- dist.all_gather(tensor_list, input_, group=tp_group)
-
- # Note: torch.cat already creates a contiguous tensor.
- output = torch.cat(tensor_list, dim=dim).contiguous()
-
- return output
-
-
-def _drop_tokens(input_, dim: int, tp_group: ProcessGroup):
- """Divide a tensor among the tensor parallel ranks"""
-
- total_chunks = tp_group.size()
- this_chunk = tp_group.rank()
- assert (
- input_.shape[dim] % total_chunks == 0
- ), f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
- chunk_size = input_.shape[dim] // total_chunks
-
- return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
-
-
-class _GatherTokens(torch.autograd.Function):
- """All gather tokens among the tensor parallel ranks"""
-
- @staticmethod
- def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
- ctx.dim = dim
- ctx.tp_group = tp_group
- return _gather_tokens(input_, dim, tp_group)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _drop_tokens(grad_output, ctx.dim, ctx.tp_group), None, None
-
-
-class _DropTokens(torch.autograd.Function):
- "Divide tokens equally among the tensor parallel ranks"
-
- @staticmethod
- def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor:
- ctx.dim = dim
- ctx.tp_group = tp_group
- return _drop_tokens(input_, dim, tp_group)
-
- @staticmethod
- def backward(ctx, input_: torch.Tensor) -> Tuple[torch.Tensor, None]:
- return _gather_tokens(input_, ctx.dim, ctx.tp_group), None, None
-
-
-def gather_tokens(input_, dim: int, tp_group: ProcessGroup):
- if tp_group.size() == 1:
- # no tensor parallelism for non-experts
- return input_
- assert (
- input_.requires_grad
- ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
- return _GatherTokens.apply(input_, dim, tp_group)
-
-
-def drop_tokens(input_, dim: int, tp_group: ProcessGroup):
- if tp_group.size() == 1:
- # no tensor parallelism for non-experts
- return input_
- assert (
- input_.requires_grad
- ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
- return _DropTokens.apply(input_, dim, tp_group)
-
-
-# ===========================================================
From 067e18f7e98a8a47de350bb8af4d6e52aa42e78d Mon Sep 17 00:00:00 2001
From: hxwang
Date: Mon, 22 Jul 2024 05:36:20 +0000
Subject: [PATCH 086/116] [test] fix test: test_zero1_2
---
colossalai/zero/low_level/low_level_optim.py | 2 +-
tests/test_shardformer/test_model/test_shard_deepseek.py | 4 ++--
tests/test_shardformer/test_model/test_shard_mixtral.py | 4 ++--
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 947dec51b..51d7d1eaa 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -880,7 +880,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return None
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
- return grad_flat[: working_param.numel()].reshape_as(working_param)
+ return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
working_grads = []
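Why the one-line change above matters, shown with illustrative shapes (not the real padded sizes): `grad_flat` has shape `(world_size, *grad.shape)`, so slicing it directly cuts along the rank dimension instead of selecting elements; flattening first recovers the working parameter.

import torch

working_param = torch.arange(10.0).reshape(2, 5)   # 10 real elements, padded to 12 across 2 ranks
shard = torch.zeros(6)                             # per-rank padded gradient shard
grad_flat = torch.stack([shard, shard])            # stands in for the all_gather_into_tensor result
wrong = grad_flat[: working_param.numel()]         # slices dim 0 (ranks); shape stays (2, 6)
right = grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
assert wrong.shape != working_param.shape and right.shape == working_param.shape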
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 16513b2f5..c301777f2 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -179,9 +179,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
-def test_mistral(world_size):
+def test_deepseek(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
- test_mistral(world_size=8)
+ test_deepseek(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 2b8623e13..419679797 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -180,9 +180,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
-def test_mistral(world_size):
+def test_mixtral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
- test_mistral(world_size=8)
+ test_mixtral(world_size=4)
From 96d0fbc5315330f1437c54cae2dd134f2e78a716 Mon Sep 17 00:00:00 2001
From: botbw
Date: Tue, 23 Jul 2024 06:17:51 +0000
Subject: [PATCH 087/116] [bug] fix: somehow logger hangs the program
---
colossalai/booster/plugin/hybrid_parallel_plugin.py | 11 -----------
.../booster/plugin/moe_hybrid_parallel_plugin.py | 6 ------
2 files changed, 17 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5d0dedeed..5a9bae479 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -27,7 +27,6 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
-from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1020,8 +1019,6 @@ class HybridParallelPlugin(PipelinePluginBase):
) -> None:
super().__init__()
- self.logger = get_dist_logger(type(self).__name__)
-
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -1070,10 +1067,6 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
- self.logger.info(
- f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
- )
-
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
@@ -1123,10 +1116,6 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
- self.logger.info(
- f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
- ranks=[0],
- )
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 9b88a8838..55855afb2 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -226,12 +226,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
)
- self.logger.info(
- f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
- f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
- ranks=[0],
- )
-
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
From b2952a59821df56499a92fef40c2dd27084d7356 Mon Sep 17 00:00:00 2001
From: haze188
Date: Tue, 23 Jul 2024 06:39:49 +0000
Subject: [PATCH 088/116] [moe] deepseek moe sp support
---
colossalai/shardformer/modeling/deepseek.py | 299 +++++++++++++++++-
colossalai/shardformer/policies/deepseek.py | 80 ++++-
tests/kit/model_zoo/transformers/__init__.py | 1 +
tests/kit/model_zoo/transformers/deepseek.py | 84 +++++
.../test_model/test_shard_deepseek.py | 28 +-
.../test_model/test_shard_deepseek_ghz.py | 231 ++++++++++++++
6 files changed, 704 insertions(+), 19 deletions(-)
create mode 100644 tests/kit/model_zoo/transformers/deepseek.py
create mode 100644 tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 854bbf4f2..468b890ab 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -1,12 +1,18 @@
-from typing import List, Optional
+import warnings
+from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
@@ -18,6 +24,11 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+ all_to_all_comm,
+ gather_forward_split_backward,
+ split_forward_gather_backward,
+)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -362,7 +373,14 @@ class DeepseekPipelineForwards:
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
# always return dict for intermediate stage
return {
"hidden_states": hidden_states,
@@ -479,3 +497,276 @@ class DeepseekPipelineForwards:
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
return out
+
+
+def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if sp_mode is not None:
+ assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert (sp_size is not None) and (
+ sp_group is not None
+ ), "Must specify sp_size and sp_group for sequence parallel"
+
+ # DeepseekFlashAttention2 attention does not support output_attentions
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # sp: modify sp_len when sequence parallel mode is ring
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
+
+ rank = dist.get_rank()
+ print(f"{rank=}, hidden states:{hidden_states.shape}")
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ rank = dist.get_rank()
+ print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+ # sp: all-to-all communication when introducing sequence parallel
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group)
+ key_states = all_to_all_comm(key_states, sp_group)
+ value_states = all_to_all_comm(value_states, sp_group)
+ bsz, q_len, _ = query_states.size()
+ print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
+ )
+ print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ print(
+ f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
+ )
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (DeepseekRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ elif torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+ print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+ )
+ # sp: all-to-all communication when introducing sequence parallel
+ if sp_mode == "all_to_all":
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
+ attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
+ else:
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ return forward
+
+
+def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
+ )
+ use_cache = False
+
+ past_key_values_length = 0
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ if sp_mode in ["ring", "split_gather"]:
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if sp_mode == "ring" or sp_mode == "split_gather":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+ elif sp_mode == "all_to_all":
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
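A shape-only sketch of the Ulysses-style all-to-all used in the new attention forward (sizes are illustrative assumptions): each rank starts with `seq_len / sp_size` tokens and all heads, and after `all_to_all_comm` holds the full sequence but only `num_heads / sp_size` heads, so flash attention sees complete sequences locally.

import torch

sp_size, bsz, local_seq, num_heads, head_dim = 2, 1, 4, 8, 16
q_local_shape = (bsz, local_seq, num_heads * head_dim)                          # before all_to_all_comm
q_after_shape = (bsz, local_seq * sp_size, (num_heads // sp_size) * head_dim)   # after scattering heads and gathering sequence
assert torch.Size(q_local_shape).numel() == torch.Size(q_after_shape).numel()   # the exchange only moves data around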
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 04d1dcd41..1e44aba4e 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -7,8 +7,14 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
-from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
+from colossalai.shardformer.modeling.deepseek import (
+ DeepseekPipelineForwards,
+ EPDeepseekMoE,
+ get_deepseek_flash_attention_forward,
+ get_deepseek_flash_attention_model_forward,
+)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
@@ -19,6 +25,13 @@ class DeepseekPolicy(Policy):
pass
def preprocess(self):
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
+ """
+ Because of a bug in the transformers library's AutoModel/AutoConfig, "attn_implementation" is popped twice (once in modeling_utils.py and once in configuration_utils.py).
+ This bug causes attn_cls to be set to sdpa; here we assign it to "flash_attention_2".
+ """
+ # self.origin_attn_implement = "flash_attention_2"
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
@@ -31,17 +44,61 @@ class DeepseekPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- policy = {}
+ ATTN_IMPLEMENTATION = {
+ "eager": "DeepseekAttention",
+ "flash_attention_2": "DeepseekFlashAttention2",
+ "sdpa": "DeepseekSdpaAttention",
+ }
+ policy = {}
+ print(f"{self.origin_attn_implement=}")
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+ sp_mode = self.shard_config.sequence_parallelism_mode or None
+ sp_size = self.shard_config.sequence_parallel_size or None
+ sp_group = self.shard_config.sequence_parallel_process_group or None
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "all_to_all":
+ decoder_attribute_replacement = {
+ "num_heads": self.model.config.num_attention_heads // sp_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+ policy[attn_cls] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
if self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
- raise NotImplementedError(
- "Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ print(f"{attn_cls=}")
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=attn_cls,
)
-
+ if self.pipeline_stage_manager is None:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_deepseek_flash_attention_model_forward(
+ self.shard_config,
+ sp_mode=sp_mode,
+ sp_size=sp_size,
+ sp_group=sp_group,
+ ),
+ },
+ policy=policy,
+ target_key="DeepseekModel",
+ )
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@@ -78,6 +135,16 @@ class DeepseekPolicy(Policy):
),
],
)
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key="DeepseekModel",
+ )
if self.shard_config.ep_group:
# expert parallel
@@ -105,10 +172,12 @@ class DeepseekPolicy(Policy):
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
@@ -119,6 +188,7 @@ class DeepseekPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
+ kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key="DeepseekModel",
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index ac5184065..4adc38619 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -4,6 +4,7 @@ from .blip2 import *
from .bloom import *
from .chatglm2 import *
from .command import *
+from .deepseek import *
from .falcon import *
from .gpt import *
from .gptj import *
diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py
new file mode 100644
index 000000000..b8b446b57
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/deepseek.py
@@ -0,0 +1,84 @@
+# modified from tests/kit/model_zoo/transformers/mistral.py
+import torch
+import transformers
+from transformers import AutoConfig
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence DeepSeek
+# ===============================
+
+
+def data_gen():
+ # Generated from following code snippet
+ #
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
+ # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
+ # tokenized_input = tokenizer([input], return_tensors="pt")
+ # input_ids = tokenized_input['input_ids']
+ # attention_mask = tokenized_input['attention_mask']
+ input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+ # LM data gen
+ # the `labels` of the LM are the output tokens; since there is no padding, use `input_ids` as `labels`
+ data = data_gen()
+ data["labels"] = data["input_ids"].clone()
+ return data
+
+
+def data_gen_for_sequence_classification():
+ # sequence classification data gen
+ data = data_gen()
+ data["labels"] = torch.tensor([1], dtype=torch.int64)
+ return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss function
+loss_fn_for_mixtral_model = lambda x: x[0].mean()
+loss_fn = lambda x: x.loss
+loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+
+def init_deepseek():
+
+ config = AutoConfig.from_pretrained(
+ "deepseek-ai/deepseek-moe-16b-base",
+ hidden_size=32,
+ intermediate_size=32,
+ moe_intermediate_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=8,
+ num_key_value_heads=8,
+ # vocab_size=2200,
+ first_k_dense_replace=1,
+ attn_implementation="flash_attention_2",
+ torch_dtype="float16",
+ n_routed_experts=8,
+ trust_remote_code=True,
+ )
+
+ if hasattr(config, "pad_token_id"):
+ config.pad_token_id = config.eos_token_id
+ print(config)
+ model = transformers.AutoModel.from_config(config, trust_remote_code=True)
+
+ return model
+
+
+model_zoo.register(
+ name="transformers_deepseek",
+ model_fn=init_deepseek,
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_mixtral_model,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index c301777f2..6e8ef2da3 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -36,8 +36,8 @@ CHECKED_CONFIG = [ # FOR_WORLD=8
[
# (2, 1, 2, 1, 1), # TODO debug deepseek pp
# (2, 1, 2, 2, 1), # TODO debug deepseek pp
- (2, 1, 1, 2, 1),
- # (2, 1, 1, 1, 2), # TODO support deepseek sp
+ # (2, 1, 1, 2, 1),
+ (2, 1, 1, 1, 2),
# (2, 1, 4, 1, 1), # TODO debug deepseek pp
# (4, 1, 2, 1, 1), # TODO debug deepseek pp
],
@@ -69,14 +69,22 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
- config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
- config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
- config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
- config.num_hidden_layers = 2
- config.num_attention_heads = NUM_HEADS
- config.num_key_value_heads = NUM_HEADS
- config.n_routed_experts = NUM_EXPERTS
- config.num_experts_per_tok = TOP_K
+ # config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
+ config = AutoConfig.from_pretrained(
+ "deepseek-ai/deepseek-moe-16b-base",
+ hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+ intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+ num_hidden_layers=2,
+ num_attention_heads=NUM_HEADS,
+ num_key_value_heads=NUM_HEADS,
+ first_k_dense_replace=1,
+ attn_implementation="flash_attention_2",
+ torch_dtype="float16",
+ n_routed_experts=NUM_EXPERTS,
+ num_experts_per_tok=TOP_K,
+ trust_remote_code=True,
+ )
# init model with the same seed
seed_all(10086)
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
new file mode 100644
index 000000000..fdca11005
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
@@ -0,0 +1,231 @@
+# modified from test_shard_mistral.py
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_all_grad_tensors,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ get_grad_tensors_for_check,
+ run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
+)
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+ # TODO: SGD failed for full dp
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+ model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
+ )
+
+ org_model = org_model.to(torch.float16)
+ org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+ org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+ )
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if test_config["precision"] == "fp32":
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
+
+ # unwrap model
+ mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
+ shard_mixtral_model = unwrap_model(sharded_model, "DeepseekModel", "model")
+
+ row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
+ col_layer_for_check = ["layers[0].self_attn.o_proj"]
+
+ name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
+ # Check the grad when using ZeRO-1 and ZeRO-2
+ if (
+ # booster.plugin.zero_stage in [1, 2]
+ booster.plugin.shard_config.enable_sequence_parallelism
+ and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
+ ):
+ rank = dist.get_rank()
+ for n, p in shard_mixtral_model.named_parameters():
+ zero_grad = sharded_optimizer.get_param_grad(p)
+ if name_to_p[n].grad is None:
+ name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
+ continue
+ assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config["precision"] == "fp32":
+ atol, rtol = 5e-5, 1e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
+ row_layer_grads = get_grad_tensors_for_check(
+ mixtral_model,
+ shard_mixtral_model,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False,
+ )
+ col_layer_grads = get_grad_tensors_for_check(
+ mixtral_model,
+ shard_mixtral_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ for n, p in shard_mixtral_model.named_parameters():
+ assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
+ for n, p in shard_mixtral_model.named_parameters():
+ assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+ # check weights
+ if stage_manager is None or stage_manager.is_first_stage():
+ if test_config["precision"] == "fp32":
+ atol, rtol = 2e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ try:
+ check_weight(
+ mixtral_model,
+ shard_mixtral_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+ except Exception as e:
+ rank = dist.get_rank()
+ print(f"{rank=}, Failed config: {test_config}")
+ raise e
+
+ torch.cuda.empty_cache()
+
+
+@parameterize(
+ "test_config",
+ [
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "num_microbatches": 2,
+ # "ep_size": 2,
+ # "zero_stage": 0,
+ # "overlap_communication": False,
+ # "precision": "fp16",
+ # }, # [dp(4)] + [moe_dp(4)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 2,
+ # "num_microbatches": 2,
+ # "ep_size": 2,
+ # "zero_stage": 1,
+ # "overlap_communication": False,
+ # "precision": "fp32",
+ # }, # [dp(2) + pp(2)] + [moe_pp(2)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 2,
+ # "ep_size": 2,
+ # "num_microbatches": 2,
+ # "zero_stage": 1,
+ # "overlap_communication": False,
+ # "precision": "fp16",
+ # "initial_scale": 1,
+ # "find_unused_parameters": True,
+ # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
+        { # Ulysses + Flash attention
+ "tp_size": 1,
+ "pp_size": 1,
+ "sp_size": 2,
+ "ep_size": 2,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "all_to_all",
+ "zero_stage": 1,
+ "overlap_communication": False,
+ "precision": "fp16",
+ "initial_scale": 1,
+ "find_unused_parameters": True,
+ },
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 2,
+ # "zero_stage": 0,
+ # "overlap_communication": False,
+ # "precision": "fp32",
+ # }, # [dp(4)] + [ep(2) + moe_tp(2)]
+ # {
+ # "tp_size": 1,
+ # "pp_size": 1,
+ # "ep_size": 4,
+ # "overlap_communication": False,
+ # "zero_stage": 0,
+ # "precision": "fp32"
+ # }, # full dp for non-moe and full ep for moe
+ ],
+)
+def run_deepseek_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek")
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+def check_deepseek(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_deepseek_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_deepseek():
+ spawn(check_deepseek, 4)
+
+
+if __name__ == "__main__":
+    test_deepseek()
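
A standalone sketch of the gradient-comparison pattern the test above relies on, using only plain torch. The toy nn.Linear modules and random data are illustrative stand-ins for the reference and sharded DeepSeek models; the atol/rtol values are copied from the test, not tuned here.

import torch
import torch.nn as nn
from torch.testing import assert_close

torch.manual_seed(0)
ref = nn.Linear(8, 8)
shard = nn.Linear(8, 8)
shard.load_state_dict(ref.state_dict())  # start the two copies from identical weights

x = torch.randn(4, 8)
ref(x).sum().backward()
shard(x.clone()).sum().backward()

# per-parameter comparison, same tolerances as the test above
for (name, p_ref), (_, p_shard) in zip(ref.named_parameters(), shard.named_parameters()):
    assert_close(p_ref.grad, p_shard.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
print("gradients match")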
From 6c39f0b144d66db1044be081657cdf899ab31053 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 23 Jul 2024 09:38:05 +0000
Subject: [PATCH 089/116] [test] add check
---
.../booster/plugin/moe_hybrid_parallel_plugin.py | 3 +++
.../test_model/test_shard_deepseek.py | 15 ++++++---------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 55855afb2..9625afc1b 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -73,6 +73,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
+ if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
+ raise ValueError("No parameters found in dp_process_group or moe_dp_group")
+
super().__init__(
model=model,
optimizer=optimizer,
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 6e8ef2da3..5e6c0bf1c 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -17,29 +17,27 @@ from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4000, 2
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
-TOP_K = 1
+TOP_K = 2
CHECKED_CONFIG = [ # FOR_WORLD=8
(2, 1, 1, 4, 1),
(4, 1, 1, 2, 1),
(4, 1, 1, 1, 1),
+ (2, 1, 2, 1, 1),
]
@parameterize(
"config",
[
- # (2, 1, 2, 1, 1), # TODO debug deepseek pp
- # (2, 1, 2, 2, 1), # TODO debug deepseek pp
+ (2, 1, 2, 1, 1),
# (2, 1, 1, 2, 1),
- (2, 1, 1, 1, 2),
- # (2, 1, 4, 1, 1), # TODO debug deepseek pp
- # (4, 1, 2, 1, 1), # TODO debug deepseek pp
+ # (2, 1, 1, 1, 2),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
@@ -69,13 +67,12 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
- # config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
- num_hidden_layers=2,
+ num_hidden_layers=4,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
first_k_dense_replace=1,
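
The guard added above fails fast when either optimizer bucket would be empty. A minimal sketch of the same partition-and-validate idea; `is_moe` below is a hypothetical name-based stand-in for colossalai's `is_moe_tensor`, which inspects tensor metadata rather than parameter names.

import torch.nn as nn

model = nn.ModuleDict({
    "attn": nn.Linear(8, 8),
    "moe_experts": nn.Linear(8, 8),
})

def is_moe(name: str) -> bool:
    # hypothetical predicate for illustration only; the real check inspects tensor metadata
    return "moe" in name

dense_params = [p for n, p in model.named_parameters() if not is_moe(n)]
moe_params = [p for n, p in model.named_parameters() if is_moe(n)]

# same fail-fast behaviour as MoeHybridParallelZeroOptimizer above
if len(dense_params) == 0 or len(moe_params) == 0:
    raise ValueError("No parameters found in dp_process_group or moe_dp_group")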
From c3dc9b4dba80f7f9948a89463ee97d96e20e641f Mon Sep 17 00:00:00 2001
From: hxwang
Date: Tue, 23 Jul 2024 12:56:58 +0000
Subject: [PATCH 090/116] [deepseek] replace attn (a workaround for a bug in
transformers)
---
colossalai/shardformer/policies/deepseek.py | 34 ++++++++++++++++---
.../test_model/test_shard_deepseek_ghz.py | 1 +
2 files changed, 30 insertions(+), 5 deletions(-)
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 1e44aba4e..d1d004ed5 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -1,4 +1,3 @@
-import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
)
if self.shard_config.enable_flash_attention:
- warnings.warn(
- "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
- )
- self.shard_config.enable_flash_attention = False
+            # NOTE: there is a bug when toggling flash attention in AutoModel, which has to be used for deepseek right now
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
+ flash_attn_cls = get_class_from_dynamic_module(
+ "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
+ "deepseek-ai/deepseek-moe-16b-base",
+ )
+
+ class TargetFlashAttn:
+ def __init__(self):
+ raise RuntimeError("This class should not be instantiated")
+
+ @staticmethod
+ def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
+ flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
+ flash_attn_module.q_proj = original_attn.q_proj
+ flash_attn_module.k_proj = original_attn.k_proj
+ flash_attn_module.v_proj = original_attn.v_proj
+ flash_attn_module.o_proj = original_attn.o_proj
+ flash_attn_module.rotary_emb = original_attn.rotary_emb
+ return flash_attn_module
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="self_attn",
+ target_module=TargetFlashAttn,
+ ),
+ policy=policy,
+ target_key="DeepseekDecoderLayer",
+ )
return policy
def postprocess(self):
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
index fdca11005..fe834a4f6 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
@@ -220,6 +220,7 @@ def check_deepseek(rank, world_size, port):
run_deepseek_test()
+@pytest.mark.skip("redundant")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
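
The `from_native_module` above instantiates the remote flash-attention class and then transplants the already-initialized projection layers into it, so no weights are re-created. A toy, self-contained sketch of that transplant pattern; `OldAttn`/`NewAttn` are illustrative stand-ins, not the real deepseek attention classes.

import torch.nn as nn

class OldAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)
        self.o_proj = nn.Linear(8, 8)

class NewAttn(OldAttn):  # stands in for DeepseekFlashAttention2
    pass

def from_native_module(original: OldAttn) -> NewAttn:
    replacement = NewAttn()
    # reuse the trained projections instead of the freshly initialized ones
    replacement.q_proj = original.q_proj
    replacement.k_proj = original.k_proj
    replacement.v_proj = original.v_proj
    replacement.o_proj = original.o_proj
    return replacement

old = OldAttn()
new = from_native_module(old)
assert new.q_proj is old.q_proj  # modules are shared, not copied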
From 59bcf56c60771bc790fc83f45cbc9a6df123d98b Mon Sep 17 00:00:00 2001
From: haze188
Date: Wed, 24 Jul 2024 02:59:35 +0000
Subject: [PATCH 091/116] [misc] skip redundant test
---
.../{test_shard_deepseek_ghz.py => test_shard_deepseek_skip.py} | 0
1 file changed, 0 insertions(+), 0 deletions(-)
rename tests/test_shardformer/test_model/{test_shard_deepseek_ghz.py => test_shard_deepseek_skip.py} (100%)
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_ghz.py b/tests/test_shardformer/test_model/test_shard_deepseek_skip.py
similarity index 100%
rename from tests/test_shardformer/test_model/test_shard_deepseek_ghz.py
rename to tests/test_shardformer/test_model/test_shard_deepseek_skip.py
From 034020bd045e639e03088eae41df98680567bcb6 Mon Sep 17 00:00:00 2001
From: haze188
Date: Wed, 24 Jul 2024 03:10:57 +0000
Subject: [PATCH 092/116] [misc] remove debug/print code
---
colossalai/shardformer/modeling/deepseek.py | 11 -----------
colossalai/shardformer/policies/deepseek.py | 11 ++---------
tests/kit/model_zoo/transformers/deepseek.py | 1 -
3 files changed, 2 insertions(+), 21 deletions(-)
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 468b890ab..52ea6c22b 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -535,28 +535,22 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
- rank = dist.get_rank()
- print(f"{rank=}, hidden states:{hidden_states.shape}")
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- rank = dist.get_rank()
- print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
- print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@@ -565,7 +559,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
)
- print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
@@ -576,9 +569,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- print(
- f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
- )
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
@@ -606,7 +596,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
- print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index d1d004ed5..963bd9d67 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -50,7 +50,6 @@ class DeepseekPolicy(Policy):
"sdpa": "DeepseekSdpaAttention",
}
policy = {}
- print(f"{self.origin_attn_implement=}")
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
@@ -71,7 +70,6 @@ class DeepseekPolicy(Policy):
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
- print(f"{attn_cls=}")
self.append_or_create_method_replacement(
description={
"forward": get_deepseek_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
@@ -208,13 +206,8 @@ class DeepseekPolicy(Policy):
@staticmethod
def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
- flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
- flash_attn_module.q_proj = original_attn.q_proj
- flash_attn_module.k_proj = original_attn.k_proj
- flash_attn_module.v_proj = original_attn.v_proj
- flash_attn_module.o_proj = original_attn.o_proj
- flash_attn_module.rotary_emb = original_attn.rotary_emb
- return flash_attn_module
+ original_attn.__class__ = flash_attn_cls
+ return original_attn
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
diff --git a/tests/kit/model_zoo/transformers/deepseek.py b/tests/kit/model_zoo/transformers/deepseek.py
index b8b446b57..ad73640a5 100644
--- a/tests/kit/model_zoo/transformers/deepseek.py
+++ b/tests/kit/model_zoo/transformers/deepseek.py
@@ -68,7 +68,6 @@ def init_deepseek():
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
- print(config)
model = transformers.AutoModel.from_config(config, trust_remote_code=True)
return model
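
Besides dropping the debug prints, the policy now swaps the attention implementation by reassigning `__class__` in place instead of building a new module. A toy sketch of that trick; the two classes are illustrative, not the real deepseek attention classes.

import torch.nn as nn

class EagerAttn(nn.Module):
    def forward(self, x):
        return x            # placeholder eager path

class FlashAttn(EagerAttn):
    def forward(self, x):
        return x * 2        # placeholder "fast" path

attn = EagerAttn()
attn.__class__ = FlashAttn  # same parameters and buffers, different forward()
assert isinstance(attn, FlashAttn)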
From cb01c0d5cee868e81edd33d2ba8fbb3d2b9da932 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 25 Jul 2024 06:19:54 +0000
Subject: [PATCH 093/116] [moe] refactor mesh assignment
---
.../plugin/moe_hybrid_parallel_plugin.py | 356 ++++++++++++------
colossalai/shardformer/modeling/deepseek.py | 21 +-
colossalai/shardformer/modeling/mixtral.py | 19 +-
colossalai/shardformer/policies/deepseek.py | 1 -
colossalai/shardformer/policies/mixtral.py | 1 -
colossalai/shardformer/shard/shard_config.py | 1 -
tests/test_moe/test_deepseek_layer.py | 1 -
tests/test_moe/test_mixtral_layer.py | 1 -
.../test_model/test_shard_deepseek.py | 23 +-
.../test_model/test_shard_mixtral.py | 23 +-
10 files changed, 277 insertions(+), 170 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 9625afc1b..9b6eae0d0 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,9 +1,8 @@
import warnings
from collections import defaultdict
from types import MethodType
-from typing import Callable, Optional, OrderedDict, Tuple
+from typing import Callable, List, Optional, OrderedDict, Tuple
-import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -13,6 +12,8 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.booster.plugin.hybrid_parallel_plugin import (
+ PRECISION_TORCH_TYPE,
+ SUPPORT_SP_MODE,
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
@@ -22,9 +23,16 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
+from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
+from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
+from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
+from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -57,7 +65,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
- WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
+ WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
if not force_overlap_comm and (overlap_communication or partition_grad):
raise RuntimeError(
WARN_STR
@@ -105,130 +113,219 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
- TODO: add docstring
+ Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
+ Extra Args:
+ ep_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+ force_overlap_comm (bool): For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training. This flag is used to force overlap_communication=True.
"""
- def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
- if "overlap_communication" not in kwargs:
- kwargs["overlap_communication"] = False # default by true in super class
-
- super().__init__(*args, **kwargs)
-
- if ep_size <= 1:
- raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
+ def __init__(
+ self,
+ tp_size: int,
+ pp_size: int,
+ ep_size: int,
+ sp_size: int = None,
+ precision: str = "fp16",
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ sequence_parallelism_mode: str = None,
+ enable_sequence_overlap: bool = False,
+ parallel_output: bool = True,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ custom_policy: Policy = None,
+ pp_style: str = "1f1b",
+ num_model_chunks: int = 1,
+ num_layers_per_stage: Optional[List[int]] = None,
+ gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
+ enable_metadata_cache: bool = True,
+ make_vocab_size_divisible_by: int = 64,
+ dp_outside: bool = True,
+ overlap_p2p: bool = True,
+ overlap_allgather: bool = False,
+ force_overlap_comm: bool = False,
+ ) -> None:
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+ if enable_sequence_parallelism:
+ self.sequence_parallelism_mode = (
+ sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
+ )
+ assert (
+ self.sequence_parallelism_mode in SUPPORT_SP_MODE
+ ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+ if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+ assert (
+ tp_size > 1
+ ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
+ if sp_size != 1:
+ warnings.warn(
+ f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+ )
+ self.sp_size = 1
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ elif self.sequence_parallelism_mode in ["all_to_all"]:
+ self.sp_size = 1 if sp_size is None else sp_size
+ self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
+ else:
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ assert (
+ sp_size == 1 or sp_size is None
+ ), f"You should not set sp_size when sequence parallelism is not enabled."
+ self.sp_size = 1
+ assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}"
+ self.moe_dp_size = self.dp_size // ep_size
self.ep_size = ep_size
- self.moe_tp_size = moe_tp_size
+ self.tp_size = tp_size
+ self.pp_size = pp_size
+ self.precision = precision
+ self.zero_stage = zero_stage
+ self.cpu_offload = cpu_offload
+ self.enable_all_optimization = enable_all_optimization
+ self.enable_fused_normalization = enable_fused_normalization
+ self.enable_flash_attention = enable_flash_attention
+ self.enable_jit_fused = enable_jit_fused
+ self.enable_sequence_parallelism = enable_sequence_parallelism
+ if dp_outside:
+ self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+ self.moe_dp_axis, self.ep_axis = 0, 1
+ self.moe_pg_mesh = ProcessGroupMesh(
+ self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
+ )
+ else:
+ self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
+ self.moe_dp_axis, self.ep_axis = 1, 2
+ self.moe_pg_mesh = ProcessGroupMesh(
+ self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
+ )
- self._init_moe_param_comm()
+ self.stage_manager = None
+ self.schedule = None
+ self.custom_policy = custom_policy
+ assert zero_stage in (0, 1, 2)
+ if self.pp_size > 1:
+ assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+ assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+ assert (
+ num_microbatches is not None or microbatch_size is not None
+ ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
+ assert (
+ self.zero_stage <= 1
+ ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+ self.stage_manager = PipelineStageManager(
+ self.pg_mesh,
+ pipeline_axis=self.pp_axis,
+ enable_interleave=pp_style == "interleaved",
+ num_model_chunks=num_model_chunks,
+ num_layers_per_stage=num_layers_per_stage,
+ )
- self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
- self.dp_size == 1
- and self.pp_size == 1
- and self.enable_sequence_parallelism
- and self.sequence_parallelism_mode == "all_to_all"
+ if pp_style == "interleaved":
+ assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
+ self.schedule = InterleavedSchedule(
+ stage_manager=self.stage_manager,
+ num_model_chunks=num_model_chunks,
+ num_microbatch=num_microbatches,
+ microbatch_size=microbatch_size,
+ enable_metadata_cache=enable_metadata_cache,
+ overlap_p2p=overlap_p2p,
+ )
+ elif pp_style == "1f1b":
+ self.schedule = OneForwardOneBackwardSchedule(
+ stage_manager=self.stage_manager,
+ num_microbatches=num_microbatches,
+ microbatch_size=microbatch_size,
+ enable_metadata_cache=enable_metadata_cache,
+ )
+ else:
+ raise NotImplementedError()
+
+ self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
+ self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+ self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
+ self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
+ self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
+ if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
+ self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
+ else:
+ self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
+
+ self.shard_config = ShardConfig(
+ tensor_parallel_process_group=self.tp_group,
+ sequence_parallel_process_group=self.sp_group,
+ ep_group=self.ep_group,
+ moe_dp_group=self.moe_dp_group,
+ pipeline_stage_manager=self.stage_manager,
+ enable_tensor_parallelism=self.tp_size > 1,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ sequence_parallelism_mode=sequence_parallelism_mode,
+ enable_sequence_overlap=enable_sequence_overlap,
+ parallel_output=parallel_output,
+ make_vocab_size_divisible_by=make_vocab_size_divisible_by,
+ gradient_checkpoint_config=gradient_checkpoint_config,
+ )
+ self.amp_config = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
)
- if self.use_ddp:
- warnings.warn(
- f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
- )
- self.ddp_config["find_unused_parameters"] = True
+ self.ddp_config = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
- if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
- # TODO it might make sense to support non-moe with tp on but moe with tp off
- raise ValueError(
- f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
- )
-
- # set param group in shard config
- self.shard_config.ep_group = self.ep_group
- self.shard_config.moe_dp_group = self.moe_dp_group
- self.shard_config.moe_tp_group = self.moe_tp_group
+ self.zero_config = dict(
+ reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(self.zero_stage == 2),
+ forced_dtype=PRECISION_TORCH_TYPE[precision],
+ overlap_allgather=overlap_allgather,
+ )
+ self.max_norm = max_norm
self.force_overlap_comm = force_overlap_comm
- def _init_moe_param_comm(self):
- world_size = dist.get_world_size()
-
- if self.enable_sequence_parallelism:
- if self.sequence_parallelism_mode == "all_to_all":
- # if sequence parallelism is enabled, ep_group reuses sp_group
- if self.ep_size != self.sp_size:
- raise ValueError(
- f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
- )
-
- # since we are reusing sp_group, moe_dp_group will be derived as dp_group
- self.moe_dp_size = self.dp_size
- self.moe_dp_group = self.dp_group
- self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
- self.ep_group = self.sp_group
- self.moe_tp_group = self.tp_group
- else:
- raise NotImplementedError(
- f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
- )
-
- else:
- self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
-
- if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
- raise ValueError(
- f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
- )
-
- self.moe_dp_group = None
- self.ep_group = None
- self.moe_tp_group = None
- self.dp_sp_group = self.dp_group
-
- # create submesh for ep, moe_dp, moe_tp
- ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
- [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
- )
-
- global_rank = self.pg_mesh.rank
- pp_rank = self.pg_mesh.coordinate(self.pp_axis)
-
- # create groups from submesh
- for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
- # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
- submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-
- # hardcode here since we only have 3 axis
- # moe_dp_group
- for ep_idx in range(self.ep_size):
- for moe_tp_idx in range(self.moe_tp_size):
- moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
- group = dist.new_group(moe_dp_ranks)
- if pp_rank == stage_idx and global_rank in moe_dp_ranks:
- assert self.moe_dp_group is None
- self.moe_dp_group = group
- # ep_group
- for moe_dp_idx in range(self.moe_dp_size):
- for moe_tp_idx in range(self.moe_tp_size):
- ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
- group = dist.new_group(ep_ranks)
- if pp_rank == stage_idx and global_rank in ep_ranks:
- assert self.ep_group is None
- self.ep_group = group
- # moe_tp_group
- for moe_dp_idx in range(self.moe_dp_size):
- for ep_idx in range(self.ep_size):
- moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
- group = dist.new_group(moe_tp_ranks)
- if pp_rank == stage_idx and global_rank in moe_tp_ranks:
- assert self.moe_tp_group is None
- self.moe_tp_group = group
-
- if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
- # NOTE: different tp settings between moe and non moe param are complex to handle
- # we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
- raise NotImplementedError(
- f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
- )
-
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
@@ -249,14 +346,37 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
+ use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+ self.dp_size == 1
+ and self.pp_size == 1
+ and self.enable_sequence_parallelism
+ and self.sequence_parallelism_mode == "all_to_all"
+ )
+ if use_ddp:
+ warnings.warn(
+ f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
+ )
+ self.ddp_config["find_unused_parameters"] = True
+
+ if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+ raise ValueError(
+ f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
+ )
+
+ # sync gradients across DP * SP ranks
+ if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+ dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+ else:
+ dp_group = self.dp_group
+
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
- dp_group=self.dp_sp_group,
+ dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
- use_ddp=self.use_ddp,
+ use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -301,7 +421,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
- dp_process_group=self.dp_sp_group,
+ dp_process_group=dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 52ea6c22b..a84a30972 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -61,13 +61,10 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
- def setup_process_groups(
- self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
- ):
+ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
- assert moe_tp_group is not None
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
@@ -85,16 +82,13 @@ class EPDeepseekMoE(nn.Module):
self.moe_dp_group = moe_dp_group
self.moe_dp_size = moe_dp_group.size()
- # setup global tp group
+ # setup tp group
self.tp_group = tp_group
-
- # setup moe tp group
- self.moe_tp_group = moe_tp_group
- if self.moe_tp_group.size() > 1:
+ if self.tp_group.size() > 1:
for expert in held_experts:
- expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
- expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
- expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
+ expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
+ expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
+ expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -105,7 +99,6 @@ class EPDeepseekMoE(nn.Module):
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
- moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPDeepseekMoE":
@@ -113,7 +106,7 @@ class EPDeepseekMoE(nn.Module):
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
- module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+ module.setup_process_groups(tp_group, moe_dp_group, ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 9148a9fba..029ac36cd 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -53,13 +53,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
- def setup_process_groups(
- self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
- ):
+ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
- assert moe_tp_group is not None
# setup ep group
self.ep_size = dist.get_world_size(ep_group)
@@ -81,14 +78,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# setup global tp group
self.tp_group = tp_group
-
- # setup moe tp group
- self.moe_tp_group = moe_tp_group
- if self.moe_tp_group.size() > 1:
+ if self.tp_group.size() > 1:
for expert in held_experts:
- expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group)
- expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
- expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
+ expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
+ expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
+ expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -99,14 +93,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
- moe_tp_group: ProcessGroup,
*args,
**kwargs,
) -> "EPMixtralSparseMoeBlock":
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
- module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+ module.setup_process_groups(tp_group, moe_dp_group, ep_group)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 963bd9d67..d729a4ecc 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -154,7 +154,6 @@ class DeepseekPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
- "moe_tp_group": self.shard_config.moe_tp_group,
},
)
],
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 4de982f44..85895820e 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -155,7 +155,6 @@ class MixtralPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
- "moe_tp_group": self.shard_config.moe_tp_group,
},
)
],
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index f12c78526..163d7a7bb 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -50,7 +50,6 @@ class ShardConfig:
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
- moe_tp_group: Optional[ProcessGroup] = None
# pipeline_parallel_size: int
# data_parallel_size: int
diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index e633cdd07..d18ba2eac 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -47,7 +47,6 @@ def check_deepseek_moe_layer():
model,
ep_group=plugin.ep_group,
moe_dp_group=plugin.moe_dp_group,
- moe_tp_group=plugin.moe_tp_group,
tp_group=plugin.tp_group,
)
ep_output = model(x)
diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
index 5d9ca622a..bc41ac4f3 100644
--- a/tests/test_moe/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -42,7 +42,6 @@ def check_mixtral_moe_layer():
ep_group=plugin.ep_group,
tp_group=plugin.tp_group,
moe_dp_group=plugin.moe_dp_group,
- moe_tp_group=plugin.moe_tp_group,
)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 5e6c0bf1c..709963613 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -24,24 +24,28 @@ NUM_HEADS = 4
TOP_K = 2
-CHECKED_CONFIG = [ # FOR_WORLD=8
- (2, 1, 1, 4, 1),
- (4, 1, 1, 2, 1),
- (4, 1, 1, 1, 1),
- (2, 1, 2, 1, 1),
+CHECKED_CONFIG = [ # FOR_WORLD=4
+ (1, 4, 1, 1, 1),
+ (1, 1, 4, 1, 1),
+ (1, 1, 1, 4, 1),
+ (1, 1, 1, 1, 4),
+ (0, 1, 4, 1, 1),
+ (0, 1, 1, 4, 1),
+ (0, 1, 1, 1, 4),
+ (1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
- (2, 1, 2, 1, 1),
- # (2, 1, 1, 2, 1),
- # (2, 1, 1, 1, 2),
+ (1, 2, 2, 1, 1),
+ (1, 2, 1, 2, 1),
+ (1, 2, 1, 1, 2),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
- ep_size, stage, pp_size, tp_size, sp_size = config
+ stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
@@ -53,7 +57,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
- moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 419679797..a3e201b67 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -25,24 +25,28 @@ NUM_HEADS = 4
TOP_K = 1
CHECKED_CONFIG = [ # FOR WORLD=4
- (2, 1, 2, 2, 1),
- (2, 1, 1, 2, 1),
- (2, 1, 4, 1, 1),
- (4, 1, 1, 1, 1),
- (4, 1, 1, 2, 1),
- (4, 1, 2, 1, 1),
- (2, 1, 2, 1, 1),
+ (0, 1, 4, 1, 1),
+ (0, 1, 1, 4, 1),
+ (0, 1, 1, 1, 4),
+ (1, 4, 1, 1, 1),
+ (1, 1, 4, 1, 1),
+ (1, 1, 1, 4, 1),
+ (1, 1, 1, 1, 4),
+ (1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
- (2, 1, 1, 2, 1),
+ (1, 2, 2, 1, 1),
+ (1, 2, 1, 2, 1),
+ (1, 2, 1, 1, 2),
+ (0, 2, 1, 1, 1),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
- ep_size, stage, pp_size, tp_size, sp_size = config
+ stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
@@ -54,7 +58,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
- moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
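
A quick sketch of the size bookkeeping the refactored constructor performs: the dense mesh is (dp, pp, tp, sp) and the MoE mesh further splits dp into (moe_dp, ep). Pure arithmetic, no process groups are created; the concrete sizes are just an example for an 8-rank world with all_to_all sequence parallelism and dp_outside=True.

world_size = 8
tp_size, pp_size, sp_size, ep_size = 1, 2, 2, 2

assert world_size % (tp_size * pp_size) == 0
dp_size = world_size // (sp_size * pp_size * tp_size)   # 2, all_to_all mode
assert dp_size % ep_size == 0
moe_dp_size = dp_size // ep_size                         # 1

# dense mesh vs. MoE mesh shapes, matching the dp_outside=True branch above
dense_mesh_shape = (dp_size, pp_size, tp_size, sp_size)              # (2, 2, 1, 2)
moe_mesh_shape = (moe_dp_size, ep_size, pp_size, tp_size, sp_size)   # (1, 2, 2, 1, 2)
print(dense_mesh_shape, moe_mesh_shape)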
From 5b4c12381bc3ec5ac7b476871973867679369621 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 25 Jul 2024 06:32:02 +0000
Subject: [PATCH 094/116] Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
---
colossalai/cluster/process_group_mesh.py | 44 +++++++-----------------
1 file changed, 12 insertions(+), 32 deletions(-)
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index a9d341efa..dc96708f0 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -209,15 +209,13 @@ class ProcessGroupMesh:
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
- return_ranks_by_group: bool = False,
- ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
+ ) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
axis (int): Axis along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
- return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -237,39 +235,25 @@ class ProcessGroupMesh:
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
for ax in axis:
reduced_shape[ax] = 1
- if return_ranks_by_group:
- ranks_by_group = []
- # use Cartesian product to generate all combinations of coordinates
- for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
- coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
- ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
- ranks_by_group.append(ranks_in_group)
- return ranks_by_group
- else:
- target_group = None
- # use Cartesian product to generate all combinations of coordinates
- for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
- coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
- ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
- group = self._get_group(ranks_in_group, backend=backend)
- if self._rank in ranks_in_group:
- target_group = group
- return target_group
+ target_group = None
+ # use Cartesian product to generate all combinations of coordinates
+ for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ group = self._get_group(ranks_in_group, backend=backend)
+ if self._rank in ranks_in_group:
+ target_group = group
+ return target_group
def get_group_along_axis(
- self,
- axis: Union[int, List[int]],
- indices_at_axis: Optional[List[int]] = None,
- backend: Optional[str] = None,
- return_ranks_by_group: bool = False,
- ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
+ self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
- return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -283,10 +267,6 @@ class ProcessGroupMesh:
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
-
- if return_ranks_by_group:
- return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
-
if ranks_in_group not in self._ranks_to_group:
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
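
For reference, the rank enumeration that `create_group_along_axis` performs (restored above to its pre-submesh form) can be reproduced outside colossalai: fix every axis except the chosen one, then list the ranks that vary along it. A small sketch using numpy only for the flat indexing; the mesh shape is an arbitrary example.

import itertools
import numpy as np

shape = (2, 2, 2)   # e.g. a (dp, pp, tp) mesh of 8 ranks
axis = 1            # enumerate the groups along the pp axis
reduced = list(shape)
reduced[axis] = 1   # the chosen axis is swept inside each group

for base in itertools.product(*[range(s) for s in reduced]):
    coords = [base[:axis] + (i,) + base[axis + 1:] for i in range(shape[axis])]
    ranks = tuple(int(np.ravel_multi_index(c, shape)) for c in coords)
    print(ranks)    # each tuple would become one dist.new_group(...)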
From 606b0891ed8de51d517c9ac0436438def687614c Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 25 Jul 2024 06:34:22 +0000
Subject: [PATCH 095/116] [chore] change moe_pg_mesh to private
---
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 9b6eae0d0..7f6608086 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -214,14 +214,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 0, 1
- self.moe_pg_mesh = ProcessGroupMesh(
+ self.__moe_pg_mesh = ProcessGroupMesh(
self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 1, 2
- self.moe_pg_mesh = ProcessGroupMesh(
+ self.__moe_pg_mesh = ProcessGroupMesh(
self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
)
@@ -269,8 +269,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
- self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
- self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
+ self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
+ self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
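
Note that the double-underscore rename above does more than signal intent: Python name-mangles `__moe_pg_mesh` per class, so the mesh is no longer reachable as a plain attribute from outside the plugin. A tiny illustration with a toy class:

class Plugin:
    def __init__(self):
        self.__mesh = "moe mesh"   # stored as _Plugin__mesh

p = Plugin()
print(p._Plugin__mesh)   # the mangled name still resolves
# p.__mesh               # would raise AttributeError: effectively private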
From 12d043ca00be7c6ce749f9f5443813f80aaa5da8 Mon Sep 17 00:00:00 2001
From: haze188
Date: Thu, 25 Jul 2024 06:40:54 +0000
Subject: [PATCH 096/116] [misc] remove incompatible test config
---
tests/test_shardformer/test_model/test_shard_mixtral.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index a3e201b67..de09eedcb 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -42,7 +42,6 @@ CHECKED_CONFIG = [ # FOR WORLD=4
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
- (0, 2, 1, 1, 1),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
From 70793ce9ede447765507b6edb230609cfa62c638 Mon Sep 17 00:00:00 2001
From: haze188
Date: Thu, 25 Jul 2024 07:50:41 +0000
Subject: [PATCH 097/116] [misc] fix ci failure: change default value to false
in moe plugin
---
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 2 +-
colossalai/shardformer/policies/deepseek.py | 2 ++
tests/test_shardformer/test_model/test_shard_deepseek.py | 3 ++-
3 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 7f6608086..4c3aece9d 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -154,7 +154,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True,
+ overlap_communication: bool = False,
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index d729a4ecc..605f69c4a 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
@@ -206,6 +207,7 @@ class DeepseekPolicy(Policy):
@staticmethod
def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
original_attn.__class__ = flash_attn_cls
+ original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
return original_attn
self.append_or_create_submodule_replacement(
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 709963613..187c642da 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+ enable_flash_attention=sp_size > 1,
overlap_communication=False,
initial_scale=1,
precision=precision,
@@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
- model_dir = "./test_mixtral"
+ model_dir = "./test_deepseek"
if rank == world_size - 1:
os.makedirs(model_dir, exist_ok=True)
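
The `_flash_attn_uses_top_left_mask` assignment above mirrors how transformers gates the causal-mask layout on the installed flash-attn version (flash-attn >= 2.1 uses bottom-right alignment, older versions top-left). A minimal check, assuming a transformers version that exposes this helper (the patch itself imports it):

from transformers.utils import is_flash_attn_greater_or_equal_2_10

uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
print(f"_flash_attn_uses_top_left_mask would be {uses_top_left_mask}")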
From 7e737df5adc6dc27f2b1c8b603c719baa12bf31f Mon Sep 17 00:00:00 2001
From: haze188
Date: Thu, 25 Jul 2024 08:40:53 +0000
Subject: [PATCH 098/116] [misc] remove useless condition
---
colossalai/shardformer/policies/mixtral.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 85895820e..10df143c9 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -61,7 +61,7 @@ class MixtralPolicy(Policy):
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
- if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+ if self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
# if both are enabled, one of them will be ignored
From f7c5485ed60aabd1e5d4c1a350701d68422ce2f8 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 25 Jul 2024 09:07:49 +0000
Subject: [PATCH 099/116] [chore] docstring
---
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 4c3aece9d..beac2d037 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -115,8 +115,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
Extra Args:
- ep_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
- force_overlap_comm (bool): For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training. This flag is used to force overlap_communication=True.
+ ep_size (int): The size of expert parallelism
+ force_overlap_comm (bool):
+ For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training.
+ This flag is used to force overlap_communication=True. Make sure every expert are routed when you use this.
"""
def __init__(
From 7bedd03739acea830cf283c29c3d5ed38277b291 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 25 Jul 2024 09:49:57 +0000
Subject: [PATCH 100/116] [moe] remove force_overlap_comm flag and add warning
instead
---
.../plugin/moe_hybrid_parallel_plugin.py | 26 ++++++-------------
1 file changed, 8 insertions(+), 18 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index beac2d037..b49b886a0 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -42,7 +42,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
- force_overlap_comm: bool, # force overlap comm
dp_process_group: Optional[ProcessGroup], # the dp pg for comm
tp_process_group: Optional[ProcessGroup], # if using tp
pp_process_group: Optional[ProcessGroup], # if using pp
@@ -65,17 +64,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
- WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
- if not force_overlap_comm and (overlap_communication or partition_grad):
- raise RuntimeError(
- WARN_STR
- + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
- )
-
- if force_overlap_comm:
- overlap_communication = True
- warnings.warn(WARN_STR + " Please make sure of this.")
-
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@@ -116,9 +104,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
Extra Args:
ep_size (int): The size of expert parallelism
- force_overlap_comm (bool):
-            For LowLevelZeroOptimizer, it might cause the program to hang when some experts are not routed and overlap_communication is True during training.
-            This flag is used to force overlap_communication=True. Make sure every expert is routed when you use this.
"""
def __init__(
@@ -167,8 +152,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
- force_overlap_comm: bool = False,
) -> None:
+ if overlap_communication or zero_stage == 2:
+ overlap_communication = False
+ zero_stage = 1
+            warnings.warn(
+                "overlap_communication and zero_stage are set to False and 1 because "
+                "ZeRO-2 or communication overlap can cause the program to hang when some experts are not routed."
+            )
+
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -326,7 +318,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.max_norm = max_norm
- self.force_overlap_comm = force_overlap_comm
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
@@ -421,7 +412,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
- force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=dp_group,
tp_process_group=self.tp_group,
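
To make the new behaviour easier to follow, here is a minimal standalone sketch of the guard this commit adds (the function name and the print calls are illustrative only; in the plugin the check lives directly in `MoeHybridParallelPlugin.__init__`):

```python
import warnings


def sanitize_moe_zero_settings(overlap_communication: bool, zero_stage: int):
    """Illustrative stand-in for the guard above: ZeRO-2 and communication overlap
    can hang when some experts receive no tokens (and hence no backward pass),
    so both settings are downgraded to safe values."""
    if overlap_communication or zero_stage == 2:
        warnings.warn(
            "overlap_communication and zero_stage are set to False and 1 because "
            "ZeRO-2 or communication overlap can cause the program to hang when some experts are not routed."
        )
        return False, 1
    return overlap_communication, zero_stage


print(sanitize_moe_zero_settings(True, 2))   # -> (False, 1)
print(sanitize_moe_zero_settings(False, 0))  # -> (False, 0)
```

Compared with the removed `force_overlap_comm` flag, the burden shifts from the user (who had to promise that every expert is routed) to the plugin, which now falls back to ZeRO-1 without communication overlap and warns about it.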
From 65daa876274e769aaeaa88cd72135f25636f6533 Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 26 Jul 2024 06:05:01 +0000
Subject: [PATCH 101/116] [doc] add MoeHybridParallelPlugin docstring
---
.../plugin/moe_hybrid_parallel_plugin.py | 66 ++++++++++++++++++-
1 file changed, 64 insertions(+), 2 deletions(-)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index b49b886a0..7a16a1737 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -101,9 +101,71 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
- Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
- Extra Args:
+    Plugin for MoE Hybrid Parallel Training, which is similar to HybridParallelPlugin.
+    Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
+    The sizes of tp and pp should be passed in by the user; the size of dp is then derived as dp_size = world_size / (tp_size * pp_size).
+
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import MoeHybridParallelPlugin
+
+ model, train_dataset, optimizer, criterion = ...
+ plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2)
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+ ```
+
+ Args:
+ tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+ pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
ep_size (int): The size of expert parallelism
+ sp_size (int): The size of sequence parallelism.
+ precision (str, optional): Specifies the precision of parameters during training.
+            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained in 'fp32'.
+ Defaults to 'fp16'.
+        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+ When set to 0, ZeRO will not be used. Defaults to 0.
+ enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+ Currently all the optimization methods include fused normalization, flash attention and JIT.
+ Defaults to False.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
+ num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+ microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+ Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+ If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+ initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+ min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+ growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+ backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+ growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+ hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+ max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+ max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+ broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+ ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+ find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+ check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+ gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+ static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+ cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
+ communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+ overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+ pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+ num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+ gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
+ enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
+        make_vocab_size_divisible_by (int, optional): It's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
+        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
"""
def __init__(
From d1d1ab871e2600f2b5e747951ed78cba84134e94 Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 26 Jul 2024 07:32:19 +0000
Subject: [PATCH 102/116] [moe] solve dp axis issue
---
.../booster/plugin/hybrid_parallel_plugin.py | 6 ++--
.../plugin/moe_hybrid_parallel_plugin.py | 28 +++++++------------
2 files changed, 13 insertions(+), 21 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5a9bae479..fcb747814 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1375,15 +1375,15 @@ class HybridParallelPlugin(PipelinePluginBase):
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader `_.
- Returns:
+ Returns:`
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(
dataset,
- num_replicas=self.pg_mesh.size(self.dp_axis),
- rank=self.pg_mesh.coordinate(self.dp_axis),
+ num_replicas=self.dp_group.size(),
+ rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()),
shuffle=shuffle,
)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 7a16a1737..b3415af0e 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -211,7 +211,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
- dp_outside: bool = True,
+ moe_dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
) -> None:
@@ -266,20 +266,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
- if dp_outside:
- self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
- self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
- self.moe_dp_axis, self.ep_axis = 0, 1
- self.__moe_pg_mesh = ProcessGroupMesh(
- self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
- )
+ if moe_dp_outside:
+ self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+ self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size)
else:
- self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
- self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
- self.moe_dp_axis, self.ep_axis = 1, 2
- self.__moe_pg_mesh = ProcessGroupMesh(
- self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
- )
+ self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
@@ -323,10 +315,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
- self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+ self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis])
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
- self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
- self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
+ self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis)
+ self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
@@ -420,7 +412,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
- dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+ dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group
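
To make the axis reshuffle concrete, the sketch below lays hypothetical ranks on a 5-D mesh in the new `(moe_dp, pp, ep, tp, sp)` order and shows that flattening the `moe_dp` and `ep` axes reproduces the non-MoE data-parallel group, which is what `get_group_along_axis([self.moe_dp_axis, self.ep_axis])` returns above (the sizes are made up and NumPy stands in for `ProcessGroupMesh`):

```python
import itertools

import numpy as np

# hypothetical sizes: 4 ranks, moe_dp=2, ep=2, everything else 1
moe_dp_size, pp_size, ep_size, tp_size, sp_size = 2, 1, 2, 1, 1
world_size = moe_dp_size * pp_size * ep_size * tp_size * sp_size
mesh = np.arange(world_size).reshape(moe_dp_size, pp_size, ep_size, tp_size, sp_size)

# dp group: ranks that differ only along the moe_dp and ep axes (the two axes flattened)
for pp, tp, sp in itertools.product(range(pp_size), range(tp_size), range(sp_size)):
    print("dp group:", mesh[:, pp, :, tp, sp].flatten().tolist())

# moe_dp groups vary only the moe_dp axis; ep groups vary only the ep axis
print("moe_dp groups:", [mesh[:, 0, ep, 0, 0].tolist() for ep in range(ep_size)])
print("ep groups:", [mesh[mdp, 0, :, 0, 0].tolist() for mdp in range(moe_dp_size)])
```

With a single mesh there is no second `__moe_pg_mesh` to keep consistent, and the dataloader change in `hybrid_parallel_plugin.py` above follows the same idea: sampler sharding is driven by `self.dp_group` itself rather than by a separate `dp_axis`.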
From 62cdac6b7b655e11626382d64e56503146a516ee Mon Sep 17 00:00:00 2001
From: botbw
Date: Tue, 30 Jul 2024 02:56:39 +0000
Subject: [PATCH 103/116] [chore] remove redundant test case, print string &
reduce test tokens
---
colossalai/shardformer/modeling/mixtral.py | 1 -
.../test_model/test_shard_deepseek.py | 2 +-
.../test_model/test_shard_deepseek_skip.py | 232 ------------------
3 files changed, 1 insertion(+), 234 deletions(-)
delete mode 100644 tests/test_shardformer/test_model/test_shard_deepseek_skip.py
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 029ac36cd..d30ce5ea8 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -245,7 +245,6 @@ class MixtralPipelineForwards:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
- print("input_ids", input_ids.shape)
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 187c642da..46da4522f 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -17,7 +17,7 @@ from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 4000, 2
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_skip.py b/tests/test_shardformer/test_model/test_shard_deepseek_skip.py
deleted file mode 100644
index fe834a4f6..000000000
--- a/tests/test_shardformer/test_model/test_shard_deepseek_skip.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# modified from test_shard_mistral.py
-import os
-
-import pytest
-import torch
-import torch.distributed as dist
-from torch.testing import assert_close
-
-import colossalai
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import (
- build_model_from_hybrid_plugin,
- check_all_grad_tensors,
- check_loss,
- check_output_hidden_state,
- check_weight,
- get_grad_tensors_for_check,
- run_forward_backward_with_hybrid_plugin,
- unwrap_model,
-)
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
- # TODO: SGD failed for full dp
- org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
- model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
- )
-
- org_model = org_model.to(torch.float16)
- org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
- org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
- )
- stage_manager = booster.plugin.stage_manager
- tp_group = booster.plugin.tp_group
-
- # check last hidden state & loss
- if stage_manager is None or stage_manager.is_last_stage():
- if test_config["precision"] == "fp32":
- atol, rtol = 1e-5, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
-
- check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
-
- # unwrap model
- mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
- shard_mixtral_model = unwrap_model(sharded_model, "DeepseekModel", "model")
-
- row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
- col_layer_for_check = ["layers[0].self_attn.o_proj"]
-
- name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
- # Check the grad when using ZeRO-1 and ZeRO-2
- if (
- # booster.plugin.zero_stage in [1, 2]
- booster.plugin.shard_config.enable_sequence_parallelism
- and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
- ):
- rank = dist.get_rank()
- for n, p in shard_mixtral_model.named_parameters():
- zero_grad = sharded_optimizer.get_param_grad(p)
- if name_to_p[n].grad is None:
- name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
- continue
- assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
-
- # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
- grads_to_check = {}
- if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
- if test_config["precision"] == "fp32":
- atol, rtol = 5e-5, 1e-4
- else:
- atol, rtol = 5e-3, 5e-3
- row_layer_grads = get_grad_tensors_for_check(
- mixtral_model,
- shard_mixtral_model,
- row_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=0,
- verbose=False,
- )
- col_layer_grads = get_grad_tensors_for_check(
- mixtral_model,
- shard_mixtral_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False,
- )
- grads_to_check.update(col_layer_grads)
- grads_to_check.update(row_layer_grads)
-
- # check grads
- check_all_grad_tensors(grads_to_check)
-
- for n, p in shard_mixtral_model.named_parameters():
- assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
-
- # optimizer executes step
- org_optimizer.step()
- sharded_optimizer.step()
-
- for n, p in shard_mixtral_model.named_parameters():
- assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
-
- # check weights
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config["precision"] == "fp32":
- atol, rtol = 2e-4, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
- try:
- check_weight(
- mixtral_model,
- shard_mixtral_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False,
- )
- except Exception as e:
- rank = dist.get_rank()
- print(f"{rank=}, Failed config: {test_config}")
- raise e
-
- torch.cuda.empty_cache()
-
-
-@parameterize(
- "test_config",
- [
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "num_microbatches": 2,
- # "ep_size": 2,
- # "zero_stage": 0,
- # "overlap_communication": False,
- # "precision": "fp16",
- # }, # [dp(4)] + [moe_dp(4)]
- # {
- # "tp_size": 1,
- # "pp_size": 2,
- # "num_microbatches": 2,
- # "ep_size": 2,
- # "zero_stage": 1,
- # "overlap_communication": False,
- # "precision": "fp32",
- # }, # [dp(2) + pp(2)] + [moe_pp(2)]
- # {
- # "tp_size": 1,
- # "pp_size": 2,
- # "ep_size": 2,
- # "num_microbatches": 2,
- # "zero_stage": 1,
- # "overlap_communication": False,
- # "precision": "fp16",
- # "initial_scale": 1,
- # "find_unused_parameters": True,
- # }, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
- { # Ulysess + Flash attention
- "tp_size": 1,
- "pp_size": 1,
- "sp_size": 2,
- "ep_size": 2,
- "enable_sequence_parallelism": True,
- "sequence_parallelism_mode": "all_to_all",
- "zero_stage": 1,
- "overlap_communication": False,
- "precision": "fp16",
- "initial_scale": 1,
- "find_unused_parameters": True,
- },
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 2,
- # "zero_stage": 0,
- # "overlap_communication": False,
- # "precision": "fp32",
- # }, # [dp(4)] + [ep(2) + moe_tp(2)]
- # {
- # "tp_size": 1,
- # "pp_size": 1,
- # "ep_size": 4,
- # "overlap_communication": False,
- # "zero_stage": 0,
- # "precision": "fp32"
- # }, # full dp for non-moe and full ep for moe
- ],
-)
-def run_deepseek_test(test_config):
- sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek")
-
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
- clear_layout_converter()
- Randomizer.reset_index()
- torch.cuda.empty_cache()
-
-
-def check_deepseek(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_deepseek_test()
-
-
-@pytest.mark.skip("redundant")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_mixtral():
- spawn(check_deepseek, 4)
-
-
-if __name__ == "__main__":
- test_mixtral()
From 19d1510ea26d10484a804eb62f6d03dbcc7c80a8 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 2 Aug 2024 10:06:25 +0800
Subject: [PATCH 104/116] [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.../colossal_eval/dataset/agieval.py | 4 +-
.../colossal_eval/dataset/base.py | 20 +++++--
.../colossal_eval/dataset/ceval.py | 4 +-
.../colossal_eval/dataset/cmmlu.py | 4 +-
.../colossal_eval/dataset/colossalai.py | 2 +-
.../colossal_eval/dataset/cvalues.py | 2 +-
.../colossal_eval/dataset/gaokaobench.py | 4 +-
.../colossal_eval/dataset/longbench.py | 2 +-
.../colossal_eval/dataset/mmlu.py | 4 +-
.../colossal_eval/dataset/mtbench.py | 6 +--
.../colossal_eval/dataset/safetybench_en.py | 2 +-
.../colossal_eval/dataset/safetybench_zh.py | 2 +-
.../colossal_eval/models/huggingface.py | 48 ++++++++---------
.../colossal_eval/utils/conversation.py | 12 ++---
.../examples/dataset_evaluation/inference.py | 54 ++++++++++++-------
15 files changed, 93 insertions(+), 77 deletions(-)
diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py
index d5f230249..c1cfe37d7 100644
--- a/applications/ColossalEval/colossal_eval/dataset/agieval.py
+++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py
@@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset):
"""
@staticmethod
- def load(
- path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
- ) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
files = glob.glob(os.path.join(path, "*.jsonl"))
diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py
index 531313d7e..a29f56fd1 100644
--- a/applications/ColossalEval/colossal_eval/dataset/base.py
+++ b/applications/ColossalEval/colossal_eval/dataset/base.py
@@ -1,6 +1,9 @@
from abc import abstractstaticmethod
from colossal_eval.utils import jdump
+from torch.utils.data import Dataset
+
+from colossalai.logging import DistributedLogger
class BaseDataset:
@@ -12,13 +15,24 @@ class BaseDataset:
logger: Logger for the dataset.
"""
- def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
- self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
+ def __init__(self, path, logger, *args, **kwargs):
+ self.dataset = self.load(path, logger, *args, **kwargs)
def save(self, save_path):
"""Save the converted dataset"""
jdump(self.dataset, save_path)
@abstractstaticmethod
- def load(path, logger):
+ def load(path, logger: DistributedLogger, *args, **kwargs):
"""Load the original dataset and convert it into the inference dataset"""
+
+
+class DistributedDataset(Dataset):
+ def __init__(self, data):
+ self.data = data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
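
The `DistributedDataset` wrapper above is what lets the evaluator hand a plain list of samples to a regular `DataLoader` with a `DistributedSampler`, so each data-parallel rank only runs inference on its own slice. A minimal sketch of that usage follows (the class is re-declared so the snippet stays self-contained, and `num_replicas`/`rank` are passed explicitly so it runs without `init_process_group`; in the evaluator they come from the process-group mesh):

```python
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ListDataset(Dataset):
    # same shape as the DistributedDataset added above, renamed to keep the sketch standalone
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


data = [{"id": i, "instruction": f"question {i}"} for i in range(10)]
dataset = ListDataset(data)

for rank in range(2):  # emulate two data-parallel ranks
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    # collate_fn=lambda x: x keeps each batch as a list of dicts, which is what inference() now expects
    loader = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=lambda x: x)
    print(rank, [item["id"] for batch in loader for item in batch])
```

When the dataset length is not divisible by the number of replicas, `DistributedSampler` pads by repeating samples, which is why the merge step in `inference.py` below deduplicates the per-rank answers before saving.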
diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py
index 915f4d9b0..1023d1e23 100644
--- a/applications/ColossalEval/colossal_eval/dataset/ceval.py
+++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py
@@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset):
"""
@staticmethod
- def load(
- path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
- ) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))
diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
index 477280663..05752c248 100644
--- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
+++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
@@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset):
"""
@staticmethod
- def load(
- path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
- ) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))
diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py
index 54ea478ae..0337454fa 100644
--- a/applications/ColossalEval/colossal_eval/dataset/colossalai.py
+++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py
@@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset):
"""
@staticmethod
- def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
data = jload(path)
data_per_category = get_data_per_category(data)
diff --git a/applications/ColossalEval/colossal_eval/dataset/cvalues.py b/applications/ColossalEval/colossal_eval/dataset/cvalues.py
index 30e802a02..4023a4c76 100644
--- a/applications/ColossalEval/colossal_eval/dataset/cvalues.py
+++ b/applications/ColossalEval/colossal_eval/dataset/cvalues.py
@@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset):
"""
@staticmethod
- def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl")
data_list = []
diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
index cda6276bf..44ccea9cf 100644
--- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
+++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
@@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset):
"""
@staticmethod
- def load(
- path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
- ) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
files = os.listdir(os.path.join(path, "data", category))
diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py
index 9ea5e3c7d..eb61efaa0 100644
--- a/applications/ColossalEval/colossal_eval/dataset/longbench.py
+++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py
@@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset):
"""
@staticmethod
- def load(path: str, logger: DistributedLogger) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": {}}
files = os.listdir(path)
diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py
index dcda68e8f..e9465c91b 100644
--- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py
+++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py
@@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset):
"""
@staticmethod
- def load(
- path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
- ) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))
diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
index 031415567..ef474ec4c 100644
--- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py
+++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
@@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset):
This dataset class will convert the original dataset into the inference dataset.
"""
- def __init__(self, path, logger, few_shot):
+ def __init__(self, path, logger: DistributedLogger, *args, **kwargs):
self.multiturn = True
- self.dataset = self.load(path, logger, few_shot)
+ self.dataset = self.load(path, logger, *args, **kwargs)
@staticmethod
- def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
dataset = {"test": defaultdict(dict)}
file_path = os.path.join(path, "question.jsonl")
diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py
index e77a3da34..8056c3dfd 100644
--- a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py
+++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py
@@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset):
"""
@staticmethod
- def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:
diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py
index 3eca808bb..f5f17e64c 100644
--- a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py
+++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py
@@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset):
"""
@staticmethod
- def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
data_files = [os.path.join(path, file_name) for file_name in FILES]
for file_path in data_files:
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
index 23c399cce..e91743525 100644
--- a/applications/ColossalEval/colossal_eval/models/huggingface.py
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -1,11 +1,11 @@
import copy
-import math
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from peft import PeftModel
+from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
@@ -130,7 +130,7 @@ class HuggingFaceModel(BaseModel):
if shard_config is not None:
self.model = AutoModel.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
- self.model, sharded_parameters = shard_former.optimize(self.model)
+ self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())
if peft_path is not None:
@@ -325,7 +325,7 @@ class HuggingFaceModel(BaseModel):
return input_ids_list, labels_list, None
- def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
+ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
"""
Infer the given data.
This function will call self.generate() to get model outputs and also self.model() to get logits.
@@ -359,26 +359,23 @@ class HuggingFaceModel(BaseModel):
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
- turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
- turn_desc = "" if turn == 0 else f"-turn{turn}"
-
bar = tqdm(
- range(math.ceil(len(data) / self.batch_size)),
- desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
+ range(len(data_loader)),
+ desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
- answers = copy.deepcopy(data)
- for i in range(0, len(data), self.batch_size):
- batch = data[i : i + self.batch_size]
+ answers = []
+
+ for i, batch in enumerate(data_loader):
batch_prompt, batch_target = get_batch_prompt(
- self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
+ self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
)
if is_rank_0() and debug and i == 0:
self.logger.info(
- f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
+ f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
)
self.logger.info("-" * 120)
self.logger.info("An example prompt and prompt with target is:")
@@ -402,7 +399,7 @@ class HuggingFaceModel(BaseModel):
# Otherwise this will violate the single-choice setting.
if calculate_loss:
- labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
+ labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
@@ -411,29 +408,30 @@ class HuggingFaceModel(BaseModel):
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
]
- for j in range(len(batch_prompt)):
+ for j in range(len(batch)):
if not pretrain:
- if isinstance(answers[i + j]["output"], list):
- answers[i + j]["output"].append(batch_decodes[j].strip())
+ if isinstance(batch[j]["output"], list):
+ batch[j]["output"].append(batch_decodes[j].strip())
else:
- answers[i + j]["output"] = batch_decodes[j].strip()
+ batch[j]["output"] = batch_decodes[j].strip()
if isinstance(scores, torch.Tensor):
- answers[i + j]["logits_over_choices"] = probs[j]
+ batch[j]["logits_over_choices"] = probs[j]
if calculate_loss:
- answers[i + j]["loss_over_choices"] = loss_over_choices[j]
+ batch[j]["loss_over_choices"] = loss_over_choices[j]
if calculate_loss:
- answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
+ batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
# loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity.
# However, loss (which is per sample loss) suffices for most cases.
- answers[i + j]["loss_sum"] = batch_losses[j]
- answers[i + j]["token_num"] = batch_target_token_nums[j]
+ batch[j]["loss_sum"] = batch_losses[j]
+ batch[j]["token_num"] = batch_target_token_nums[j]
if batch_bytes_nums:
- answers[i + j]["byte_num"] = batch_bytes_nums[j]
+ batch[j]["byte_num"] = batch_bytes_nums[j]
+ answers.extend(batch)
bar.update()
@@ -600,7 +598,7 @@ class HuggingFaceCausalLM(HuggingFaceModel):
if shard_config is not None:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
- self.model, sharded_parameters = shard_former.optimize(self.model)
+ self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())
if peft_path is not None:
diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py
index 330083aa6..c0445e84e 100644
--- a/applications/ColossalEval/colossal_eval/utils/conversation.py
+++ b/applications/ColossalEval/colossal_eval/utils/conversation.py
@@ -123,15 +123,13 @@ class Conversation:
}
-def get_few_shot_prefix(
- conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
-) -> str:
+def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str:
"""
Get few shot prefix.
Args:
- conv: Conversation template.
- few_shot_examples: Few shot examples to generate few shot prompt prefix.
+ few_shot_data: Few shot examples to generate few shot prompt prefix.
+ tokenizer: tokenizer used to tokenize data.
Returns:
Few shot prompt prefix.
@@ -157,7 +155,6 @@ def get_batch_prompt(
batch: List[Dict],
few_shot_data: List[str],
tokenizer: Optional[AutoTokenizer],
- language: Optional[str],
model_max_length: Optional[int],
) -> Tuple[List[Dict], List[Dict]]:
"""
@@ -167,6 +164,7 @@ def get_batch_prompt(
conv: Conversation template.
batch: Batch data to generate prompt from.
few_shot_data: Few shot data to generate few shot prompt prefix.
+ tokenizer: tokenizer used to tokenize data.
Returns:
Tuple containg batch prompt and target.
@@ -192,7 +190,7 @@ def get_batch_prompt(
else:
raise Exception("When using few-shot, target answer should be a string.")
- few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
+ few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens)
conv.append_message(conv.roles[0], few_shot_prefix + query_text)
conv.append_message(conv.roles[1], None)
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
index a7307635d..c651970ee 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/inference.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -5,6 +5,8 @@ from typing import Dict, List
import torch.distributed as dist
from colossal_eval import dataset, models, utils
+from colossal_eval.dataset.base import DistributedDataset
+from torch.utils.data import DataLoader, DistributedSampler
import colossalai
from colossalai.accelerator import get_accelerator
@@ -13,6 +15,7 @@ from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig
logger = get_dist_logger()
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
def rm_and_merge(
@@ -54,7 +57,8 @@ def rm_and_merge(
)
else:
rank_answers = utils.jload(directory)
- answers["data"].extend(rank_answers["data"])
+                deduplicated_answers = [x for x in rank_answers["data"] if x not in answers["data"]]
+                answers["data"].extend(deduplicated_answers)
answers["inference_kwargs"] = rank_answers["inference_kwargs"]
for r in range(dp_size):
@@ -65,7 +69,7 @@ def rm_and_merge(
os.remove(directory)
except Exception as e:
print(e)
-
+ print(len(answers["data"]))
all_answers[category] = answers
all_answers_with_dataset_class["inference_results"] = all_answers
@@ -108,7 +112,12 @@ def main(args):
tp_rank = coordinates[TP_AXIS]
shard_config = (
- ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
+ ShardConfig(
+ tensor_parallel_process_group=tp_group,
+ enable_tensor_parallelism=args.tp_size > 1,
+ parallel_output=False,
+ enable_all_optimization=True,
+ )
if args.tp_size > 1
else None
)
@@ -183,6 +192,7 @@ def main(args):
model_name = model_parameter["name"]
model_class = eval(f"models.{model_parameter['model_class']}")
paramerters = model_parameter["parameters"]
+ batch_size = paramerters["batch_size"]
paramerters.update({"logger": logger})
paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
paramerters.update({"shard_config": shard_config})
@@ -192,7 +202,6 @@ def main(args):
raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
for dataset_name, split_data in inference_data.items():
- start = 0
prev_questions = None
for category, category_data in split_data.items():
num_turn = category_data["inference_kwargs"].get("turns", 1)
@@ -201,26 +210,33 @@ def main(args):
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
answers_to_dump = copy.deepcopy(category_data)
- partition_size = len(category_data["data"]) // dp_size
- redundant = len(category_data["data"]) % dp_size
-
- # Ensure that the amount of data for inference is as consistent as possible across different processes.
- lengths = [partition_size for _ in range(dp_size)]
- for j in range(redundant):
- lengths[(j + start) % dp_size] += 1
-
- start = (start + redundant) % dp_size
-
for turn in range(num_turn):
if turn == 0:
- questions = category_data["data"][
- sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
- ]
+ dist_dataset = DistributedDataset(category_data["data"])
else:
- questions = prev_questions
+ dist_dataset = DistributedDataset(prev_questions)
+
+ sampler = DistributedSampler(
+ dist_dataset,
+ num_replicas=pg_mesh.size(DP_AXIS),
+ rank=pg_mesh.coordinate(DP_AXIS),
+ shuffle=False,
+ )
+ questions_loader = DataLoader(
+ dist_dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=8,
+ pin_memory=True,
+ collate_fn=lambda x: x,
+ )
+ category_data["inference_kwargs"]["dataset"] = dataset_name
+ category_data["inference_kwargs"]["category"] = category
answers_per_rank = model_.inference(
- questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+ data_loader=questions_loader,
+ inference_kwargs=category_data["inference_kwargs"],
+ debug=debug_args[dataset_name],
)
prev_questions = answers_per_rank
From 75c963686f889fa68110cac460baaad08ac78f82 Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Fri, 2 Aug 2024 10:36:58 +0800
Subject: [PATCH 105/116] [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
---
.../booster/plugin/hybrid_parallel_plugin.py | 25 ++++++++++++++++---
.../hybrid_parallel_checkpoint_io.py | 14 +++++++++++
.../shardformer/policies/auto_policy.py | 3 +++
tests/test_lora/test_lora.py | 7 ++++--
4 files changed, 44 insertions(+), 5 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index fcb747814..d2933a4af 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -30,6 +30,7 @@ from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
@@ -1187,7 +1188,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return True
def support_lora(self) -> bool:
- return False
+ return True
def control_checkpoint_io(self) -> bool:
return True
@@ -1415,6 +1416,24 @@ class HybridParallelPlugin(PipelinePluginBase):
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
def enable_lora(
- self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+ self,
+ model: Module,
+ pretrained_dir: Optional[str] = None,
+ lora_config: Optional[Dict] = None,
+ bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> Module:
- raise NotImplementedError
+ from peft import PeftModel, get_peft_model
+
+ assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
+ assert self.pp_size == 1 and self.tp_size == 1
+ self.lora_enabled = True
+        warnings.warn("You have enabled LoRA training. Please check the hyperparameters such as lr")
+
+ if bnb_quantization_config is not None:
+ model = quantize_model(model, bnb_quantization_config)
+
+ if pretrained_dir is None:
+ peft_model = get_peft_model(model, lora_config)
+ else:
+ peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+ return peft_model
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index b7097e432..0310df548 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -947,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_[k] = v.detach().clone().to(device)
return state_
+
+ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+ from peft import PeftModel
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model._force_wait_all_gather()
+ peft_model = model.unwrap()
+ assert isinstance(
+ peft_model, PeftModel
+ ), "The model doesn't have lora adapters, please enable lora before saving."
+ return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index f2533da4b..7b9c759a6 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -243,6 +243,9 @@ def _fullname(obj):
# patch custom models which are not in transformers
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+ if module.startswith("peft"):
+ klass = obj.base_model.model.__class__
+ module = klass.__module__
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py
index b8daf775d..1ae17025d 100644
--- a/tests/test_lora/test_lora.py
+++ b/tests/test_lora/test_lora.py
@@ -9,7 +9,8 @@ from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
- test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
+ test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
test_configs = [
{
"lora_config": lora_config,
@@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
# test fwd bwd correctness
test_model = model_load
+ if isinstance(model_load, HybridParallelModule):
+ model_load = model_load.module.module
model_copy = copy.deepcopy(model_load)
data = data_gen_fn()
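
A rough end-to-end sketch of the workflow this commit enables is below. It is only an illustration: the model name, learning rate and single-process `torchrun` launch are placeholders, `plugin.enable_lora` is called directly, and per the assertion in `enable_lora` only `tp_size=1, pp_size=1` is supported for now:

```python
# run with: torchrun --nproc_per_node 1 <this_script>.py
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # initialize the (single-process) distributed environment

plugin = HybridParallelPlugin(tp_size=1, pp_size=1)
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
# LoRA adapters must be attached before boosting; after boost() the model is wrapped in HybridParallelModule
model = plugin.enable_lora(model, lora_config=lora_config)

optimizer = AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```

Saving the adapters afterwards goes through the `save_lora_as_pretrained` hook added to `HybridParallelCheckpointIO` above.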
From 0b2d55c4ab518bd2e6e66195aaead28d7311ab8f Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 2 Aug 2024 06:51:38 +0000
Subject: [PATCH 106/116] Support overall loss, update KTO logging
---
.../coati/dataset/tokenization_utils.py | 19 +++-
.../ColossalChat/coati/models/loss.py | 16 ++-
.../ColossalChat/coati/trainer/dpo.py | 9 ++
.../ColossalChat/coati/trainer/kto.py | 37 ++++++-
.../ColossalChat/coati/trainer/orpo.py | 12 ++
.../ColossalChat/coati/trainer/ppo.py | 12 +-
.../ColossalChat/coati/trainer/sft.py | 14 ++-
applications/ColossalChat/examples/README.md | 1 +
.../examples/inference/inference.py | 4 +-
.../ColossalChat/examples/inference/round.txt | 104 ------------------
.../examples/training_scripts/train_dpo.py | 2 +
.../examples/training_scripts/train_kto.py | 2 +
.../examples/training_scripts/train_orpo.py | 2 +
.../examples/training_scripts/train_ppo.py | 2 +
.../examples/training_scripts/train_sft.py | 2 +
15 files changed, 119 insertions(+), 119 deletions(-)
delete mode 100644 applications/ColossalChat/examples/inference/round.txt
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 9eb2eba87..4f890ffc9 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -49,6 +49,10 @@ def tokenize_sft(
messages = data_point["messages"]
template = deepcopy(conversation_template)
+
+ if messages[0]["from"] == "system":
+ template.system_message = str(messages[0]["content"])
+ messages.pop(0)
template.messages = []
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
@@ -148,11 +152,14 @@ def tokenize_prompt(
template = deepcopy(conversation_template)
template.messages = []
+ if messages[0]["from"] == "system":
+ template.system_message = str(messages[0]["content"])
+ messages.pop(0)
+
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
- f"Message should iterate between user and assistant and starts with a \
- line from the user. Got the following data:\n{messages}"
+ f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
@@ -225,6 +232,10 @@ def tokenize_rlhf(
template = deepcopy(conversation_template)
template.clear()
+ if context[0]["from"] == "system":
+ template.system_message = str(context[0]["content"])
+ context.pop(0)
+
for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
@@ -345,6 +356,10 @@ def tokenize_kto(
template = deepcopy(conversation_template)
template.clear()
+ if prompt[0]["from"] == "system":
+ template.system_message = str(prompt[0]["content"])
+ prompt.pop(0)
+
if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index 840cca074..bd0bbd36b 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
- ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+ if action_mask is None:
+ ratio_ = (log_probs - old_log_probs).exp()
+ else:
+ ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
@@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
- loss = masked_mean(loss, action_mask)
+ if action_mask is not None:
+ loss = masked_mean(loss, action_mask)
+ else:
+ loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()
@@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
- loss = torch.max(surr1, surr2) / torch.sum(action_mask)
- loss = torch.sum(loss * action_mask)
+ if action_mask is not None:
+ loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+ else:
+ loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss
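
The loss changes above all follow the same rule: with an `action_mask`, only action tokens contribute to the mean; without one, the loss falls back to a plain per-sequence mean. A small worked example (with `masked_mean` re-implemented locally so the snippet is self-contained):

```python
import torch


def masked_mean(tensor, mask, dim=1):
    # local re-implementation for the example: mean over positions where mask == 1
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim)


per_token_loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
action_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

print(per_token_loss.mean(dim=1))                # no mask:   tensor([2., 5.])
print(masked_mean(per_token_loss, action_mask))  # with mask: tensor([1.5000, 4.0000])
```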
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index c7ef2be8f..24ddca654 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma)
+ self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
@@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
index 8ab0bc66b..6462ba816 100755
--- a/applications/ColossalChat/coati/trainer/kto.py
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -6,7 +6,7 @@ import os
from typing import Any, Optional
import torch
-import torch.distributed
+import torch.distributed as dist
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
@@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer):
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
+ self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
+ if not self.apply_loss_mask:
+ loss_mask = loss_mask.fill_(1.0)
+ kl_loss_mask = kl_loss_mask.fill_(1.0)
+
batch_size = input_ids.size()[0]
# actor logits
@@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer):
# sync
loss_mean = all_reduce_mean(tensor=loss)
- chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
- rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
+ chosen_reward_mean = chosen_rewards.mean()
+ chosen_rewards_list = [
+ torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(chosen_rewards_list, chosen_reward_mean)
+ rejected_reward_mean = rejected_rewards.mean()
+ rejected_rewards_list = [
+ torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(rejected_rewards_list, rejected_reward_mean)
+ chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
+ rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
+ chosen_rewards_mean = (
+ torch.stack(chosen_rewards_list).mean()
+ if len(chosen_rewards_list) > 0
+ else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+ )
+ rejected_rewards_mean = (
+ torch.stack(rejected_rewards_list).mean()
+ if len(rejected_rewards_list) > 0
+ else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+ )
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
@@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ loss_mask = loss_mask.fill_(1.0)
+ kl_loss_mask = kl_loss_mask.fill_(1.0)
+
batch_size = input_ids.size()[0]
# actor logits
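The reward synchronisation added above gathers one mean per rank and drops NaN entries, i.e. ranks whose micro-batch contained no chosen (or no rejected) samples, before averaging the rest. A single-process sketch of that filtering step (function and variable names are illustrative, not the trainer's API):

```python
import torch

def aggregate_rank_means(rank_means):
    # rank_means: one scalar tensor per rank; NaN marks ranks that had no samples
    valid = [m for m in rank_means if not torch.isnan(m)]
    if not valid:
        return torch.tensor(float("nan"))
    return torch.stack(valid).mean()

# e.g. ranks 0 and 2 saw chosen samples, rank 1 did not
means = [torch.tensor(0.8), torch.tensor(float("nan")), torch.tensor(1.2)]
print(aggregate_rank_means(means))  # ~ tensor(1.0)
```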
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index b039da4af..c2f75771c 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
+ self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
+
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py
index 287767669..63c813b39 100755
--- a/applications/ColossalChat/coati/trainer/ppo.py
+++ b/applications/ColossalChat/coati/trainer/ppo.py
@@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
+ apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
@@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.save_interval = save_interval
+ self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic")
@@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
- action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ action_log_probs,
+ experience.action_log_probs,
+ experience.advantages,
+ action_mask=experience.action_mask if self.apply_loss_mask else None,
)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
@@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn(
- values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
+ values[:, -num_actions:],
+ experience.values,
+ experience.advantages,
+ action_mask=experience.action_mask if self.apply_loss_mask else None,
)
critic_loss = critic_loss * self.vf_coef
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py
index c09d61034..d37676ada 100755
--- a/applications/ColossalChat/coati/trainer/sft.py
+++ b/applications/ColossalChat/coati/trainer/sft.py
@@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
+ apply_loss_mask: bool = True,
start_epoch=0,
save_interval: int = None,
save_dir: str = None,
@@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
self.coordinator = coordinator
self.num_train_step = 0
self.num_eval_step = 0
+ self.apply_loss_mask = apply_loss_mask
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
@@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ outputs = self.model(
+ batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+ )
loss = outputs.loss
self.booster.backward(loss=loss, optimizer=self.optimizer)
@@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ outputs = self.model(
+ batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+ )
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
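For context on the SFT change: with the loss mask applied, prompt tokens in `labels` are masked out (HuggingFace models use -100 as the ignore index), whereas passing `input_ids` as labels trains on every token. A hedged sketch of the difference (shapes and the prompt length are illustrative):

```python
import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(1, 6, vocab_size)            # [batch, seq, vocab]
input_ids = torch.randint(0, vocab_size, (1, 6))

masked_labels = input_ids.clone()
masked_labels[:, :3] = -100                       # ignore the first 3 (prompt) tokens

def lm_loss(logits, labels):
    # standard causal-LM shift: position t predicts token t+1
    shift_logits = logits[:, :-1].reshape(-1, vocab_size)
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

print("masked  :", lm_loss(logits, masked_labels).item())  # loss on response tokens only
print("unmasked:", lm_loss(logits, input_ids).item())      # loss on all tokens
```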
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index b749f197e..904d69cfc 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- save_dir: path to store the model checkpoints.
- max_length: input will be padded/truncated to max_length before feeding to the model.
- max_epochs: number of epochs to train.
+- disable_loss_mask: whether to disable the loss mask. For example, in SFT, if the loss mask is disabled, the model computes the loss over all tokens in the sequence; if the loss mask is applied, only tokens corresponding to the assistant responses contribute to the final loss.
- batch_size: training batch size.
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py
index 103bd8d95..5f59ba452 100755
--- a/applications/ColossalChat/examples/inference/inference.py
+++ b/applications/ColossalChat/examples/inference/inference.py
@@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
tuple: A tuple containing the loaded model and tokenizer.
"""
- model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model.to(device)
diff --git a/applications/ColossalChat/examples/inference/round.txt b/applications/ColossalChat/examples/inference/round.txt
deleted file mode 100644
index ba02074c1..000000000
--- a/applications/ColossalChat/examples/inference/round.txt
+++ /dev/null
@@ -1,104 +0,0 @@
-
-
-==========
-round 1:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] Great, let’s hear a story.
-
-==========
-
-
-==========
-round 2:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2
-
-==========
-
-
-==========
-round 3:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] Great, let’s hear a story. [INST] calculate 1+1 [/INST] 1+1=2 [INST] who is the first president of the USA [/INST] The first president of the United States was George Washington.
-
-==========
-
-
-==========
-round 1:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-who is the first president of the USA? [/INST] The first president of the United States was George Washington.
-
-==========
-
-
-==========
-round 2:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-who is the first president of the USA? [/INST] The first president of the United States was George Washington. [INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president.
-
-==========
-
-
-==========
-round 1:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear?
-
-==========
-
-
-==========
-round 2:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump?
-
-==========
-
-
-==========
-round 3:
-[INST] <>
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-
-<>
-
-tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? [INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? [INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016.
-
-==========
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index 44131f572..d88750aeb 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -278,6 +278,7 @@ def train(args):
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
+ apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -346,6 +347,7 @@ if __name__ == "__main__":
default=False,
help="Disable the reference model (enabled by default)",
)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py
index d063b82bb..598fd8062 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.py
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.py
@@ -297,6 +297,7 @@ def train(args):
beta=args.beta,
desirable_weight=args.desirable_weight,
undesirable_weight=args.undesirable_weight,
+ apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -341,6 +342,7 @@ if __name__ == "__main__":
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index f06524507..87860f7ea 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -259,6 +259,7 @@ def train(args):
save_dir=args.save_dir,
coordinator=coordinator,
lam=args.lam,
+ apply_loss_mask=not args.disable_loss_mask,
)
trainer.fit(
@@ -301,6 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py
index 333be9963..c10418394 100755
--- a/applications/ColossalChat/examples/training_scripts/train_ppo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py
@@ -411,6 +411,7 @@ def train(args):
use_cache=True,
do_sample=True,
temperature=0.7,
+ apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
@@ -498,6 +499,7 @@ if __name__ == "__main__":
parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.0)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--log_dir", default="logs", type=str)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index 6007a8599..c4ef3b783 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -272,6 +272,7 @@ def train(args):
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
+ apply_loss_mask=not args.disable_loss_mask,
start_epoch=start_epoch,
save_interval=args.save_interval,
save_dir=args.save_path,
@@ -317,6 +318,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
From 9179d4088e378c437178432168dea9f32fbf739f Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Wed, 7 Aug 2024 13:53:48 +0800
Subject: [PATCH 107/116] [Docs] clarify launch port
Co-authored-by: Edenzzzz
---
docs/source/en/basics/launch_colossalai.md | 7 ++++---
docs/source/zh-Hans/basics/launch_colossalai.md | 6 +++---
2 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md
index 8a6028d6c..32748dae1 100644
--- a/docs/source/en/basics/launch_colossalai.md
+++ b/docs/source/en/basics/launch_colossalai.md
@@ -131,17 +131,18 @@ with one simple command. There are two ways you can launch multi-node jobs.
This is suitable when you only have a few nodes. Let's say I have two nodes, namely `host1` and `host2`, I can start
multi-node training with the following command. Compared to single-node training, you must specify the `master_addr`
-option, which is auto-set to localhost if running on a single node only.
+option, which is auto-set to localhost if running on a single node only. \
+Additionally, you must ensure that all nodes share the same open SSH port, which can be specified using --ssh-port.
:::caution
-`master_addr` cannot be localhost when running on multiple nodes, it should be the hostname or IP address of a node.
+`master_addr` cannot be localhost when running on multiple nodes, it should be the **hostname or IP address** of a node.
:::
```shell
# run on these two nodes
-colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22
```
- Run with `--hostfile`
diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md
index a80d16717..9e40f64c2 100644
--- a/docs/source/zh-Hans/basics/launch_colossalai.md
+++ b/docs/source/zh-Hans/basics/launch_colossalai.md
@@ -116,17 +116,17 @@ colossalai run --nproc_per_node 4 --master_port 29505 test.py
- 通过`--hosts`来启动
这个方式适合节点数不多的情况。假设我们有两个节点,分别为`host`和`host2`。我们可以用以下命令进行多节点训练。
-比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。
+比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。同时,你需要确保每个节点都使用同一个ssh port。可以通过--ssh-port设置。
:::caution
-多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的名字或者IP地址。
+多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的**名字或者IP地址**。
:::
```shell
# 在两个节点上训练
-colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py --ssh-port 22
```
From ad3fa4f49cee16579e2997b55ff3ffd89577419d Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 8 Aug 2024 18:04:47 +0800
Subject: [PATCH 108/116] [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
---
applications/ColossalChat/.gitignore | 1 +
applications/ColossalChat/README.md | 2 +-
applications/README.md | 2 +-
3 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/applications/ColossalChat/.gitignore b/applications/ColossalChat/.gitignore
index 757cbb5da..7b361d38e 100755
--- a/applications/ColossalChat/.gitignore
+++ b/applications/ColossalChat/.gitignore
@@ -151,6 +151,7 @@ examples/training_scripts/wandb
examples/training_scripts/output
examples/awesome-chatgpt-prompts/
+examples/inference/round.txt
temp/
# ColossalChat
diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index de27ebaf6..3604fab10 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -121,7 +121,7 @@ cd $COLOSSAL_AI_ROOT
BUILD_EXT=1 pip install .
# Install ColossalChat
-cd $COLOSSAL_AI_ROOT/applications/Chat
+cd $COLOSSAL_AI_ROOT/applications/ColossalChat
pip install .
```
diff --git a/applications/README.md b/applications/README.md
index 5b8b5e501..9957300ae 100644
--- a/applications/README.md
+++ b/applications/README.md
@@ -14,9 +14,9 @@ This directory contains the applications that are powered by Colossal-AI.
The list of applications include:
- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
+- [X] [ColossalChat](./ColossalChat/): Replication of ChatGPT with RLHF.
- [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervised Fine-tuning of LLaMA2 / LLaMA3.
- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
-- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
- [X] [ColossalQA](./ColossalQA/README.md): Document Retrieval Conversation System
- [X] [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): Breaks the Length Limit of LLM Inference for Multi-Round Conversations
From b4d2377d4c482960af21bb77bf5ff78099865b02 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 9 Aug 2024 18:17:09 +0800
Subject: [PATCH 109/116] [Hotfix] Avoid fused RMSnorm import error without
apex (#5985)
Co-authored-by: Edenzzzz
---
colossalai/shardformer/layer/normalization.py | 25 ++++++++++---------
1 file changed, 13 insertions(+), 12 deletions(-)
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 59e1da9fc..043bf6aeb 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -42,7 +42,7 @@ try:
return output
except ImportError:
- warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
+ warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
@@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm):
Returns:
nn.Module: FusedRMSNorm module.
"""
- try:
- pass
- except ImportError:
- raise ImportError(
- "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
- )
LazyInitContext.materialize(module)
@@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm):
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)
- rmsnorm = FusedRMSNormWithHook(
- normalized_shape=normalized_shape,
- eps=eps,
- elementwise_affine=elementwise_affine,
- )
+ try:
+ rmsnorm = FusedRMSNormWithHook(
+ normalized_shape=normalized_shape,
+ eps=eps,
+ elementwise_affine=elementwise_affine,
+ )
+ except ImportError:
+ warnings.warn(
+ "Module replacement failed.\
+ Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
+ )
+ return module
rmsnorm.weight = module.weight
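The hunk above moves the apex dependency check to the point of module replacement and falls back to the unmodified module instead of raising. A hedged sketch of that fallback pattern (the helper name is illustrative; `apex.normalization.FusedRMSNorm` is the upstream apex class):

```python
import warnings
import torch.nn as nn

def replace_with_fused_rmsnorm(module: nn.Module) -> nn.Module:
    """Swap an RMSNorm-style module for apex's fused kernel when apex is available."""
    try:
        from apex.normalization import FusedRMSNorm  # raises ImportError without apex
        fused = FusedRMSNorm(module.weight.shape, eps=getattr(module, "eps", 1e-6))
        fused.weight = module.weight
        return fused
    except ImportError:
        warnings.warn(
            "Module replacement failed. Please install apex from source "
            "(https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
        )
        return module  # keep the original module so the model still runs
```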
From ed97d3a5d3bb8cd3b2ff62b0097c96bf0991df92 Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Mon, 12 Aug 2024 14:55:17 +0800
Subject: [PATCH 110/116] [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.../coati/dataset/tokenization_utils.py | 2 +-
.../ColossalChat/coati/models/utils.py | 7 ++--
applications/ColossalChat/examples/README.md | 28 +++++++--------
.../examples/inference/inference.py | 1 -
.../examples/training_scripts/train_ppo.py | 2 +-
applications/ColossalChat/requirements.txt | 2 +-
applications/ColossalChat/tests/test_train.sh | 36 +++++++++----------
7 files changed, 38 insertions(+), 40 deletions(-)
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 4f890ffc9..020432b9e 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -169,7 +169,7 @@ def tokenize_prompt(
template.messages = template.messages[:-1]
# Prepare data
- prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
+ prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None:
diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py
index 8ed8d3401..c583f057a 100755
--- a/applications/ColossalChat/coati/models/utils.py
+++ b/applications/ColossalChat/coati/models/utils.py
@@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module):
Returns:
None
"""
- for module in model.modules():
- if isinstance(module, torch.nn.Dropout):
- module.p = 0.0
+ if model is not None:
+ for module in model.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.p = 0.0
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index 904d69cfc..fec7bc061 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -462,26 +462,24 @@ Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of
#### Step 1: Data Collection
-The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
+The first step in Stage 1 is to collect a dataset of human demonstrations of the following JSONL format.
```json
-[
- {"messages":
- [
- {
- "from": "user",
- "content": "what are some pranks with a pen i can do?"
- },
- {
- "from": "assistant",
- "content": "Are you looking for practical joke ideas?"
- },
- ...
- ]
+{"messages":
+ [
+ {
+ "from": "user",
+ "content": "what are some pranks with a pen i can do?"
+ },
+ {
+ "from": "assistant",
+ "content": "Are you looking for practical joke ideas?"
},
...
-]
+ ]
+},
+...
```
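Since each line of the file above is an independent JSON object, a quick sanity check of a prepared SFT dataset can be done with a few lines of Python (the file path and field checks are illustrative):

```python
import json

def validate_sft_jsonl(path: str) -> None:
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            messages = json.loads(line)["messages"]
            roles = [m["from"] for m in messages]
            # only user/assistant turns are expected, each with a content field
            assert all(r in ("user", "assistant") for r in roles), f"bad role on line {line_no}"
            assert all("content" in m for m in messages), f"missing content on line {line_no}"

# validate_sft_jsonl("sft_data.jsonl")  # hypothetical file name
```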
diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py
index 5f59ba452..32310cce9 100755
--- a/applications/ColossalChat/examples/inference/inference.py
+++ b/applications/ColossalChat/examples/inference/inference.py
@@ -151,7 +151,6 @@ def main(args):
chat_io.prompt_for_output("assistant")
prompt = conv.get_prompt(add_generation_prompt=True)
- print(prompt + "")
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
torch.cuda.current_device()
)
diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py
index c10418394..a0a10e239 100755
--- a/applications/ColossalChat/examples/training_scripts/train_ppo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py
@@ -502,7 +502,7 @@ if __name__ == "__main__":
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_seq_len", type=int, default=256)
- parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index 2188de12f..ac40ae821 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -2,7 +2,7 @@ transformers==4.39.3
tqdm
datasets==2.14.7
loralib
-colossalai==0.4.0
+colossalai>=0.4.0
torch>=2.1.0
langchain
tokenizers
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
index c26b25c83..69036de63 100755
--- a/applications/ColossalChat/tests/test_train.sh
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
set -xu
@@ -119,11 +119,11 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
- tp='4'
+ tp='2'
bs='8'
fi
if [[ $plugin == "tp_zero2" ]]; then
- tp='4'
+ tp='2'
bs='8'
zero_stage='2'
plugin='3d'
@@ -136,13 +136,13 @@ for lora_rank in ${LORA_RANK[@]}; do
fi
if [[ $plugin == "pp" ]]; then
bs='8'
- pp='4'
+ pp='2'
plugin='3d'
fi
if [[ $plugin == "sp_split_gather" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='split_gather'
- tp='4'
+ tp='2'
sp='1'
bs='8'
plugin='3d'
@@ -150,7 +150,7 @@ for lora_rank in ${LORA_RANK[@]}; do
if [[ $plugin == "sp_ring" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='ring'
- tp='4'
+ tp='2'
sp='1'
bs='8'
plugin='3d'
@@ -159,7 +159,7 @@ for lora_rank in ${LORA_RANK[@]}; do
enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='all_to_all'
tp='1'
- sp='4'
+ sp='2'
bs='8'
plugin='3d'
fi
@@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
- colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+ colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@@ -242,7 +242,7 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
- tp='4'
+ tp='2'
bs='8'
fi
grad_accu='2'
@@ -256,7 +256,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
- colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
+ colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@@ -325,7 +325,7 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
- tp='4'
+ tp='2'
bs='16'
ebs='32'
fi
@@ -350,7 +350,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
- colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
+ colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
@@ -417,7 +417,7 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
- tp='4'
+ tp='2'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
@@ -442,7 +442,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
- colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
+ colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@@ -500,7 +500,7 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
- tp='4'
+ tp='2'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
@@ -525,7 +525,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
- colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
+ colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@@ -583,7 +583,7 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
- tp='4'
+ tp='2'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
@@ -608,7 +608,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
done
- colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
+ colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
From ceb1e262e765242c1f130aa72ab9d5e2289162be Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Wed, 14 Aug 2024 11:22:39 +0800
Subject: [PATCH 111/116] fix sync condition (#6000)
---
colossalai/booster/plugin/hybrid_parallel_plugin.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index d2933a4af..e5acdb051 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase):
)
# run with gradients accumulation
- if model.require_grad_sync == False or (
- isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+ if (
+ model.require_grad_sync == False
+ or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
+ or not torch.is_grad_enabled()
):
return outputs
From 406f984063423042e25d0723258530ba506a44a9 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 15 Aug 2024 10:41:22 +0800
Subject: [PATCH 112/116] [plugin] add cast inputs option for zero (#6003)
---
colossalai/booster/plugin/low_level_zero_plugin.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 66491821c..e4c386a22 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -62,7 +62,9 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
- def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+ def __init__(
+ self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
+ ) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@@ -74,7 +76,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
- if self.dtype is not None:
+ if self.dtype is not None and cast_inputs:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
if overlap_allgather:
@@ -334,6 +336,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
+ cast_inputs: bool = True,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -360,6 +363,7 @@ class LowLevelZeroPlugin(DPPluginBase):
)
self.lora_enabled = False
self.verbose = verbose
+ self.cast_inputs = cast_inputs
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -474,7 +478,10 @@ class LowLevelZeroPlugin(DPPluginBase):
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(
- model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+ model,
+ self.precision,
+ overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+ cast_inputs=self.cast_inputs,
)
# TODO: Support Galore + ZeRO
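A hedged usage sketch of the new option: passing `cast_inputs=False` keeps input tensors in their original dtype instead of converting them to the plugin's mixed-precision dtype (the model is a placeholder, and the script is assumed to be launched with `colossalai run`/`torchrun`):

```python
import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # distributed init is required before building the plugin

model = nn.Linear(16, 16)  # placeholder model
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", cast_inputs=False)
booster = Booster(plugin=plugin)
# model, optimizer, dataloader, ... would then be wrapped via booster.boost(...)
```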
From 4dd03999ecb1016b0919c090a065e5bf425432ea Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 15 Aug 2024 14:40:03 +0800
Subject: [PATCH 113/116] [pre-commit.ci] pre-commit autoupdate (#5995)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9088d0e1b..a4132a507 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
name: sort all imports (python)
- repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.4.2
+ rev: 24.8.0
hooks:
- id: black
name: black formatter
From 887d2d579b522cadab12571f2357d9e2cbd23aed Mon Sep 17 00:00:00 2001
From: Haze188
Date: Thu, 15 Aug 2024 14:40:26 +0800
Subject: [PATCH 114/116] [misc] Bypass the huggingface bug to solve the mask
mismatch problem (#5991)
---
colossalai/shardformer/modeling/deepseek.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index a84a30972..429c4350c 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -666,6 +666,9 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
+ # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
+ self._use_flash_attention_2 = shard_config.enable_flash_attention
+ self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
From f5c84af0b01bcd2e993d38dc628793f7f0a8ba64 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 16 Aug 2024 13:56:38 +0800
Subject: [PATCH 115/116] [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 1 +
.../booster/plugin/hybrid_parallel_plugin.py | 50 +-
.../hybrid_parallel_checkpoint_io.py | 4 -
colossalai/lazy/pretrained.py | 4 -
.../moe/openmoe/model/openmoe_policy.py | 2 +-
.../legacy/nn/layer/parallel_1d/_operation.py | 3 +
colossalai/logging/logger.py | 5 +-
.../pipeline/schedule/interleaved_pp.py | 1 -
colossalai/pipeline/schedule/one_f_one_b.py | 1 +
colossalai/shardformer/layer/__init__.py | 4 +-
colossalai/shardformer/layer/_operation.py | 38 +-
colossalai/shardformer/layer/attn.py | 904 +++++++++++++++++-
colossalai/shardformer/layer/linear.py | 14 +-
colossalai/shardformer/layer/loss.py | 165 +++-
colossalai/shardformer/layer/utils.py | 198 +++-
colossalai/shardformer/modeling/command.py | 8 +-
colossalai/shardformer/modeling/llama.py | 143 ++-
.../shardformer/policies/base_policy.py | 1 +
colossalai/shardformer/policies/command.py | 31 +-
colossalai/shardformer/policies/deepseek.py | 2 +-
colossalai/shardformer/policies/llama.py | 45 +-
colossalai/shardformer/policies/mistral.py | 2 +-
colossalai/shardformer/policies/mixtral.py | 2 +-
colossalai/shardformer/policies/qwen2.py | 2 +-
colossalai/shardformer/shard/shard_config.py | 13 +-
examples/language/llama/benchmark.py | 29 +-
examples/language/opt/README.md | 2 +-
examples/language/performance_evaluator.py | 24 +-
examples/tutorial/opt/opt/README.md | 2 +-
.../flash_attention_dao_cuda.py | 8 +-
tests/kit/model_zoo/__init__.py | 4 +-
tests/kit/model_zoo/transformers/command.py | 12 +-
tests/kit/model_zoo/transformers/llama.py | 41 +-
tests/kit/model_zoo/transformers/mistral.py | 2 +-
tests/kit/model_zoo/transformers/qwen2.py | 12 +-
.../test_plugin/test_3d_plugin.py | 2 +-
.../test_plugin/test_low_level_zero_plugin.py | 2 +-
.../test_gemini_checkpoint_io.py | 2 +-
.../test_gemini_torch_compability.py | 2 +-
...st_hybrid_parallel_plugin_checkpoint_io.py | 2 +-
.../test_low_level_zero_checkpoint_io.py | 2 +-
.../test_plugins_huggingface_compatibility.py | 2 +-
tests/test_lora/test_lora.py | 2 +-
.../test_schedule/test_interleaved.py | 17 +-
.../test_schedule/test_oneF_oneB.py | 17 +-
.../test_shardformer/test_flash_attention.py | 3 +
.../test_layer/test_ring_attn.py | 186 ++++
tests/test_shardformer/test_model/_utils.py | 27 +-
.../test_model/test_shard_command.py | 4 +-
.../test_model/test_shard_llama.py | 147 ++-
50 files changed, 1870 insertions(+), 326 deletions(-)
create mode 100644 tests/test_shardformer/test_layer/test_ring_attn.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a4132a507..f7217a8f1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,6 +12,7 @@ repos:
hooks:
- id: isort
name: sort all imports (python)
+ args: ["--profile", "black"] # avoid conflict with black
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index e5acdb051..63427192f 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -32,7 +32,7 @@ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackw
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
-from colossalai.shardformer.layer.utils import SeqParallelUtils
+from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -42,7 +42,7 @@ from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_hand
from .pp_plugin_base import PipelinePluginBase
-SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@@ -72,7 +72,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.dp_group = dp_group
self.tp_group = tp_group
self.sp_group = sp_group
- self.use_dpp = use_ddp
+ self.use_ddp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
@@ -139,8 +139,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
# Disable automatic gradient synchronization.
self.require_grad_sync = False
try:
- if self.use_dpp:
- # If using data parallel processing (use_dpp), disable synchronization too.
+ if self.use_ddp:
+ # If using data parallel processing (use_ddp), disable synchronization too.
with self.module.no_sync():
yield
else:
@@ -188,7 +188,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
"""
if self.shard_config.enable_sequence_parallelism:
- if self.shard_config.sequence_parallelism_mode == "all_to_all":
+ if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
return
if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
@@ -970,6 +970,9 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
+ inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
+ It's advisable not to tune this (especially in single-node settings) and to let it be set heuristically based on topology by default.
+
"""
def __init__(
@@ -1017,6 +1020,7 @@ class HybridParallelPlugin(PipelinePluginBase):
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
+ inner_ring_size: int = None,
) -> None:
super().__init__()
@@ -1041,9 +1045,11 @@ class HybridParallelPlugin(PipelinePluginBase):
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
- elif self.sequence_parallelism_mode in ["all_to_all"]:
+ elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
+ if self.sequence_parallelism_mode == "ring_attn":
+ enable_flash_attention = True
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
@@ -1063,10 +1069,21 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
- self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+ if sequence_parallelism_mode == "ring_attn":
+ # Swap tp and sp since 2D Ring has better inter-node latency
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
+ self.sp_axis = 2
+ self.tp_axis = 3
+ else:
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
- self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
+ if sequence_parallelism_mode == "ring_attn":
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size)
+ self.sp_axis = 2
+ self.tp_axis = 3
+ else:
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
@@ -1108,6 +1125,8 @@ class HybridParallelPlugin(PipelinePluginBase):
)
else:
raise NotImplementedError()
+ if sequence_parallelism_mode == "ring_attn":
+ assert parallel_output, "Ring Attention doesn't support gathering output yet."
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1132,6 +1151,7 @@ class HybridParallelPlugin(PipelinePluginBase):
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
+ inner_ring_size=inner_ring_size,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -1216,15 +1236,15 @@ class HybridParallelPlugin(PipelinePluginBase):
zero_stage = 0
if not isinstance(model, ModelWrapper):
+ # Shouldn't use pp (frequent grad accumulation) with torch ddp
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
- self.dp_size == 1
- and self.pp_size == 1
- and self.enable_sequence_parallelism
- and self.sequence_parallelism_mode == "all_to_all"
+ self.dp_size == 1 and self.pp_size == 1
)
- # sync gradients across DP * SP ranks
- if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+
+ # Apply Hybrid ZeRO across DP * SP ranks
+ if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+ self.dp_size = get_world_size(dp_group)
else:
dp_group = self.dp_group
model = HybridParallelModule(
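For reference, a hedged sketch of enabling the new sequence-parallel mode from user code; the parallel sizes are placeholders, while `sequence_parallelism_mode="ring_attn"` and `inner_ring_size` are the options introduced by this patch (flash attention is switched on automatically for this mode):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the script is started via colossalai run / torchrun

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,                              # ranks in the sequence-parallel group
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",  # 2D Ring Attention
    inner_ring_size=None,                   # let the heuristic pick it based on topology
    precision="bf16",
)
booster = Booster(plugin=plugin)
```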
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 0310df548..043e5c2b0 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -203,7 +203,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
-
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
@@ -643,14 +642,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
-
if self.dp_rank != 0:
return
# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()
-
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
@@ -660,7 +657,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
-
# Only the master rank do the saving.
if self.coordinator.is_master():
complete_state_dict = dict()
diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py
index 736ffc5e4..226951598 100644
--- a/colossalai/lazy/pretrained.py
+++ b/colossalai/lazy/pretrained.py
@@ -62,7 +62,6 @@ def new_from_pretrained(
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -116,7 +115,6 @@ def new_from_pretrained(
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
- resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
@@ -195,7 +193,6 @@ def new_from_pretrained(
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
- "resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"user_agent": user_agent,
@@ -312,7 +309,6 @@ def new_from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
- resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
diff --git a/colossalai/legacy/moe/openmoe/model/openmoe_policy.py b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
index ccd566b08..d5824afcb 100644
--- a/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
+++ b/colossalai/legacy/moe/openmoe/model/openmoe_policy.py
@@ -171,7 +171,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
# TODO: recursively assign ep group foe all modules
new_item = {
OpenMoeForCausalLM: ModulePolicyDescription(
diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
index 8b8f04ccf..e892336bc 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -81,6 +81,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
+ # TODO: This seems to only work if you add torch.cuda.Event.wait()
+
+ # _ = torch.zeros(1, device=grad_output.device)
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py
index eb5f28e2a..9f4b7a7b0 100644
--- a/colossalai/logging/logger.py
+++ b/colossalai/logging/logger.py
@@ -64,7 +64,10 @@ class DistributedLogger:
self._logger.propagate = False
DistributedLogger.__instances[name] = self
- self.rank = dist.get_rank() if dist.is_initialized() else 0
+
+ @property
+ def rank(self):
+ return dist.get_rank() if dist.is_initialized() else 0
@staticmethod
def __get_call_info():
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index a21b45c44..412f3896f 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -286,7 +286,6 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
-
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 7f0d0e349..03df67ae7 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches
+
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 331e49729..8882a33c1 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -1,5 +1,5 @@
from ._operation import all_to_all_comm
-from .attn import AttnMaskType, ColoAttention
+from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
@@ -31,5 +31,7 @@ __all__ = [
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
+ "RingAttention",
+ "get_pad_info",
"all_to_all_comm",
]
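A small usage sketch of the newly exported get_pad_info (assuming this patch is applied and colossalai is importable); the (max_seqlen, cu_seqlens, indices) triple is what the varlen attention paths introduced below consume.

import torch
from colossalai.shardformer.layer import get_pad_info

padding_mask = torch.tensor([[1, 1, 1, 0],   # 1 = valid token, 0 = padding
                             [1, 1, 0, 0]])
max_seqlen, cu_seqlens, indices = get_pad_info(padding_mask)
# max_seqlen == 3
# cu_seqlens == tensor([0, 3, 5], dtype=torch.int32)
# indices    == tensor([0, 1, 2, 4, 5])  -- flat positions of the 5 valid tokens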
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 19da348e7..25983e0a9 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -2,6 +2,8 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
+from .utils import is_share_sp_tp
+
try:
import fused_mix_prec_layer_norm_cuda
except:
@@ -93,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
- # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output)
@@ -143,7 +145,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
- # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ _ = torch.zeros(1, device=grad_input.device)
+
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
@@ -331,7 +335,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
@@ -646,8 +650,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
- # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function):
Args:
input_: input matrix.
- parallel_mode: parallel mode.
+ process_group: communication group.
+
"""
@staticmethod
- def forward(ctx, input_, process_group):
+ def forward(ctx, input_, process_group, grad_scale=None):
+ ctx.grad_scale = grad_scale
return _reduce(input_, process_group)
@staticmethod
def backward(ctx, grad_output):
- return grad_output, None
+ if ctx.grad_scale is not None:
+ grad_output = grad_output * ctx.grad_scale
+ return grad_output, None, None
class _ReduceBackward(torch.autograd.Function):
@@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
-def reduce_forward(input_, process_group):
- return _ReduceForward.apply(input_, process_group)
+def reduce_forward(input_, process_group, grad_scale=None):
+ return _ReduceForward.apply(input_, process_group, grad_scale)
def reduce_backward(input_, process_group):
@@ -989,3 +997,13 @@ def reduce_backward(input_, process_group):
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+
+
+def gather_sp_output(hidden_states, sp_group, sp_mode):
+ """
+ Gather the output of the last layer for cross entropy computation
+ """
+ # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
+ scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
+ hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
+ return hidden_states
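A minimal single-process sketch of the grad_scale hook added to _ReduceForward/reduce_forward above, with the all-reduce replaced by an identity so it runs without a process group: the forward value is unchanged while the backward gradient is rescaled, which is what offsets the ZeRO gradient averaging over the DP * SP group.

import torch

class ReduceForwardSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, grad_scale=None):
        ctx.grad_scale = grad_scale
        return input_  # dist.all_reduce(input_, group=process_group) in the real op

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale is not None:
            grad_output = grad_output * ctx.grad_scale
        return grad_output, None

x = torch.ones(3, requires_grad=True)
ReduceForwardSketch.apply(x, 4).sum().backward()
print(x.grad)  # tensor([4., 4., 4.])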
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 5872c6485..6dab17ec0 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -2,7 +2,10 @@ from enum import Enum
from typing import Callable, Dict, Optional, Tuple
import torch
+import torch.distributed
+import torch.distributed as dist
import torch.nn.functional as F
+from einops import rearrange
from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader,
@@ -10,12 +13,18 @@ from colossalai.kernel.kernel_loader import (
FlashAttentionWithCustomMaskLoader,
KernelLoader,
)
+from colossalai.logging import get_dist_logger
+
+from .utils import RingComm, get_half_index, split_varlen_zigzag
__all__ = [
"AttnMaskType",
"ColoAttention",
]
+_flash_attn_forward = _flash_attn_backward = None
+_unpad_input = _pad_input = None
+
class AttnMaskType(Enum):
CUSTOM = 0
@@ -38,20 +47,32 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor:
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
+def get_pad_info(
+ padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True
+) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Get padding information from padding mask.
Args:
- padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
+ padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
+ invert (Optional[bool], optional): Whether to reverse the padding mask.
+ return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.
Returns:
- Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
+ max_seqlen_in_batch (int): Maximum sequence length in the batch.
+ cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch.
+ indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence.
"""
+ if invert:
+ padding_mask = padding_mask.logical_not()
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+ if return_indices:
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
- return max_seqlen_in_batch, cu_seqlens, indices
+ if return_indices:
+ return max_seqlen_in_batch, cu_seqlens, indices
+ return max_seqlen_in_batch, cu_seqlens
class ColoAttention:
@@ -107,6 +128,7 @@ class ColoAttention:
q_padding_mask: Optional[torch.Tensor] = None,
kv_padding_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
+ invert: bool = True,
) -> Dict[str, torch.Tensor]:
"""Return a dictionary of keyword arguments for attention function. It supports 4 mask type.
1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
@@ -124,7 +146,7 @@ class ColoAttention:
The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
-
+ invert (bool, optional): Whether to invert the mask. Defaults to True.
Returns:
Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
"""
@@ -154,7 +176,7 @@ class ColoAttention:
assert kv_padding_mask.shape == (
b,
s_kv,
- ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+ ), f"Padding mask shape {kv_padding_mask.shape} should match (B, Skv) = ({b}, {s_kv})"
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
outputs.update(
{
@@ -172,7 +194,8 @@ class ColoAttention:
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else:
outputs["attention_mask_type"] = AttnMaskType.PADDED
- attention_mask = invert_mask(attention_mask).unsqueeze(1)
+ if invert:
+ attention_mask = invert_mask(attention_mask).unsqueeze(1)
outputs["attention_mask"] = attention_mask
return outputs
@@ -191,6 +214,7 @@ class ColoAttention:
kv_indices: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
+ **kwargs,
) -> torch.Tensor:
"""Flash Attention function. It supports 4 mask type.
1. custom mask: recv attention_mask
@@ -199,9 +223,9 @@ class ColoAttention:
4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
Args:
- q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
- k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
- v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
+ q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
+ k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
+ v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
@@ -218,7 +242,7 @@ class ColoAttention:
scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
Returns:
- torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
+ torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D]
"""
# known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
# this case is usual when padding mask is used and self attention is performed
@@ -252,6 +276,7 @@ class ColoAttention:
else:
# if attention_mask is None, attention_mask_type should be the default value
assert attention_mask_type == AttnMaskType.CUSTOM
+
# kernel dispatch
mask_type = attention_mask_type if attention_mask is not None else None
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
@@ -274,3 +299,858 @@ class ColoAttention:
q_indices=q_indices,
kv_indices=kv_indices,
)
+
+
+def _load_varlen_helpers():
+ """Helper to load functions for padding and unpadding packed sequences.
+ Use only when flash attn is installed
+ """
+ global _pad_input, _unpad_input
+ # Flash attn claims this is more efficient than torch's bool indexing due to avoiding
+ # broadcast
+ if _pad_input is None or _unpad_input is None:
+ try:
+ from flash_attn.bert_padding import index_first_axis, pad_input
+
+ def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
+ return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
+
+ _pad_input = pad_input
+ _unpad_input = unpad_input
+ except ImportError as e:
+ raise RuntimeError(
+ f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
+ ) from e
+
+
+def _load_flash_attn():
+ """A light-weight loader to check whether flash-attn is installed.
+ Can't use ColoAttention._dispatch_kernel because we mutate the backward pass
+ """
+ global _flash_attn_forward, _flash_attn_backward
+ if _flash_attn_forward is None or _flash_attn_backward is None:
+ try:
+ from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
+ from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
+ except ImportError as e:
+ raise RuntimeError(
+ f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
+ ) from e
+
+ _load_varlen_helpers()
+
+
+# NOTE: This can cause spawned processes to hang on exit
+# with python 3.9
+@torch.compile()
+def _rescale_out_lse(out, block_out, lse, block_lse):
+ """
+ Compute the new attention denominator:
+ exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
+ Args:
+ out: (T, H, D)
+ block_out: (T, H, D)
+ lse: (H, T, 1)
+ block_lse: (H, T, 1)
+ """
+
+ # min_scale = torch.min(lse, block_lse)
+ # max_scale = torch.max(lse, block_lse)
+ # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+
+ # NOTE: directly assigning to .data here is buggy
+ # probably due to casting dtypes/strides
+ new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+
+ new_block_lse = torch.exp(block_lse - new_lse)
+ out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out)
+ lse = new_lse
+
+ # Equivalent to the above
+ # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
+ # out = (out - F.sigmoid(block_lse - lse) * (out - block_out))
+ # lse = (lse - F.logsigmoid(lse - block_lse))
+ return out, lse
+
+
+class RingAttention(torch.autograd.Function):
+ """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
+ (https://arxiv.org/abs/2310.01889).
+ For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
+ For portable integration with more models, we don't follow the spirit of "block-wise FFN" in the original paper,
+ which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
+ implemented in Jax and not optimized).
+ We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+ NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
+ ring at once.
+ """
+
+ # Global cache to avoid recomputation for same-length sequences
+ CU_SEQLENS: torch.Tensor = None # [B+1]
+ TOTAL_SEQLEN: int = None
+ HALF_INDICES: Tuple = None
+ SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
+ ATTN_DONE: torch.cuda.Event = None
+ SP_STREAM: torch.cuda.Stream = None
+ SP_GROUP: dist.ProcessGroup = None
+ # Duplicate process groups for concurrent NCCL streams.
+ # While both PyTorch and NCCL warn (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+ # against this, in practice it seems to work fine.
+ INNER_RING_GROUP: dist.ProcessGroup = None
+ INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+ INTER_RING_GROUP: dist.ProcessGroup = None
+ INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+
+ @staticmethod
+ def get_double_ring_groups(sp_group, inner_ring_size=None):
+ """
+ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
+ shouldn't be larger than the number of NICs on each node.
+ Args:
+ sp_group (dist.ProcessGroup): Process group for sequence parallelism
+ inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None.
+ Returns:
+ Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
+ """
+ sp_size = dist.get_world_size(sp_group)
+ sp_rank = dist.get_rank(sp_group)
+
+ if inner_ring_size is None:
+ if torch.cuda.device_count() >= dist.get_world_size():
+ # single node, no need to consider NICs
+ return sp_group, sp_group
+ if sp_size <= 4:
+ inner_ring_size = min(2, sp_size)
+ else:
+ inner_ring_size = min(4, sp_size)
+ else:
+ assert (
+ inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+ ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+
+ if inner_ring_size == sp_size:
+ return sp_group, sp_group
+ assert (
+ sp_size % inner_ring_size == 0
+ ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+
+ logger = get_dist_logger()
+ logger.info(
+ f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximize NIC utilization for inter-node comm. Cross your fingers for speed-ups!",
+ ranks=[0],
+ )
+ num_rings = sp_size // inner_ring_size
+ inner_ring_group = None
+ inter_ring_group = None
+
+ # Create inner ring groups: num_rings groups of inner_ring_size consecutive ranks each
+ for i in range(num_rings):
+ ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
+ group = dist.new_group(ranks)
+ if sp_rank in ranks:
+ inner_ring_group = group
+
+ # Create inter ring groups: inner_ring_size groups whose members stride across the inner rings
+ for i in range(inner_ring_size):
+ ranks = list(range(i, sp_size, inner_ring_size))
+ group = dist.new_group(ranks)
+ if sp_rank in ranks:
+ inter_ring_group = group
+
+ return inner_ring_group, inter_ring_group
+
+ @staticmethod
+ def attention(
+ q, # (B, H, Sq, D)
+ k,
+ v,
+ sp_group,
+ attention_mask_type,
+ cu_seqlens=None,
+ max_seqlen=None,
+ valid_indices=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ deterministic=False,
+ return_softmax=False,
+ inner_ring_size=None,
+ **kwargs,
+ ):
+ """
+ Ring Attention forward pass supporting variable-length sequences. When using varlen mode,
+ each sequence in the batch should have length divisible by sp_size * 2.
+
+ Args:
+ q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
+ k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
+ v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
+ sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+ attention_mask_type (AttnMaskType): The attention mask type; only causal and padded causal masks are supported.
+ cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
+ of the sequences in the batch, used to index into q.
+ Shape should be [B+1].
+ max_seqlen (Optional[int], optional): Maximum query sequence length in the batch.
+ valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info.
+ Shape should be [t].
+ dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+ softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax.
+ deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349
+ return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp).
+ inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide.
+
+ Returns:
+ out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False.
+ softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
+ Shape should be [total_q_seqlen, nHeads]
+ """
+ # Check input args
+ _load_flash_attn()
+ if RingAttention.ATTN_DONE is None:
+ RingAttention.ATTN_DONE = torch.cuda.Event()
+ if RingAttention.SP_STREAM is None:
+ RingAttention.SP_STREAM = torch.cuda.Stream()
+
+ assert (
+ q.shape[2] == k.shape[2]
+ ), "Q, K and V having different sequence lengths (inference or cross-attn)\
+ is not supported yet in training."
+ assert (
+ attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
+ ), f"Mask type {attention_mask_type} is not supported yet."
+
+ clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+
+ if RingAttention.SP_GROUP is not sp_group:
+ RingAttention.SP_GROUP = sp_group
+ inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+ RingAttention.INNER_RING_GROUP = inner_ring_group
+ RingAttention.INTER_RING_GROUP = inter_ring_group
+ else:
+ inner_ring_group = RingAttention.INNER_RING_GROUP
+ inter_ring_group = RingAttention.INTER_RING_GROUP
+
+ # (B, H, Sq, D) -> (B, Sq, H, D)
+ q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
+ pad_output = q.dim() == 4
+
+ # Get sequence length info for varlen forward
+ if attention_mask_type == AttnMaskType.CAUSAL:
+ # All sequences share the same length
+ b, sq, h, d = q.shape
+ max_seqlen = sq
+ # Cache to avoid recreation for a single sequence
+ if sq * b == RingAttention.TOTAL_SEQLEN:
+ cu_seqlens = RingAttention.CU_SEQLENS
+ else:
+ cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32)
+ RingAttention.TOTAL_SEQLEN = b * sq
+
+ # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D]
+ elif attention_mask_type == AttnMaskType.PADDED_CAUSAL:
+ assert (
+ cu_seqlens is not None and max_seqlen is not None and valid_indices is not None
+ ), "Packed mode requires pre-computed cu_seqlens and max_seq_len."
+ if pad_output:
+ b, sq, h, d = q.shape
+ q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)]
+
+ out, softmax_lse = RingAttention.apply(
+ q,
+ k,
+ v,
+ sp_group,
+ RingAttention.SP_STREAM,
+ cu_seqlens,
+ max_seqlen,
+ dropout_p,
+ softmax_scale,
+ deterministic,
+ return_softmax,
+ attention_mask_type == AttnMaskType.PADDED_CAUSAL,
+ inner_ring_group,
+ inter_ring_group,
+ )
+
+ if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
+ if pad_output:
+ out = _pad_input(out, valid_indices, b, sq) # (T, ...) -> (B, Sq, ...)
+ out = out.transpose(1, 2) # (B, Sq, H, D) -> (B, H, Sq, D)
+ else:
+ out = out.transpose(1, 2)
+
+ if return_softmax:
+ return out, softmax_lse
+ return out
+
+ @staticmethod
+ def forward(
+ ctx,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ sp_group: dist.ProcessGroup,
+ sp_stream: torch.cuda.Stream,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int,
+ dropout_p: float = 0.0,
+ softmax_scale: Optional[float] = None,
+ deterministic: Optional[bool] = False,
+ return_softmax: Optional[bool] = False,
+ is_packed: Optional[bool] = False,
+ inner_ring_group: Optional[dist.ProcessGroup] = None,
+ inter_ring_group: Optional[dist.ProcessGroup] = None,
+ ):
+
+ cu_seqlens_q = cu_seqlens_kv = cu_seqlens
+ max_seqlen_q = max_seqlen_kv = max_seqlen
+ cu_seqlens_half = cu_seqlens // 2
+ max_seqlen_half = max_seqlen // 2
+
+ misc_kwargs = {
+ "window_size": (-1, -1),
+ "alibi_slopes": None,
+ "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
+ "dropout_p": dropout_p,
+ "block_table": None,
+ "softcap": 0.0,
+ "return_softmax": False,
+ }
+
+ if (
+ RingAttention.HALF_INDICES is not None
+ and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
+ and (cu_seqlens == RingAttention.CU_SEQLENS).all()
+ ):
+ half_idx_front, half_idx_back = RingAttention.HALF_INDICES
+ else:
+ half_idx_front = get_half_index(cu_seqlens, front=True)
+ half_idx_back = get_half_index(cu_seqlens, front=False)
+ RingAttention.HALF_INDICES = (half_idx_front, half_idx_back)
+ RingAttention.CU_SEQLENS = cu_seqlens
+
+ if is_packed:
+ t, h, d = q.shape
+ else:
+ b, sq, h, d = q.shape
+ t = b * sq
+ # Be careful about GQA/MQA in reshape
+ q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]
+
+ if inner_ring_group is None or inter_ring_group is None:
+ # Use one ring if not specified
+ inner_ring_group = inter_ring_group = sp_group
+
+ sp_size = dist.get_world_size(sp_group)
+ sp_rank = dist.get_rank(sp_group)
+ # Attempt to achieve concurrent comm in the two-stream forward
+ local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
+ inter_ring_comm = RingComm(inter_ring_group)
+ local_sp_size = dist.get_world_size(inner_ring_group)
+ local_sp_rank = dist.get_rank(inner_ring_group)
+ inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
+ num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
+
+ # Non-contiguous indexing copies to a new contiguous tensor,
+ # so only do it once
+ if sp_rank != sp_size - 1:
+ q1 = q[half_idx_back]
+
+ # Pre-allocate double buffer for overlapping and receiving next step's inputs
+ kv_buffers = [torch.stack((k, v))] # (2, T, H, D)
+ kv_buffers.append(torch.empty_like(kv_buffers[0]))
+
+ # outputs
+ out = None
+ block_out = [None, None]
+ softmax_lse = [None, None]
+ block_softmax_lse = [None, None] # log sum exp, the denominator of softmax in attention
+ rng_states = [None for _ in range(sp_size)]
+ sp_streams = [torch.cuda.current_stream(), sp_stream]
+
+ def _forward(q, k, v, causal):
+ (
+ _,
+ _,
+ _,
+ _,
+ out,
+ softmax_lse,
+ _,
+ rng_state,
+ ) = _flash_attn_forward(
+ q,
+ k,
+ v,
+ cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+ cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+ max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+ max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+ causal=causal,
+ **misc_kwargs,
+ )
+ return out, softmax_lse, rng_state
+
+ def _local_ring_forward():
+ # (Hopefully) overlap output correction with next flash attn
+ for i in range(local_sp_size):
+ with torch.cuda.stream(sp_streams[i % 2]):
+ # Wait for current kv from prev rank
+ # NOTE: waiting outside the current stream will NOT correctly synchronize.
+ if i > 0:
+ local_kv_comms[(i + 1) % 2].wait()
+
+ # Avoid overwriting attn input when it shares mem with buffer
+ if not RingAttention.ATTN_DONE.query():
+ kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+ if i < local_sp_size - 1:
+ local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
+ if i == 0:
+ # Compute with local KV; no mask
+ kv_block = kv_buffers[0]
+ q_block = q
+ (block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T)
+ q_block, kv_block[0], kv_block[1], causal=True
+ )
+ elif i <= local_sp_rank:
+ # Received the "surrounding" kv chunks
+ # Drop the second half of received kv
+ # (2, t // 2, H, D)
+ kv_block = kv_buffers[i % 2][:, half_idx_front]
+ q_block = q
+ (
+ block_out[i % 2], # (T, H, D)
+ block_softmax_lse[i % 2], # (H, T)
+ rng_states[i],
+ ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
+ else:
+ # Received the inner kv chunks
+ # Drop the first half of q
+ kv_block = kv_buffers[i % 2]
+ q_block = q1
+ (
+ block_out[i % 2], # (T, H, D)
+ block_softmax_lse[i % 2], # (H, T)
+ rng_states[i],
+ ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
+ RingAttention.ATTN_DONE.record()
+
+ block_softmax_lse[i % 2] = (
+ block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
+ ) # (H, T) -> (T, H, 1)
+ assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
+ # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
+ # In reality this always finishes before next flash attn; no need for extra sync.
+ if i == 0:
+ out = block_out[0]
+ softmax_lse = block_softmax_lse[0]
+ elif i <= local_sp_rank:
+ out, softmax_lse = _rescale_out_lse(
+ out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
+ )
+ else:
+ out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
+ out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
+ )
+
+ torch.cuda.current_stream().wait_stream(sp_stream)
+ return out, softmax_lse
+
+ def _other_ring_forward(ring_num_idx, out, softmax_lse):
+ # Loop through the inner ring after receiving
+ # all new KVs from the previous inner ring
+ for i in range(local_sp_size):
+ with torch.cuda.stream(sp_streams[i % 2]):
+ if not RingAttention.ATTN_DONE.query():
+ kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+ if i < local_sp_size - 1:
+ local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
+ # Wait for the current KV from the previous rank
+ if i > 0:
+ local_kv_comms[(i + 1) % 2].wait()
+
+ if ring_num_idx > inter_ring_rank:
+ kv_block = kv_buffers[i % 2]
+ (
+ block_out[i % 2],
+ block_softmax_lse[i % 2],
+ rng_states[i + local_sp_size * ring_num_idx],
+ ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
+ RingAttention.ATTN_DONE.record()
+ block_softmax_lse[i % 2] = (
+ block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
+ )
+ out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
+ out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
+ )
+ else:
+ kv_block = kv_buffers[i % 2][:, half_idx_front]
+ (
+ block_out[i % 2],
+ block_softmax_lse[i % 2],
+ rng_states[i + local_sp_size * ring_num_idx],
+ ) = _forward(q, kv_block[0], kv_block[1], causal=False)
+ RingAttention.ATTN_DONE.record()
+ block_softmax_lse[i % 2] = (
+ block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
+ )
+ out, softmax_lse = _rescale_out_lse(
+ out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
+ )
+
+ torch.cuda.current_stream().wait_stream(sp_stream)
+ return out, softmax_lse
+
+ # Send and recv KV between rings at once to maximize NIC util.
+ inter_ring_kv = None
+ for ring_num_idx in range(num_rings):
+ if ring_num_idx > 0:
+ inter_ring_comm.wait()
+ # Reset indices
+ kv_buffers[0] = inter_ring_kv
+
+ if ring_num_idx < num_rings - 1:
+ if ring_num_idx == 0:
+ to_send = kv_buffers[0]
+ else:
+ # The last received KV
+ to_send = kv_buffers[(local_sp_size - 1) % 2]
+ inter_ring_kv = inter_ring_comm.send_recv(to_send)
+
+ if ring_num_idx == 0:
+ out, softmax_lse = _local_ring_forward()
+ else:
+ out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse)
+
+ out = out.to(q.dtype)
+ if not is_packed:
+ out = out.view(b, sq, h, d)
+ q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D)
+ softmax_lse = softmax_lse.squeeze(-1)
+
+ ctx.sp_group = sp_group
+ ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen
+ misc_kwargs["deterministic"] = deterministic
+ del misc_kwargs["return_softmax"]
+ ctx.misc_kwargs = misc_kwargs
+ ctx.is_packed = is_packed
+
+ ctx.kv_group = inner_ring_group
+ ctx.inter_kv_group = inter_ring_group
+
+ ctx.save_for_backward(
+ q,
+ k,
+ v,
+ out,
+ softmax_lse.transpose(0, 1).contiguous(), # (T, H) -> (H, T)
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ half_idx_front,
+ half_idx_back,
+ *rng_states,
+ )
+
+ if return_softmax:
+ return out, softmax_lse
+ return out, None
+
+ @staticmethod
+ def backward(ctx, dout, _):
+ """
+ During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
+ over all ranks for accumulation.
+ """
+ (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
+ rng_states = ctx.saved_tensors[9:]
+
+ is_packed = ctx.is_packed
+ max_seqlen_q = ctx.max_seqlen_q
+ max_seqlen_kv = ctx.max_seqlen_kv
+ cu_seqlens_half = cu_seqlens_q // 2
+ max_seqlen_half = max_seqlen_q // 2
+ misc_kwargs = ctx.misc_kwargs
+ del misc_kwargs["block_table"]
+
+ assert (
+ out.shape == dout.shape == q.shape
+ ), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})."
+
+ if is_packed:
+ t, h, d = q.shape
+ else:
+ b, sq, h, d = q.shape
+ t = b * sq
+ q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]
+
+ # Sequence parallel args
+ sp_group = ctx.sp_group
+ local_kv_group = ctx.kv_group
+ inter_kv_group = ctx.inter_kv_group
+
+ local_sp_rank = dist.get_rank(sp_group)
+ sp_size = dist.get_world_size(sp_group)
+ # Using separate streams (pg) for concurrent kv and dkv comm may
+ # cause NCCL "software caused connection abort" here...
+ local_kv_comm = RingComm(local_kv_group)
+ local_dkv_comm = RingComm(local_kv_group)
+ inter_kv_comm = RingComm(inter_kv_group)
+ inter_dkv_comm = RingComm(inter_kv_group)
+ local_sp_size = dist.get_world_size(local_kv_group)
+ local_sp_rank = dist.get_rank(local_kv_group)
+
+ if dist.get_world_size(inter_kv_group) != sp_size:
+ num_rings = dist.get_world_size(inter_kv_group)
+ inter_ring_rank = dist.get_rank(inter_kv_group)
+ else:
+ num_rings = 1
+ inter_ring_rank = 0
+
+ if local_sp_rank != sp_size - 1:
+ softmax_lse1 = softmax_lse[:, half_idx_back]
+ dout = dout.contiguous()
+
+ # Double comm buffers for sending and receiving kv
+ kv_buffers = [torch.stack((k, v))] # (2, T, H, D)
+ kv_buffers.append(torch.empty_like(kv_buffers[0]))
+
+ dq = None # (T, H, D)
+ # Intermediate outputs
+ dq_block = torch.empty_like(q) # (T, H, D)
+ dk_block = torch.empty_like(k) # (T, H, D)
+ dv_block = torch.empty_like(v) # (T, H, D)
+ dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
+ del k, v
+
+ def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
+ _flash_attn_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
+ cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
+ max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
+ max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
+ causal=causal,
+ rng_state=rng_state,
+ **misc_kwargs,
+ )
+
+ # NOTE: We avoid using two streams due to doubled buffers
+ # and that backward is more communication intensive.
+ def _local_ring_backward():
+ for i in range(local_sp_size):
+ if i > 0:
+ local_kv_comm.wait()
+
+ if i < local_sp_size - 1:
+ # Send kv to next rank for backward
+ local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
+ if i == 0:
+ # Backward with local kv
+ k_, v_ = kv_buffers[i % 2]
+ q_, dout_, out_ = q, dout, out
+ dq_, dk_, dv_ = dq_block, dk_block, dv_block
+ _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)
+
+ elif i <= local_sp_rank:
+ # Drop the second half of kv
+ # (T, H, D) -> (T // 2, H, D)
+ k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
+ dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
+ dq_, q_, out_, dout_ = (dq_block, q, out, dout)
+ _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)
+
+ else:
+ # Drop the first half of q
+ k_, v_ = kv_buffers[i % 2]
+ dk_, dv_ = dk_block, dv_block
+ q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
+ dq_ = dq_block[: t // 2]
+ _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)
+
+ # Accumulate grads
+ if i == 0:
+ dq = dq_block.float()
+ dkv_buffers[i % 2][0] = dk_block.float()
+ dkv_buffers[i % 2][1] = dv_block.float()
+ else:
+ # Accumulate local dq
+ if i <= local_sp_rank:
+ dq += dq_ # (T, H, D)
+ else:
+ dq[half_idx_back] += dq_
+
+ # Wait for mobile kv grad accumulators
+ local_dkv_comm.wait()
+
+ if i <= local_sp_rank:
+ # q blocks "surrounded" by kv blocks
+ dkv_buffers[i % 2][0][half_idx_front] += dk_
+ dkv_buffers[i % 2][1][half_idx_front] += dv_
+ else:
+ # q blocks "surrounding" kv blocks
+ dkv_buffers[i % 2][0] += dk_
+ dkv_buffers[i % 2][1] += dv_
+ local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])
+
+ local_dkv_comm.wait()
+ dkv_recv = dkv_buffers[local_sp_size % 2]
+ dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
+ return dq, dkv_recv, dkv_send
+
+ def _other_ring_backward(ring_num_idx, dq):
+ if ring_num_idx > inter_ring_rank:
+ # Indexing is expensive
+ q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
+ else:
+ q_, out_, dout_ = (q, out, dout)
+
+ for i in range(local_sp_size):
+ if i > 0:
+ local_kv_comm.wait()
+
+ if i < local_sp_size - 1:
+ local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
+ rng_state = rng_states[i + local_sp_size * ring_num_idx]
+ if ring_num_idx > inter_ring_rank:
+ k_, v_ = kv_buffers[i % 2]
+ dk_, dv_ = dk_block, dv_block
+ dq_ = dq_block[: t // 2]
+ _backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False)
+
+ dq[half_idx_back] += dq_
+ if i > 0:
+ local_dkv_comm.wait()
+ else:
+ inter_dkv_comm.wait()
+
+ dkv_buffers[i % 2][0] += dk_
+ dkv_buffers[i % 2][1] += dv_
+ else:
+ k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
+ dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
+ dq_ = dq_block
+ _backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False)
+
+ dq += dq_
+ if i > 0:
+ local_dkv_comm.wait()
+ else:
+ inter_dkv_comm.wait()
+
+ dkv_buffers[i % 2][0][half_idx_front] += dk_
+ dkv_buffers[i % 2][1][half_idx_front] += dv_
+
+ local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])
+
+ local_dkv_comm.wait()
+ dkv_recv = dkv_buffers[local_sp_size % 2]
+ dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
+ return dq, dkv_recv, dkv_send
+
+ inter_ring_kv = None
+ for ring_num_idx in range(num_rings):
+ if ring_num_idx > 0:
+ inter_kv_comm.wait()
+ kv_buffers[0] = inter_ring_kv
+
+ if ring_num_idx < num_rings - 1:
+ # Re-allocate a buffer in each inter-ring step
+ inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0])
+
+ if ring_num_idx == 0:
+ dq, dkv_recv, dkv_send = _local_ring_backward()
+ else:
+ dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq)
+
+ if num_rings > 1:
+ # Reuse the local buffers
+ inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
+ # Reset indices
+ dkv_buffers[0] = dkv_send
+ dkv_buffers[1] = dkv_recv
+ if ring_num_idx == num_rings - 1:
+ inter_dkv_comm.wait()
+ dkv_recv = dkv_buffers[0]
+
+ dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
+ if not is_packed:
+ dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
+
+ return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+
+ @staticmethod
+ def prepare_varlen_batch(
+ attention_mask: torch.Tensor,
+ sp_group: dist.ProcessGroup,
+ inputs_embeds: torch.Tensor = None,
+ position_ids: Optional[torch.Tensor] = None,
+ is_label: bool = False,
+ is_2d: bool = True,
+ ):
+ """
+ Preprocess a batch of padded sequences by splitting each sequence into sp_size chunks along
+ the sequence dimension and packing each rank's chunks into one sequence. Updates the mask info accordingly.
+ Args:
+ attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
+ sp_group (dist.ProcessGroup): Process group for sequence parallelism
+ inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
+ position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
+ is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
+ token of each sequence.
+ is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
+ the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
+
+ Returns:
+ inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
+ mask_info: A dictionary of mask info.
+ position_ids: Packed position ids of shape [..., Sq // sp_size].
+
+ """
+ _load_varlen_helpers()
+ sp_size = dist.get_world_size(group=sp_group)
+ sp_rank = dist.get_rank(group=sp_group)
+ mask_info = {}
+ mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
+
+ # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
+ # Split mask to compute local nonzero position indices
+ # (B, Sq) -> (B, max_seqlen // sp_size)
+ attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
+ if inputs_embeds is not None:
+ inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
+ inputs_embeds = split_varlen_zigzag(
+ inputs_embeds,
+ mask_info["cu_seqlens"],
+ sp_group,
+ mask_info["max_seqlen"],
+ is_2d=is_2d,
+ is_label=is_label,
+ )
+ attention_mask = split_varlen_zigzag(
+ attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
+ )
+
+ if position_ids is not None:
+ indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device)
+ position_ids = (
+ position_ids[..., : mask_info["max_seqlen"]] # unpad
+ .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2))
+ .index_select(-2, indices)
+ .view(-1, mask_info["max_seqlen"] // sp_size)
+ )
+
+ mask_info["max_seqlen"] //= sp_size
+ mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ mask_info["cu_seqlens"] //= sp_size
+ mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
+ return inputs_embeds, mask_info, position_ids
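A torch-only, single-process sketch of the log-sum-exp correction that _rescale_out_lse performs: attention computed over two KV blocks is recombined from the per-block outputs and their logsumexp denominators and matches attention over the full KV. Shapes follow the docstring above (out: (T, H, D), lse: (H, T, 1)).

import torch

torch.manual_seed(0)
T, H, D = 4, 2, 8
q = torch.randn(T, H, D)
k = torch.randn(2 * T, H, D)
v = torch.randn(2 * T, H, D)

def block_attn(q, k, v):
    scores = torch.einsum("thd,shd->hts", q, k) / D**0.5      # (H, T, Skv)
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)        # (H, T, 1)
    out = torch.einsum("hts,shd->thd", scores.softmax(-1), v)  # (T, H, D)
    return out, lse

out1, lse1 = block_attn(q, k[:T], v[:T])   # first KV block
out2, lse2 = block_attn(q, k[T:], v[T:])   # second KV block

# Combine exactly as _rescale_out_lse does: new_lse = log(exp(lse1) + exp(lse2))
new_lse = lse1 + torch.log1p(torch.exp(lse2 - lse1))
w1 = torch.exp(lse1 - new_lse).transpose(0, 1)  # (T, H, 1)
w2 = torch.exp(lse2 - new_lse).transpose(0, 1)
out = w1 * out1 + w2 * out2

ref, _ = block_attn(q, k, v)  # attention over the full KV
assert torch.allclose(out, ref, atol=1e-5)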
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 37c754241..020e793af 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -200,9 +200,7 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode is None:
- output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
- elif self.seq_parallel_mode == "split_gather":
+ if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim
)
@@ -211,6 +209,8 @@ class Linear1D_Col(ParallelModule):
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
)
+ else:
+ output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
@@ -416,10 +416,7 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
- if self.seq_parallel_mode is None:
- output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
- output = reduce_forward(output_parallel, self.process_group)
- elif self.seq_parallel_mode == "split_gather":
+ if self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim
@@ -432,6 +429,9 @@ class Linear1D_Row(ParallelModule):
dim=self.seq_parallel_dim,
ring=True,
)
+ else:
+ output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+ output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index cea2da03f..12df824d1 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -4,10 +4,15 @@ from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
+from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig
+from .utils import is_share_sp_tp
+
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+_IGNORE_IDX = -100
+
class DistCrossEntropy(Function):
r"""
@@ -26,11 +31,12 @@ class DistCrossEntropy(Function):
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
+ mode="mean",
):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
- and can be rewrite as:
+ and can be rewritten as:
loss = log(sum(exp(x[i])) - x[class]
To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i]
@@ -44,12 +50,10 @@ class DistCrossEntropy(Function):
Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
+ assert mode in ["mean", "sum"]
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
- dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
-
- # minus the max to avoid the result of sum of exp is too large and the log is nan
- vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+ handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
# mask the target in the local device
rank = dist.get_rank(group=process_group)
@@ -70,24 +74,25 @@ class DistCrossEntropy(Function):
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
+ masked_target_1d = masked_target.view(-1).contiguous()
+ # minus the max to avoid the result of sum of exp is too large and the log is nan
+ handle.wait()
+ vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size)
- masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
- pred_logits_1d = logits_2d[
- torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d
- ]
- pred_logits_1d = pred_logits_1d.clone().contiguous()
+ idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
+ pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous()
pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0
- # allreduce the get all x(i,y)
- dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
+ # all-reduce to get full x[i, y]
+ handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
@@ -95,23 +100,29 @@ class DistCrossEntropy(Function):
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
+ handle.wait()
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
- num_non_zero = torch.sum(loss != 0.0)
- ctx.inv_num_non_zero = 1.0 / num_non_zero
- loss = torch.sum(loss).div_(num_non_zero)
+ if mode == "mean":
+ num_non_zero = torch.sum(loss != 0.0)
+ ctx.inv_num_non_zero = 1.0 / num_non_zero
+ loss = torch.sum(loss).div_(num_non_zero)
+ else:
+ loss = torch.sum(loss)
# calculate the softmax
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype
+ ctx.mode = mode
return loss
@staticmethod
def backward(ctx, grad_output):
# retrieve the saved tensors
- grad_output = grad_output * ctx.inv_num_non_zero
+ if ctx.mode == "mean":
+ grad_output = grad_output * ctx.inv_num_non_zero
exp_logits, mask, masked_target_1d = ctx.saved_tensors
# use exp logits as the input grad
@@ -123,55 +134,113 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
- return grad_logits, None, None, None, None, None
+ return grad_logits, None, None, None, None, None, None
def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
- ignore_index: int = -100,
+ ignore_index: int = _IGNORE_IDX,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
+ mode: str = "mean",
) -> torch.Tensor:
- return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+ return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
def dist_cross_entropy(
- labels: torch.Tensor,
- logits: torch.Tensor,
+ labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
+ logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
out_features: int,
vocab_size: int,
dtype: torch.dtype,
+ seq_dim: int = 1,
) -> torch.Tensor:
"""
- Helper to compute cross entropy loss for most shardformer models,
- compatible with PP, TP and SP.
+ Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP.
"""
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_labels = shift_labels.view(-1)
- shift_labels = shift_labels.to(shift_logits.device)
- if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
- # Cross entropy with all-reduce for TP
- new_vocab_size = logits.shape[-1]
- shift_logits = shift_logits.view(-1, new_vocab_size)
- loss = cross_entropy_1d(
- shift_logits,
- shift_labels,
- process_group=shard_config.tensor_parallel_process_group,
- vocab_size=out_features,
- dtype=dtype,
- )
- else:
- # NOTE if use TP and not parallel_output, the output is gathered.
- # see VocabParallelLMHead1D
- shift_logits = shift_logits.view(-1, vocab_size)
- loss = loss_fct(shift_logits, shift_labels)
+ # Split labels if not gather output
+ sp_group = shard_config.sequence_parallel_process_group
+ sp_rank = dist.get_rank(sp_group)
+ sp_size = shard_config.sequence_parallel_size
+ sp_mode = shard_config.sequence_parallelism_mode
+ parallel_output = shard_config.parallel_output
+ is_tp = shard_config.enable_tensor_parallelism
+ is_packed = labels.dim() == 2
+ if is_packed:
+ bs, seq_len = labels.shape
+ else:
+ # padded sequence
+ seq_len = labels.shape[-1]
+ logits = logits.reshape(-1, *logits.shape[2:])
+ seq_dim = 0
- return loss
+ # Shift labels by one to predict the next token, and drop the last logit, which has no target
+ is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode))
+ split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward
+
+ if sp_mode == "ring_attn":
+ # For Zigzag Ring Attention, labels should've been split and
+ # shifted by RingAttention.prepare_varlen_batch()
+ if sp_rank == 0:
+ logits = logits[..., :-1, :]
+ logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
+ elif is_sp:
+ # Shift only once: either before splitting or in the last rank without splitting
+ if split_labels_here or (sp_rank == sp_size - 1):
+ labels = labels[..., 1:]
+ if split_labels_here:
+ labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank]
+
+ if sp_rank == sp_size - 1:
+ logits = logits[..., :-1, :]
+ # Pad logits and labels to the same shape across all ranks for TP all_reduce
+ if is_tp and parallel_output:
+ # If is packed sequence (label dim is 1), then each seq already has the end label token padded.
+ # torch.cat is faster than F.pad...
+ pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:])
+ padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device)
+ logits = torch.cat([logits, padding], dim=seq_dim)
+ pad_shape = (labels.shape[0], 1) if is_packed else (1,)
+ padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device)
+ labels = torch.cat([labels, padding], dim=seq_dim)
+ else:
+ labels = labels[..., 1:]
+ logits = logits[..., :-1, :]
+ labels = labels.contiguous()
+ logits = logits.contiguous()
+ num_nonzero = (labels != _IGNORE_IDX).sum()
+ assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
+ labels = labels.view(-1)
+
+ if is_tp and parallel_output:
+ # Cross entropy with all-reduce for TP
+ new_vocab_size = logits.shape[-1]
+ logits = logits.view(-1, new_vocab_size)
+ loss = cross_entropy_1d(
+ logits,
+ labels,
+ process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=out_features,
+ dtype=dtype,
+ mode="sum",
+ )
+ else:
+ # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
+ logits = logits.view(-1, vocab_size)
+ loss = loss_fct(logits, labels)
+
+ # Reduce loss instead of gathering logits over seq dim for savings
+ if split_labels_here or sp_mode == "ring_attn":
+ # Get the global non-zero count
+ loss = torch.stack((loss, num_nonzero))
+ # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin
+ loss = reduce_forward(loss, sp_group, grad_scale=sp_size)
+ loss, num_nonzero = loss[0], loss[1].detach()
+ loss = (loss / num_nonzero).squeeze()
+ return loss
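A minimal single-process sketch of the reduction strategy in dist_cross_entropy above: each sequence-parallel shard computes a sum cross entropy plus its count of non-ignored tokens, and the global mean is recovered by summing both across shards (the real code does this with reduce_forward over the SP group).

import torch
import torch.nn.functional as F

_IGNORE_IDX = -100
torch.manual_seed(0)
logits = torch.randn(1, 8, 16)               # (B, S, V), already shifted
labels = torch.randint(0, 16, (1, 8))
labels[0, -2:] = _IGNORE_IDX                 # padding positions

# Reference: global mean over non-ignored tokens
ref = F.cross_entropy(logits.view(-1, 16), labels.view(-1), ignore_index=_IGNORE_IDX)

# Simulate 2 SP shards; the two sums below stand in for the all-reduced pair
losses, counts = [], []
for lg, lb in zip(logits.chunk(2, dim=1), labels.chunk(2, dim=1)):
    losses.append(F.cross_entropy(lg.reshape(-1, 16), lb.reshape(-1),
                                  ignore_index=_IGNORE_IDX, reduction="sum"))
    counts.append((lb != _IGNORE_IDX).sum())

loss = torch.stack(losses).sum() / torch.stack(counts).sum()
assert torch.allclose(loss, ref)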
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 9c6ced445..c1a73ce05 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from typing import List
+from typing import List, Optional, Union
import torch
import torch.distributed as dist
@@ -289,3 +289,199 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)
+
+
+def split_batch_zigzag(
+ batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
+) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """
+ Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
+ in the causal setting will result in the preceding ranks having much less workload.
+ We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
+ For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
+
+ Args:
+ batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
+ sp_group (ProcessGroup): The process group for sequence parallelism.
+ seq_dim (int): The sequence dimension to split.
+ is_label (bool): If True, mask and shift the tensor for next token prediction.
+
+ """
+ sp_size = dist.get_world_size(sp_group)
+ sp_rank = dist.get_rank(sp_group)
+ if isinstance(batch, torch.Tensor):
+ batch = [batch]
+ seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
+
+ if sp_size > 1:
+ for idx, tensor in enumerate(batch):
+ assert (
+ tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
+ ), f"Sequence length {tensor.shape[seq_dim]} of tensor {idx} must be divisible by (and larger than) 2 * sp_size ({sp_size * 2}) for zigzag splitting!"
+ if is_label:
+ assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
+ tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)
+
+ tensor = tensor.view(
+ *tensor.shape[:seq_dim],
+ 2 * sp_size,
+ tensor.shape[seq_dim] // (2 * sp_size),
+ *tensor.shape[seq_dim + 1 :],
+ )
+ indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
+ tensor = tensor.index_select(seq_dim, indices).contiguous()
+ # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
+ batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])
+
+ if len(batch) == 1:
+ return batch[0]
+ return batch
+
+
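# A minimal single-rank sketch of the zigzag split above for sp_size = 2 and
# seq_len = 8: the sequence is cut into 2 * sp_size chunks and rank r keeps the
# pair (r, 2 * sp_size - 1 - r), giving every rank a balanced mix of early
# (cheap) and late (expensive) causal positions.
import torch

x = torch.arange(8).view(1, 8)                        # (B, Sq)
sp_size, sp_rank = 2, 0
chunks = x.view(1, 2 * sp_size, 8 // (2 * sp_size))   # (B, 2 * sp_size, chunk_len)
idx = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank])
local = chunks.index_select(1, idx).view(1, -1)
print(local)  # rank 0 -> tensor([[0, 1, 6, 7]]); rank 1 -> tensor([[2, 3, 4, 5]])
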
+def split_varlen_zigzag(
+ batch: Union[List[torch.Tensor], torch.Tensor],
+ cu_seqlens: torch.Tensor,
+ sp_group: ProcessGroup,
+ max_seqlen: int = 0,
+ is_2d: bool = False,
+ is_label: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+ """Split each sequence in a batch of packed sequences in a zigzag fashion.
+ For each tensor in batch, return packed sequences if is_2d is False;
+ else return a padded batch of sequences.
+
+ Args:
+ batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+ cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
+ sp_group (ProcessGroup): The process group for sequence parallelism.
+ max_seqlen (int): The maximum sequence length in the batch before splitting.
+ is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+        is_label (bool): If True, mask and shift the tensor for next token prediction.
+
+ Returns:
+ batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
+ or (B, max_seqlen // sp_size, ...) if is_2d
+ """
+ sp_size = dist.get_world_size(sp_group)
+ sp_rank = dist.get_rank(sp_group)
+ if is_2d:
+ assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
+
+ if isinstance(batch, torch.Tensor):
+ batch = [batch]
+ for i, packed_seq in enumerate(batch):
+ device = packed_seq.device
+ dtype = packed_seq.dtype
+
+ if is_2d:
+ assert max_seqlen % (sp_size * 2) == 0
+ # Recreate a padded tensor with the new max seqlen
+ shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
+ local_seq = torch.zeros(shape, dtype=dtype, device=device)
+ else:
+ total_seqlen = cu_seqlens[-1]
+ assert (
+ total_seqlen % (2 * sp_size) == 0
+ ), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
+ local_seq = []
+
+ for j in range(len(cu_seqlens) - 1):
+ start, end = cu_seqlens[j], cu_seqlens[j + 1]
+ seqlen = end - start
+ assert (
+ seqlen % (2 * sp_size) == 0
+ ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
+
+ if is_2d:
+ seq = packed_seq[j][:seqlen]
+ if is_label:
+ # Shift one position to the right for next token prediction
+ seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
+
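+                # Zigzag placement: chunk sp_rank fills the first half of this rank's slot,
+                # its mirror chunk (2 * sp_size - 1 - sp_rank) fills the second half.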
+ seq = seq.chunk(2 * sp_size, dim=0)
+ half = seqlen // sp_size // 2
+ local_seq[j][:half] = seq[sp_rank]
+ local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
+ else:
+ seq = packed_seq[start:end]
+ if is_label:
+                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
+ seq = seq.chunk(sp_size * 2)
+ local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
+
+ if is_2d:
+ batch[i] = local_seq.contiguous()
+ else:
+ batch[i] = torch.cat(local_seq, dim=0)
+
+ if len(batch) == 1:
+ batch = batch[0]
+ return batch
+
+
+def is_share_sp_tp(sp_mode: str):
+ """sp_mode "ring" and "split_gather" use the TP group as SP group
+ to split both the vocab and sequence, so we must gather the sequence
+    to correctly get the logits at each position.
+ """
+ return sp_mode in ["ring", "split_gather"]
+
+
+class RingComm:
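+    """Ring peer-to-peer communicator: every rank sends to the next rank and receives
+    from the previous one in the given process group.
+
+    Usage sketch (illustrative only, assuming ``sp_group`` is an initialized process
+    group and ``local_kv`` is the tensor block to pass around the ring):
+
+        comm = RingComm(sp_group)
+        next_kv = comm.send_recv(local_kv)  # posts batched isend/irecv
+        comm.wait()                         # block until the exchange finishes
+    """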
+ def __init__(self, process_group: dist.ProcessGroup):
+ self._process_group = process_group
+ self._ops = []
+ self.rank = dist.get_rank(self._process_group)
+ self.world_size = dist.get_world_size(self._process_group)
+ self._reqs = []
+
+ self.send_rank = (self.rank + 1) % self.world_size
+ self.recv_rank = (self.rank - 1) % self.world_size
+
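+        # Convert the group-local neighbor ranks to global ranks for the P2P ops below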
+ self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
+ self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
+
+ def send_recv(
+ self,
+ send_tensor: torch.Tensor,
+ recv_tensor: Optional[torch.Tensor] = None,
+ commit: bool = True,
+ ) -> torch.Tensor:
+ if recv_tensor is None:
+ res = torch.empty_like(send_tensor)
+ else:
+ res = recv_tensor
+
+ # looks like batch_isend_irecv doesn't deadlock even
+ # when we don't swap send recv ops based on rank
+ send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
+ recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
+ self._ops.extend([send_op, recv_op])
+
+ if commit:
+ self._reqs = dist.batch_isend_irecv(self._ops)
+ return res
+
+ def commit(self):
+ assert len(self._ops) > 0, "No ops to commit"
+ self._reqs = dist.batch_isend_irecv(self._ops)
+
+ def wait(self):
+ assert len(self._reqs) > 0, "No requests to wait for"
+ for req in self._reqs:
+ req.wait()
+ self._reqs = []
+ self._ops = []
+
+
+@torch.jit.script
+def get_half_index(cu_seqlens, *, front: bool):
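+    """Build a boolean index of length ``cu_seqlens[-1]`` that selects the front half
+    of every packed sequence when ``front`` is True, and the back half otherwise.
+    """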
+ index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
+ for i in range(len(cu_seqlens) - 1):
+ start, end = cu_seqlens[i], cu_seqlens[i + 1]
+ if front:
+ end = (start + end) // 2
+ else:
+ start = (start + end) // 2
+ index[start:end] = True
+ return index
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 5b36fc7db..67c20eed8 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -26,6 +26,8 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy
+_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
+
class CommandPipelineForwards:
"""
@@ -349,7 +351,7 @@ class CommandPipelineForwards:
return {"hidden_states": hidden_states}
-def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self,
hidden_states: torch.Tensor,
@@ -362,7 +364,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None:
- assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
@@ -459,7 +461,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
return forward
-def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 9ffbca517..af610500a 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1,8 +1,9 @@
import math
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torch
+import torch.distributed
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -24,14 +25,14 @@ from transformers.models.llama.modeling_llama import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
- all_to_all_comm,
- gather_forward_split_backward,
- split_forward_gather_backward,
-)
+from colossalai.shardformer.layer import AttnMaskType
+from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
+from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, dist_cross_entropy
+from ..layer import ColoAttention, RingAttention, dist_cross_entropy
+
+_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
class LlamaPipelineForwards:
@@ -57,6 +58,10 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
+ # Split output only when computing cross entropy using llama_for_causal_lm_forward
+ # or get_lm_forward_with_dist_cross_entropy
+        # Default to True to avoid a bug when calling the classification forward from HuggingFace
+ force_sp_output_gather: bool = True,
):
logger = logging.get_logger(__name__)
@@ -97,7 +102,7 @@ class LlamaPipelineForwards:
sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
- # For correct positions ids. The states will be gather along the seq dim in the attention layer later.
+            # For generating full position ids, as the states will be gathered along the seq dim in the attention layer later.
seq_length *= sp_size
past_seen_tokens = 0
@@ -127,22 +132,36 @@ class LlamaPipelineForwards:
position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
- if shard_config.enable_flash_attention:
+ if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
+ _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
+ elif shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
- attention_mask = ColoAttention.prepare_attn_kwargs(
+ attn_kwargs = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
+ invert=(sp_mode != "ring_attn"),
)
else:
- attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+ attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# Support SP + PP
+        # TODO: support padded causal cu_seqlens across stages
if stage_manager.is_first_stage():
- if sp_mode in ["ring", "split_gather"]:
+ # Ring Attention zigzag batch processing
+ if sp_mode == "ring_attn":
+ assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
+ if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+ hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
+ attention_mask, sp_group, hidden_states, position_ids
+ )
+ else:
+ hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
+
+ elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
@@ -177,12 +196,11 @@ class LlamaPipelineForwards:
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
-
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- attention_mask,
+ attn_kwargs,
position_ids,
past_key_values,
output_attentions,
@@ -192,14 +210,13 @@ class LlamaPipelineForwards:
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=attention_mask,
+ attention_mask=attn_kwargs,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
-
hidden_states = layer_outputs[0]
if use_cache:
@@ -209,10 +226,8 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
- if sp_mode == "ring" or sp_mode == "split_gather":
- hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
- elif sp_mode == "all_to_all":
- hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+ if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+ hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -298,6 +313,15 @@ class LlamaPipelineForwards:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
+ if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
+ # Split labels in a zigzag fashion too
+ sp_group = shard_config.sequence_parallel_process_group
+ if attention_mask.bool().all():
+ labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
+ else:
+ # [B, max_seqlen // sp_size]
+ labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
self.model,
@@ -315,6 +339,7 @@ class LlamaPipelineForwards:
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
+ force_sp_output_gather=False,
)
past_key_values = None
@@ -457,11 +482,11 @@ class LlamaPipelineForwards:
return {"hidden_states": hidden_states}
-def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
@@ -470,7 +495,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None:
- assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
@@ -481,7 +506,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
- if sp_mode in ["split_gather", "ring"]:
+ if is_share_sp_tp(sp_mode):
q_len *= sp_size
if self.config.pretraining_tp > 1:
@@ -526,6 +551,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -537,12 +563,21 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- if shard_config.enable_flash_attention:
+ if sp_mode == "ring_attn":
+ attn_output = RingAttention.attention(
+ query_states,
+ key_states,
+ value_states,
+ sp_group,
+ **attention_mask,
+ inner_ring_size=shard_config.inner_ring_size,
+ )
+
+ elif shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@@ -588,7 +623,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
return forward
-def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
def forward(
@@ -603,6 +638,10 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+ # Split output only when computing cross entropy using llama_for_causal_lm_forward
+ # or get_lm_forward_with_dist_cross_entropy
+        # Default to True to avoid a bug when calling the classification forward from HuggingFace
+ force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -629,32 +668,45 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
past_seen_tokens = 0
seq_len = inputs_embeds.shape[1]
+ batch_size = inputs_embeds.shape[0]
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
+
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
- mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
- attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
+ attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
+ invert=(sp_mode != "ring_attn"),
)
- else:
- attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
- if sp_mode in ["ring", "split_gather"]:
+ else:
+ attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+ # Ring Attention zigzag batch processing
+ if sp_mode == "ring_attn":
+ assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
+ if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
+ inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
+ attention_mask, sp_group, inputs_embeds, position_ids
+ )
+ else:
+ inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
+ attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
+
+ elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
@@ -672,7 +724,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- attention_mask,
+ attn_kwargs,
position_ids,
past_key_values,
output_attentions,
@@ -683,7 +735,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=attention_mask,
+ attention_mask=attn_kwargs,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@@ -700,11 +752,9 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
-
- if sp_mode == "ring" or sp_mode == "split_gather":
- hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
- elif sp_mode == "all_to_all":
- hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+ # Cases that don't support parallelizing cross entropy computation along sequence
+ if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+ hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -777,6 +827,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
+ # Special processing: Split labels in a zigzag fashion too
+ sp_group = shard_config.sequence_parallel_process_group
+ if attention_mask.bool().all():
+ labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
+ else:
+ # [B, max_seq_len // sp_size]
+ labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@@ -789,6 +848,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
+ force_sp_output_gather=False,
)
hidden_states = outputs[0]
@@ -799,7 +859,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
-
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 282cf0464..7c1e6f0d7 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -75,6 +75,7 @@ class Policy(ABC):
def __init__(self) -> None:
self.shard_config: Optional[ShardConfig] = None
self.model: Optional[Module] = None
+ self.is_causal = None # Whether we're doing causal lm, i.e. using cross entropy
def set_model(self, model: nn.Module) -> None:
r"""
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index a9b915d10..1efd3d017 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -69,13 +69,18 @@ class CommandPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "ring_attn" and not self.is_causal:
+ raise ValueError("Ring attention is only meant for causal language modeling.")
+ tp_size = self.shard_config.tensor_parallel_size or None
+ num_q_heads = self.model.config.num_attention_heads
+ num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all":
- decoder_attribute_replacement = {
- "num_heads": self.model.config.num_attention_heads // sp_size,
- }
- if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+ num_q_heads //= sp_size
+ decoder_attribute_replacement = {"num_heads": num_q_heads}
+ if num_kv_heads:
+ num_kv_heads //= sp_size
+ decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -104,21 +109,18 @@ class CommandPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
assert (
- self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ num_q_heads % tp_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
- self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
- and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
decoder_attribute_replacement = {
- "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
+ "self_attn.num_heads": num_q_heads // tp_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
- self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
- )
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size
policy[CohereDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -290,10 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy):
def module_policy(self):
from transformers import CohereForCausalLM
+ self.is_causal = True
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
CohereForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 605f69c4a..ea68649d5 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -298,7 +298,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 36491b4b5..f72a72df0 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -69,13 +69,20 @@ class LlamaPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "ring_attn" and not self.is_causal:
+ raise ValueError("Ring attention is only meant for causal language modeling.")
+
+ tp_size = self.shard_config.tensor_parallel_size
+ # Modified by SP and TP
+ num_q_heads = self.model.config.num_attention_heads
+ num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all":
- decoder_attribute_replacement = {
- "num_heads": self.model.config.num_attention_heads // sp_size,
- }
- if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+ num_q_heads //= sp_size
+ decoder_attribute_replacement = {"num_heads": num_q_heads}
+ if num_kv_heads:
+ num_kv_heads //= sp_size
+ decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -104,21 +111,20 @@ class LlamaPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
assert (
- self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ num_q_heads % tp_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
- self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
- and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
+ num_q_heads //= tp_size
decoder_attribute_replacement = {
- "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
+ "self_attn.num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
- decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
- self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
- )
+ num_kv_heads //= tp_size
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -295,10 +301,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
def module_policy(self):
from transformers import LlamaForCausalLM
+ self.is_causal = True
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
@@ -313,10 +320,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
],
)
}
- if self.shard_config.parallel_output:
- new_item[LlamaForCausalLM].method_replacement = {
- "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
- }
else:
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
@@ -336,7 +339,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
)
-
+ elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
+            # Compute the loss in a distributed fashion along the sequence dimension
+ new_item[LlamaForCausalLM].method_replacement = {
+ "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+ }
return policy
def get_held_layers(self) -> List[Module]:
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index c5a0277a5..6ea27e210 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -271,7 +271,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 10df143c9..e11edae9f 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -275,7 +275,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 362c14060..235dc7d56 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -313,7 +313,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
setattr(self.shard_config, "causal_lm", True)
if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
+ # add a new item for causal lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 163d7a7bb..70eb271c9 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -10,7 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from .grad_ckpt_config import GradientCheckpointConfig
__all__ = ["ShardConfig"]
-SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
@dataclass
@@ -29,6 +29,8 @@ class ShardConfig:
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
+        parallel_output (bool): For TP: whether to parallelize the cross entropy computation along the feature dim.
+ For SP: set to True to NOT gather the output along the seq dim.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -47,10 +49,11 @@ class ShardConfig:
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+ # For ring attention
+ inner_ring_size: Optional[int] = None
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
-
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -80,9 +83,9 @@ class ShardConfig:
self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
elif self.sequence_parallelism_mode in ["all_to_all"]:
- assert (
- not self.enable_tensor_parallelism
- ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
+ # assert (
+ # not self.enable_tensor_parallelism
+ # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
if self.enable_sequence_overlap:
self.enable_sequence_overlap = False
warnings.warn(
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index e530e2d6a..093377e7a 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
# Constants
# ==============================
+# Predefined Llama configs of various sizes to choose from
MODEL_CONFIGS = {
"100m": LlamaConfig(
max_position_embeddings=4096,
@@ -36,6 +37,7 @@ MODEL_CONFIGS = {
intermediate_size=2048,
hidden_size=1024,
),
+ "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
"7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,
@@ -68,9 +70,6 @@ def main():
default="gemini",
help="Choose which plugin to use",
)
- parser.add_argument(
- "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
- )
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@@ -94,11 +93,24 @@ def main():
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
- parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
+ parser.add_argument("--profile", action="store_true", help="Profile the code")
+ parser.add_argument(
+ "--nsys",
+ action="store_true",
+ help="Use nsys for profiling. \
+ You should put something like this before colossalai launch: \
+ nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
+ )
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
+ parser.add_argument(
+ "--sp_mode",
+ default="all_to_all",
+ choices=["all_to_all", "ring_attn", "ring", "split_gather"],
+ help="Sequence parallelism mode",
+ )
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -195,12 +207,12 @@ def main():
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
+ sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
- overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
@@ -218,7 +230,6 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
- overlap_p2p=args.overlap,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -295,6 +306,7 @@ def main():
args.ignore_steps,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+ nsys=args.nsys,
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
@@ -320,13 +332,16 @@ def main():
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
+ del outputs # free memory
+
+ if dist.get_rank() == dist.get_world_size() - 1:
+ print(f"Step {step} loss: {loss}")
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
prof.step()
-
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md
index af1e79437..694c5cf91 100644
--- a/examples/language/opt/README.md
+++ b/examples/language/opt/README.md
@@ -17,7 +17,7 @@ limitations under the License.
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
+The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
## Our Modifications
diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py
index ca4a02cd2..f5ad1d23d 100644
--- a/examples/language/performance_evaluator.py
+++ b/examples/language/performance_evaluator.py
@@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
return tensor.item()
-def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False):
class DummyProfiler:
def __init__(self):
self.step_number = 0
@@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
def __exit__(self, exc_type, exc_value, traceback):
pass
+ class NsysProfiler:
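+        """Profiler stand-in for nsys: starts the CUDA profiler after the warmup steps and
+        stops it after the active steps, so that an `nsys profile` run launched with
+        `--capture-range=cudaProfilerApi` records only that window.
+        """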
+ def __init__(self, warmup_steps, active_steps):
+ self.step_number = 0
+ self.warmup_steps = warmup_steps
+ self.active_steps = active_steps
+
+ def step(self):
+ if self.step_number == self.warmup_steps:
+ torch.cuda.cudart().cudaProfilerStart()
+ elif self.step_number == self.warmup_steps + self.active_steps:
+ torch.cuda.cudart().cudaProfilerStop()
+ self.step_number += 1
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
if enable_flag:
+ if nsys:
+ return NsysProfiler(warmup_steps, active_steps)
+
return profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md
index a01209cbd..3776e0c64 100644
--- a/examples/tutorial/opt/opt/README.md
+++ b/examples/tutorial/opt/opt/README.md
@@ -19,7 +19,7 @@ limitations under the License.
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
+The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
diff --git a/extensions/pybind/flash_attention/flash_attention_dao_cuda.py b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py
index a108377a8..560d952f6 100644
--- a/extensions/pybind/flash_attention/flash_attention_dao_cuda.py
+++ b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py
@@ -57,14 +57,14 @@ class FlashAttentionDaoCudaExtension(_Extension):
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
- # [B, N, S, D] -> [B, S, N, D]
+ # [B, H, S, D] -> [B, S, H, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
b, s_q = q.shape[:2]
if cu_seqlens_q is not None:
# padded / padded causal
- # unpad input: [B, S, N, D] -> [T, N, D]
+ # unpad input: [B, S, H, D] -> [T, H, D]
q = _unpad_input(q, q_indices)
kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
attn_output = flash_attn_varlen_kvpacked_func(
@@ -78,7 +78,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
softmax_scale=scale,
causal=is_causal,
)
- # pad output: [T, N, D] -> [B, S, N, D]
+ # pad output: [T, H, D] -> [B, S, H, D]
attn_output = pad_input(attn_output, q_indices, b, s_q)
else:
# causal / no attn mask
@@ -90,7 +90,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
softmax_scale=scale,
causal=is_causal,
)
- # [B, S, N, D] -> [B, N, S, D]
+ # [B, S, H, D] -> [B, H, S, D]
return attn_output.transpose(1, 2)
return flash_attention
diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py
index 66c794a7d..9c1a11e7b 100644
--- a/tests/kit/model_zoo/__init__.py
+++ b/tests/kit/model_zoo/__init__.py
@@ -22,9 +22,9 @@ COMMON_MODELS = [
"transformers_bloom_for_causal_lm",
"transformers_falcon_for_causal_lm",
"transformers_chatglm_for_conditional_generation",
- "transformers_llama_for_casual_lm",
+ "transformers_llama_for_causal_lm",
"transformers_vit_for_masked_image_modeling",
- "transformers_mistral_for_casual_lm",
+ "transformers_mistral_for_causal_lm",
]
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py
index a8b8842c5..3f4ea4583 100644
--- a/tests/kit/model_zoo/transformers/command.py
+++ b/tests/kit/model_zoo/transformers/command.py
@@ -32,8 +32,8 @@ if HAS_COMMAND:
return dict(input_ids=input_ids, attention_mask=attention_mask)
- # label is needed for casual lm
- def data_gen_for_casual_lm():
+ # label is needed for causal lm
+ def data_gen_for_causal_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
@@ -44,7 +44,7 @@ if HAS_COMMAND:
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
- loss_fn_for_casual_lm = lambda output: output["loss"]
+ loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = CohereConfig(
@@ -70,10 +70,10 @@ if HAS_COMMAND:
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
- name="transformers_command_for_casual_lm",
+ name="transformers_command_for_causal_lm",
model_fn=lambda: transformers.CohereForCausalLM(config),
- data_gen_fn=data_gen_for_casual_lm,
+ data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
- loss_fn=loss_fn_for_casual_lm,
+ loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
index 61fa56050..05ac9d8d2 100644
--- a/tests/kit/model_zoo/transformers/llama.py
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -33,20 +33,21 @@ if HAS_LLAMA:
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
]
).long()
-
- attention_mask = torch.Tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ).long()
-
+ attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
- # label is needed for casual lm
- def data_gen_for_casual_lm():
+ # label is needed for causal lm
+ def data_gen_for_causal_lm():
data = data_gen()
+
+ # Test padded sequence
+ padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long)
+ data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
+ data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
+
+ ignore_idx = -100
labels = data["input_ids"].clone()
+ labels[~data["attention_mask"].bool()] = ignore_idx
data["labels"] = labels
return data
@@ -55,7 +56,7 @@ if HAS_LLAMA:
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
- loss_fn_for_casual_lm = lambda output: output["loss"]
+ loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig(
@@ -70,9 +71,17 @@ if HAS_LLAMA:
config.pad_token_id = config.eos_token_id
# register the following models
- # transformers.LlamaModel,
# transformers.LlamaForCausalLM,
+ # transformers.LlamaModel,
# transformers.LlamaForSequenceClassification,
+ model_zoo.register(
+ name="transformers_llama_for_causal_lm",
+ model_fn=lambda: transformers.LlamaForCausalLM(config),
+ data_gen_fn=data_gen_for_causal_lm,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_causal_lm,
+ model_attribute=ModelAttribute(has_control_flow=True),
+ )
model_zoo.register(
name="transformers_llama",
model_fn=lambda: transformers.LlamaModel(config),
@@ -81,14 +90,6 @@ if HAS_LLAMA:
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
- model_zoo.register(
- name="transformers_llama_for_casual_lm",
- model_fn=lambda: transformers.LlamaForCausalLM(config),
- data_gen_fn=data_gen_for_casual_lm,
- output_transform_fn=output_transform_fn,
- loss_fn=loss_fn_for_casual_lm,
- model_attribute=ModelAttribute(has_control_flow=True),
- )
model_zoo.register(
name="transformers_llama_for_sequence_classification",
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py
index ae5a97002..43fc662cc 100644
--- a/tests/kit/model_zoo/transformers/mistral.py
+++ b/tests/kit/model_zoo/transformers/mistral.py
@@ -64,7 +64,7 @@ model_zoo.register(
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
- name="transformers_mistral_for_casual_lm",
+ name="transformers_mistral_for_causal_lm",
model_fn=lambda: transformers.MistralForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
diff --git a/tests/kit/model_zoo/transformers/qwen2.py b/tests/kit/model_zoo/transformers/qwen2.py
index 1c26af698..83bc9f941 100644
--- a/tests/kit/model_zoo/transformers/qwen2.py
+++ b/tests/kit/model_zoo/transformers/qwen2.py
@@ -33,8 +33,8 @@ if HAS_QWEN2:
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
- # label is needed for casual lm
- def data_gen_for_casual_lm():
+ # label is needed for causal lm
+ def data_gen_for_causal_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
@@ -45,7 +45,7 @@ if HAS_QWEN2:
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
- loss_fn_for_casual_lm = lambda output: output["loss"]
+ loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = Qwen2Config(
@@ -72,11 +72,11 @@ if HAS_QWEN2:
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
- name="transformers_qwen2_for_casual_lm",
+ name="transformers_qwen2_for_causal_lm",
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
- data_gen_fn=data_gen_for_casual_lm,
+ data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
- loss_fn=loss_fn_for_casual_lm,
+ loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index e57cadfd8..3e8532955 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
# TODO(ver217): add more models
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
- "transformers_llama_for_casual_lm"
+ "transformers_llama_for_causal_lm"
).items():
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index 8c59f430c..c2a08a541 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
sub_model_zoo = model_zoo.get_sub_registry(model_name)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
- if name == "transformers_llama_for_casual_lm":
+ if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index fd13ce0bf..b133be948 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
index 4897907ff..ce4d10322 100644
--- a/tests/test_checkpoint_io/test_gemini_torch_compability.py
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index 4f8f26041..86d7924fb 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -39,7 +39,7 @@ else:
@parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run()
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index ab48944d4..a8e05a25a 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO(
if name != "transformers_llama":
continue
task_type = None
- if name == "transformers_llama_for_casual_lm":
+ if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
index df8636141..6f8eb2ad2 100644
--- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
+++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("plugin_type", ["ddp", "zero", "gemini"])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py
index 1ae17025d..b0ec767cc 100644
--- a/tests/test_lora/test_lora.py
+++ b/tests/test_lora/test_lora.py
@@ -91,7 +91,7 @@ def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
- if name == "transformers_llama_for_casual_lm":
+ if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py
index a626b834a..04a1296e6 100644
--- a/tests/test_pipeline/test_schedule/test_interleaved.py
+++ b/tests/test_pipeline/test_schedule/test_interleaved.py
@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(
# check loss
if stage_manager.is_last_stage(ignore_chunk=True):
- assert torch.allclose(torch_loss, pp_ret["loss"])
+ assert_close(torch_loss, pp_ret["loss"])
# check gradients
for i in range(num_model_chunk):
idx = world_size * i + rank
- assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
- assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+ assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+ assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step
torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
# check updated param
for i in range(num_model_chunk):
idx = world_size * i + rank
- assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
- assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+ assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+ assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
if stage_manager.is_last_stage(ignore_chunk=True):
- assert torch.allclose(torch_loss, pp_ret["loss"])
+ assert_close(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
- assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
- assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+ assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+ assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
@pytest.mark.dist
diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
index c4bfa7b69..8ae4f6daa 100644
--- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py
+++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):
# check loss
if stage_manager.is_last_stage():
- assert torch.allclose(torch_loss, pp_ret["loss"])
+ assert_close(torch_loss, pp_ret["loss"])
# check gradients
for i in range(len(sharded_model)):
idx = rank * num_local_layer + i
- assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
- assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+ assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+ assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step
torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
# check updated param
for i in range(len(sharded_model)):
idx = rank * num_local_layer + i
- assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
- assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+ assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+ assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
if stage_manager.is_last_stage():
- assert torch.allclose(torch_loss, pp_ret["loss"])
+ assert_close(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
- assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
- assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+ assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+ assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
def run_dist(
diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py
index 9aa24a166..42ca6b198 100644
--- a/tests/test_shardformer/test_flash_attention.py
+++ b/tests/test_shardformer/test_flash_attention.py
@@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma
padding_mask = padding_mask[:, None, :, None].logical_not()
ref_output = ref_output.masked_fill(padding_mask, 0)
output = output.masked_fill(padding_mask, 0)
+
assert_close(output, ref_output, **tols)
output.mean().backward()
ref_output.mean().backward()
@@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype):
attn_kwargs, padding_mask = gen_kwargs_func(dtype)
for attn_func, name, need_postprocess in attn_funcs:
print(f"{dtype}, {name}, {mask_type}")
+ if mask_type == "padded":
+ pass
if need_postprocess:
check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
else:
diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py
new file mode 100644
index 000000000..1c7647a7d
--- /dev/null
+++ b/tests/test_shardformer/test_layer/test_ring_attn.py
@@ -0,0 +1,186 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer import AttnMaskType
+from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
+from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+
+
+@parameterize("seq_len", [4096])
+@parameterize("bs", [2])
+@parameterize("nheads", [5])
+@parameterize("d", [128])
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+def check_ring_attn(seq_len, bs, nheads, d, dtype):
+ torch.cuda.manual_seed(2)
+ device = get_current_device()
+ sp_group = dist.group.WORLD
+ sp_size = dist.get_world_size()
+ # Some outliers may seem large, but our errors are still lower than
+ # Megatron-LM context parallel's
+ # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
+ # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
+ atol = rtol = 7e-3
+
+ # Setup inputs
+ qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
+ local_qkv = split_batch_zigzag(qkv, sp_group)
+ q, k, v = local_qkv.unbind(dim=-3)
+ q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D)
+ q.requires_grad = k.requires_grad = v.requires_grad = True
+
+ # Ring attention vs single GPU
+ ring_out, ring_lse = RingAttention.attention(
+ q,
+ k,
+ v,
+ sp_group,
+ AttnMaskType.CAUSAL,
+ return_softmax=True,
+ inner_ring_size=max(2, sp_size // 2),
+ # inner_ring_size=4
+ )
+ ring_out = ring_out.transpose(1, 2)
+ out, lse, _ = flash_attn_qkvpacked_func(
+ qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
+ )
+
+ # Check output and softmax denominator
+ local_out = split_batch_zigzag(out, sp_group)
+ local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
+ local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads)
+ assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
+ assert_close(ring_out, local_out, atol=atol, rtol=rtol)
+
+ # Check grads
+ ring_out.sum().backward()
+ out.sum().backward()
+ ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
+ dqkv = qkv.grad
+ local_dqkv = split_batch_zigzag(dqkv, sp_group)
+
+ assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
+ assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
+ assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
+ if dist.get_rank() == 0:
+ print(
+ f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
+ )
+
+
+@parameterize("seqlen", [4096])
+@parameterize("bs", [2])
+@parameterize("nheads", [5])
+@parameterize("d", [128])
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+def check_packed_seq(seqlen, bs, nheads, d, dtype):
+ device = get_current_device()
+ sp_group = dist.group.WORLD
+ sp_size = dist.get_world_size()
+ atol = rtol = 7e-3
+ torch.cuda.manual_seed(2)
+ # Prepare varlen attention mask
+ padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
+ padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
+ padding_mask[:, seqlen // 2 :] = 0
+
+ input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+ # Forward
+ # out = ColoAttention.attention(q, k, v, **mask_info)
+ flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
+ qkv = torch.stack([flat_input] * 3, dim=1)
+ qkv.retain_grad()
+
+ input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
+ out, lse, _ = flash_attn_varlen_qkvpacked_func(
+ qkv,
+ mask_info["cu_seqlens"] * sp_size,
+ mask_info["max_seqlen"] * sp_size,
+ return_attn_probs=True,
+ causal=True,
+ # deterministic=True
+ )
+ # Test the splitting function
+ local_input = split_varlen_zigzag(
+ flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
+ )
+ assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
+ del local_input, flat_input
+
+ q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
+ q_ring.retain_grad()
+ k_ring.retain_grad()
+ v_ring.retain_grad()
+
+ ring_out, ring_lse = RingAttention.attention(
+ q_ring,
+ k_ring,
+ v_ring,
+ sp_group,
+ **mask_info,
+ pad_output=False,
+ return_softmax=True,
+ # deterministic=True
+ )
+ ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
+ # Check output
+ lse = lse.transpose(0, 1)
+ out, lse = split_varlen_zigzag(
+ [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
+ )
+ assert_close(lse, ring_lse, atol=atol, rtol=rtol)
+ assert_close(out, ring_out, atol=atol, rtol=rtol)
+
+ # Check grads
+ labels = torch.ones(out.shape[0], dtype=dtype, device=device)
+ F.mse_loss(out.sum((-2, -1)), labels).backward()
+ F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
+ dq, dk, dv = [
+ split_varlen_zigzag(
+ qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
+ )
+ for i in range(3)
+ ]
+ dq_ring, dk_ring, dv_ring = [
+ x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
+ for x in (q_ring.grad, k_ring.grad, v_ring.grad)
+ ]
+
+ assert_close(dq, dq_ring, atol=atol, rtol=rtol)
+ assert_close(dk, dk_ring, atol=atol, rtol=rtol)
+ assert_close(dv, dv_ring, atol=atol, rtol=rtol)
+
+
+def launch_single_ring(rank, world_size, port):
+ colossalai.launch(rank, world_size, "localhost", port)
+ check_packed_seq()
+ check_ring_attn()
+
+
+def launch_double_ring(rank, world_size, port):
+ colossalai.launch(rank, world_size, "localhost", port)
+ check_ring_attn()
+
+
+@rerun_if_address_is_in_use()
+@parameterize("world_size", [2])
+def test_ring_attn(world_size):
+ spawn(launch_single_ring, nprocs=world_size)
+
+
+@rerun_if_address_is_in_use()
+@parameterize("world_size", [4])
+def test_double_ring(world_size):
+ spawn(launch_double_ring, nprocs=world_size)
+
+
+if __name__ == "__main__":
+ test_ring_attn()
+ test_double_ring()
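For readers skimming this new test: split_batch_zigzag distributes a causal sequence so that each rank's attention workload is balanced, by pairing an early chunk with its mirror-image late chunk. The helper below is a minimal standalone sketch of that layout, assuming a 2 * sp_size chunking along the sequence dimension; it is illustrative only and not the colossalai implementation.

import torch

# Sketch of the zigzag split assumed above: cut the sequence into 2 * sp_size
# chunks and give rank r the chunks r and (2 * sp_size - 1 - r).
def zigzag_split(x: torch.Tensor, sp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
    chunks = x.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=seq_dim)

# Example with a length-8 sequence on 2 ranks:
seq = torch.arange(8).view(1, 8)
print(zigzag_split(seq, sp_size=2, rank=0))  # tensor([[0, 1, 6, 7]])
print(zigzag_split(seq, sp_size=2, rank=1))  # tensor([[2, 3, 4, 5]])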
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 190fee129..9ad84341a 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
+from transformers.modeling_outputs import BaseModelOutputWithPast
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
@@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output)
org_loss.backward()
-
return org_loss, org_output, sharded_loss, sharded_output
@@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(
def check_output_hidden_state(
- org_output: Tensor,
- sharded_output: Tensor,
+ org_output: BaseModelOutputWithPast,
+ sharded_output: BaseModelOutputWithPast,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3,
+ shard_config: Optional[ShardConfig] = None,
):
org_hidden_state = org_output.last_hidden_state
@@ -315,6 +316,14 @@ def check_output_hidden_state(
else:
sharded_hidden_state = sharded_output.last_hidden_state
+ # Check if the output sequence is gathered before cross entropy
+ if shard_config is not None:
+ seq_dim = 1
+ sp_group = shard_config.sequence_parallel_process_group
+ sp_size = shard_config.sequence_parallel_size
+ if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
+ org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
+
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
@@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
- if shard_grad.shape[0] > org_grad.shape[0]:
- shard_grad = shard_grad[: org_grad.shape[0], :]
+ try:
+ if shard_grad.shape[0] > org_grad.shape[0]:
+ shard_grad = shard_grad[: org_grad.shape[0], :]
+ except:
+ pass
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
@@ -404,9 +416,6 @@ def check_grad(
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
- # if verbose and dist.get_rank() == 0:
- # print("shard_weight", shard_weight)
- # print("org_grad", org_grad)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
- for suffix, check_info in check_tensors.items():
+ for idx, (suffix, check_info) in enumerate(check_tensors.items()):
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
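A side note on the shard_config branch added to check_output_hidden_state above: when sequence parallelism leaves the sharded output split along the sequence dimension, only the rank-local chunk of the full reference hidden state should be compared. A small standalone sketch of that selection, with hypothetical names and a contiguous chunking assumed:

import torch

# Take the rank-local contiguous slice of the full-sequence reference before comparing.
def local_reference_chunk(full_hidden: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    return full_hidden.chunk(sp_size, dim=seq_dim)[sp_rank]

full = torch.randn(2, 8, 16)                       # (batch, seq, hidden)
local = local_reference_chunk(full, sp_size=4, sp_rank=2)
assert local.shape == (2, 2, 16)                   # each rank checks seq_len / sp_size positions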
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 3281b50e1..efe5cee2a 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
],
)
def run_command_test(test_config):
- sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -321,7 +321,7 @@ def run_command_test(test_config):
],
)
def run_command_3d_test(test_config):
- sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 88e54176b..3c66f6097 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
master2working = sharded_optimizer.get_master_to_working_map()
- for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+ for (name, p1), p2 in zip(
+ llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
+ ):
working_p = master2working[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
@@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
- assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+ try:
+ assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+ except Exception as e:
+ raise RuntimeError(f"Failed to check grad for {name}") from e
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@@ -114,89 +119,130 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "LlamaModel":
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
+ check_output_hidden_state(
+ org_output,
+ sharded_output,
+ stage_manager,
+ atol=atol,
+ rtol=rtol,
+ shard_config=booster.plugin.shard_config,
+ )
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
# check weights
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
- try:
- check_weight(
- llama_model,
- shard_llama_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False,
- )
- except Exception as e:
- print(f"Failed config: {test_config}")
- raise e
+ check_weight(
+ llama_model,
+ shard_llama_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
# check grads
check_all_grad_tensors(grads_to_check)
-
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
- { # Ulysess + Flash attention
+ # Double Ring Attention
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "sp_size": 4,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "ring_attn",
+ "use_lazy_init": True,
+ "zero_stage": 0,
+ "precision": "fp16",
+ "initial_scale": 1,
+ "inner_ring_size": 2,
+ },
+ # Ring Attention + PP
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "sp_size": 2,
+ "num_microbatches": 2,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "ring_attn",
+ "use_lazy_init": True,
+ "zero_stage": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ # Ring Attention + TP
+ {
+ "tp_size": 2,
+ "pp_size": 1,
+ "sp_size": 2,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "ring_attn",
+ "use_lazy_init": True,
+ "zero_stage": 2,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ { # Ulysses + TP
+ "tp_size": 2,
+ "pp_size": 1,
+ "sp_size": 2,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "all_to_all",
+ "enable_all_optimization": True,
+ "use_lazy_init": True,
+ "zero_stage": 0,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ { # Ulysses + PP
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
- "enable_flash_attention": True,
+ "enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
},
- { # Test ring + Flash attention
- "tp_size": 2,
- "pp_size": 1,
- "sp_size": 2,
- "num_microbatches": 1,
- "enable_sequence_parallelism": True,
- "sequence_parallelism_mode": "ring",
- "enable_flash_attention": True,
- "use_lazy_init": True,
- "zero_stage": 2,
- "precision": "fp16",
- "initial_scale": 1,
- },
- {
- "tp_size": 1,
- "pp_size": 1,
- "sp_size": 2,
- "num_microbatches": 1,
- "enable_sequence_parallelism": True,
- "sequence_parallelism_mode": "all_to_all",
- "use_lazy_init": True,
- "zero_stage": 1,
- "precision": "fp16",
- "initial_scale": 1,
- },
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
- "enable_flash_attention": False,
+ "enable_flash_attention": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
+ {
+ "tp_size": 2,
+ "pp_size": 1,
+ "sp_size": 1,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "ring",
+ "enable_flash_attention": True,
+ "use_lazy_init": True,
+ "zero_stage": 2,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
{
"tp_size": 2,
"pp_size": 2,
@@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
-
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
+ continue
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
- print(f"Failed config: {test_config}")
+ print(f"Failed config: {test_config}, model name: {name}")
raise e
clear_layout_converter()
Randomizer.reset_index()
From 26493b97d3b9c55a27d822c63b9176ac79b9ff5e Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 16 Aug 2024 18:49:14 +0800
Subject: [PATCH 116/116] [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
---
.compatibility | 1 +
.cuda_ext.json | 4 ++--
.github/workflows/build_on_pr.yml | 2 +-
.github/workflows/build_on_schedule.yml | 2 +-
colossalai/testing/utils.py | 2 +-
requirements/requirements.txt | 2 +-
.../test_plugin/test_torch_ddp_plugin.py | 2 +-
tests/test_lazy/test_models.py | 14 +++++++++++---
8 files changed, 19 insertions(+), 10 deletions(-)
diff --git a/.compatibility b/.compatibility
index 4f808740b..62d19faff 100644
--- a/.compatibility
+++ b/.compatibility
@@ -1,3 +1,4 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0
+2.4.0-12.4.1
diff --git a/.cuda_ext.json b/.cuda_ext.json
index 8c9d5916c..1e617755b 100644
--- a/.cuda_ext.json
+++ b/.cuda_ext.json
@@ -5,8 +5,8 @@
"cuda_image": "hpcaitech/cuda-conda:12.1"
},
{
- "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118",
- "cuda_image": "hpcaitech/cuda-conda:11.8"
+ "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
+ "cuda_image": "hpcaitech/cuda-conda:12.4"
}
]
}
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 151454239..58cd88268 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -141,7 +141,7 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v -e .
- pip install -r requirements/requirements-test.txt
+ pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Store Colossal-AI Cache
run: |
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index fc6424503..fc688a71b 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -57,7 +57,7 @@ jobs:
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
BUILD_EXT=1 pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
- pip install -r requirements/requirements-test.txt
+ pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Unit Testing
if: steps.check-avai.outputs.avai == 'true'
diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py
index 5f6864ff0..90d35dc85 100644
--- a/colossalai/testing/utils.py
+++ b/colossalai/testing/utils.py
@@ -176,7 +176,7 @@ def rerun_if_address_is_in_use():
else:
exception = Exception
- func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
+ func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*(A|a)ddress already in use.*")
return func_wrapper
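A quick illustration (not part of the patch) of why the rerun pattern above was relaxed: the "address already in use" error text is not consistently capitalized across PyTorch versions and backends, so the pattern now accepts both casings. The sample messages below are made up for illustration only.

import re

old_pattern = r".*Address already in use.*"
new_pattern = r".*(A|a)ddress already in use.*"

samples = [
    "RuntimeError: Address already in use",     # hypothetical upper-case variant
    "DistNetworkError: address already in use", # hypothetical lower-case variant
]
for msg in samples:
    # The old pattern misses the lower-case variant; the relaxed one matches both.
    print(bool(re.match(old_pattern, msg)), bool(re.match(new_pattern, msg)))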
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 651eb66e8..578122d47 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
-torch>=2.1.0,<=2.3.0
+torch>=2.1.0,<=2.4.0
safetensors
einops
pydantic
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index f92b5c6e5..2a3b6e5a3 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -47,7 +47,7 @@ def check_torch_ddp_plugin():
registry = model_zoo
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
- if name == "dlrm_interactionarch" or name.startswith("simple_"):
+ if name in ("dlrm_interactionarch", "transformers_mixtral") or name.startswith("simple_"):
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py
index c85860a8d..0a919955f 100644
--- a/tests/test_lazy/test_models.py
+++ b/tests/test_lazy/test_models.py
@@ -18,9 +18,17 @@ def test_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
- if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
- ("transformers_vit", "transformers_blip2", "transformers_whisper")
- ):
+ if name in (
+ "torchaudio_wav2vec2_base",
+ "torchaudio_hubert_base",
+ "timm_beit",
+ "timm_vision_transformer",
+ "timm_deit",
+ "timm_beitv2",
+ "timm_deit3",
+ "timm_convit",
+ "timm_tnt_b_patch16_224",
+ ) or name.startswith(("transformers_vit", "transformers_blip2", "transformers_whisper")):
continue
check_lazy_init(entry, verbose=True, default_device=default_device)