Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use one cross-entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686
.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
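The changelog above introduces FP8 operators for compressed communication (`cast_to_fp8`, `cast_from_fp8`, `all_reduce_fp8`) and a fix to the scaling algorithm used in FP8 casting. As a rough illustration only, here is a minimal per-tensor-scaling sketch; the function names mirror the changelog, but the body is an assumption, not ColossalAI's actual implementation.

```python
# Minimal sketch of per-tensor-scaled FP8 casting (assumption, not ColossalAI's code).
# Requires PyTorch >= 2.1 for the float8_e4m3fn dtype.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max


def cast_to_fp8(x: torch.Tensor):
    # Scale so the largest magnitude maps onto the FP8 representable range.
    amax = x.abs().max().clamp(min=1e-12)
    scale = FP8_MAX / amax
    x_fp8 = (x.float() * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_fp8, scale


def cast_from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    # Undo the scaling after the compressed tensor has been communicated.
    return (x_fp8.float() / scale).to(dtype)
```

An `all_reduce_fp8` built on these would typically cast, exchange the FP8 payload together with its scale, and cast back before accumulating.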
This commit is contained in:
22  examples/inference/stable_diffusion/README.md  Normal file
@@ -0,0 +1,22 @@
## File Structure
```
|- sd3_generation.py: an example of how to use the ColossalAI Inference Engine to generate results by loading a diffusion model.
|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion.
|- benchmark_sd3.py: benchmark the performance of our InferenceEngine.
|- run_benchmark.sh: run the benchmark commands.
```

Note: `compute_metric.py` requires extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` file is located in `examples/inference/stable_diffusion/`).

## Run Inference

The provided example `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. We've added `DiffusionPipeline` as a model class, and the script is ready to run inference with Stable Diffusion 3.

For a basic setup, you can run the example with:

```bash
colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
```

Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:

```bash
colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
```
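As a quick reference, `compute_metric.py` compares two folders of images passed via `--input_root0` and `--input_root1`, e.g. `python compute_metric.py --input_root0 <generated_images> --input_root1 <reference_images>` (the folder paths here are placeholders).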
179  examples/inference/stable_diffusion/benchmark_sd3.py  Normal file
@@ -0,0 +1,179 @@
import argparse
import json
import time
from contextlib import nullcontext

import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline

import colossalai
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

_DTYPE_MAPPING = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def log_generation_time(log_data, log_file):
    with open(log_file, "a") as f:
        json.dump(log_data, f, indent=2)
        f.write("\n")


def warmup(engine, args):
    for _ in range(args.n_warm_up_steps):
        engine.generate(
            prompts=["hello world"],
            generation_config=DiffusionGenerationConfig(
                num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
            ),
        )


def profile_context(args):
    return (
        torch.profiler.profile(
            record_shapes=True,
            with_stack=True,
            with_modules=True,
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
        )
        if args.profile
        else nullcontext()
    )


def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
    log_data = {
        "mode": mode,
        "model": model_name,
        "batch_size": args.batch_size,
        "patched_parallel_size": args.patched_parallel_size,
        "num_inference_steps": args.num_inference_steps,
        "height": h,
        "width": w,
        "dtype": args.dtype,
        "profile": args.profile,
        "n_warm_up_steps": args.n_warm_up_steps,
        "n_repeat_times": args.n_repeat_times,
        "avg_generation_time": avg_time,
        "log_message": log_msg,
    }

    if args.log:
        log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
        log_generation_time(log_data=log_data, log_file=log_file)

    if args.profile:
        file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
        prof.export_chrome_trace(file)


def benchmark_colossalai(rank, world_size, port, args):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    from colossalai.cluster.dist_coordinator import DistCoordinator

    coordinator = DistCoordinator()

    inference_config = InferenceConfig(
        dtype=args.dtype,
        patched_parallelism_size=args.patched_parallel_size,
    )
    engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)

    warmup(engine, args)

    for h, w in zip(args.height, args.width):
        with profile_context(args) as prof:
            start = time.perf_counter()
            for _ in range(args.n_repeat_times):
                engine.generate(
                    prompts=["hello world"],
                    generation_config=DiffusionGenerationConfig(
                        num_inference_steps=args.num_inference_steps, height=h, width=w
                    ),
                )
            end = time.perf_counter()

        avg_time = (end - start) / args.n_repeat_times
        log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
        coordinator.print_on_master(log_msg)

        if dist.get_rank() == 0:
            log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)


def benchmark_diffusers(args):
    model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")

    for _ in range(args.n_warm_up_steps):
        model(
            prompt="hello world",
            num_inference_steps=args.num_inference_steps,
            height=args.height[0],
            width=args.width[0],
        )

    for h, w in zip(args.height, args.width):
        with profile_context(args) as prof:
            start = time.perf_counter()
            for _ in range(args.n_repeat_times):
                model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
            end = time.perf_counter()

        avg_time = (end - start) / args.n_repeat_times
        log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
        print(log_msg)

        log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def benchmark(args):
    if args.mode == "colossalai":
        spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
    elif args.mode == "diffusers":
        benchmark_diffusers(args)


"""
# enable log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log

# enable profiler
python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
    parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
    parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
    parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
    parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
    parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
    parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
    parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
    parser.add_argument(
        "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
    )
    args = parser.parse_args()
    benchmark(args)
80  examples/inference/stable_diffusion/compute_metric.py  Normal file
@@ -0,0 +1,80 @@
# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os

import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm


def read_image(path: str):
    """
    input: path
    output: tensor (C, H, W)
    """
    img = np.asarray(Image.open(path))
    if len(img.shape) == 2:
        img = np.repeat(img[:, :, None], 3, axis=2)
    img = torch.from_numpy(img).permute(2, 0, 1)
    return img


class MultiImageDataset(Dataset):
    def __init__(self, root0, root1, is_gt=False):
        super().__init__()
        self.root0 = root0
        self.root1 = root1
        file_names0 = os.listdir(root0)
        file_names1 = os.listdir(root1)

        self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
        self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
        self.is_gt = is_gt
        assert len(self.image_names0) == len(self.image_names1)

    def __len__(self):
        return len(self.image_names0)

    def __getitem__(self, idx):
        img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
        if self.is_gt:
            # resize to 1024 x 1024
            img0 = Resize((1024, 1024))(img0)
        img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))

        batch_list = [img0, img1]
        return batch_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--is_gt", action="store_true")
    parser.add_argument("--input_root0", type=str, required=True)
    parser.add_argument("--input_root1", type=str, required=True)
    args = parser.parse_args()

    psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
    lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")

    dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    progress_bar = tqdm(dataloader)
    with torch.inference_mode():
        for i, batch in enumerate(progress_bar):
            batch = [img.to("cuda") / 255 for img in batch]
            batch_size = batch[0].shape[0]
            psnr.update(batch[0], batch[1])
            lpips.update(batch[0], batch[1])
    fid_score = fid.compute_fid(args.input_root0, args.input_root1)

    print("PSNR:", psnr.compute().item())
    print("LPIPS:", lpips.compute().item())
    print("FID:", fid_score)
3  examples/inference/stable_diffusion/requirements.txt  Normal file
@@ -0,0 +1,3 @@
torchvision
torchmetrics
cleanfid
42  examples/inference/stable_diffusion/run_benchmark.sh  Normal file
@@ -0,0 +1,42 @@
#!/bin/bash

models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
parallelism=(1 2 4 8)
resolutions=(1024 2048 3840)
modes=("colossalai" "diffusers")

CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

for model in "${models[@]}"; do
    for p in "${parallelism[@]}"; do
        for resolution in "${resolutions[@]}"; do
            for mode in "${modes[@]}"; do
                if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
                    continue
                fi
                if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
                    continue
                fi
                CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p

                cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"

                echo "Executing: $cmd"
                eval $cmd
            done
        done
    done
done
81  examples/inference/stable_diffusion/sd3_generation.py  Normal file
@@ -0,0 +1,81 @@
import argparse

from diffusers import DiffusionPipeline
from torch import bfloat16
from torch import distributed as dist
from torch import float16, float32

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = DiffusionPipeline

TORCH_DTYPE_MAP = {
    "fp16": float16,
    "fp32": float32,
    "bf16": bfloat16,
}


def infer(args):
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Load model and tokenizer
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master("Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        patched_parallelism_size=dist.get_world_size(),
    )
    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master("Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
    if dist.get_rank() == 0:
        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
    coordinator.print_on_master(out)


# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH

# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1

# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
    args = parser.parse_args()

    infer(args)
2  examples/inference/stable_diffusion/test_ci.sh  Normal file
@@ -0,0 +1,2 @@
#!/bin/bash
echo "Skip the test (this test is slow)"
@@ -1,4 +1,5 @@
import argparse
from contextlib import nullcontext
from typing import Callable, List, Union

import evaluate
@@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

# ==============================
@@ -252,10 +254,16 @@ def main():
        pad_token_id=data_builder.tokenizer.pad_token_id,
    )

    if model_name == "gpt2":
        model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
    else:
        raise RuntimeError
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin))
        else nullcontext()
    )
    with init_ctx:
        if model_name == "gpt2":
            model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
        else:
            raise RuntimeError

    # optimizer
    no_decay = ["bias", "LayerNorm.weight"]
@@ -98,6 +98,7 @@ def main():
    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--overlap_allgather", action="store_true")
    args = parser.parse_args()

    colossalai.launch_from_torch()
@@ -199,9 +200,9 @@ def main():
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            overlap_p2p=args.overlap,
            enable_metadata_cache=not args.no_cache,
            overlap_allgather=args.overlap_allgather,
            **hybrid_kwargs,
        )
    elif args.plugin == "3d_cpu":
@@ -292,7 +293,7 @@ def main():
    with get_profile_context(
        args.profile,
        args.ignore_steps,
        len(dataloader) - 1,
        1,  # avoid creating massive log files
        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
    ) as prof:
        if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
@@ -1,138 +0,0 @@
## OpenMoE
[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in JAX, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use this model. The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates finetuning and inference methods.

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/MOE_training.png" width=800/>
</p>

* [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/openmoe)
[[blog]](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient)

## Usage

### 1. Installation

Please install the latest ColossalAI from source.

```bash
BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```

Then install dependencies.

```bash
cd ColossalAI/examples/language/openmoe
pip install -r requirements.txt
```

Additionally, we recommend using torch 1.13.1; we've tested our code on it and found it compatible with flash attention.

### 2. Install kernels (Optional)

We use the `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
```bash
# install triton via pip
pip install triton

# install flash attention via pip
pip install flash-attn==2.0.5

# install apex from source
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 741bdf50825a97664db08574981962d66436d16a
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
```

### 3. Train
You can use colossalai run to launch single-node training:
```bash
colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
```
You can also use colossalai run to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
```

Here is a sample hostfile:

```text
hostname1
hostname2
hostname3
hostname4
```

Each hostname is the IP address of one of your nodes. Make sure the master node can access all nodes (including itself) via passwordless SSH.

Here are the details of the CLI arguments:

- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the lowest memory consumption, and `hybrid` suits large-scale training.
- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
- Number of epochs: `--num_epochs`. The default value is 1.
- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
- Save interval: `-i`, `--save_interval`. The interval (in steps) for saving checkpoints. The default value is 1000.
- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
- Max length: `--max_length`. Max sequence length. Defaults to 2048.
- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format.
- Task name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
- Learning rate: `--lr`. The default value is 1e-5.
- Weight decay: `--weight_decay`. The default value is 0.
- Zero stage: `--zero_stage`. ZeRO stage. Recommend 2 for `ep` and 1 for `ep_zero`.
- Extra dp size: `--extra_dp_size`. Extra MoE-parameter dp size for the `ep_zero` plugin. Recommended to be 2 or 4.
- Use kernel: `--use_kernel`. Use kernel optimizations. Requires flash attention and triton to enable all kernel optimizations. Skipped if not installed.
- Use layernorm kernel: `--use_layernorm_kernel`. Use the layernorm kernel. Requires apex; raises an error if not installed.
- Router aux loss factor: `--router_aux_loss_factor`. MoE router aux loss factor. You can refer to ST-MoE for details.
- Router z loss factor: `--router_z_loss_factor`. MoE router z loss factor. You can refer to ST-MoE for details.
- Label smoothing: `--label_smoothing`. Label smoothing.
- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
- Load balance: `--load_balance`. Expert load balance. Defaults to False. Recommend enabling.
- Load balance interval: `--load_balance_interval`. Expert load balance interval.
- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.

### 4. Shell Script Examples

For your convenience, we provide some shell scripts to train with various configurations. Here we show an example of how to run training for OpenMoE.

#### a. Running environment
This experiment was performed on a single compute node with 8 A800 80GB GPUs in total for OpenMoE-8B. The GPUs are fully connected with NVLink.

#### b. Running command
We demonstrate how to run the three plugins in `train.sh`. You can choose any one of them and use your own arguments.

```bash
bash train.sh
```

#### c. Multi-Node Training

To run on multiple nodes, you can modify the script as:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
    train.py --OTHER_CONFIGURATIONS
```

## Reference
```bibtex
@article{bian2021colossal,
    title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
    author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
    journal={arXiv preprint arXiv:2110.14883},
    year={2021}
}
```

```bibtex
@misc{openmoe2023,
    author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
    title = {OpenMoE: Open Mixture-of-Experts Language Models},
    year = {2023},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
}
```
@@ -1,298 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from huggingface_hub import snapshot_download
|
||||
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
|
||||
from model.openmoe_policy import OpenMoeForCausalLMPolicy
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import T5Tokenizer
|
||||
from transformers.models.llama import LlamaConfig
|
||||
from utils import PerformanceEvaluator, get_model_numel
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
|
||||
ckpt_path = snapshot_download(repo_name)
|
||||
# single ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
|
||||
# shard ckpt
|
||||
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
|
||||
else:
|
||||
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
|
||||
booster.load_model(model, ckpt_path)
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(
|
||||
self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
|
||||
):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
if os.path.exists("./mock_data.json"):
|
||||
self.input_ids = []
|
||||
self.attention_mask = []
|
||||
with open("./mock_data.json", "r") as f:
|
||||
data = json.load(f)
|
||||
for v in data.values():
|
||||
d = v["text"]
|
||||
encode = tokenizer(
|
||||
"<pad>" + d,
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
self.input_ids.append(encode["input_ids"])
|
||||
self.attention_mask.append(encode["attention_mask"])
|
||||
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device())
|
||||
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device())
|
||||
repeat_times = num_samples // self.input_ids.shape[0] + 1
|
||||
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
|
||||
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
|
||||
else:
|
||||
self.input_ids = torch.randint(
|
||||
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
|
||||
)
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="base",
|
||||
choices=["base", "8b"],
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per dp group) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="sequence length for the training dataloader.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
help="parallel plugin",
|
||||
)
|
||||
# hybrid plugin
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="pp size")
|
||||
parser.add_argument("--dp_size", type=int, default=1, help="dp size")
|
||||
parser.add_argument("--ep_size", type=int, default=2, help="ep size")
|
||||
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
|
||||
parser.add_argument("--extra_dp_size", type=int, default=1)
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
|
||||
)
|
||||
# bench
|
||||
parser.add_argument("--warmup", type=int, default=20)
|
||||
parser.add_argument("--active", type=int, default=20)
|
||||
# load balance
|
||||
parser.add_argument("--load_balance", action="store_true")
|
||||
|
||||
# overlap communication
|
||||
parser.add_argument("--overlap_comm", action="store_true")
|
||||
# hierarchical all-to-all
|
||||
parser.add_argument("--hierarchical_alltoall", action="store_true")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": OpenMoeForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": "bf16",
|
||||
"zero_stage": args.zero_stage,
|
||||
}
|
||||
mgr_dict = {}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
elif args.plugin == "ep_zero":
|
||||
dp_size = dist.get_world_size()
|
||||
use_ep_inside = False
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
ep_size=args.ep_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size // args.extra_dp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**mgr_dict,
|
||||
)
|
||||
elif args.plugin == "hybrid":
|
||||
dp_size = dist.get_world_size() // args.pp_size
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=args.pp_size,
|
||||
zero_stage=args.zero_stage,
|
||||
microbatch_size=args.microbatch_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=args.dp_size,
|
||||
fixed_ep_size=args.ep_size,
|
||||
fixed_pp_size=args.pp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin}")
|
||||
|
||||
# Build OpenMoe model
|
||||
repo_name = "hpcai-tech/openmoe-" + args.model_name
|
||||
config = LlamaConfig.from_pretrained(repo_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_load_balance=args.load_balance,
|
||||
enable_kernel=args.use_kernel,
|
||||
enable_comm_overlap=args.overlap_comm,
|
||||
enable_hierarchical_alltoall=args.hierarchical_alltoall,
|
||||
)
|
||||
with skip_init():
|
||||
model = OpenMoeForCausalLM(config)
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
|
||||
dataset = RandomDataset(
|
||||
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
|
||||
max_length=args.seq_length,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
performance_evaluator = PerformanceEvaluator(
|
||||
model_numel,
|
||||
enable_grad_checkpoint=True,
|
||||
ignore_steps=args.warmup,
|
||||
dp_world_size=dp_size,
|
||||
)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
load_ckpt(repo_name, model, booster)
|
||||
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
|
||||
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# Start finetuning
|
||||
coordinator.print_on_master(f"Start training")
|
||||
model.train()
|
||||
train_dataloader_iter = iter(dataloader)
|
||||
total_len = len(train_dataloader_iter) - 1
|
||||
exmaple_data = next(train_dataloader_iter)
|
||||
with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
|
||||
for step in pbar:
|
||||
performance_evaluator.on_step_start(step)
|
||||
if use_pipeline:
|
||||
# Forward pass
|
||||
outputs = booster.execute_pipeline(
|
||||
train_dataloader_iter,
|
||||
model,
|
||||
lambda x, y: x.loss,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = outputs["loss"]
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
else:
|
||||
# Forward pass
|
||||
data = next(train_dataloader_iter)
|
||||
data = move_to_cuda(data, torch.cuda.current_device())
|
||||
outputs = model(**data)
|
||||
loss = outputs["loss"]
|
||||
# Backward
|
||||
booster.backward(loss, optimizer)
|
||||
pbar.set_postfix({"loss": loss.item()})
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
performance_evaluator.on_step_end(exmaple_data["input_ids"])
|
||||
if (step == args.warmup // 2) and args.load_balance:
|
||||
coordinator.print_on_master(f"Apply load balance")
|
||||
apply_load_balance(model, optimizer)
|
||||
performance_evaluator.on_fit_end()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,78 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
NUM_GPU=8
|
||||
MODEL="8b"
|
||||
SEQ_LENGTH=2048
|
||||
WARMUP=20
|
||||
ACTIVE=4
|
||||
|
||||
# HACK: make model importable
|
||||
example_dir=$(dirname $(realpath $(dirname $0)))
|
||||
if [ -z ${PYTHONPATH+x} ]; then
|
||||
export PYTHONPATH=$example_dir
|
||||
else
|
||||
export PYTHONPATH=$example_dir:$PYTHONPATH
|
||||
fi
|
||||
|
||||
|
||||
# ep
|
||||
echo -e "\n\n Naive EP \n\n"
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 8 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep \
|
||||
--zero_stage 2
|
||||
|
||||
|
||||
# ep_zero
|
||||
echo -e "\n\n EP-ZERO \n\n"
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 16 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep_zero \
|
||||
--use_kernel \
|
||||
--extra_dp_size 2 \
|
||||
--zero_stage 1 \
|
||||
--load_balance
|
||||
|
||||
echo -e "\n\n EP-ZERO + Overlap \n\n"
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 16 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep_zero \
|
||||
--use_kernel \
|
||||
--extra_dp_size 2 \
|
||||
--zero_stage 1 \
|
||||
--load_balance \
|
||||
--overlap_alltoall
|
||||
|
||||
|
||||
# hybrid
|
||||
torchrun --standalone --nproc_per_node $NUM_GPU \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 128 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--use_kernel \
|
||||
--plugin hybrid \
|
||||
--pp_size 2 \
|
||||
--dp_size 1 \
|
||||
--ep_size 4 \
|
||||
--zero_stage 1 \
|
||||
--microbatch_size 32
|
@@ -1,47 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -xue
|
||||
|
||||
NUM_GPU=8
|
||||
MODEL="8b"
|
||||
SEQ_LENGTH=2048
|
||||
WARMUP=20
|
||||
ACTIVE=4
|
||||
|
||||
# HACK: make model importable
|
||||
example_dir=$(dirname $(realpath $(dirname $0)))
|
||||
if [ -z ${PYTHONPATH+x} ]; then
|
||||
export PYTHONPATH=$example_dir
|
||||
else
|
||||
export PYTHONPATH=$example_dir:$PYTHONPATH
|
||||
fi
|
||||
|
||||
|
||||
# ep
|
||||
echo -e "\n\n Naive EP \n\n"
|
||||
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 12 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep \
|
||||
--zero_stage 2
|
||||
|
||||
|
||||
# ep_zero
|
||||
echo -e "\n\n EP-ZERO \n\n"
|
||||
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
|
||||
$example_dir/benchmark/benchmark_cai.py \
|
||||
--model_name $MODEL \
|
||||
--batch_size 20 \
|
||||
--seq_length $SEQ_LENGTH \
|
||||
--warmup $WARMUP \
|
||||
--active $ACTIVE \
|
||||
--plugin ep_zero \
|
||||
--use_kernel \
|
||||
--extra_dp_size 2 \
|
||||
--zero_stage 1 \
|
||||
--load_balance \
|
||||
--overlap_alltoall
|
@@ -1,139 +0,0 @@
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import tqdm
|
||||
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers.models.llama import LlamaConfig
|
||||
from utils import PerformanceEvaluator, get_model_numel
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def fsdp_main(rank, world_size, args):
|
||||
# initialize the process group
|
||||
|
||||
# initialize the process group
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
MOE_MANAGER.setup(parallel=None)
|
||||
|
||||
dp_size = dist.get_world_size()
|
||||
dataset = RandomDataset(
|
||||
max_length=args.seq_length,
|
||||
num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
|
||||
)
|
||||
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
|
||||
train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
|
||||
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-%s" % args.model_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_load_balance=False,
|
||||
enable_kernel=False,
|
||||
enable_comm_overlap=False,
|
||||
)
|
||||
torch.set_default_dtype(torch.float16)
|
||||
model = OpenMoeForCausalLM(config)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={
|
||||
OpenMoeDecoderLayer,
|
||||
},
|
||||
)
|
||||
model = FSDP(
|
||||
model,
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.bfloat16,
|
||||
buffer_dtype=torch.bfloat16,
|
||||
),
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
device_id=torch.cuda.current_device(),
|
||||
)
|
||||
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
|
||||
model.train()
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
performance_evaluator = PerformanceEvaluator(
|
||||
model_numel,
|
||||
enable_grad_checkpoint=True,
|
||||
ignore_steps=args.warmup,
|
||||
dp_world_size=dist.get_world_size(),
|
||||
)
|
||||
|
||||
for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||
performance_evaluator.on_step_start(step)
|
||||
input_ids, attention_mask, labels = (
|
||||
data["input_ids"].cuda(),
|
||||
data["attention_mask"].cuda(),
|
||||
data["labels"].cuda(),
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
attention_mask=attention_mask,
|
||||
chunk_head=False,
|
||||
)
|
||||
loss = output["loss"]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
performance_evaluator.on_step_end(input_ids)
|
||||
|
||||
performance_evaluator.on_fit_end()
|
||||
if dist.get_rank() == 0:
|
||||
print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="base",
|
||||
choices=["base", "8b"],
|
||||
help="base or 8b",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--seq_length", type=int, default=2048)
|
||||
parser.add_argument("--warmup", type=int, default=20)
|
||||
parser.add_argument("--active", type=int, default=20)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
fsdp_main(local_rank, world_size, args)
|
@@ -1,34 +0,0 @@
#!/bin/bash

set -xue

MODEL="8b"
BATCH_SIZE=1
SEQ_LENGTH=2048
WARMUP=8
ACTIVE=4

# HACK: make model importable
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
    export PYTHONPATH=$example_dir
else
    export PYTHONPATH=$example_dir:$PYTHONPATH
fi

# single node
torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \
    --model_name $MODEL \
    --batch_size $BATCH_SIZE \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE

# multi node
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
    $example_dir/benchmark/benchmark_fsdp.py \
    --model_name $MODEL \
    --batch_size $BATCH_SIZE \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE
@@ -1,2 +0,0 @@
host1
host2
@@ -1,126 +0,0 @@
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
outputs = "Model param count: "
|
||||
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
if model_param >= B:
|
||||
outputs += f"{model_param / B:.2f} B\n"
|
||||
elif model_param >= M:
|
||||
outputs += f"{model_param / M:.2f} M\n"
|
||||
elif model_param >= K:
|
||||
outputs += f"{model_param / K:.2f} K\n"
|
||||
else:
|
||||
outputs += f"{model_param}\n"
|
||||
logger.info(outputs, ranks=[0])
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return model_param
|
||||
|
||||
|
||||
def divide(x: float, y: float) -> float:
|
||||
if y == 0:
|
||||
return float("inf")
|
||||
elif y == float("inf"):
|
||||
return float("nan")
|
||||
return x / y
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.0
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
|
||||
def end(self) -> None:
|
||||
assert self.start_time is not None
|
||||
self.duration += time() - self.start_time
|
||||
self.start_time = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.0
|
||||
|
||||
|
||||
class PerformanceEvaluator:
|
||||
"""
|
||||
    Callback used to evaluate the training performance of the model.
    Args:
        model_numel: The number of parameters of the model.
        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
        ignore_steps: The number of warm-up steps to ignore when measuring throughput.
        dp_world_size: The data-parallel world size.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_numel: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_steps: int = 0,
|
||||
dp_world_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model_numel = model_numel
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_steps = ignore_steps
|
||||
self.dp_world_size = dp_world_size
|
||||
self.world_size = dist.get_world_size()
|
||||
self.disable: bool = False
|
||||
self.timer = Timer()
|
||||
self.num_samples: int = 0
|
||||
self.flop: int = 0
|
||||
|
||||
def on_step_start(self, step: int) -> None:
|
||||
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.start()
|
||||
|
||||
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
self.timer.end()
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
self.num_samples += batch_size
|
||||
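        # rough decoder-only FLOP estimate: ~2 * num_params FLOPs per token for one forward pass,
        # x3 for forward + backward, plus one extra forward when gradient checkpointing recomputes activations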
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
|
||||
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
|
||||
mp_world_size = self.world_size // self.dp_world_size
|
||||
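        # attribute the measured FLOPs to individual GPUs by dividing by the model-parallel world size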
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
|
||||
if dist.get_rank() == 0:
|
||||
print(
|
||||
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
|
||||
f"avg_throughput: {avg_throughput}"
|
||||
)
|
||||
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
|
@@ -1,55 +0,0 @@
from argparse import ArgumentParser

import torch
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
    return parser.parse_args()


def inference(args):
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
    if args.model == "test":
        config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
        set_openmoe_args(
            config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=True
        )
        model = OpenMoeForCausalLM(config)
    else:
        config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
        set_openmoe_args(
            config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False
        )
        model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
    model = model.eval().bfloat16()
    model = model.to(torch.cuda.current_device())

    input_str = """```
y = list(map(int, ['1', 'hello', '2']))
```
What error does this program produce?
ValueError: invalid literal for int() with base 10: 'hello'

```
sum = 0
for i in range(100):
    sum += i
```
What is the value of sum immediately after the 10th time line 3 is executed?"""

    # print("model config: ", model.config)
    input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
    input_ids = input_ids.input_ids.to(torch.cuda.current_device())
    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
    out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
    print(f"output: \n{out}\n")


if __name__ == "__main__":
    args = parse_args()
    inference(args)
@@ -1 +0,0 @@
python infer.py --model "base"
@@ -1,220 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 Google LLC and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert T5X checkpoint to PyTorch
|
||||
|
||||
Steps:
|
||||
- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
|
||||
- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
|
||||
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
|
||||
- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use
|
||||
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
|
||||
- Convert:
|
||||
```
|
||||
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
|
||||
--pytorch_dump_path=$HOME/t5_1_1_small_pt
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
|
||||
import torch
|
||||
from flax import traverse_util
|
||||
from modeling_openmoe import OpenMoeForCausalLM
|
||||
from t5x import checkpoints
|
||||
from transformers import LlamaConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
    """Returns the K, O, Q, V parameters of (self-)attention. Does not transpose."""
|
||||
k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
|
||||
o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
|
||||
q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
|
||||
v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
|
||||
return k, o, q, v
|
||||
|
||||
|
||||
def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
|
||||
"""Returns the MLP parameters of a layer. Does not transpose."""
|
||||
if split_mlp_wi:
|
||||
wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
|
||||
wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
|
||||
wi = (wi_0, wi_1)
|
||||
else:
|
||||
wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
|
||||
|
||||
wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
|
||||
return wi, wo
|
||||
|
||||
|
||||
def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the extra-MLP parameters of a layer. Does not transpose."""
|
||||
if split_mlp_wi:
|
||||
wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
|
||||
wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
|
||||
wi = (wi_0, wi_1)
|
||||
else:
|
||||
wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
|
||||
|
||||
wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
|
||||
return wi, wo
|
||||
|
||||
|
||||
def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the expert MLP parameters of a MoE layer. Does not transpose."""
|
||||
if split_mlp_wi:
|
||||
wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
|
||||
wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
|
||||
wi = (wi_0, wi_1)
|
||||
else:
|
||||
wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
|
||||
|
||||
wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
|
||||
return wi, wo
|
||||
|
||||
|
||||
def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the router gate weights of a MoE layer. Does not transpose."""
|
||||
return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
|
||||
|
||||
|
||||
def t5x_layer_norm_lookup(params, i, prefix, layer_name):
|
||||
"""Returns the layer norm param of a layer."""
|
||||
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
|
||||
|
||||
|
||||
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
|
||||
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
|
||||
old = traverse_util.flatten_dict(variables["target"])
|
||||
old = {"/".join(k): v for k, v in old.items()}
|
||||
|
||||
# v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
|
||||
split_mlp_wi = True
|
||||
print("Split MLP:", split_mlp_wi)
|
||||
|
||||
new = collections.OrderedDict()
|
||||
print(old.keys())
|
||||
for key, value in old.items():
|
||||
print(f"{key}: {value.shape}")
|
||||
|
||||
# Shared embeddings.
|
||||
new["model.embed_tokens.weight"] = old["token_embedder/embedding"]
|
||||
|
||||
# Decoder.
|
||||
for i in range(num_layers):
|
||||
# Block i, layer 0 (Self Attention).
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
|
||||
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
|
||||
new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
|
||||
new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
|
||||
new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
|
||||
new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
|
||||
new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T
|
||||
|
||||
# Block i, layer 2 (MLP).
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
|
||||
new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
|
||||
|
||||
if (i + 1) % moe_interval == 0:
|
||||
# moe
|
||||
gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
|
||||
wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
|
||||
new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
|
||||
new[f"model.layers.{i}.mlp.experts.wo"] = wo
|
||||
# extra
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
|
||||
new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
|
||||
wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
|
||||
new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
|
||||
new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
|
||||
else:
|
||||
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
|
||||
new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
|
||||
new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
|
||||
new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T
|
||||
|
||||
new["model.norm.weight"] = old["decoder/decoder_norm/scale"]
|
||||
|
||||
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
|
||||
if "decoder/logits_dense/kernel" in old:
|
||||
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
|
||||
|
||||
return new
|
||||
|
||||
|
||||
def make_state_dict(converted_params):
|
||||
"""Prepares a state dict for the PyTorch model."""
|
||||
# Make a state dict with torch tensors.
|
||||
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
    """Replaces the params in the model with the T5X converted params."""
|
||||
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
|
||||
converted = convert_t5x_to_pytorch(
|
||||
variables, num_layers=config.num_hidden_layers, moe_interval=config.moe_layer_interval
|
||||
)
|
||||
state_dict = make_state_dict(converted)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
|
||||
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
|
||||
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
|
||||
# Initialise PyTorch model
|
||||
config = LlamaConfig.from_json_file(config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
|
||||
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
|
||||
model = OpenMoeForCausalLM(config)
|
||||
|
||||
    # Load weights from the T5X checkpoint
|
||||
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Verify that we can load the checkpoint.
|
||||
    OpenMoeForCausalLM.from_pretrained(pytorch_dump_path)
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
@@ -1 +0,0 @@
python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save
File diff suppressed because it is too large
@@ -1,24 +0,0 @@
{
  "architectures": [
    "OpenMoeForCausalLM"
  ],
  "intermediate_size": 8192,
  "hidden_size": 2048,
  "num_hidden_layers": 24,
  "head_dim": 128,
  "num_attention_heads": 24,
  "dropout_rate": 0.0,
  "layer_norm_epsilon": 1e-06,
  "vocab_size": 256384,
  "hidden_act": "swiglu",
  "num_experts": 32,
  "topk": 2,
  "capacity_factor_train": 1.25,
  "capacity_factor_eval": 2.0,
  "min_capacity": 4,
  "noisy_policy": null,
  "drop_tks": true,
  "expert_parallel": null,
  "gated": true,
  "moe_layer_interval": 6
}
@@ -1,24 +0,0 @@
{
  "architectures": [
    "OpenMoeForCausalLM"
  ],
  "intermediate_size": 2048,
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "head_dim": 64,
  "num_attention_heads": 12,
  "dropout_rate": 0.0,
  "layer_norm_epsilon": 1e-06,
  "vocab_size": 256384,
  "hidden_act": "swiglu",
  "num_experts": 16,
  "topk": 2,
  "capacity_factor_train": 1.25,
  "capacity_factor_eval": 2.0,
  "min_capacity": 4,
  "noisy_policy": null,
  "drop_tks": true,
  "expert_parallel": null,
  "gated": true,
  "moe_layer_interval": 4
}
@@ -1,565 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel
|
||||
|
||||
__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
|
||||
|
||||
|
||||
class OpenMoePolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
            raise NotImplementedError(
                "OpenMoe doesn't support sequence parallelism for now; the sequence parallelism flag will be ignored."
            )
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="pre_extra_mlp_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
ignore_if_not_exist=True,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OpenMoeDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OpenMoeModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in openmoe.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "OpenMoeModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "OpenMoeModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
|
||||
"""Divide layers into stages"""
|
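        # hand-tuned splits: the last stage holds fewer decoder layers since it also runs the final norm and the large-vocabulary LM head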
||||
if num_layers == 24 and num_stages == 4:
|
||||
return [7, 7, 7, 3]
|
||||
elif num_layers == 24 and num_stages == 2:
|
||||
return [15, 9]
|
||||
elif num_layers == 12 and num_stages == 4:
|
||||
return [5, 5, 5, 1]
|
||||
elif num_layers == 12 and num_stages == 2:
|
||||
return [8, 4]
|
||||
else:
|
||||
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
|
||||
return super().distribute_layers(num_layers, num_stages)
|
||||
|
||||
|
||||
class OpenMoeModelPolicy(OpenMoePolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=OpenMoeModel,
|
||||
new_forward=OpenMoePipelineForwards.openmoe_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
            # add a new item for causal lm
            # TODO: recursively assign ep group for all modules
|
||||
new_item = {
|
||||
OpenMoeForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=OpenMoeForCausalLM,
|
||||
new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class OpenMoePipelineForwards:
|
||||
"""
|
||||
    This class serves as a micro-library of forward-function substitutions for OpenMoE (Llama-based) models
    under the pipeline parallel setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def openmoe_model_forward(
|
||||
self: OpenMoeModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
past_router_aux_loss: Optional[torch.FloatTensor] = None,
|
||||
past_router_z_loss: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# reset moe loss for different data
|
||||
MOE_MANAGER.reset_loss()
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# concat past losses with current ones
|
||||
router_aux_loss, router_z_loss = MOE_MANAGER.get_loss()
|
||||
if past_router_aux_loss is not None and past_router_z_loss is not None:
|
||||
router_aux_loss = past_router_aux_loss + router_aux_loss
|
||||
router_z_loss = past_router_z_loss + router_z_loss
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
return tuple(
|
||||
[
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
router_aux_loss,
|
||||
router_z_loss,
|
||||
]
|
||||
)
|
||||
        # always return a dict for intermediate stages
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"router_aux_loss": router_aux_loss,
|
||||
"router_z_loss": router_z_loss,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def llama_for_causal_lm_forward(
|
||||
self: OpenMoeForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
chunk_head: Optional[bool] = True,
|
||||
past_router_aux_loss: Optional[torch.FloatTensor] = None,
|
||||
past_router_z_loss: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = OpenMoePipelineForwards.openmoe_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
past_router_aux_loss=past_router_aux_loss,
|
||||
past_router_z_loss=past_router_z_loss,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
(
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
all_hidden_states,
|
||||
attentions,
|
||||
router_aux_loss,
|
||||
router_z_loss,
|
||||
) = outputs
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
|
||||
loss = None
|
||||
# if no training, just do forward
|
||||
if labels is None:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
            # the vocab size of openmoe is over 300k,
            # which leads to very large activation memory during training (up to ~20 GB for a single sequence),
            # so we chunk the LM head and use activation checkpointing to reduce memory
|
||||
else:
|
||||
                if chunk_head:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
logits = module(inputs[0])
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous().float()
|
||||
shift_labels = inputs[1][..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss = self._calculate_loss(shift_logits, shift_labels)
|
||||
return loss
|
||||
|
||||
return custom_forward
|
||||
|
||||
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
|
||||
loss = aux_loss + z_loss
|
||||
for batch_idx in range(hidden_states.shape[0]):
|
||||
loss = loss + torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.lm_head),
|
||||
hidden_states[batch_idx : batch_idx + 1, :],
|
||||
labels[batch_idx : batch_idx + 1, :],
|
||||
)
|
||||
logits = None
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
|
||||
loss = aux_loss + z_loss
|
||||
loss = loss + self._calculate_loss(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs["hidden_states"]
|
||||
router_aux_loss = outputs["router_aux_loss"]
|
||||
router_z_loss = outputs["router_z_loss"]
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"past_router_aux_loss": router_aux_loss,
|
||||
"past_router_z_loss": router_z_loss,
|
||||
}
|
@@ -1,5 +0,0 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers >= 4.20.0, <= 4.34.0
sentencepiece
datasets
@@ -1,37 +0,0 @@
# pip install -r requirements.txt

# inference
# python infer.py --model "test"

# train
# torchrun --standalone --nproc_per_node 4 train.py \
# --num_epoch 1 \
# --model_name "test" \
# --plugin "ep" \
# --batch_size 1

# torchrun --standalone --nproc_per_node 4 train.py \
# --num_epoch 1 \
# --model_name "test" \
# --plugin "ep_zero" \
# --batch_size 1 \
# --zero_stage 1 \
# --extra_dp_size 2 \

# torchrun --standalone --nproc_per_node 4 train.py \
# --num_epoch 1 \
# --model_name "test" \
# --plugin "ep_zero" \
# --batch_size 1 \
# --zero_stage 2 \
# --extra_dp_size 2 \

# torchrun --standalone --nproc_per_node 4 train.py \
# --model_name "test" \
# --plugin "hybrid" \
# --num_epoch 1 \
# --pp_size 2 \
# --dp_size 1 \
# --ep_size 2 \
# --zero_stage 1 \
# --batch_size 1
@@ -1,383 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
|
||||
from model.openmoe_policy import OpenMoeForCausalLMPolicy
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import T5Tokenizer
|
||||
from transformers.models.llama import LlamaConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.shardformer.layer.moe import apply_load_balance
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
|
||||
ckpt_path = snapshot_download(repo_name)
|
||||
# single ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
|
||||
# shard ckpt
|
||||
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
|
||||
else:
|
||||
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
|
||||
booster.load_model(model, ckpt_path)
|
||||
|
||||
|
||||
def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
|
||||
texts = ["<pad>" + sample["prompt"] + sample["completion"] for sample in batch]
|
||||
data = tokenizer(
|
||||
texts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
|
||||
self.num_samples = num_samples
|
||||
self.max_length = max_length
|
||||
self.input_ids = torch.randint(
|
||||
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
|
||||
)
|
||||
self.attention_mask = torch.ones_like(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {
|
||||
"input_ids": self.input_ids[idx],
|
||||
"attention_mask": self.attention_mask[idx],
|
||||
"labels": self.input_ids[idx],
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
# basic settings
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="base",
|
||||
choices=["base", "8b", "test"],
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["ep", "ep_zero", "hybrid"],
|
||||
        help="Parallel method. ep_zero is recommended for general cases. ep provides the least memory consumption and hybrid suits large-scale training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./outputs",
|
||||
help="The path of your saved model after finetuning.",
|
||||
)
|
||||
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size (per dp group) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_interval",
|
||||
type=int,
|
||||
default=1000,
|
||||
        help="The interval (steps) of saving checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp32", "bf16", "fp16"],
|
||||
help="The mixed precision training.",
|
||||
)
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="yizhongw/self_instruct",
|
||||
help="dataset name from `datasets` repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
type=str,
|
||||
default="super_natural_instructions",
|
||||
help="task of corresponding dataset.",
|
||||
)
|
||||
|
||||
# optim
|
||||
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
|
||||
# zero stage for all plugins
|
||||
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
|
||||
# ep_zero plugin
|
||||
parser.add_argument(
|
||||
"--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
|
||||
)
|
||||
# hybrid plugin
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
|
||||
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
|
||||
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
|
||||
|
||||
# kernel
|
||||
parser.add_argument(
|
||||
"--use_kernel",
|
||||
action="store_true",
|
||||
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_layernorm_kernel",
|
||||
action="store_true",
|
||||
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
|
||||
)
|
||||
|
||||
# loss
|
||||
parser.add_argument(
|
||||
"--router_aux_loss_factor",
|
||||
type=float,
|
||||
default=0.01,
|
||||
        help="MoE router aux loss factor. You can refer to ST-MoE for details.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--router_z_loss_factor",
|
||||
type=float,
|
||||
default=0.0001,
|
||||
        help="MoE router z loss factor. You can refer to ST-MoE for details.",
|
||||
)
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
|
||||
parser.add_argument(
|
||||
"--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
|
||||
)
|
||||
|
||||
# load balance
|
||||
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommended to enable."
    )
|
||||
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
|
||||
# communicate overlap
|
||||
parser.add_argument(
|
||||
"--comm_overlap",
|
||||
action="store_true",
|
||||
help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
|
||||
)
|
||||
# hierarchical all-to-all
|
||||
parser.add_argument(
|
||||
"--hierarchical_alltoall",
|
||||
action="store_true",
|
||||
help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Launch ColossalAI
|
||||
colossalai.launch_from_torch(seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
test_mode = args.model_name == "test"
|
||||
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": OpenMoeForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_layernorm_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": args.precision,
|
||||
"zero_stage": args.zero_stage,
|
||||
}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
ep_size=args.ep_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
# MOE_MANAGER.setup(
|
||||
# parallel="EP",
|
||||
# max_ep_size=dp_size,
|
||||
# **mgr_dict,
|
||||
# )
|
||||
elif args.plugin == "ep_zero":
|
||||
dp_size = dist.get_world_size()
|
||||
use_ep_inside = False
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
ep_size=dp_size // args.ep_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**hybrid_dict,
|
||||
)
|
||||
# MOE_MANAGER.setup(
|
||||
# parallel="EP",
|
||||
# max_ep_size=dp_size // args.extra_dp_size,
|
||||
# use_ep_inside=use_ep_inside,
|
||||
# **mgr_dict,
|
||||
# )
|
||||
elif args.plugin == "hybrid":
|
||||
dp_size = dist.get_world_size() // args.pp_size
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=args.pp_size,
|
||||
ep_size=args.ep_size,
|
||||
microbatch_size=args.microbatch_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
# MOE_MANAGER.setup(
|
||||
# parallel="EP",
|
||||
# mode="fixed",
|
||||
# fixed_dp_size=args.dp_size,
|
||||
# fixed_ep_size=args.ep_size,
|
||||
# fixed_pp_size=args.pp_size,
|
||||
# **mgr_dict,
|
||||
# )
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build OpenMoe model
|
||||
if test_mode:
|
||||
config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
|
||||
config.hidden_size = 128
|
||||
config.intermediate_size = 256
|
||||
config.vocab_size = 32000
|
||||
else:
|
||||
repo_name = "hpcai-tech/openmoe-" + args.model_name
|
||||
config = LlamaConfig.from_pretrained(repo_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
router_aux_loss_factor=args.router_aux_loss_factor,
|
||||
router_z_loss_factor=args.router_z_loss_factor,
|
||||
z_loss_factor=args.z_loss_factor,
|
||||
enable_load_balance=args.load_balance,
|
||||
enable_comm_overlap=args.comm_overlap,
|
||||
enable_hierarchical_alltoall=args.hierarchical_alltoall,
|
||||
enable_kernel=args.use_kernel,
|
||||
)
|
||||
with skip_init():
|
||||
model = OpenMoeForCausalLM(config)
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
|
||||
if test_mode:
|
||||
dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
|
||||
collate_fn = None
|
||||
else:
|
||||
dataset = load_dataset(args.dataset, args.task_name)
|
||||
dataset = dataset["train"]
|
||||
collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# Set optimizer
|
||||
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
if not test_mode:
|
||||
load_ckpt(repo_name, model, booster)
|
||||
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
|
||||
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
|
||||
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
    # Start finetuning
    coordinator.print_on_master(f"Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master(),
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        pbar.set_postfix({"loss": loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                optimizer.zero_grad()

                # Apply load balance
                if (
                    args.load_balance
                    and args.load_balance_interval > 0
                    and (step + 1) % args.load_balance_interval == 0
                ):
                    coordinator.print_on_master(f"Apply load balance")
                    apply_load_balance(model, optimizer)
                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    booster.save_model(model, args.output_path, shard=True)

        # Save a checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master(f"Finish training")


if __name__ == "__main__":
    main()
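The tokenize_data collate function used when building the dataloader above lives in the example's utility module and is not shown in this diff. As an illustrative sketch only (the column name, padding strategy, and helper signature are assumptions, not the example's actual implementation), a causal-LM collate function in that spirit could look like:

# Sketch of a causal-LM collate function; not the example's actual tokenize_data.
from typing import Dict, List

import torch
from transformers import PreTrainedTokenizerBase


def tokenize_data_sketch(
    batch: List[Dict[str, str]], tokenizer: PreTrainedTokenizerBase, max_length: int
) -> Dict[str, torch.Tensor]:
    texts = [sample["text"] for sample in batch]  # the "text" column is an assumption
    data = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    data["labels"] = data["input_ids"].clone()  # causal LM: labels mirror the inputs
    return data

The sharded checkpoints written by booster.save_model(..., shard=True) can later be restored through Booster.load_model(model, checkpoint); this script itself reloads pretrained weights via its load_ckpt helper.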
@@ -1,40 +0,0 @@
#!/bin/bash

set -xue

NUM_GPU=8
MODEL="8b"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001

# ep zero
torchrun --standalone --nproc_per_node $NUM_GPU train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "ep_zero" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--extra_dp_size 2

# ep
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "ep_zero" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1

# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
# --num_epoch 1 \
# --model_name $MODEL \
# --plugin "hybrid" \
# --batch_size $BATCH_SIZE \
# --lr $LR \
# --zero_stage 1 \
# --pp_size 2 \
# --dp_size 1 \
# --ep_size 2 \
@@ -1,4 +1,5 @@
import time
from contextlib import nullcontext

import torch
import tqdm
@@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

@@ -62,14 +65,6 @@ def main():
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM(config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@ def main():
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Build OPT model
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin))
        else nullcontext()
    )
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    with init_ctx:
        model = OPTForCausalLM(config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()
    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

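The hunk above defers OPT model construction until after the plugin is chosen, so that GeminiPlugin runs can build the model under LazyInitContext instead of eagerly allocating full weights. A minimal, self-contained sketch of that pattern follows; the launch call, plugin defaults, model name, and learning rate are illustrative assumptions rather than values taken from this diff.

# Sketch of the lazy-init pattern introduced above (illustrative values, not the example's defaults).
from contextlib import nullcontext

from transformers import AutoConfig, OPTForCausalLM

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # assumes the script is started with torchrun
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# Only build the model lazily when the plugin supports it (GeminiPlugin here);
# otherwise fall back to eager construction via nullcontext().
init_ctx = (
    LazyInitContext(default_device=get_accelerator().get_current_device())
    if isinstance(plugin, GeminiPlugin)
    else nullcontext()
)
config = AutoConfig.from_pretrained("facebook/opt-125m")
with init_ctx:
    model = OPTForCausalLM(config=config)

optimizer = HybridAdam(model.parameters(), lr=1e-4)
# booster.boost() materializes the lazily-initialized parameters and wraps model/optimizer.
model, optimizer, *_ = booster.boost(model, optimizer)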
@@ -1,3 +1,5 @@
from contextlib import nullcontext

import datasets
import torch
import transformers
@@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s
from transformers.utils.versions import require_version

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

@@ -78,14 +82,6 @@ def main():
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@ def main():

    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    # Build OPT model
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )
    with init_ctx:
        model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)

@@ -113,13 +113,13 @@ class PerformanceEvaluator:
        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
        if self.disable:
            return
        get_accelerator().synchronize()
        # get_accelerator().synchronize()
        self.timer.start()

    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
        if self.disable:
            return
        get_accelerator().synchronize()
        # get_accelerator().synchronize()
        self.timer.end()

        batch_size, seq_len = input_ids.shape
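This hunk comments out the accelerator synchronization around the evaluator's timer, presumably to drop the extra barriers. Because device kernels execute asynchronously, a host-side timer without such a barrier can stop before the queued work has actually finished, so per-step timings become approximate. A small illustrative helper (not part of the evaluator) showing both variants:

# Illustrative only: timing a step with and without a device synchronization barrier.
import time

from colossalai.accelerator import get_accelerator


def time_step(step_fn, sync: bool = True) -> float:
    """Run one step and return its wall-clock duration in seconds."""
    if sync:
        get_accelerator().synchronize()  # wait for previously queued device work
    start = time.perf_counter()
    step_fn()
    if sync:
        get_accelerator().synchronize()  # ensure the step's kernels have actually finished
    return time.perf_counter() - start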