[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement
Author: Runyu Lu (committed by GitHub)
Date: 2024-07-30 10:43:26 +08:00
Parent: 7b38964e3a
Commit: bcf0181ecd
15 changed files with 1089 additions and 16 deletions

File: examples/inference/stable_diffusion/README.md

@@ -0,0 +1,22 @@
## File Structure
```
|- sd3_generation.py: an example of how to use the ColossalAI Inference Engine to generate images with a loaded diffusion model
|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion
|- benchmark_sd3.py: benchmark the performance of the ColossalAI InferenceEngine
|- run_benchmark.sh: script that runs the benchmark commands
```
Note: `compute_metric.py` requires additional dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is located in `examples/inference/stable_diffusion/`).
## Run Inference
The provided example `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. `DiffusionPipeline` has been added as a supported model class, so the script can run inference with Stable Diffusion 3 out of the box.
For a basic single-GPU setting, you can run the example with:
```bash
colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
```
To run multi-GPU inference (Patched Parallelism), launch more processes, as in the following example using 2 GPUs:
```bash
colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
```
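
The engine can also be driven directly from Python. Below is a minimal sketch assembled from `sd3_generation.py` and `benchmark_sd3.py` in this change; the `colossalai.launch_from_torch()` call, the placeholder model path, and the prompt are illustrative assumptions rather than part of the shipped scripts.
```python
import colossalai
from torch import distributed as dist

from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Initialize the distributed environment (assumes the script is started with `colossalai run`).
colossalai.launch_from_torch()
coordinator = DistCoordinator()

# patched_parallelism_size > 1 enables Distrifusion-style Patched Parallelism
# across all launched processes.
inference_config = InferenceConfig(
    dtype="fp16",
    patched_parallelism_size=dist.get_world_size(),
)
engine = InferenceEngine("PATH_MODEL", inference_config=inference_config, verbose=True)  # placeholder model path

coordinator.print_on_master("Generating...")
out = engine.generate(
    prompts=["hello world"],
    generation_config=DiffusionGenerationConfig(num_inference_steps=50),
)[0]

# Save the result on rank 0 only.
if dist.get_rank() == 0:
    out.save("output.jpg")
```
Launching this with `colossalai run --nproc_per_node N` spreads the generation work across N devices, matching the commands above.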

File: examples/inference/stable_diffusion/benchmark_sd3.py

@@ -0,0 +1,179 @@
import argparse
import json
import time
from contextlib import nullcontext
import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline
import colossalai
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
_DTYPE_MAPPING = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
def log_generation_time(log_data, log_file):
with open(log_file, "a") as f:
json.dump(log_data, f, indent=2)
f.write("\n")
def warmup(engine, args):
for _ in range(args.n_warm_up_steps):
engine.generate(
prompts=["hello world"],
generation_config=DiffusionGenerationConfig(
num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
),
)
def profile_context(args):
return (
torch.profiler.profile(
record_shapes=True,
with_stack=True,
with_modules=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
if args.profile
else nullcontext()
)
def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
log_data = {
"mode": mode,
"model": model_name,
"batch_size": args.batch_size,
"patched_parallel_size": args.patched_parallel_size,
"num_inference_steps": args.num_inference_steps,
"height": h,
"width": w,
"dtype": args.dtype,
"profile": args.profile,
"n_warm_up_steps": args.n_warm_up_steps,
"n_repeat_times": args.n_repeat_times,
"avg_generation_time": avg_time,
"log_message": log_msg,
}
if args.log:
log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
log_generation_time(log_data=log_data, log_file=log_file)
if args.profile:
file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
prof.export_chrome_trace(file)
# Benchmark generation latency with the ColossalAI InferenceEngine (Patched Parallelism across world_size ranks).
def benchmark_colossalai(rank, world_size, port, args):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
from colossalai.cluster.dist_coordinator import DistCoordinator
coordinator = DistCoordinator()
inference_config = InferenceConfig(
dtype=args.dtype,
patched_parallelism_size=args.patched_parallel_size,
)
engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
warmup(engine, args)
for h, w in zip(args.height, args.width):
with profile_context(args) as prof:
start = time.perf_counter()
for _ in range(args.n_repeat_times):
engine.generate(
prompts=["hello world"],
generation_config=DiffusionGenerationConfig(
num_inference_steps=args.num_inference_steps, height=h, width=w
),
)
end = time.perf_counter()
avg_time = (end - start) / args.n_repeat_times
log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
coordinator.print_on_master(log_msg)
if dist.get_rank() == 0:
log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
# Benchmark the plain Diffusers pipeline on a single GPU as the baseline.
def benchmark_diffusers(args):
model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
for _ in range(args.n_warm_up_steps):
model(
prompt="hello world",
num_inference_steps=args.num_inference_steps,
height=args.height[0],
width=args.width[0],
)
for h, w in zip(args.height, args.width):
with profile_context(args) as prof:
start = time.perf_counter()
for _ in range(args.n_repeat_times):
model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
end = time.perf_counter()
avg_time = (end - start) / args.n_repeat_times
log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
print(log_msg)
log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def benchmark(args):
if args.mode == "colossalai":
spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
elif args.mode == "diffusers":
benchmark_diffusers(args)
"""
# enable log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
# enable profiler
python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
parser.add_argument(
"--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
)
args = parser.parse_args()
benchmark(args)

File: examples/inference/stable_diffusion/compute_metric.py

@@ -0,0 +1,80 @@
# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os
import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm
def read_image(path: str):
"""
input: path
output: tensor (C, H, W)
"""
img = np.asarray(Image.open(path))
if len(img.shape) == 2:
img = np.repeat(img[:, :, None], 3, axis=2)
img = torch.from_numpy(img).permute(2, 0, 1)
return img
class MultiImageDataset(Dataset):
def __init__(self, root0, root1, is_gt=False):
super().__init__()
self.root0 = root0
self.root1 = root1
file_names0 = os.listdir(root0)
file_names1 = os.listdir(root1)
self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
self.is_gt = is_gt
assert len(self.image_names0) == len(self.image_names1)
def __len__(self):
return len(self.image_names0)
def __getitem__(self, idx):
img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
if self.is_gt:
# resize to 1024 x 1024
img0 = Resize((1024, 1024))(img0)
img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
batch_list = [img0, img1]
return batch_list
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--is_gt", action="store_true")
parser.add_argument("--input_root0", type=str, required=True)
parser.add_argument("--input_root1", type=str, required=True)
args = parser.parse_args()
psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
progress_bar = tqdm(dataloader)
with torch.inference_mode():
for i, batch in enumerate(progress_bar):
batch = [img.to("cuda") / 255 for img in batch]
batch_size = batch[0].shape[0]
psnr.update(batch[0], batch[1])
lpips.update(batch[0], batch[1])
fid_score = fid.compute_fid(args.input_root0, args.input_root1)
print("PSNR:", psnr.compute().item())
print("LPIPS:", lpips.compute().item())
print("FID:", fid_score)

File: examples/inference/stable_diffusion/requirements.txt

@@ -0,0 +1,3 @@
torchvision
torchmetrics
cleanfid

File: examples/inference/stable_diffusion/run_benchmark.sh

@@ -0,0 +1,42 @@
#!/bin/bash
models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
parallelism=(1 2 4 8)
resolutions=(1024 2048 3840)
modes=("colossalai" "diffusers")
# Select the $n GPUs with the least memory currently in use and expose only them via CUDA_VISIBLE_DEVICES.
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
for model in "${models[@]}"; do
for p in "${parallelism[@]}"; do
for resolution in "${resolutions[@]}"; do
for mode in "${modes[@]}"; do
if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
continue
fi
if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
continue
fi
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
echo "Executing: $cmd"
eval $cmd
done
done
done
done

File: examples/inference/stable_diffusion/sd3_generation.py

@@ -1,18 +1,17 @@
import argparse
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
from torch import bfloat16, float16, float32
from diffusers import DiffusionPipeline
from torch import bfloat16
from torch import distributed as dist
from torch import float16, float32
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
MODEL_CLS = DiffusionPipeline
TORCH_DTYPE_MAP = {
"fp16": float16,
@@ -43,20 +42,27 @@ def infer(args):
max_batch_size=args.max_batch_size,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
patched_parallelism_size=dist.get_world_size(),
)
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
# ==============================
# Generation
# ==============================
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
out.save("cat.jpg")
if dist.get_rank() == 0:
out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
coordinator.print_on_master(out)
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
if __name__ == "__main__":