mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support
This commit is contained in:
@@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.generation import GenerationConfig
|
||||
@@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
|
||||
use_cuda_kernel: bool = False
|
||||
use_spec_dec: bool = False
|
||||
use_flash_attn: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionGenerationConfig:
|
||||
"""
|
||||
Param for diffusion model forward
|
||||
"""
|
||||
|
||||
prompt_2: Optional[Union[str, List[str]]] = None
|
||||
prompt_3: Optional[Union[str, List[str]]] = None
|
||||
height: Optional[int] = None
|
||||
width: Optional[int] = None
|
||||
num_inference_steps: int = None
|
||||
timesteps: List[int] = None
|
||||
guidance_scale: float = None
|
||||
negative_prompt: Optional[Union[str, List[str]]] = (
|
||||
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
|
||||
)
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None
|
||||
num_images_per_prompt: Optional[int] = None
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
|
||||
latents: Optional[torch.FloatTensor] = None
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
output_type: Optional[str] = None # "pil"
|
||||
return_dict: bool = None
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None
|
||||
clip_skip: Optional[int] = None
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
|
||||
callback_on_step_end_tensor_inputs: List[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# NOTE(@lry89757) Only return the dict that not the default value None
|
||||
result = {}
|
||||
for field in fields(self):
|
||||
value = getattr(self, field.name)
|
||||
if value is not None:
|
||||
result[field.name] = value
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
|
||||
return cls(**kwargs)
|
||||
|
Reference in New Issue
Block a user