diff --git a/train.py b/train.py index 189df41c..8605af11 100644 --- a/train.py +++ b/train.py @@ -100,7 +100,7 @@ def train(accelerator, config): name="cosine", optimizer=optimizer, num_warmup_steps=config["warmup_steps"] * accelerator.num_processes, - num_training_steps=total_num_steps * accelerator.num_processes, + num_training_steps=total_num_steps, ) else: scheduler = DummyScheduler(