[example] fixed seed error in train_dreambooth_colossalai.py (#2445)

This commit is contained in:
Haofan Wang 2023-01-11 16:56:15 +08:00 committed by GitHub
parent ac18a445fa
commit cfd1d5ee49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -355,10 +355,11 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
def main(args): def main(args):
colossalai.launch_from_torch(config={})
if args.seed is not None: if args.seed is None:
gpc.set_seed(args.seed) colossalai.launch_from_torch(config={})
else:
colossalai.launch_from_torch(config={}, seed=args.seed)
if args.with_prior_preservation: if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir) class_images_dir = Path(args.class_data_dir)