From bcf0181ecd3e1e639200b66d6e1aab6c6b3d5b7b Mon Sep 17 00:00:00 2001
From: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Date: Tue, 30 Jul 2024 10:43:26 +0800
Subject: [PATCH] [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
---
 colossalai/inference/README.md                |  12 +-
 colossalai/inference/config.py                |  16 +
 colossalai/inference/core/diffusion_engine.py |   2 +-
 .../modeling/{models => layers}/diffusion.py  |   0
 .../inference/modeling/layers/distrifusion.py | 626 ++++++++++++++++++
 .../inference/modeling/models/pixart_alpha.py |   2 +-
 .../modeling/models/stablediffusion3.py       |   2 +-
 .../inference/modeling/policy/pixart_alpha.py |  49 +-
 .../modeling/policy/stablediffusion3.py       |  48 +-
 examples/inference/stable_diffusion/README.md |  22 +
 .../stable_diffusion/benchmark_sd3.py         | 179 +++++
 .../stable_diffusion/compute_metric.py        |  80 +++
 .../stable_diffusion/requirements.txt         |   3 +
 .../stable_diffusion/run_benchmark.sh         |  42 ++
 .../stable_diffusion/sd3_generation.py        |  22 +-
 15 files changed, 1089 insertions(+), 16 deletions(-)
 rename colossalai/inference/modeling/{models => layers}/diffusion.py (100%)
 create mode 100644 colossalai/inference/modeling/layers/distrifusion.py
 create mode 100644 examples/inference/stable_diffusion/README.md
 create mode 100644 examples/inference/stable_diffusion/benchmark_sd3.py
 create mode 100644 examples/inference/stable_diffusion/compute_metric.py
 create mode 100644 examples/inference/stable_diffusion/requirements.txt
 create mode 100644 examples/inference/stable_diffusion/run_benchmark.sh

diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 0a9b5293d..76813a4a3 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -18,7 +18,7 @@
## 📌 Introduction
-ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
+ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT diffusion models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below.
journal={arXiv},
year={2023}
}
+
+# Distrifusion
+@InProceedings{Li_2024_CVPR,
+ author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
+ title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month={June},
+ year={2024},
+ pages={7183-7193}
+}
```
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 1beb86874..072ddbcfd 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
+ patched_parallelism_size(int): The patched parallelism size, used when enabling Distrifusion.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512
+ # Acceleration for diffusion models (PipeFusion or Distrifusion)
+ patched_parallelism_size: int = 1 # for distrifusion
+ # pipeFusion_m_size: int = 1 # for pipefusion
+ # pipeFusion_n_size: int = 1 # for pipefusion
+
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
+ # check Distrifusion
+ # TODO(@lry89757) need more detailed check
+ if self.patched_parallelism_size > 1:
+ # self.use_patched_parallelism = True
+ self.tp_size = (
+ self.patched_parallelism_size
+ ) # this is not a real tp size; some existing checks require tp_size to match the parallel world size, so we set it to patched_parallelism_size
+
# check prompt template
if self.prompt_template is None:
return
@@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
+ patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
+ patched_parallelism_size: int = 1 # for diffusion models using the Distrifusion technique
@dataclass
diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py
index 75b9889bf..8bed508cb 100644
--- a/colossalai/inference/core/diffusion_engine.py
+++ b/colossalai/inference/core/diffusion_engine.py
@@ -11,7 +11,7 @@ from torch import distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py
similarity index 100%
rename from colossalai/inference/modeling/models/diffusion.py
rename to colossalai/inference/modeling/layers/diffusion.py
diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py
new file mode 100644
index 000000000..ea97cceef
--- /dev/null
+++ b/colossalai/inference/modeling/layers/distrifusion.py
@@ -0,0 +1,626 @@
+# Code refer and adapted from:
+# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
+# https://github.com/PipeFusion/PipeFusion
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from diffusers.models import attention_processor
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from torch import nn
+from torch.distributed import ProcessGroup
+
+from colossalai.inference.config import ModelShardInferenceConfig
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from colossalai.utils import get_current_device
+
+try:
+ from flash_attn import flash_attn_func
+
+ HAS_FLASH_ATTN = True
+except ImportError:
+ HAS_FLASH_ATTN = False
+
+
+logger = get_dist_logger(__name__)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
+def PixArtAlphaTransformer2DModel_forward(
+ self: PixArtTransformer2DModel,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+):
+ assert hasattr(
+ self, "patched_parallel_size"
+ ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
+
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size = hidden_states.shape[0]
+ height, width = (
+ hidden_states.shape[-2] // self.config.patch_size,
+ hidden_states.shape[-1] // self.config.patch_size,
+ )
+ hidden_states = self.pos_embed(hidden_states)
+
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ if self.caption_projection is not None:
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
+ 2, dim=1
+ )
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ hidden_states = hidden_states.reshape(
+ shape=(
+ -1,
+ height // self.patched_parallel_size,
+ width,
+ self.config.patch_size,
+ self.config.patch_size,
+ self.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ -1,
+ self.out_channels,
+ height // self.patched_parallel_size * self.config.patch_size,
+ width * self.config.patch_size,
+ )
+ )
+
+ # enable Distrifusion Optimization
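+ # each rank has computed only its horizontal band of the latent (height is split by
+ # patched_parallel_size); all_gather the bands and concatenate them along the height dim (dim=2)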
+ if hasattr(self, "patched_parallel_size"):
+ from torch import distributed as dist
+
+ if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+ self.output_buffer = torch.empty_like(output)
+ if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+ self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+ output = output.contiguous()
+ dist.all_gather(self.buffer_list, output, async_op=False)
+ torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+ output = self.output_buffer
+
+ return (output,)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
+def SD3Transformer2DModel_forward(
+ self: SD3Transformer2DModel,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+) -> Union[torch.FloatTensor]:
+
+ assert hasattr(
+ self, "patched_parallel_size"
+ ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
+
+ height, width = hidden_states.shape[-2:]
+
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ for block in self.transformer_blocks:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size // self.patched_parallel_size
+ width = width // patch_size
+
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+ )
+
+ # enable Distrifusion Optimization
+ if hasattr(self, "patched_parallel_size"):
+ from torch import distributed as dist
+
+ if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+ self.output_buffer = torch.empty_like(output)
+ if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+ self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+ output = output.contiguous()
+ dist.all_gather(self.buffer_list, output, async_op=False)
+ torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+ output = self.output_buffer
+
+ return (output,)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
+class DistrifusionPatchEmbed(ParallelModule):
+ def __init__(
+ self,
+ module: PatchEmbed,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.module = module
+ self.rank = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+ @staticmethod
+ def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ distrifusion_embed = DistrifusionPatchEmbed(
+ module, process_group, model_shard_infer_config=model_shard_infer_config
+ )
+ return distrifusion_embed
+
+ def forward(self, latent):
+ module = self.module
+ if module.pos_embed_max_size is not None:
+ height, width = latent.shape[-2:]
+ else:
+ height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
+
+ latent = module.proj(latent)
+ if module.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if module.layer_norm:
+ latent = module.norm(latent)
+ if module.pos_embed is None:
+ return latent.to(latent.dtype)
+ # Interpolate or crop positional embeddings as needed
+ if module.pos_embed_max_size:
+ pos_embed = module.cropped_pos_embed(height, width)
+ else:
+ if module.height != height or module.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=module.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=module.base_size,
+ interpolation_scale=module.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = module.pos_embed
+
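+ # split the positional embedding along the token dimension and keep only this rank's slice,
+ # matching the height-wise patch split of the latent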
+ b, c, h = pos_embed.shape
+ pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
+
+ return (latent + pos_embed).to(latent.dtype)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
+class DistrifusionConv2D(ParallelModule):
+
+ def __init__(
+ self,
+ module: nn.Conv2d,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.module = module
+ self.rank = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+ @staticmethod
+ def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
+ return distrifusion_conv
+
+ def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
+
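+ # convolve only this rank's horizontal slice of the input; the slice is extended by
+ # `padding` rows of halo from neighbouring rows (or zero padding at the image border)
+ # so the output matches the corresponding rows of a full-image convolution, and the
+ # convolution itself then runs with "valid" padding on the padded slice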
+ b, c, h, w = x.shape
+
+ stride = self.module.stride[0]
+ padding = self.module.padding[0]
+
+ output_h = x.shape[2] // stride // self.patched_parallelism_size
+ idx = dist.get_rank()
+ h_begin = output_h * idx * stride - padding
+ h_end = output_h * (idx + 1) * stride + padding
+ final_padding = [padding, padding, 0, 0]
+ if h_begin < 0:
+ h_begin = 0
+ final_padding[2] = padding
+ if h_end > h:
+ h_end = h
+ final_padding[3] = padding
+ sliced_input = x[:, :, h_begin:h_end, :]
+ padded_input = F.pad(sliced_input, final_padding, mode="constant")
+ return F.conv2d(
+ padded_input,
+ self.module.weight,
+ self.module.bias,
+ stride=stride,
+ padding="valid",
+ )
+
+ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ output = self.sliced_forward(input)
+ return output
+
+
+# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
+class DistrifusionFusedAttention(ParallelModule):
+
+ def __init__(
+ self,
+ module: attention_processor.Attention,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.counter = 0
+ self.module = module
+ self.buffer_list = None
+ self.kv_buffer_idx = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+ self.handle = None
+ self.process_group = process_group
+ self.warm_step = 5 # for warmup
+
+ @staticmethod
+ def from_native_module(
+ module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ return DistrifusionFusedAttention(
+ module=module,
+ process_group=process_group,
+ model_shard_infer_config=model_shard_infer_config,
+ )
+
+ def _forward(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
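+ # Distrifusion KV exchange: on the very first call (no buffers yet) each rank simply
+ # replicates its own KV; for the first `warm_step` steps KV is all_gathered synchronously;
+ # afterwards the slightly stale KV gathered asynchronously in the previous diffusion step
+ # is reused, and a new async all_gather is launched to overlap with this step's compute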
+ if self.patched_parallelism_size == 1:
+ full_kv = kv
+ else:
+ if self.buffer_list is None: # buffer not created
+ full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+ elif self.counter <= self.warm_step:
+ # logger.info(f"warmup: {self.counter}")
+ dist.all_gather(
+ self.buffer_list,
+ kv,
+ group=self.process_group,
+ async_op=False,
+ )
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ else:
+ # logger.info(f"use old kv to infer: {self.counter}")
+ self.buffer_list[self.kv_buffer_idx].copy_(kv)
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ assert self.handle is None, "the async KV all_gather from the previous step should already have been awaited"
+ self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, is_causal=False
+ ) # NOTE(@lry89757) for torch >= 2.2, flash attention is already integrated into scaled_dot_product_attention, see https://pytorch.org/blog/pytorch2-2/
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+
+ if self.handle is not None:
+ self.handle.wait()
+ self.handle = None
+
+ b, l, c = hidden_states.shape
+ kv_shape = (b, l, self.module.to_k.out_features * 2)
+ if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+ self.buffer_list = [
+ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+ for _ in range(self.patched_parallelism_size)
+ ]
+
+ self.counter = 0
+
+ attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ output = self._forward(
+ self.module,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ self.counter += 1
+
+ return output
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
+class DistriSelfAttention(ParallelModule):
+ def __init__(
+ self,
+ module: Attention,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ model_shard_infer_config: ModelShardInferenceConfig = None,
+ ):
+ super().__init__()
+ self.counter = 0
+ self.module = module
+ self.buffer_list = None
+ self.kv_buffer_idx = dist.get_rank(group=process_group)
+ self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+ self.handle = None
+ self.process_group = process_group
+ self.warm_step = 3 # for warmup
+
+ @staticmethod
+ def from_native_module(
+ module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+ return DistriSelfAttention(
+ module=module,
+ process_group=process_group,
+ model_shard_infer_config=model_shard_infer_config,
+ )
+
+ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
+ attn = self.module
+ assert isinstance(attn, Attention)
+
+ residual = hidden_states
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ query = attn.to_q(hidden_states)
+
+ encoder_hidden_states = hidden_states
+ k = self.module.to_k(encoder_hidden_states)
+ v = self.module.to_v(encoder_hidden_states)
+ kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
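+ # same Distrifusion KV scheme as DistrifusionFusedAttention: synchronous all_gather during
+ # warmup, then stale-KV reuse with an async all_gather overlapped with computation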
+ if self.patched_parallelism_size == 1:
+ full_kv = kv
+ else:
+ if self.buffer_list is None: # buffer not created
+ full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+ elif self.counter <= self.warm_step:
+ # logger.info(f"warmup: {self.counter}")
+ dist.all_gather(
+ self.buffer_list,
+ kv,
+ group=self.process_group,
+ async_op=False,
+ )
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ else:
+ # logger.info(f"use old kv to infer: {self.counter}")
+ self.buffer_list[self.kv_buffer_idx].copy_(kv)
+ full_kv = torch.cat(self.buffer_list, dim=1)
+ assert self.handle is None, "the async KV all_gather from the previous step should already have been awaited"
+ self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+ if HAS_FLASH_ATTN:
+ # flash attn
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
+ else:
+ # naive attn
+ key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+
+ # wait for the async KV all_gather launched in the previous diffusion step
+ if self.handle is not None:
+ self.handle.wait()
+ self.handle = None
+
+ b, l, c = hidden_states.shape
+ kv_shape = (b, l, self.module.to_k.out_features * 2)
+ if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+ self.buffer_list = [
+ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+ for _ in range(self.patched_parallelism_size)
+ ]
+
+ self.counter = 0
+
+ output = self._forward(hidden_states, scale=scale)
+
+ self.counter += 1
+ return output
diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py
index d5774946e..cc2bee5ef 100644
--- a/colossalai/inference/modeling/models/pixart_alpha.py
+++ b/colossalai/inference/modeling/models/pixart_alpha.py
@@ -14,7 +14,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retri
from colossalai.logging import get_dist_logger
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
logger = get_dist_logger(__name__)
diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py
index d1c63a6dc..b12316403 100644
--- a/colossalai/inference/modeling/models/stablediffusion3.py
+++ b/colossalai/inference/modeling/models/stablediffusion3.py
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
# TODO(@lry89757) temporarily image, please support more return output
diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py
index 356056ba7..1150b2432 100644
--- a/colossalai/inference/modeling/policy/pixart_alpha.py
+++ b/colossalai/inference/modeling/policy/pixart_alpha.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+ DistrifusionConv2D,
+ DistrifusionPatchEmbed,
+ DistriSelfAttention,
+ PixArtAlphaTransformer2DModel_forward,
+)
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
@@ -12,9 +20,46 @@ class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
def module_policy(self):
policy = {}
+
+ if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+ policy[PixArtTransformer2DModel] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="pos_embed.proj",
+ target_module=DistrifusionConv2D,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ SubModuleReplacementDescription(
+ suffix="pos_embed",
+ target_module=DistrifusionPatchEmbed,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ ],
+ attribute_replacement={
+ "patched_parallel_size": self.shard_config.extra_kwargs[
+ "model_shard_infer_config"
+ ].patched_parallelism_size
+ },
+ method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
+ )
+
+ policy[BasicTransformerBlock] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn1",
+ target_module=DistriSelfAttention,
+ kwargs={
+ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+ },
+ )
+ ]
+ )
+
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
+
return policy
def preprocess(self) -> nn.Module:
diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py
index c9877f7dc..39b764b92 100644
--- a/colossalai/inference/modeling/policy/stablediffusion3.py
+++ b/colossalai/inference/modeling/policy/stablediffusion3.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.transformers import SD3Transformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+ DistrifusionConv2D,
+ DistrifusionFusedAttention,
+ DistrifusionPatchEmbed,
+ SD3Transformer2DModel_forward,
+)
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
@@ -12,6 +20,42 @@ class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
def module_policy(self):
policy = {}
+
+ if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+ policy[SD3Transformer2DModel] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="pos_embed.proj",
+ target_module=DistrifusionConv2D,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ SubModuleReplacementDescription(
+ suffix="pos_embed",
+ target_module=DistrifusionPatchEmbed,
+ kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+ ),
+ ],
+ attribute_replacement={
+ "patched_parallel_size": self.shard_config.extra_kwargs[
+ "model_shard_infer_config"
+ ].patched_parallelism_size
+ },
+ method_replacement={"forward": SD3Transformer2DModel_forward},
+ )
+
+ policy[JointTransformerBlock] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn",
+ target_module=DistrifusionFusedAttention,
+ kwargs={
+ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+ },
+ )
+ ]
+ )
+
self.append_or_create_method_replacement(
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
)
diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md
new file mode 100644
index 000000000..c11b98043
--- /dev/null
+++ b/examples/inference/stable_diffusion/README.md
@@ -0,0 +1,22 @@
+## File Structure
+```
+|- sd3_generation.py: an example of how to use the ColossalAI Inference Engine to generate images with a diffusion model.
+|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion.
+|- benchmark_sd3.py: benchmark the performance of our InferenceEngine.
+|- run_benchmark.sh: run the benchmark commands.
+```
+Note: `compute_metric.py` requires extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is located in `examples/inference/stable_diffusion/`).
+
+## Run Inference
+
+The provided example `sd3_generation.py` shows how to configure and initialize the engine and run inference on a given model. We've added `DiffusionPipeline` as the model class, and the script runs inference with Stable Diffusion 3 out of the box.
+
+For a basic setup, you can run the example with:
+```bash
+colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
+```
+
+Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:
+```bash
+colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
+```
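+
+## Programmatic Usage
+
+Below is a minimal sketch, condensed from `sd3_generation.py`, of driving the engine directly from Python under a `colossalai run --nproc_per_node <N>` launch. The model id, prompt, and output filename are placeholders, and `colossalai.launch_from_torch` may require a `config` argument on older ColossalAI versions:
+```python
+import colossalai
+from torch import distributed as dist
+
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+
+colossalai.launch_from_torch()
+
+# enable Distrifusion (patched parallelism) across all launched GPUs
+inference_config = InferenceConfig(
+    dtype="fp16",
+    patched_parallelism_size=dist.get_world_size(),
+)
+engine = InferenceEngine(
+    "stabilityai/stable-diffusion-3-medium-diffusers",  # model id or local path
+    inference_config=inference_config,
+    verbose=True,
+)
+
+out = engine.generate(prompts=["hello world"], generation_config=DiffusionGenerationConfig())[0]
+if dist.get_rank() == 0:
+    out.save("output.jpg")  # save the final image from a single rank
+```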
diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py
new file mode 100644
index 000000000..19db57c33
--- /dev/null
+++ b/examples/inference/stable_diffusion/benchmark_sd3.py
@@ -0,0 +1,179 @@
+import argparse
+import json
+import time
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from diffusers import DiffusionPipeline
+
+import colossalai
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+GIGABYTE = 1024**3
+MEGABYTE = 1024 * 1024
+
+_DTYPE_MAPPING = {
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32,
+}
+
+
+def log_generation_time(log_data, log_file):
+ with open(log_file, "a") as f:
+ json.dump(log_data, f, indent=2)
+ f.write("\n")
+
+
+def warmup(engine, args):
+ for _ in range(args.n_warm_up_steps):
+ engine.generate(
+ prompts=["hello world"],
+ generation_config=DiffusionGenerationConfig(
+ num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
+ ),
+ )
+
+
+def profile_context(args):
+ return (
+ torch.profiler.profile(
+ record_shapes=True,
+ with_stack=True,
+ with_modules=True,
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ )
+ if args.profile
+ else nullcontext()
+ )
+
+
+def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
+ log_data = {
+ "mode": mode,
+ "model": model_name,
+ "batch_size": args.batch_size,
+ "patched_parallel_size": args.patched_parallel_size,
+ "num_inference_steps": args.num_inference_steps,
+ "height": h,
+ "width": w,
+ "dtype": args.dtype,
+ "profile": args.profile,
+ "n_warm_up_steps": args.n_warm_up_steps,
+ "n_repeat_times": args.n_repeat_times,
+ "avg_generation_time": avg_time,
+ "log_message": log_msg,
+ }
+
+ if args.log:
+ log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
+ log_generation_time(log_data=log_data, log_file=log_file)
+
+ if args.profile:
+ file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
+ prof.export_chrome_trace(file)
+
+
+def benchmark_colossalai(rank, world_size, port, args):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ from colossalai.cluster.dist_coordinator import DistCoordinator
+
+ coordinator = DistCoordinator()
+
+ inference_config = InferenceConfig(
+ dtype=args.dtype,
+ patched_parallelism_size=args.patched_parallel_size,
+ )
+ engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
+
+ warmup(engine, args)
+
+ for h, w in zip(args.height, args.width):
+ with profile_context(args) as prof:
+ start = time.perf_counter()
+ for _ in range(args.n_repeat_times):
+ engine.generate(
+ prompts=["hello world"],
+ generation_config=DiffusionGenerationConfig(
+ num_inference_steps=args.num_inference_steps, height=h, width=w
+ ),
+ )
+ end = time.perf_counter()
+
+ avg_time = (end - start) / args.n_repeat_times
+ log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+ coordinator.print_on_master(log_msg)
+
+ if dist.get_rank() == 0:
+ log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
+
+
+def benchmark_diffusers(args):
+ model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
+
+ for _ in range(args.n_warm_up_steps):
+ model(
+ prompt="hello world",
+ num_inference_steps=args.num_inference_steps,
+ height=args.height[0],
+ width=args.width[0],
+ )
+
+ for h, w in zip(args.height, args.width):
+ with profile_context(args) as prof:
+ start = time.perf_counter()
+ for _ in range(args.n_repeat_times):
+ model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
+ end = time.perf_counter()
+
+ avg_time = (end - start) / args.n_repeat_times
+ log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+ print(log_msg)
+
+ log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def benchmark(args):
+ if args.mode == "colossalai":
+ spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
+ elif args.mode == "diffusers":
+ benchmark_diffusers(args)
+
+
+"""
+# enable log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
+
+# enable profiler
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+"""
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
+ parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
+ parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
+ parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
+ parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
+ parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
+ parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
+ parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
+ parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
+ parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
+ parser.add_argument(
+ "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
+ )
+ args = parser.parse_args()
+ benchmark(args)
diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py
new file mode 100644
index 000000000..14c92501b
--- /dev/null
+++ b/examples/inference/stable_diffusion/compute_metric.py
@@ -0,0 +1,80 @@
+# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
+import argparse
+import os
+
+import numpy as np
+import torch
+from cleanfid import fid
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
+from torchvision.transforms import Resize
+from tqdm import tqdm
+
+
+def read_image(path: str):
+ """
+ input: path
+ output: tensor (C, H, W)
+ """
+ img = np.asarray(Image.open(path))
+ if len(img.shape) == 2:
+ img = np.repeat(img[:, :, None], 3, axis=2)
+ img = torch.from_numpy(img).permute(2, 0, 1)
+ return img
+
+
+class MultiImageDataset(Dataset):
+ def __init__(self, root0, root1, is_gt=False):
+ super().__init__()
+ self.root0 = root0
+ self.root1 = root1
+ file_names0 = os.listdir(root0)
+ file_names1 = os.listdir(root1)
+
+ self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
+ self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
+ self.is_gt = is_gt
+ assert len(self.image_names0) == len(self.image_names1)
+
+ def __len__(self):
+ return len(self.image_names0)
+
+ def __getitem__(self, idx):
+ img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
+ if self.is_gt:
+ # resize to 1024 x 1024
+ img0 = Resize((1024, 1024))(img0)
+ img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
+
+ batch_list = [img0, img1]
+ return batch_list
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--num_workers", type=int, default=8)
+ parser.add_argument("--is_gt", action="store_true")
+ parser.add_argument("--input_root0", type=str, required=True)
+ parser.add_argument("--input_root1", type=str, required=True)
+ args = parser.parse_args()
+
+ psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
+ lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
+
+ dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+
+ progress_bar = tqdm(dataloader)
+ with torch.inference_mode():
+ for i, batch in enumerate(progress_bar):
+ batch = [img.to("cuda") / 255 for img in batch]
+ batch_size = batch[0].shape[0]
+ psnr.update(batch[0], batch[1])
+ lpips.update(batch[0], batch[1])
+ fid_score = fid.compute_fid(args.input_root0, args.input_root1)
+
+ print("PSNR:", psnr.compute().item())
+ print("LPIPS:", lpips.compute().item())
+ print("FID:", fid_score)
diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt
new file mode 100644
index 000000000..c4e74162d
--- /dev/null
+++ b/examples/inference/stable_diffusion/requirements.txt
@@ -0,0 +1,3 @@
+torchvision
+torchmetrics
+cleanfid
diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh
new file mode 100644
index 000000000..f3e45a335
--- /dev/null
+++ b/examples/inference/stable_diffusion/run_benchmark.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
+parallelism=(1 2 4 8)
+resolutions=(1024 2048 3840)
+modes=("colossalai" "diffusers")
+
+CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+ | tail -n +2 \
+ | nl -v 0 \
+ | tee /dev/tty \
+ | sort -g -k 2 \
+ | awk '{print $1}' \
+ | head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+for model in "${models[@]}"; do
+ for p in "${parallelism[@]}"; do
+ for resolution in "${resolutions[@]}"; do
+ for mode in "${modes[@]}"; do
+ if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
+ continue
+ fi
+ if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
+ continue
+ fi
+ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
+
+ cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
+
+ echo "Executing: $cmd"
+ eval $cmd
+ done
+ done
+ done
+done
diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py
index fe989eed7..9e146c34b 100644
--- a/examples/inference/stable_diffusion/sd3_generation.py
+++ b/examples/inference/stable_diffusion/sd3_generation.py
@@ -1,18 +1,17 @@
import argparse
-from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
-from torch import bfloat16, float16, float32
+from diffusers import DiffusionPipeline
+from torch import bfloat16
+from torch import distributed as dist
+from torch import float16, float32
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
-from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
-from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
# For Stable Diffusion 3, we'll use the following configuration
-MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
-POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+MODEL_CLS = DiffusionPipeline
TORCH_DTYPE_MAP = {
"fp16": float16,
@@ -43,20 +42,27 @@ def infer(args):
max_batch_size=args.max_batch_size,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
+ patched_parallelism_size=dist.get_world_size(),
)
- engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+ engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
# ==============================
# Generation
# ==============================
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
- out.save("cat.jpg")
+ if dist.get_rank() == 0:
+ out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
coordinator.print_on_master(out)
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
if __name__ == "__main__":