[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion support source
* computation/communication overlap optimization
* SD3 benchmark
* PixArt Distrifusion bug fix
* SD3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* README and requirements
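The core technique here, Patched Parallelism (from the Distrifusion paper), splits the diffusion latent into patches across GPUs and overlaps the resulting communication with computation. Below is a minimal, illustrative sketch of the patch-splitting idea; it is not the ColossalAI implementation, and `denoise_fn` is a hypothetical stand-in for one denoising step of the model:

```python
import torch
import torch.distributed as dist


def patch_parallel_step(latent: torch.Tensor, denoise_fn) -> torch.Tensor:
    """Denoise one patch of the latent per rank, then reassemble the full latent.

    Illustrative only: Distrifusion additionally reuses slightly stale activations
    from the other patches and issues the communication asynchronously so it
    overlaps with computation; a blocking all_gather is shown here for clarity.
    Assumes the latent height is divisible by the world size.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # Split the latent (B, C, H, W) into world_size horizontal strips.
    patches = latent.chunk(world_size, dim=2)
    local = denoise_fn(patches[rank].contiguous())
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    return torch.cat(gathered, dim=2)
```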
examples/inference/stable_diffusion/README.md (new file, 22 lines)
@@ -0,0 +1,22 @@
## File Structure
```
|- sd3_generation.py: an example of how to use the ColossalAI inference engine to generate images with a diffusion model
|- compute_metric.py: compares the quality of images generated with and without an acceleration method such as Distrifusion
|- benchmark_sd3.py: benchmarks the performance of our InferenceEngine
|- run_benchmark.sh: runs the benchmark commands
```

Note: `compute_metric.py` requires extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is in `examples/inference/stable_diffusion/`).

## Run Inference

The provided example `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. `DiffusionPipeline` has been added as a model class, so the script works with Stable Diffusion 3 out of the box.

For a basic setup, you can run the example with:

```bash
colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
```

To run multi-GPU inference (Patched Parallelism), for example on 2 GPUs:

```bash
colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
```
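For reference, the sketch below shows roughly what the script does programmatically, using the `InferenceEngine` and `DiffusionGenerationConfig` APIs from `benchmark_sd3.py` below; the model path and generation parameters are placeholders, and it assumes a ColossalAI version where `launch_from_torch()` takes no config argument.

```python
import colossalai
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# `colossalai run --nproc_per_node N` starts one process per GPU;
# launch_from_torch picks up the distributed environment it sets.
colossalai.launch_from_torch()

# patched_parallelism_size should match the number of launched processes.
inference_config = InferenceConfig(dtype="fp16", patched_parallelism_size=2)
engine = InferenceEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    inference_config=inference_config,
    verbose=True,
)

# generate() returns a list of images, one per prompt.
image = engine.generate(
    prompts=["hello world"],
    generation_config=DiffusionGenerationConfig(num_inference_steps=50, height=1024, width=1024),
)[0]
image.save("out.jpg")
```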
examples/inference/stable_diffusion/benchmark_sd3.py (new file, 179 lines)
@@ -0,0 +1,179 @@
import argparse
import json
import time
from contextlib import nullcontext

import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline

import colossalai
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

_DTYPE_MAPPING = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def log_generation_time(log_data, log_file):
    # Append one JSON record per benchmark run to the log file.
    with open(log_file, "a") as f:
        json.dump(log_data, f, indent=2)
        f.write("\n")


def warmup(engine, args):
    # Run a few generations first so CUDA kernels and caches are initialized
    # before timing starts.
    for _ in range(args.n_warm_up_steps):
        engine.generate(
            prompts=["hello world"],
            generation_config=DiffusionGenerationConfig(
                num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
            ),
        )


def profile_context(args):
    # Return a torch profiler when --profile is set, otherwise a no-op context.
    return (
        torch.profiler.profile(
            record_shapes=True,
            with_stack=True,
            with_modules=True,
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
        )
        if args.profile
        else nullcontext()
    )


def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
    log_data = {
        "mode": mode,
        "model": model_name,
        "batch_size": args.batch_size,
        "patched_parallel_size": args.patched_parallel_size,
        "num_inference_steps": args.num_inference_steps,
        "height": h,
        "width": w,
        "dtype": args.dtype,
        "profile": args.profile,
        "n_warm_up_steps": args.n_warm_up_steps,
        "n_repeat_times": args.n_repeat_times,
        "avg_generation_time": avg_time,
        "log_message": log_msg,
    }

    if args.log:
        log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
        log_generation_time(log_data=log_data, log_file=log_file)

    if args.profile:
        file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
        prof.export_chrome_trace(file)


def benchmark_colossalai(rank, world_size, port, args):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    from colossalai.cluster.dist_coordinator import DistCoordinator

    # DistCoordinator must be created after colossalai.launch has initialized
    # the process group.
    coordinator = DistCoordinator()

    inference_config = InferenceConfig(
        dtype=args.dtype,
        patched_parallelism_size=args.patched_parallel_size,
    )
    engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)

    warmup(engine, args)

    for h, w in zip(args.height, args.width):
        with profile_context(args) as prof:
            start = time.perf_counter()
            for _ in range(args.n_repeat_times):
                engine.generate(
                    prompts=["hello world"],
                    generation_config=DiffusionGenerationConfig(
                        num_inference_steps=args.num_inference_steps, height=h, width=w
                    ),
                )
            end = time.perf_counter()

        avg_time = (end - start) / args.n_repeat_times
        log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
        coordinator.print_on_master(log_msg)

        # Only rank 0 writes logs and profiler traces.
        if dist.get_rank() == 0:
            log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)


def benchmark_diffusers(args):
    model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")

    for _ in range(args.n_warm_up_steps):
        model(
            prompt="hello world",
            num_inference_steps=args.num_inference_steps,
            height=args.height[0],
            width=args.width[0],
        )

    for h, w in zip(args.height, args.width):
        with profile_context(args) as prof:
            start = time.perf_counter()
            for _ in range(args.n_repeat_times):
                model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
            end = time.perf_counter()

        avg_time = (end - start) / args.n_repeat_times
        log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
        print(log_msg)

        log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def benchmark(args):
    if args.mode == "colossalai":
        # Spawn one process per GPU for Patched Parallelism.
        spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
    elif args.mode == "diffusers":
        benchmark_diffusers(args)


"""
# enable log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log

# enable profiler
python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
    parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
    parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
    parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm-up steps")
    parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
    parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
    parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
    parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
    parser.add_argument(
        "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
    )
    args = parser.parse_args()
    benchmark(args)
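Note that `log_generation_time` appends pretty-printed JSON objects back to back, so the resulting log is not line-delimited JSON. Here is a small sketch for reading such a file back (a hypothetical helper, not part of this commit; the file name is a placeholder following the `benchmark_{model_name}_{mode}.json` pattern above):

```python
import json


def read_benchmark_log(path: str):
    """Yield each JSON object from a file of concatenated, indented JSON records."""
    decoder = json.JSONDecoder()
    with open(path) as f:
        text = f.read()
    idx = 0
    while idx < len(text):
        # Skip whitespace between records, then decode the next object.
        while idx < len(text) and text[idx].isspace():
            idx += 1
        if idx >= len(text):
            break
        obj, idx = decoder.raw_decode(text, idx)
        yield obj


for record in read_benchmark_log("benchmark_stable-diffusion-3-medium-diffusers_colossalai.json"):
    print(record["height"], record["width"], record["avg_generation_time"])
```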
examples/inference/stable_diffusion/compute_metric.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os

import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm


def read_image(path: str):
    """
    input: path to an image file
    output: uint8 tensor of shape (C, H, W)
    """
    img = np.asarray(Image.open(path))
    if len(img.shape) == 2:
        # Grayscale image: replicate the single channel to get 3 channels.
        img = np.repeat(img[:, :, None], 3, axis=2)
    img = torch.from_numpy(img).permute(2, 0, 1)
    return img


class MultiImageDataset(Dataset):
    """Pairs up same-index images (sorted by name) from two directories."""

    def __init__(self, root0, root1, is_gt=False):
        super().__init__()
        self.root0 = root0
        self.root1 = root1
        file_names0 = os.listdir(root0)
        file_names1 = os.listdir(root1)

        self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
        self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
        self.is_gt = is_gt
        assert len(self.image_names0) == len(self.image_names1)

    def __len__(self):
        return len(self.image_names0)

    def __getitem__(self, idx):
        img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
        if self.is_gt:
            # resize to 1024 x 1024
            img0 = Resize((1024, 1024))(img0)
        img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))

        batch_list = [img0, img1]
        return batch_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--is_gt", action="store_true")
    parser.add_argument("--input_root0", type=str, required=True)
    parser.add_argument("--input_root1", type=str, required=True)
    args = parser.parse_args()

    psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
    lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")

    dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    progress_bar = tqdm(dataloader)
    with torch.inference_mode():
        for i, batch in enumerate(progress_bar):
            # Scale images from [0, 255] to [0, 1] before updating the metrics.
            batch = [img.to("cuda") / 255 for img in batch]
            batch_size = batch[0].shape[0]
            psnr.update(batch[0], batch[1])
            lpips.update(batch[0], batch[1])
        # FID is computed by clean-fid directly on the two image folders.
        fid_score = fid.compute_fid(args.input_root0, args.input_root1)

    print("PSNR:", psnr.compute().item())
    print("LPIPS:", lpips.compute().item())
    print("FID:", fid_score)
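Following the usage-docstring convention of `benchmark_sd3.py`, a typical invocation of this script might look like the following; the two directories are hypothetical placeholders and simply need to contain equal numbers of `.png`/`.jpg` images whose sorted names pair up:

```python
"""
python examples/inference/stable_diffusion/compute_metric.py \
    --input_root0 outputs/colossalai \
    --input_root1 outputs/diffusers
"""
```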
examples/inference/stable_diffusion/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
torchvision
torchmetrics
cleanfid
examples/inference/stable_diffusion/run_benchmark.sh (new file, 42 lines)
@@ -0,0 +1,42 @@
#!/bin/bash

models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
parallelism=(1 2 4 8)
resolutions=(1024 2048 3840)
modes=("colossalai" "diffusers")

# Select the n GPUs with the lowest current memory usage and expose them
# via CUDA_VISIBLE_DEVICES.
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

for model in "${models[@]}"; do
    for p in "${parallelism[@]}"; do
        for resolution in "${resolutions[@]}"; do
            for mode in "${modes[@]}"; do
                # Skip redundant combinations: the ColossalAI engine is only
                # benchmarked with parallelism > 1, diffusers only on a single GPU.
                if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
                    continue
                fi
                if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
                    continue
                fi
                CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p

                cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"

                echo "Executing: $cmd"
                eval $cmd
            done
        done
    done
done
examples/inference/stable_diffusion/sd3_generation.py (modified)
@@ -1,18 +1,17 @@
 import argparse

-from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
-from torch import bfloat16, float16, float32
+from diffusers import DiffusionPipeline
+from torch import bfloat16
+from torch import distributed as dist
+from torch import float16, float32

 import colossalai
 from colossalai.cluster import DistCoordinator
 from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
-from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy

 # For Stable Diffusion 3, we'll use the following configuration
-MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
-POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+MODEL_CLS = DiffusionPipeline

 TORCH_DTYPE_MAP = {
     "fp16": float16,
@@ -43,20 +42,27 @@ def infer(args):
         max_batch_size=args.max_batch_size,
         tp_size=args.tp_size,
         use_cuda_kernel=args.use_cuda_kernel,
+        patched_parallelism_size=dist.get_world_size(),
     )
-    engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

     # ==============================
     # Generation
     # ==============================
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
-    out.save("cat.jpg")
+    if dist.get_rank() == 0:
+        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
     coordinator.print_on_master(out)


-# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1


 if __name__ == "__main__":