diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index b7e24bfe4..7c90b939a 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -355,10 +355,11 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): def main(args): - colossalai.launch_from_torch(config={}) - if args.seed is not None: - gpc.set_seed(args.seed) + if args.seed is None: + colossalai.launch_from_torch(config={}) + else: + colossalai.launch_from_torch(config={}, seed=args.seed) if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir)