Commit: DistriFusion support

* comp-comm overlap optimization
* SD3 benchmark
* PixArt DistriFusion bug fix
* SD3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement

sd3_generation.py · 82 lines · 3.4 KiB · Python
import argparse

from diffusers import DiffusionPipeline
from torch import bfloat16
from torch import distributed as dist
from torch import float16, float32

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = DiffusionPipeline

TORCH_DTYPE_MAP = {
    "fp16": float16,
    "fp32": float32,
    "bf16": bfloat16,
}


def infer(args):
    # ==============================
    # Launch ColossalAI and set up the distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()
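    # launch_from_torch() initializes the process group from the environment
    # variables set by the launcher (colossalai run / torchrun); DistCoordinator
    # then provides rank-aware helpers such as print_on_master used below.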

    # ==============================
    # Load the diffusion pipeline
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
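    # Note: argparse restricts --dtype to the map's keys, so the .get(..., None)
    # fallback is never hit in practice; a None torch_dtype would simply let
    # diffusers load the checkpoint in its default precision.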

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master("Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        patched_parallelism_size=dist.get_world_size(),
    )
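    # patched_parallelism_size follows the launched world size: running with
    # --nproc_per_node N splits the DistriFusion-style patched computation
    # across N devices (see the example commands below).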
    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master("Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
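    # Only rank 0 writes the image to disk, so concurrent ranks don't all save
    # the same file.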
    if dist.get_rank() == 0:
        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
    coordinator.print_on_master(out)


# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH

# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1

# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
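# The same entry point serves both SD3 and PixArt checkpoints:
# DiffusionPipeline.from_pretrained resolves the concrete pipeline class from
# the checkpoint itself, so only the -m argument changes between the commands above.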


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
    args = parser.parse_args()

    infer(args)