Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 support
* PixArtAlpha support
@@ -5,10 +5,12 @@ Utils for model inference
 import math
 import os
 import re
+from enum import Enum
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
+from diffusers import DiffusionPipeline
 from torch import nn
 
 from colossalai.logging import get_dist_logger
@@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
     except ImportError:
         logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
         return False
+
+
+class ModelType(Enum):
+    DIFFUSION_MODEL = "Diffusion Model"
+    LLM = "Large Language Model (LLM)"
+    UNKNOWN = "Unknown Model Type"
+
+
+def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
+    if isinstance(model_or_path, DiffusionPipeline):
+        return ModelType.DIFFUSION_MODEL
+    elif isinstance(model_or_path, nn.Module):
+        return ModelType.LLM
+    elif isinstance(model_or_path, str):
+        try:
+            from transformers import AutoConfig
+
+            hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+            return ModelType.LLM
+        except:
+            """
+            model type is not `ModelType.LLM`
+            """
+
+        try:
+            from diffusers import DiffusionPipeline
+
+            DiffusionPipeline.load_config(model_or_path)
+            return ModelType.DIFFUSION_MODEL
+        except:
+            """
+            model type is not `ModelType.DIFFUSION_MODEL`
+            """
+    else:
+        return ModelType.UNKNOWN
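
A minimal usage sketch of the new helper follows. The import path colossalai.inference.utils is an assumption (the changed file's path is not shown in this excerpt), and "gpt2" is only a placeholder checkpoint id; the branching mirrors the diff above.

# Usage sketch for get_model_type; module path and checkpoint id are assumptions.
from torch import nn

from colossalai.inference.utils import ModelType, get_model_type  # assumed location of this hunk

# An nn.Module instance is classified as an LLM.
assert get_model_type(nn.Linear(8, 8)) is ModelType.LLM

# A string is probed with transformers.AutoConfig first, so a HF checkpoint id maps to LLM
# (this needs the config in the local cache or network access).
assert get_model_type("gpt2") is ModelType.LLM

# Inputs that are neither a DiffusionPipeline, an nn.Module, nor a string fall through to UNKNOWN.
assert get_model_type(42) is ModelType.UNKNOWN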
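
One behaviour worth noting: when model_or_path is a string that neither transformers.AutoConfig nor DiffusionPipeline.load_config can resolve, both bare except blocks swallow the error and the function falls off the end of the elif branch, so it implicitly returns None rather than ModelType.UNKNOWN; the final else only covers non-string, non-module inputs. A stricter variant, sketched below and not part of this commit, narrows the exception handling and returns UNKNOWN explicitly:

# Sketch of a stricter variant (not the committed code): explicit UNKNOWN fallback and
# narrowed exception handling; assumes the same module-level imports as the diff above.
def get_model_type_strict(model_or_path: Union[nn.Module, str, DiffusionPipeline]) -> ModelType:
    if isinstance(model_or_path, DiffusionPipeline):
        return ModelType.DIFFUSION_MODEL
    if isinstance(model_or_path, nn.Module):
        return ModelType.LLM
    if isinstance(model_or_path, str):
        try:
            from transformers import AutoConfig

            AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
            return ModelType.LLM
        except Exception:
            pass  # not resolvable as a transformers checkpoint
        try:
            DiffusionPipeline.load_config(model_or_path)
            return ModelType.DIFFUSION_MODEL
        except Exception:
            pass  # not resolvable as a diffusers pipeline config either
    return ModelType.UNKNOWN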