Merge pull request #3905 from MaruyamaAya/dreambooth

[example] Adding an example of training dreambooth with the new booster API
Commit e277534a18 by Liu Ziming, 2023-06-09 08:44:18 +08:00, committed by GitHub
6 changed files with 194 additions and 106 deletions


@@ -92,6 +92,29 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
   --placement="cuda"
 ```
+
+## New API
+We have rewritten our previous implementation of DreamBooth with the new Booster API, which offers a more flexible and efficient way to train your model. You can find the new implementation in `train_dreambooth_colossalai.py`.
+We also provide a shell script, `test_ci.sh`, that runs the example with every available booster plugin.
+For more information about the Booster API, please refer to https://colossalai.org/docs/basics/booster_api/.
+
+## Performance
+
+| Strategy       | #GPU | Batch Size | GPU RAM (GB) | Speedup |
+|:--------------:|:----:|:----------:|:------------:|:-------:|
+| Traditional    |  1   |     16     |     OOM      |    \    |
+| Traditional    |  1   |     8      |    61.81     |    1    |
+| torch_ddp      |  4   |     16     |     OOM      |    \    |
+| torch_ddp      |  4   |     8      |    41.97     |  0.97   |
+| gemini         |  4   |     16     |    53.29     |    \    |
+| gemini         |  4   |     8      |    29.36     |  2.00   |
+| low_level_zero |  4   |     16     |    52.80     |    \    |
+| low_level_zero |  4   |     8      |    28.87     |  2.02   |
+
+The evaluation was performed on four NVIDIA A100 GPUs with 80 GB of memory each, with GPUs 0 & 1 and 2 & 3 connected via NVLink.
+We fine-tuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model at 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared the memory cost and throughput of the plugins.

 ## Inference

 Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `--instance_prompt="a photo of sks dog"` in the above example) in your prompt.
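
For reference, a minimal inference sketch (not part of this diff; it assumes the trained weights were saved as a complete `StableDiffusionPipeline` directory such as the `./weight_output` used in the scripts below, and that `diffusers` is installed):

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical output directory from the training command above.
model_dir = "./weight_output"
pipe = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float16).to("cuda")

# The prompt must contain the identifier used in --instance_prompt.
image = pipe("a photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```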


@@ -1,22 +1,18 @@
-export MODEL_NAME= <Your Pretrained Model Path>
-export INSTANCE_DIR= <Your Input Pics Path>
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="path-to-save-model"
 HF_DATASETS_OFFLINE=1
 TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1

-torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --instance_prompt="a photo of a dog" \
+torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+  --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+  --instance_data_dir="/data/dreambooth/Teyvat/data" \
+  --output_dir="./weight_output" \
+  --instance_prompt="a picture of a dog" \
   --resolution=512 \
+  --plugin="gemini" \
   --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --placement="cuda" \
+  --test_run=True \
+  --placement="auto" \


@@ -1,7 +1,7 @@
 python train_dreambooth.py \
-  --pretrained_model_name_or_path= ## Your Model Path \
-  --instance_data_dir= ## Your Training Input Pics Path \
-  --output_dir="path-to-save-model" \
+  --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+  --instance_data_dir="/data/dreambooth/Teyvat/data" \
+  --output_dir="./weight_output" \
   --instance_prompt="a photo of a dog" \
   --resolution=512 \
   --train_batch_size=1 \


@@ -0,0 +1,25 @@
+#!/bin/bash
+set -xe
+pip install -r requirements.txt
+
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+
+# Other available plugins: "torch_ddp" "torch_ddp_fp16" "low_level_zero"
+for plugin in "gemini"; do
+  torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+    --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+    --instance_data_dir="/data/dreambooth/Teyvat/data" \
+    --output_dir="./weight_output" \
+    --instance_prompt="a picture of a dog" \
+    --resolution=512 \
+    --plugin=$plugin \
+    --train_batch_size=1 \
+    --learning_rate=5e-6 \
+    --lr_scheduler="constant" \
+    --lr_warmup_steps=0 \
+    --test_run=True \
+    --num_class_images=200 \
+    --placement="auto" # "cuda"
+done


@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil

 import torch
 import torch.nn.functional as F
@@ -21,9 +22,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
+from colossalai.zero import ColoInitContext
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

 disable_existing_loggers()
 logger = get_dist_logger()
@@ -58,6 +62,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the externel unet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
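
As a sketch of what the new `--externel_unet_path` option does (see the model-loading hunk further down): when the flag is set, the UNet is read from its own directory rather than from the base model's `unet` subfolder. A standalone illustration, with a hypothetical path:

```python
from diffusers import UNet2DConditionModel

# Hypothetical directory containing config.json and diffusion_pytorch_model.bin,
# e.g. one produced by booster.save_model() plus the copied unet/config.json.
unet = UNet2DConditionModel.from_pretrained("/path/to/external/unet", low_cpu_mem_usage=False)
```
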
@@ -187,12 +198,19 @@ def parse_args(input_args=None):
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
     parser.add_argument(
         "--hub_model_id",
         type=str,
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
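
One caveat about the `--test_run` flag added above: it is declared without `type=` or `action=`, so argparse stores whatever string is passed, and even `--test_run=False` is truthy. A safer sketch (an editorial suggestion, not what this commit does; the shell scripts would then pass the bare flag instead of `--test_run=True`):

```python
# A store-true flag avoids string-truthiness surprises:
# omit the flag for a full run, pass --test_run for the 10-image subset.
parser.add_argument("--test_run", action="store_true",
                    help="Use a small (10-image) subset of the dataset for a quick test run.")
```
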
@@ -250,6 +268,7 @@ class DreamBoothDataset(Dataset):
         class_prompt=None,
         size=512,
         center_crop=False,
+        test=False,
     ):
         self.size = size
         self.center_crop = center_crop
@@ -260,6 +279,8 @@ class DreamBoothDataset(Dataset):
             raise ValueError("Instance images root doesn't exists.")

         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        if test:
+            self.instance_images_path = self.instance_images_path[:10]
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
@@ -339,18 +360,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"

-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model

 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -392,7 +401,7 @@ def main(args):
             images = pipeline(example["prompt"]).images

             for i, image in enumerate(images):
-                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)
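
The sha1-to-sha256 swap above only changes how generated class images are named. A self-contained illustration of the naming scheme (illustrative values; PIL assumed available):

```python
import hashlib
from PIL import Image

image = Image.new("RGB", (512, 512))  # stand-in for a generated class image
hash_image = hashlib.sha256(image.tobytes()).hexdigest()
index, cur_class_images = 3, 200      # illustrative counter values
image_filename = f"{index + cur_class_images}-{hash_image}.jpg"
```
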
@@ -452,12 +461,18 @@ def main(args):
         revision=args.revision,
     )

-    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext(device=get_current_device()):
+    if args.externel_unet_path is None:
+        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="unet",
                                                     revision=args.revision,
                                                     low_cpu_mem_usage=False)
+    else:
+        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -468,10 +483,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size

-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use Booster API to use Gemini/Zero with ColossalAI
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+    booster = Booster(plugin=plugin, **booster_kwargs)

     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -486,6 +513,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        test=args.test_run
     )

     def collate_fn(examples):
@@ -554,6 +582,8 @@ def main(args):
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

+    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
     # Train!
     total_batch_size = args.train_batch_size * world_size
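
Pulling the new pieces together, here is a condensed, self-contained sketch of the Booster workflow this file now follows (illustrative values; run under torchrun; `booster.backward` is shown under the assumption that the training loop routes the backward pass through the booster):

```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(placement_policy="auto", strict_ddp_mode=True, initial_scale=2**5)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(8, 8).cuda()       # stand-in for the UNet
optimizer = HybridAdam(model.parameters(), lr=5e-6)
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(4, 8, device="cuda")).sum()
booster.backward(loss, optimizer)          # instead of loss.backward()
optimizer.step()
booster.save_model(model, "model.bin")
```
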
@@ -642,36 +672,24 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
-                if local_rank == 0:
-                    pipeline = DiffusionPipeline.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        unet=torch_unet,
-                        revision=args.revision,
-                    )
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    pipeline.save_pretrained(save_path)
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                 logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])

             if global_step >= args.max_train_steps:
                 break

     torch.cuda.synchronize()
-    unet = get_static_torch_model(unet)
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
     if local_rank == 0:
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            unet=unet,
-            revision=args.revision,
-        )
-        pipeline.save_pretrained(args.output_dir)
-        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
+        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)

     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


 if __name__ == "__main__":
     args = parse_args()
     main(args)
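
A note on the new checkpoint layout: `booster.save_model` writes a flat `diffusion_pytorch_model.bin`, and the copied `unet/config.json` sits next to it, so each checkpoint folder can be reloaded directly with `diffusers`. A sketch with a hypothetical checkpoint path:

```python
from diffusers import UNet2DConditionModel

# checkpoint-<step> contains diffusion_pytorch_model.bin + config.json,
# exactly what from_pretrained expects for a standalone UNet.
unet = UNet2DConditionModel.from_pretrained("./weight_output/checkpoint-400")
```
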


@@ -4,6 +4,7 @@ import math
 import os
 from pathlib import Path
 from typing import Optional
+import shutil

 import torch
 import torch.nn.functional as F
@@ -23,9 +24,12 @@ import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
 from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

 disable_existing_loggers()
 logger = get_dist_logger()
@@ -60,6 +64,13 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--externel_unet_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to the externel unet model.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -195,6 +206,12 @@ def parse_args(input_args=None):
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -341,18 +358,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"

-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
-    from colossalai.nn.parallel import GeminiDDP
-
-    model = GeminiDDP(model,
-                      device=get_current_device(),
-                      placement_policy=placement_policy,
-                      pin_memory=True,
-                      search_range_mb=64)
-    return model

 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
@@ -394,7 +399,7 @@ def main(args):
             images = pipeline(example["prompt"]).images

             for i, image in enumerate(images):
-                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)
@@ -454,8 +459,18 @@ def main(args):
         revision=args.revision,
     )

-    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext(device=get_current_device()):
+    if args.externel_unet_path is None:
+        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+                                                    subfolder="unet",
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
+    else:
+        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="unet",
                                                     revision=args.revision,
@@ -490,10 +505,22 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.train_batch_size * world_size

-    unet = gemini_zero_dpp(unet, args.placement)
+    # Use Booster API to use Gemini/Zero with ColossalAI
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+    booster = Booster(plugin=plugin, **booster_kwargs)

     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
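
Note that unlike the previous file, which passes `args.placement` through, this one hardcodes `placement_policy='cuda'` in its GeminiPlugin. With Gemini, 'cuda' keeps all model data on the GPU, while 'auto' lets the runtime move tensors between GPU and CPU based on memory usage, for example:

```python
from colossalai.booster.plugin import GeminiPlugin

# 'cuda' pins everything on device; 'auto' lets Gemini place tensors dynamically.
plugin = GeminiPlugin(placement_policy="auto", strict_ddp_mode=True, initial_scale=2**5)
```
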
@@ -576,6 +603,8 @@ def main(args):
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

+    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
     # Train!
     total_batch_size = args.train_batch_size * world_size
@@ -664,27 +693,24 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                torch_unet = get_static_torch_model(unet)
-                if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    torch_unet = torch_unet.to(torch.float32)
-                    torch_unet.save_attn_procs(save_path)
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
+                    if not os.path.exists(os.path.join(save_path, "config.json")):
+                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                 logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])

             if global_step >= args.max_train_steps:
                 break

     torch.cuda.synchronize()
-    torch_unet = get_static_torch_model(unet)
+    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
     if local_rank == 0:
-        torch_unet = torch_unet.to(torch.float32)
-        torch_unet.save_attn_procs(save_path)
+        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)

     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


 if __name__ == "__main__":
     args = parse_args()
     main(args)