Mirror of https://github.com/hpcaitech/ColossalAI.git
[example] diffuser, support quant inference for stable diffusion (#2186)
@@ -22,6 +22,7 @@ from imwatermark import WatermarkEncoder
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
+from utils import replace_module, getModelSize


 def chunk(it, size):
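The new import pulls replace_module and getModelSize from a local utils module whose body this diff does not show. As a rough sketch of what such a helper could look like, assuming the int8 path swaps nn.Linear layers for bitsandbytes' Linear8bitLt (an assumption, not something this diff confirms):

import torch.nn as nn
import bitsandbytes as bnb  # assumption: int8 kernels come from bitsandbytes

def replace_module(model: nn.Module) -> nn.Module:
    # Recursively replace every nn.Linear with an 8-bit equivalent.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            int8_linear = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,
            )
            # Wrap the existing weights; the actual int8 quantization
            # is triggered when the module is moved to a CUDA device.
            int8_linear.weight = bnb.nn.Int8Params(child.weight.data, requires_grad=False)
            if child.bias is not None:
                int8_linear.bias = child.bias
            setattr(model, name, int8_linear)
        else:
            replace_module(child)
    return model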
@@ -44,7 +45,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

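Dropping the hard-coded model.cuda() from load_model_from_config avoids a redundant device transfer: main() already places the model explicitly with model.to(device), as a later hunk shows.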
@@ -183,6 +183,12 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )

     opt = parser.parse_args()
     seed_everything(opt.seed)
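A caveat on the new flag: argparse's type=bool does not parse strings the way it appears to, since bool() on any non-empty string (including "False") returns True, so --use_int8 False would still enable quantization. A more conventional variant, offered here as a suggestion rather than what the commit ships, is a store_true flag:

# Suggested alternative: a store_true flag sidesteps the
# bool("False") == True pitfall of argparse's type=bool.
parser.add_argument(
    "--use_int8",
    action="store_true",
    help="use int8 for inference",
)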
@@ -193,6 +199,12 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)
+
     sampler = DDIMSampler(model)

     os.makedirs(opt.outdir, exist_ok=True)
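The commented-out getModelSize(model) call points at a size-reporting helper in the same utils module. A minimal sketch, assuming it simply sums parameter and buffer bytes (the body below is hypothetical, not taken from the diff):

import torch.nn as nn

def getModelSize(model: nn.Module) -> float:
    # Sum the bytes held by parameters and buffers and report them in MB.
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    size_mb = (param_bytes + buffer_bytes) / 1024 / 1024
    print(f"model size: {size_mb:.2f} MB")
    return size_mb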
@@ -280,3 +292,5 @@ def main():

 if __name__ == "__main__":
     main()
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
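torch.cuda.max_memory_allocated() reports the peak tensor allocation since the process started (or since the counter was last reset), so an isolated measurement of the sampling run would reset the statistics first, e.g.:

import torch

torch.cuda.reset_peak_memory_stats()  # start the peak counter from zero
main()
# peak GPU memory allocated by tensors during sampling, in MB
print(torch.cuda.max_memory_allocated() / 1024 / 1024)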