diff --git a/train.py b/train.py index 1f0b4852..0f35bc5b 100644 --- a/train.py +++ b/train.py @@ -137,7 +137,7 @@ def train(accelerator, config): # log gradients - if accelerator.is_local_main_process and config["wandb"]: + if accelerator.is_main_process and config["wandb"]: wandb.watch(model, log_freq=config["log_grads_every"]) for epoch in range(config["num_epochs"]):