[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement
This commit is contained in:
Runyu Lu
2024-07-30 10:43:26 +08:00
committed by GitHub
parent 7b38964e3a
commit bcf0181ecd
15 changed files with 1089 additions and 16 deletions

View File

@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
patched_parallelism_size(int): Patched Parallelism Size, When using Distrifusion
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512
# Acceleration for Diffusion Model(PipeFusion or Distrifusion)
patched_parallelism_size: int = 1 # for distrifusion
# pipeFusion_m_size: int = 1 # for pipefusion
# pipeFusion_n_size: int = 1 # for pipefusion
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
# check Distrifusion
# TODO(@lry89757) need more detailed check
if self.patched_parallelism_size > 1:
# self.use_patched_parallelism = True
self.tp_size = (
self.patched_parallelism_size
) # this is not a real tp, because some annoying check, so we have to set this to patched_parallelism_size
# check prompt template
if self.prompt_template is None:
return
@@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique
@dataclass