[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support
This commit is contained in:
Runyu Lu
2024-07-08 16:02:07 +08:00
committed by GitHub
parent 8ec24b6a4d
commit cba20525a8
16 changed files with 1860 additions and 740 deletions

View File

@@ -5,10 +5,12 @@ Utils for model inference
import math
import os
import re
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline
from torch import nn
from colossalai.logging import get_dist_logger
@@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False
class ModelType(Enum):
DIFFUSION_MODEL = "Diffusion Model"
LLM = "Large Language Model (LLM)"
UNKNOWN = "Unknown Model Type"
def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
if isinstance(model_or_path, DiffusionPipeline):
return ModelType.DIFFUSION_MODEL
elif isinstance(model_or_path, nn.Module):
return ModelType.LLM
elif isinstance(model_or_path, str):
try:
from transformers import AutoConfig
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
return ModelType.LLM
except:
"""
model type is not `ModelType.LLM`
"""
try:
from diffusers import DiffusionPipeline
DiffusionPipeline.load_config(model_or_path)
return ModelType.DIFFUSION_MODEL
except:
"""
model type is not `ModelType.DIFFUSION_MODEL`
"""
else:
return ModelType.UNKNOWN