[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support
This commit is contained in:
Runyu Lu
2024-07-08 16:02:07 +08:00
committed by GitHub
parent 8ec24b6a4d
commit cba20525a8
16 changed files with 1860 additions and 740 deletions

View File

@@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers.generation import GenerationConfig
@@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
@dataclass
class DiffusionGenerationConfig:
"""
Param for diffusion model forward
"""
prompt_2: Optional[Union[str, List[str]]] = None
prompt_3: Optional[Union[str, List[str]]] = None
height: Optional[int] = None
width: Optional[int] = None
num_inference_steps: int = None
timesteps: List[int] = None
guidance_scale: float = None
negative_prompt: Optional[Union[str, List[str]]] = (
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
)
negative_prompt_2: Optional[Union[str, List[str]]] = None
negative_prompt_3: Optional[Union[str, List[str]]] = None
num_images_per_prompt: Optional[int] = None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
latents: Optional[torch.FloatTensor] = None
prompt_embeds: Optional[torch.FloatTensor] = None
negative_prompt_embeds: Optional[torch.FloatTensor] = None
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
output_type: Optional[str] = None # "pil"
return_dict: bool = None
joint_attention_kwargs: Optional[Dict[str, Any]] = None
clip_skip: Optional[int] = None
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
callback_on_step_end_tensor_inputs: List[str] = None
def to_dict(self) -> Dict[str, Any]:
# NOTE(@lry89757) Only return the dict that not the default value None
result = {}
for field in fields(self):
value = getattr(self, field.name)
if value is not None:
result[field.name] = value
return result
@classmethod
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
return cls(**kwargs)