mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
[example] fixed seed error in train_dreambooth_colossalai.py (#2445)
This commit is contained in:
parent
ac18a445fa
commit
cfd1d5ee49
@ -355,10 +355,11 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
colossalai.launch_from_torch(config={})
|
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is None:
|
||||||
gpc.set_seed(args.seed)
|
colossalai.launch_from_torch(config={})
|
||||||
|
else:
|
||||||
|
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||||
|
|
||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
class_images_dir = Path(args.class_data_dir)
|
class_images_dir = Path(args.class_data_dir)
|
||||||
|
Loading…
Reference in New Issue
Block a user