mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
This commit is contained in:
@@ -1,18 +1,17 @@
|
||||
import argparse
|
||||
|
||||
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
|
||||
from torch import bfloat16, float16, float32
|
||||
from diffusers import DiffusionPipeline
|
||||
from torch import bfloat16
|
||||
from torch import distributed as dist
|
||||
from torch import float16, float32
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
|
||||
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
|
||||
|
||||
# For Stable Diffusion 3, we'll use the following configuration
|
||||
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
|
||||
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
|
||||
MODEL_CLS = DiffusionPipeline
|
||||
|
||||
TORCH_DTYPE_MAP = {
|
||||
"fp16": float16,
|
||||
@@ -43,20 +42,27 @@ def infer(args):
|
||||
max_batch_size=args.max_batch_size,
|
||||
tp_size=args.tp_size,
|
||||
use_cuda_kernel=args.use_cuda_kernel,
|
||||
patched_parallelism_size=dist.get_world_size(),
|
||||
)
|
||||
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
|
||||
engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
|
||||
|
||||
# ==============================
|
||||
# Generation
|
||||
# ==============================
|
||||
coordinator.print_on_master(f"Generating...")
|
||||
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
|
||||
out.save("cat.jpg")
|
||||
if dist.get_rank() == 0:
|
||||
out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
|
||||
coordinator.print_on_master(out)
|
||||
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
|
||||
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
|
||||
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user