[moe] merge moe into main (#4978)

* update moe module
* support openmoe
This commit is contained in:
Xuanlei Zhao
2023-11-02 10:21:24 +08:00
committed by GitHub
parent 8993c8a817
commit dc003c304c
67 changed files with 7618 additions and 1657 deletions

View File

@@ -0,0 +1,296 @@
import argparse
import json
import os
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
ckpt_path = snapshot_download(repo_name)
# single ckpt
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
# shard ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
booster.load_model(model, ckpt_path)
class RandomDataset(Dataset):
def __init__(
self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
):
self.num_samples = num_samples
self.max_length = max_length
if os.path.exists("./mock_data.json"):
self.input_ids = []
self.attention_mask = []
with open("./mock_data.json", "r") as f:
data = json.load(f)
for v in data.values():
d = v["text"]
encode = tokenizer(
"<pad>" + d,
return_tensors="pt",
add_special_tokens=False,
max_length=max_length,
truncation=True,
padding="max_length",
)
self.input_ids.append(encode["input_ids"])
self.attention_mask.append(encode["attention_mask"])
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
repeat_times = num_samples // self.input_ids.shape[0] + 1
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
else:
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--seq_length",
type=int,
default=2048,
help="sequence length for the training dataloader.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
help="parallel plugin",
)
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size")
parser.add_argument("--dp_size", type=int, default=1, help="dp size")
parser.add_argument("--ep_size", type=int, default=2, help="ep size")
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
parser.add_argument("--extra_dp_size", type=int, default=1)
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
)
# bench
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
# load balance
parser.add_argument("--load_balance", action="store_true")
# overlap
parser.add_argument("--overlap_alltoall", action="store_true")
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_kernel,
"enable_jit_fused": args.use_kernel,
"precision": "bf16",
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
zero_stage=args.zero_stage,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin}")
# Build OpenMoe model
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=args.load_balance,
enable_kernel=args.use_kernel,
enable_comm_overlap=args.overlap_alltoall,
)
with skip_init():
model = OpenMoeForCausalLM(config)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
max_length=args.seq_length,
tokenizer=tokenizer,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
# Set optimizer
optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dp_size,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
load_ckpt(repo_name, model, booster)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Start finetuning
coordinator.print_on_master(f"Start training")
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter) - 1
exmaple_data = next(train_dataloader_iter)
with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
for step in pbar:
performance_evaluator.on_step_start(step)
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
pbar.set_postfix({"loss": loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(exmaple_data["input_ids"])
if (step == args.warmup // 2) and args.load_balance:
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
performance_evaluator.on_fit_end()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,78 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 8 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance
echo -e "\n\n EP-ZERO + Overlap \n\n"
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 16 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall
# hybrid
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 128 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--use_kernel \
--plugin hybrid \
--pp_size 2 \
--dp_size 1 \
--ep_size 4 \
--zero_stage 1 \
--microbatch_size 32

View File

@@ -0,0 +1,47 @@
#!/bin/bash
set -xue
NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
WARMUP=20
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# ep
echo -e "\n\n Naive EP \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 12 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep \
--zero_stage 2
# ep_zero
echo -e "\n\n EP-ZERO \n\n"
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
$example_dir/benchmark/benchmark_cai.py \
--model_name $MODEL \
--batch_size 20 \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE \
--plugin ep_zero \
--use_kernel \
--extra_dp_size 2 \
--zero_stage 1 \
--load_balance \
--overlap_alltoall

View File

@@ -0,0 +1,139 @@
import argparse
import functools
import os
import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel
from colossalai.moe.manager import MOE_MANAGER
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def fsdp_main(rank, world_size, args):
# initialize the process group
# initialize the process group
dist.init_process_group("nccl")
MOE_MANAGER.setup(seed=42, parallel=None)
dp_size = dist.get_world_size()
dataset = RandomDataset(
max_length=args.seq_length,
num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
)
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
torch.cuda.set_device(rank)
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
set_openmoe_args(
config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=False,
enable_kernel=False,
enable_comm_overlap=False,
)
torch.set_default_dtype(torch.float16)
model = OpenMoeForCausalLM(config)
torch.set_default_dtype(torch.float32)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
OpenMoeDecoderLayer,
},
)
model = FSDP(
model,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
model.train()
model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
model_numel,
enable_grad_checkpoint=True,
ignore_steps=args.warmup,
dp_world_size=dist.get_world_size(),
)
for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
performance_evaluator.on_step_start(step)
input_ids, attention_mask, labels = (
data["input_ids"].cuda(),
data["attention_mask"].cuda(),
data["labels"].cuda(),
)
optimizer.zero_grad()
output = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
chunk_head=False,
)
loss = output["loss"]
loss.backward()
optimizer.step()
performance_evaluator.on_step_end(input_ids)
performance_evaluator.on_fit_end()
if dist.get_rank() == 0:
print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
help="base or 8b",
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--seq_length", type=int, default=2048)
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
args = parser.parse_args()
torch.manual_seed(42)
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
fsdp_main(local_rank, world_size, args)

View File

@@ -0,0 +1,34 @@
#!/bin/bash
set -xue
MODEL="8b"
BATCH_SIZE=1
SEQ_LENGTH=2048
WARMUP=8
ACTIVE=4
# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi
# single node
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE
# multi node
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
$example_dir/benchmark/benchmark_fsdp.py \
--model_name $MODEL \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--warmup $WARMUP \
--active $ACTIVE

View File

@@ -0,0 +1,2 @@
host1
host2

View File

@@ -0,0 +1,126 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from colossalai.logging import DistributedLogger
def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
B = 1024**3
M = 1024**2
K = 1024
outputs = "Model param count: "
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
if model_param >= B:
outputs += f"{model_param / B:.2f} B\n"
elif model_param >= M:
outputs += f"{model_param / M:.2f} M\n"
elif model_param >= K:
outputs += f"{model_param / K:.2f} K\n"
else:
outputs += f"{model_param}\n"
logger.info(outputs, ranks=[0])
def get_model_numel(model: nn.Module) -> None:
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
return model_param
def divide(x: float, y: float) -> float:
if y == 0:
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
assert self.start_time is not None
self.duration += time() - self.start_time
self.start_time = None
def reset(self) -> None:
self.duration = 0.0
class PerformanceEvaluator:
"""
Callback for valuate the performance of the model.
Args:
actor_num_params: The number of parameters of the actor model.
critic_num_params: The number of parameters of the critic model.
initial_model_num_params: The number of parameters of the initial model.
reward_model_num_params: The number of parameters of the reward model.
enable_grad_checkpoint: Whether to enable gradient checkpointing.
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(
self,
model_numel: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None,
) -> None:
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
self.dp_world_size = dp_world_size
self.world_size = dist.get_world_size()
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
torch.cuda.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
torch.cuda.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.world_size // self.dp_world_size
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
if dist.get_rank() == 0:
print(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}")
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")