mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support
This commit is contained in:
34
colossalai/inference/modeling/policy/pixart_alpha.py
Normal file
34
colossalai/inference/modeling/policy/pixart_alpha.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
return policy
|
||||
|
||||
def preprocess(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def to_rpc_param(self) -> str:
|
||||
return __class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def from_rpc_param() -> "PixArtAlphaInferPolicy":
|
||||
return PixArtAlphaInferPolicy()
|
Reference in New Issue
Block a user