[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
        enable_streamingllm(bool): Whether to use StreamingLLM; see the paper at https://arxiv.org/pdf/2309.17453 for the underlying algorithm.
        start_token_size(int): The number of start tokens kept when using StreamingLLM.
        generated_token_size(int): The number of generated tokens kept when using StreamingLLM.
        patched_parallelism_size(int): The patched parallelism size, used when Distrifusion is enabled.
    """

    # NOTE: arrange configs according to their importance and frequency of usage
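For reference, a minimal sketch of configuring StreamingLLM through these fields (the import path is assumed from ColossalAI's inference module; the values mirror the defaults added below):

```python
# Minimal sketch, assuming this import path; not taken verbatim from the commit.
from colossalai.inference.config import InferenceConfig

streaming_config = InferenceConfig(
    enable_streamingllm=True,
    start_token_size=4,        # tokens kept at the start of the sequence (attention sinks)
    generated_token_size=512,  # window of generated tokens kept in the KV cache
)
```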
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
    start_token_size: int = 4
    generated_token_size: int = 512

    # Acceleration for diffusion models (PipeFusion or Distrifusion)
    patched_parallelism_size: int = 1  # for Distrifusion
    # pipeFusion_m_size: int = 1  # for PipeFusion
    # pipeFusion_n_size: int = 1  # for PipeFusion

    def __post_init__(self):
        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
        self._verify_config()
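Similarly, a minimal sketch of switching on Distrifusion via the new flag (same assumed import path; anything beyond the fields shown in this diff is illustrative):

```python
# Minimal sketch; patched_parallelism_size > 1 activates the Distrifusion path.
from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    patched_parallelism_size=2,  # split the diffusion latent into 2 patches across 2 devices
)
# __post_init__ then mirrors this value into tp_size (see the next hunk), so the
# existing tensor-parallel validation passes unchanged.
```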
@@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
            # Thereafter, we swap out tokens in units of blocks, always swapping out the second block once the generated tokens exceed the limit.
            self.start_token_size = self.block_size

        # check Distrifusion
        # TODO(@lry89757) need more detailed check
        if self.patched_parallelism_size > 1:
            # self.use_patched_parallelism = True
            self.tp_size = (
                self.patched_parallelism_size
            )  # not real tensor parallelism; tp_size is set to patched_parallelism_size only to satisfy existing parallelism checks

        # check prompt template
        if self.prompt_template is None:
            return
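To make the override above concrete, a hypothetical helper (not part of the commit) spelling out the world-size implication of reusing tp_size for Distrifusion:

```python
# Hypothetical illustration only: because tp_size is overwritten with
# patched_parallelism_size, the launch must provide one device per patch.
import torch.distributed as dist

def assert_distrifusion_world_size(config) -> None:
    if config.patched_parallelism_size > 1:
        assert (
            dist.get_world_size() == config.patched_parallelism_size
        ), "Distrifusion expects exactly patched_parallelism_size devices"
```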
@@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
            use_cuda_kernel=self.use_cuda_kernel,
            use_spec_dec=self.use_spec_dec,
            use_flash_attn=use_flash_attn,
            patched_parallelism_size=self.patched_parallelism_size,
        )
        return model_inference_config
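A short usage sketch of the forwarding above (the enclosing method name, to_model_shard_inference_config, is inferred from context and should be treated as an assumption):

```python
# Sketch: the top-level flag is carried down into the per-shard config.
shard_config = config.to_model_shard_inference_config()
assert shard_config.patched_parallelism_size == config.patched_parallelism_size
```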
@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
    use_cuda_kernel: bool = False
    use_spec_dec: bool = False
    use_flash_attn: bool = False
    patched_parallelism_size: int = 1  # for diffusion models (Distrifusion technique)


@dataclass
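For intuition, a toy sketch of the patch parallelism behind Distrifusion, not ColossalAI's implementation: each rank denoises one patch of the latent, and boundary activations are exchanged asynchronously so communication overlaps computation (the commit's "comp comm overlap optimization"):

```python
# Toy sketch of Distrifusion-style patch parallelism; illustrative only.
import torch

def take_local_patch(latent: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Split a (B, C, H, W) latent along the height dim and keep this rank's patch."""
    return latent.chunk(world_size, dim=2)[rank]

# With patched_parallelism_size=2, rank 0 denoises the top half and rank 1 the
# bottom half; stale boundary activations from the previous step are exchanged
# with async collectives so the next layer's compute overlaps the communication.
```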