[moe] merge moe into main (#4978)

* update moe module
* support openmoe
This commit is contained in:
Xuanlei Zhao
2023-11-02 10:21:24 +08:00
committed by GitHub
parent 8993c8a817
commit dc003c304c
67 changed files with 7618 additions and 1657 deletions

View File

@@ -0,0 +1,129 @@
## OpenMoE
[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in Jax, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered an efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use this model. The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates finetune and inference methods.
## Usage
### 1. Installation
Please install the latest ColossalAI from source.
```bash
CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```
Then install dependencies.
```bash
cd ColossalAI/examples/language/openmoe
pip install -r requirements.txt
```
Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention.
### 2. Install kernels (Optional)
We have utilized `Triton`, `FlashAttention` and `Apex` kernel for better performance. They are not necessary but we recommend you to install them to fully utilize your hardware.
```
# install triton via pip
pip install triton
# install flash attention via pip
pip install flash-attn==2.0.5
# install apex from source
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 741bdf50825a97664db08574981962d66436d16a
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
```
### 3. Train
Yon can use colossalai run to launch single-node training:
```bash
colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
```
Yon can also use colossalai run to launch multi-nodes training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
```
Here is a sample hostfile:
```text
hostname1
hostname2
hostname3
hostname4
```
The hostname refers to the ip address of your nodes. Make sure master node can access all nodes (including itself) by ssh without password.
Here is details about CLI arguments:
- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases. `ep` can provides least memory consumption and `hybrid` suits large scale training.
- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
- Number of epochs: `--num_epochs`. The default value is 1.
- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
- Max length: `--max_length`. Max sequence length. Default to 2048.
- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as it.
- Task Name: `--task_name`. Task of corresponding dataset. Default to `super_natural_instructions`.
- Learning rate: `--lr`. The default value is 1e-5.
- Weight decay: `--weight_decay`. The default value is 0.
- Zero stage: `--zero_stage`. Zero stage. Recommend 2 for ep and 1 for ep zero.
- Extra dp size: `--extra_dp_size`. Extra moe param dp size for ep_zero plugin. Recommended to be 2 or 4.
- Use kernel: `--use_kernel`. Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.
- Use layernorm kernel: `--use_layernorm_kernel`. Use layernorm kernel. Need to install apex. Raise error if not installed.
- Router aux loss factor: `--router_aux_loss_factor`. Moe router z loss factor. You can refer to STMoE for details.
- Router z loss factor: `--router_z_loss_factor`. Moe router aux loss factor. You can refer to STMoE for details.
- Label smoothing: `--label_smoothing`. Label smoothing.
- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
Load balance: `--load_balance`. Expert load balance. Defaults to False. Recommend enabling.
- Load balance interval: `--load_balance_interval`. Expert load balance interval.
- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
### 4. Shell Script Examples
For your convenience, we provide some shell scripts to train with various configurations. Here we will show an example of how to run training
OpenMoE.
#### a. Running environment
This experiment was performed on a single computing nodes with 8 A800 80GB GPUs in total for OpenMoE-8B. The GPUs are fully connected with NVLink.
#### b. Running command
We demonstrate how to run three plugins in `train.sh`. You can choose anyone and use your own args.
```bash
bash train.sh
```
#### c. Multi-Nodes Training
To run on multi-nodes, you can modify the script as:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
train.py --OTHER_CONFIGURATIONS
```
## Reference
```
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
```
```bibtex
@misc{openmoe2023,
author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
title = {OpenMoE: Open Mixture-of-Experts Language Models},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
}
```

View File

@@ -0,0 +1,296 @@
import argparse
import json
import os
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
class RandomDataset(Dataset):
def __init__(
self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
):
self.num_samples = num_samples
self.max_length = max_length
if os.path.exists("./mock_data.json"):
self.input_ids = []
self.attention_mask = []
with open("./mock_data.json", "r") as f:
data = json.load(f)
for v in data.values():
d = v["text"]
encode = tokenizer(
"<pad>" + d,
return_tensors="pt",
add_special_tokens=False,
max_length=max_length,
truncation=True,
padding="max_length",
)
self.input_ids.append(encode["input_ids"])
self.attention_mask.append(encode["attention_mask"])
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
repeat_times = num_samples // self.input_ids.shape[0] + 1
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
else:
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--seq_length",
type=int,
default=2048,
help="sequence length for the training dataloader.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
help="parallel plugin",
)
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size")
parser.add_argument("--dp_size", type=int, default=1, help="dp size")
parser.add_argument("--ep_size", type=int, default=2, help="ep size")
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
parser.add_argument("--extra_dp_size", type=int, default=1)
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
)
# bench
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
# load balance
parser.add_argument("--load_balance", action="store_true")
# overlap
parser.add_argument("--overlap_alltoall", action="store_true")
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_kernel,
"enable_jit_fused": args.use_kernel,
"precision": "bf16",
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
zero_stage=args.zero_stage,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin}")
# Build OpenMoe model
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=args.load_balance,
enable_kernel=args.use_kernel,
enable_comm_overlap=args.overlap_alltoall,
)
with skip_init():
model = OpenMoeForCausalLM(config)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
max_length=args.seq_length,
tokenizer=tokenizer,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
# Set optimizer
optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dp_size,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
load_ckpt(repo_name, model, booster)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Start finetuning
coordinator.print_on_master(f"Start training")
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter) - 1
exmaple_data = next(train_dataloader_iter)
with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
for step in pbar:
performance_evaluator.on_step_start(step)
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(exmaple_data["input_ids"])
if (step == args.warmup // 2) and args.load_balance:
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
performance_evaluator.on_fit_end()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,78 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 8 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance
echo -e "\n\n EP-ZERO + Overlap \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall
# hybrid
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 128 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--use_kernel \
--plugin hybrid \
--pp_size 2 \
--dp_size 1 \
--ep_size 4 \
--zero_stage 1 \
--microbatch_size 32

View File

@@ -0,0 +1,47 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 12 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 20 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall

View File

@@ -0,0 +1,139 @@
import argparse
import functools
import os
import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
from colossalai.moe.manager import MOE_MANAGER
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def fsdp_main(rank, world_size, args):
# initialize the process group
# initialize the process group
dist.init_process_group("nccl")
MOE_MANAGER.setup(seed=42, parallel=None)
dp_size = dist.get_world_size()
dataset = RandomDataset(
max_length=args.seq_length,
num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
)
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
torch.cuda.set_device(rank)
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=False,
enable_kernel=False,
enable_comm_overlap=False,
)
torch.set_default_dtype(torch.float16)
model = OpenMoeForCausalLM(config)
torch.set_default_dtype(torch.float32)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
OpenMoeDecoderLayer,
},
)
model = FSDP(
model,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
model.train()
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dist.get_world_size(),
)
for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
performance_evaluator.on_step_start(step)
input_ids, attention_mask, labels = (
data["input_ids"].cuda(),
data["attention_mask"].cuda(),
data["labels"].cuda(),
)
optimizer.zero_grad()
output = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
chunk_head=False,
)
loss = output["loss"]
loss.backward()
optimizer.step()
performance_evaluator.on_step_end(input_ids)
performance_evaluator.on_fit_end()
if dist.get_rank() == 0:
print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="base or 8b",
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--seq_length", type=int, default=2048)
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
args = parser.parse_args()
torch.manual_seed(42)
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
fsdp_main(local_rank, world_size, args)

View File

@@ -0,0 +1,34 @@
#!/bin/bash
set -xue
MODEL="8b"
BATCH_SIZE=1
SEQ_LENGTH=2048
WARMUP=8
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# single node
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE
# multi node
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
$example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE

View File

@@ -0,0 +1,2 @@
host1
host2

View File

@@ -0,0 +1,126 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from colossalai.logging import DistributedLogger
def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
B = 1024**3
M = 1024**2
K = 1024
outputs = "Model param count: "
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
if model_param >= B:
outputs += f"{model_param / B:.2f} B\n"
elif model_param >= M:
outputs += f"{model_param / M:.2f} M\n"
elif model_param >= K:
outputs += f"{model_param / K:.2f} K\n"
else:
outputs += f"{model_param}\n"
logger.info(outputs, ranks=[0])
def get_model_numel(model: nn.Module) -> None:
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
return model_param
def divide(x: float, y: float) -> float:
if y == 0:
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
assert self.start_time is not None
self.duration += time() - self.start_time
self.start_time = None
def reset(self) -> None:
self.duration = 0.0
class PerformanceEvaluator:
"""
Callback for valuate the performance of the model.
Args:
actor_num_params: The number of parameters of the actor model.
critic_num_params: The number of parameters of the critic model.
initial_model_num_params: The number of parameters of the initial model.
reward_model_num_params: The number of parameters of the reward model.
enable_grad_checkpoint: Whether to enable gradient checkpointing.
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(
self,
model_numel: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None,
) -> None:
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
self.dp_world_size = dp_world_size
self.world_size = dist.get_world_size()
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
torch.cuda.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
torch.cuda.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.world_size // self.dp_world_size
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
if dist.get_rank() == 0:
print(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}")
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")

View File

@@ -0,0 +1,57 @@
from argparse import ArgumentParser
import torch
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
def parse_args():
parser = ArgumentParser()
parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
return parser.parse_args()
def inference(args):
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if args.model == "test":
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
set_openmoe_args(config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_kernel=True)
model = OpenMoeForCausalLM(config)
else:
config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
set_openmoe_args(config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_kernel=False)
model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
model = model.eval().bfloat16()
model = model.to(torch.cuda.current_device())
input_str = """```
y = list(map(int, ['1', 'hello', '2']))
```
What error does this program produce?
ValueError: invalid literal for int() with base 10: 'hello'
```
sum = 0
for i in range(100):
sum += i
```
What is the value of sum immediately after the 10th time line 3 is executed?"""
# print("model config: ", model.config)
input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
input_ids = input_ids.input_ids.to(torch.cuda.current_device())
generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
print(f"output: \n{out}\n")
if __name__ == "__main__":
args = parse_args()
inference(args)

View File

@@ -0,0 +1 @@
python infer.py --model "base"

View File

@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2022 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert T5X checkpoint to PyTorch
Steps:
- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
- Convert:
```
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
--pytorch_dump_path=$HOME/t5_1_1_small_pt
```
"""
import argparse
import collections
import torch
from flax import traverse_util
from modeling_openmoe import OpenMoeForCausalLM
from t5x import checkpoints
from transformers import LlamaConfig
from transformers.utils import logging
logging.set_verbosity_info()
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
"""Returns the KOQV parameters of (self-)attention. Does not transpose."""
k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
return k, o, q, v
def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the MLP parameters of a layer. Does not transpose."""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
return wi, wo
def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the MLP parameters of a layer. Does not transpose."""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
return wi, wo
def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the MLP parameters of a layer. Does not transpose."""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
return wi, wo
def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
"""Returns the MLP parameters of a layer. Does not transpose."""
return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
def t5x_layer_norm_lookup(params, i, prefix, layer_name):
"""Returns the layer norm param of a layer."""
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
# v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
split_mlp_wi = True
print("Split MLP:", split_mlp_wi)
new = collections.OrderedDict()
print(old.keys())
for key, value in old.items():
print(f"{key}: {value.shape}")
# Shared embeddings.
new["model.embed_tokens.weight"] = old["token_embedder/embedding"]
# Decoder.
for i in range(num_layers):
# Block i, layer 0 (Self Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T
# Block i, layer 2 (MLP).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
if (i + 1) % moe_interval == 0:
# moe
gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
new[f"model.layers.{i}.mlp.experts.wo"] = wo
# extra
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
else:
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T
new["model.norm.weight"] = old["decoder/decoder_norm/scale"]
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
return new
def make_state_dict(converted_params):
"""Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
"""Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables,
num_layers=config.num_hidden_layers,
moe_interval=config.moe_layer_interval)
state_dict = make_state_dict(converted)
model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model
config = LlamaConfig.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
model = OpenMoeForCausalLM(config)
# Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
# Verify that we can load the checkpoint.
model.from_pretrained(pytorch_dump_path)
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
# Required parameters
parser.add_argument("--t5x_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the T5X checkpoint.")
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
)
parser.add_argument("--pytorch_dump_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)

View File

@@ -0,0 +1 @@
python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,24 @@
{
"architectures": [
"OpenMoeForCausalLM"
],
"intermediate_size": 8192,
"hidden_size": 2048,
"num_hidden_layers": 24,
"head_dim": 128,
"num_attention_heads": 24,
"dropout_rate": 0.0,
"layer_norm_epsilon": 1e-06,
"vocab_size": 256384,
"hidden_act": "swiglu",
"num_experts": 32,
"topk": 2,
"capacity_factor_train": 1.25,
"capacity_factor_eval": 2.0,
"min_capacity": 4,
"noisy_policy": null,
"drop_tks": true,
"expert_parallel": null,
"gated": true,
"moe_layer_interval": 6
}

View File

@@ -0,0 +1,24 @@
{
"architectures": [
"OpenMoeForCausalLM"
],
"intermediate_size": 2048,
"hidden_size": 768,
"num_hidden_layers": 12,
"head_dim": 64,
"num_attention_heads": 12,
"dropout_rate": 0.0,
"layer_norm_epsilon": 1e-06,
"vocab_size": 256384,
"hidden_act": "swiglu",
"num_experts": 16,
"topk": 2,
"capacity_factor_train": 1.25,
"capacity_factor_eval": 2.0,
"min_capacity": 4,
"noisy_policy": null,
"drop_tks": true,
"expert_parallel": null,
"gated": true,
"moe_layer_interval": 4
}

View File

@@ -0,0 +1,562 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from colossalai.moe.manager import MOE_MANAGER
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel
__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
class OpenMoePolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="pre_extra_mlp_layernorm",
target_module=FusedRMSNorm,
ignore_if_not_exist=True,
),
],
policy=policy,
target_key=OpenMoeDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=OpenMoeModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in openmoe.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "OpenMoeModel":
module = self.model
else:
module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "OpenMoeModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages
"""
if num_layers == 24 and num_stages == 4:
return [7, 7, 7, 3]
elif num_layers == 24 and num_stages == 2:
return [15, 9]
elif num_layers == 12 and num_stages == 4:
return [5, 5, 5, 1]
elif num_layers == 12 and num_stages == 2:
return [8, 4]
else:
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
return Policy.distribute_layers(num_layers, num_stages)
class OpenMoeModelPolicy(OpenMoePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=OpenMoeModel,
new_forward=OpenMoePipelineForwards.openmoe_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
OpenMoeForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
])
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=OpenMoeForCausalLM,
new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1):
# tie weights
return [{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}]
return []
class OpenMoePipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
under pipeline setting.
"""
@staticmethod
def openmoe_model_forward(
self: OpenMoeModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
past_router_aux_loss: Optional[torch.FloatTensor] = None,
past_router_z_loss: Optional[torch.FloatTensor] = None,
):
# reset moe loss for different data
MOE_MANAGER.reset_loss()
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=hidden_states.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (past_key_values[idx] if past_key_values is not None else None)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
# concat past losses with current ones
router_aux_loss, router_z_loss = MOE_MANAGER.get_loss()
if past_router_aux_loss is not None and past_router_z_loss is not None:
router_aux_loss = past_router_aux_loss + router_aux_loss
router_z_loss = past_router_z_loss + router_z_loss
if stage_manager.is_last_stage():
return tuple([
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
])
# always return dict for imediate stage
return {
"hidden_states": hidden_states,
"router_aux_loss": router_aux_loss,
"router_z_loss": router_z_loss,
}
@staticmethod
def llama_for_causal_lm_forward(
self: OpenMoeForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
chunk_head: Optional[bool] = True,
past_router_aux_loss: Optional[torch.FloatTensor] = None,
past_router_z_loss: Optional[torch.FloatTensor] = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = OpenMoePipelineForwards.openmoe_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
past_router_aux_loss=past_router_aux_loss,
past_router_z_loss=past_router_z_loss,
)
if stage_manager.is_last_stage():
(
hidden_states,
past_key_values,
all_hidden_states,
attentions,
router_aux_loss,
router_z_loss,
) = outputs
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
loss = None
# if no training, just do forward
if labels is None:
logits = self.lm_head(hidden_states)
logits = logits.float()
# the vocab size for openmoe is 30w+
# which causes great activation memory in training, up to 20G for one sequence
# so we use chunk and checkpoint to reduce memory
else:
if chunk_head == True:
def create_custom_forward(module):
def custom_forward(*inputs):
logits = module(inputs[0])
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous().float()
shift_labels = inputs[1][..., 1:].contiguous()
# Flatten the tokens
loss = self._calculate_loss(shift_logits, shift_labels)
return loss
return custom_forward
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
loss = aux_loss + z_loss
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx:batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :],
)
logits = None
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
loss = aux_loss + z_loss
loss = loss + self._calculate_loss(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=attentions,
)
else:
hidden_states = outputs["hidden_states"]
router_aux_loss = outputs["router_aux_loss"]
router_z_loss = outputs["router_z_loss"]
return {
"hidden_states": hidden_states,
"past_router_aux_loss": router_aux_loss,
"past_router_z_loss": router_z_loss,
}

View File

@@ -0,0 +1,5 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers >= 4.20.0
sentencepiece
datasets

View File

@@ -0,0 +1,37 @@
pip install -r requirements.txt
# inference
python infer.py --model "test"
# train
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep" \
--batch_size 1
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 1 \
--extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 2 \
--extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--model_name "test" \
--plugin "hybrid" \
--num_epoch 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 2 \
--zero_stage 1 \
--batch_size 1

View File

@@ -0,0 +1,377 @@
import argparse
import os
from functools import partial
from typing import Dict
import torch
import torch.distributed as dist
from datasets import load_dataset
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
data = tokenizer(
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length,
add_special_tokens=False,
)
data = {k: v.cuda() for k, v in data.items()}
data["labels"] = data["input_ids"].clone()
return data
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b", "test"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["ep", "ep_zero", "hybrid"],
help="Parallel methos. ep_zero is recommended for general cases. ep can provides least memory consumption and hybrid suits large scale training.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help=" The interval (steps) of saving checkpoints.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--dataset",
type=str,
default="yizhongw/self_instruct",
help="dataset name from `datasets` repo.",
)
parser.add_argument(
"--task_name",
type=str,
default="super_natural_instructions",
help="task of corresponding dataset.",
)
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
# ep_zero plugin
parser.add_argument(
"--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
)
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
# loss
parser.add_argument(
"--router_aux_loss_factor",
type=float,
default=0.01,
help="Moe router z loss. You can refer to STMoE for details.",
)
parser.add_argument(
"--router_z_loss_factor",
type=float,
default=0.0001,
help="Moe router aux loss. You can refer to STMoE for details.",
)
parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
parser.add_argument(
"--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
)
# load balance
parser.add_argument(
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
)
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
# communicate overlap
parser.add_argument(
"--comm_overlap",
action="store_true",
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
test_mode = args.model_name == "test"
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build OpenMoe model
if test_mode:
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
config.hidden_size = 128
config.intermediate_size = 256
config.vocab_size = 32000
else:
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
router_aux_loss_factor=args.router_aux_loss_factor,
router_z_loss_factor=args.router_z_loss_factor,
z_loss_factor=args.z_loss_factor,
enable_load_balance=args.load_balance,
enable_comm_overlap=args.comm_overlap,
enable_kernel=args.use_kernel,
)
with skip_init():
model = OpenMoeForCausalLM(config)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if test_mode:
dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
collate_fn = None
else:
dataset = load_dataset(args.dataset, args.task_name)
dataset = dataset["train"]
collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
dataloader = plugin.prepare_dataloader(
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
)
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
if not test_mode:
load_ckpt(repo_name, model, booster)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Start finetuning
coordinator.print_on_master(f"Start finetuning")
for epoch in range(args.num_epoch):
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter)
with tqdm(
range(total_len),
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
disable=not coordinator.is_master(),
) as pbar:
for step in pbar:
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
optimizer.zero_grad()
# Apply load balance
if (
args.load_balance
and args.load_balance_interval > 0
and (step + 1) % args.load_balance_interval == 0
):
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
# save ckeckpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
booster.save_model(model, args.output_path, shard=True)
# save checkpoint at the end of each epochs
booster.save_model(model, args.output_path, shard=True)
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
# Finish training
coordinator.print_on_master(f"Finish training")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,40 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001
# ep zero
torchrun --standalone --nproc_per_node $NUM_GPU train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "ep_zero" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--extra_dp_size 2
# ep
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "ep_zero" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1
# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "hybrid" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1 \
# --pp_size 2 \
# --dp_size 1 \
# --ep_size 2 \