fix: add epoch train

commit 24765a1965
parent bb28929305
Author: Zach Nussbaum
Date:   2023-03-27 16:32:35 +00:00

@@ -115,6 +115,7 @@ def train(accelerator, config):
         "gradient_accumulation_steps"
     ]
+    for epoch in range(config["num_epochs"]):
     for step, batch in enumerate(tqdm(train_dataloader)):
         model.train()
         outputs = model(**batch)
@@ -158,6 +159,13 @@ def train(accelerator, config):
                 train_loss.reset()
+        accelerator.print(f"Epoch {epoch} finished")
+        accelerator.print(f"Pushing to HF hub")
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
+        if accelerator.is_main_process:
+            unwrapped_model.push_to_hub(config["save_name"], private=True)
+
     accelerator.wait_for_everyone()
     unwrapped_model = accelerator.unwrap_model(model)
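The per-epoch push added here follows the usual Accelerate pattern: synchronize all ranks, unwrap the prepared model, and let only the main process talk to the Hub. A minimal sketch of that pattern; the helper name and signature are illustrative, not from the repo:

    from accelerate import Accelerator

    def push_checkpoint(accelerator: Accelerator, model, save_name: str) -> None:
        # Block until every rank reaches this point, so no process is
        # still mid-step while the weights are read.
        accelerator.wait_for_everyone()
        # Strip the DDP/FSDP wrapper that accelerator.prepare() added.
        unwrapped = accelerator.unwrap_model(model)
        # Only rank 0 uploads; every rank pushing would race on the same repo.
        if accelerator.is_main_process:
            unwrapped.push_to_hub(save_name, private=True)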
@@ -168,6 +176,7 @@ def train(accelerator, config):
         state_dict=accelerator.get_state_dict(model),
     )
-    unwrapped_model.push_to_hub(config["save_name"], private=True)
+    if accelerator.is_main_process:
+        unwrapped_model.push_to_hub(config["save_name"], private=True)
     accelerator.end_training()
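Taken together, the three hunks leave train() shaped roughly as below. This is a hedged reconstruction for orientation, not the repo's actual file: the optimizer choice, the gradient-accumulation handling, and the lr config key are assumptions; only the config keys shown in the diff, the epoch loop, and the save/push calls come from this commit.

    from accelerate import Accelerator
    from torch.optim import AdamW
    from tqdm import tqdm

    def train(accelerator: Accelerator, config: dict, model, train_dataloader):
        # Optimizer and lr key are assumptions made for the sketch.
        optimizer = AdamW(model.parameters(), lr=config.get("lr", 2e-5))
        model, optimizer, train_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader
        )

        for epoch in range(config["num_epochs"]):
            for step, batch in enumerate(tqdm(train_dataloader)):
                model.train()
                outputs = model(**batch)
                loss = outputs.loss / config["gradient_accumulation_steps"]
                accelerator.backward(loss)
                if (step + 1) % config["gradient_accumulation_steps"] == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            # New in this commit: checkpoint to the Hub after every epoch.
            accelerator.print(f"Epoch {epoch} finished")
            accelerator.print("Pushing to HF hub")
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            if accelerator.is_main_process:
                unwrapped_model.push_to_hub(config["save_name"], private=True)

        # Final save; get_state_dict() gathers full weights under sharded setups.
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            config["save_name"],
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )
        # Also new in this commit: the final push is guarded by is_main_process.
        if accelerator.is_main_process:
            unwrapped_model.push_to_hub(config["save_name"], private=True)
        accelerator.end_training()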