mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-27 15:58:25 +00:00
fix: try except push
This commit is contained in:
parent
885b7f1a3a
commit
838b19bea5
12
train.py
12
train.py
@ -13,6 +13,7 @@ from torchmetrics import MeanMetric
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
def format_metrics(metrics, split, prefix=""):
|
def format_metrics(metrics, split, prefix=""):
|
||||||
log = f"[{split}]" + prefix
|
log = f"[{split}]" + prefix
|
||||||
@ -192,9 +193,20 @@ def train(accelerator, config):
|
|||||||
accelerator.print(f"Pushing to HF hub")
|
accelerator.print(f"Pushing to HF hub")
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
try:
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
accelerator.print(e)
|
||||||
|
accelerator.print(f"Failed to push to hub")
|
||||||
|
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
f"{config['output_dir']}/-epoch_{epoch}",
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
state_dict=accelerator.get_state_dict(model),
|
||||||
|
)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
Loading…
Reference in New Issue
Block a user