Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-19 01:36:37 +00:00)
fix: add epoch train
parent bb28929305
commit 24765a1965

Changed file: train.py (9 additions)
@@ -115,6 +115,7 @@ def train(accelerator, config):
             "gradient_accumulation_steps"
         ]
 
+    for epoch in range(config["num_epochs"]):
         for step, batch in enumerate(tqdm(train_dataloader)):
             model.train()
             outputs = model(**batch)
@@ -158,6 +159,13 @@ def train(accelerator, config):
 
                 train_loss.reset()
 
+        accelerator.print(f"Epoch {epoch} finished")
+        accelerator.print(f"Pushing to HF hub")
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
+        if accelerator.is_main_process:
+            unwrapped_model.push_to_hub(config["save_name"], private=True)
+
 
     accelerator.wait_for_everyone()
     unwrapped_model = accelerator.unwrap_model(model)
@@ -168,6 +176,7 @@ def train(accelerator, config):
         state_dict=accelerator.get_state_dict(model),
     )
 
+    if accelerator.is_main_process:
         unwrapped_model.push_to_hub(config["save_name"], private=True)
 
     accelerator.end_training()
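
Read together, the three hunks re-introduce the outer epoch loop and the per-epoch Hub push. Below is a minimal sketch of that pattern using Hugging Face accelerate, not the repo's actual train.py: the optimizer step, scheduler, evaluation, and logging are simplified, and train_sketch and its arguments are illustrative names rather than code from this commit.

from accelerate import Accelerator
from tqdm import tqdm

def train_sketch(config, model, optimizer, train_dataloader):
    # Illustrative only: model/optimizer/dataloader construction is assumed to
    # happen elsewhere, as it does in the real train.py.
    accelerator = Accelerator()
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    for epoch in range(config["num_epochs"]):  # hunk 1: outer epoch loop
        for step, batch in enumerate(tqdm(train_dataloader)):
            model.train()
            outputs = model(**batch)
            accelerator.backward(outputs.loss)
            optimizer.step()
            optimizer.zero_grad()

        # hunk 2: after each epoch, sync all ranks and push a checkpoint
        accelerator.print(f"Epoch {epoch} finished")
        accelerator.print("Pushing to HF hub")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        if accelerator.is_main_process:
            unwrapped_model.push_to_hub(config["save_name"], private=True)

    # hunk 3: the final push is likewise guarded so only one process uploads
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.is_main_process:
        unwrapped_model.push_to_hub(config["save_name"], private=True)

    accelerator.end_training()

Guarding push_to_hub with accelerator.is_main_process keeps multi-GPU runs from uploading the same checkpoint once per rank, while wait_for_everyone() ensures every process has finished the epoch before the upload starts.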