[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support
This commit is contained in:
Runyu Lu
2024-07-08 16:02:07 +08:00
committed by GitHub
parent 8ec24b6a4d
commit cba20525a8
16 changed files with 1860 additions and 740 deletions

View File

@@ -0,0 +1,34 @@
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import Policy
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = {}
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
return policy
def preprocess(self) -> nn.Module:
return self.model
def postprocess(self):
return self.model
def config_sanity_check(self):
pass
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "PixArtAlphaInferPolicy":
return PixArtAlphaInferPolicy()