diff --git a/train.py b/train.py index 72e53f4d..189df41c 100644 --- a/train.py +++ b/train.py @@ -192,7 +192,7 @@ def train(accelerator, config): accelerator.print(f"Failed to push to hub") unwrapped_model.save_pretrained( - f"{config['output_dir']}/-epoch_{epoch}", + f"{config['output_dir']}/epoch_{epoch}", is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model),