@@ -92,6 +92,29 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
  --placement="cuda"
```

## New API

We have modified our previous DreamBooth implementation to use the new Booster API, which offers a more flexible and efficient way to train your model. You can find the new implementation in `train_dreambooth_colossalai.py`.

We also provide a shell script, `test_ci.sh`, that runs the example with each of the Booster plugins.

For more information about the Booster API, please refer to https://colossalai.org/docs/basics/booster_api/. A minimal sketch of how this example uses the API is shown below.
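The snippet below is only a rough sketch of how the plugins and the `Booster` are wired together in this example; the toy model and hyper-parameters are placeholders, and the full training loop lives in `train_dreambooth_colossalai.py`.

```python
# Minimal sketch of the Booster workflow used in this example.
# Run under torchrun (as in test_ci.sh); the model below is a placeholder.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8)  # placeholder for the UNet being fine-tuned
optimizer = HybridAdam(model.parameters(), lr=5e-6, initial_scale=2**5)

# Pick one of the supported plugins (selected via --plugin in the script).
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
# plugin = TorchDDPPlugin()
# plugin = LowLevelZeroPlugin(initial_scale=2**5)

booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# ... training loop ...
# The boosted model is saved through the booster as well:
# booster.save_model(model, "diffusion_pytorch_model.bin")
```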

## Performance

| Strategy       | #GPU | Batch Size | GPU RAM (GB) | Speedup |
|:--------------:|:----:|:----------:|:------------:|:-------:|
| Traditional    | 1    | 16         | OOM          | \       |
| Traditional    | 1    | 8          | 61.81        | 1       |
| torch_ddp      | 4    | 16         | OOM          | \       |
| torch_ddp      | 4    | 8          | 41.97        | 0.97    |
| gemini         | 4    | 16         | 53.29        | \       |
| gemini         | 4    | 8          | 29.36        | 2.00    |
| low_level_zero | 4    | 16         | 52.80        | \       |
| low_level_zero | 4    | 8          | 28.87        | 2.02    |

The evaluation was performed on 4 NVIDIA A100 GPUs with 80 GB of memory each, with GPU 0 & 1 and GPU 2 & 3 connected via NVLink.
We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model at 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared the memory cost and throughput of the plugins.

## Inference

Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `--instance_prompt="a photo of sks dog"` in the above example) in your prompt.
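As a minimal sketch, inference could look like the following; the checkpoint directory and prompt are placeholders, and it assumes the fine-tuned weights were saved in diffusers format by the training script.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder: directory where the fine-tuned model was saved (--output_dir above).
model_path = "path-to-save-model"

pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")

# Include the instance identifier used during training (e.g. "sks dog") in the prompt.
prompt = "a photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```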
@@ -1,20 +1,15 @@
export MODEL_NAME= <Your Pretrained Model Path>
export INSTANCE_DIR= <Your Input Pics Path>
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1

torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of a dog" \
torchrun --nproc_per_node 4 --master_port=25641 train_dreambooth_colossalai.py \
  --pretrained_model_name_or_path="Path_to_your_model" \
  --instance_data_dir="Path_to_your_training_image" \
  --output_dir="Path_to_your_save_dir" \
  --instance_prompt="your prompt" \
  --resolution=512 \
  --plugin="gemini" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
@@ -0,0 +1,23 @@
#!/bin/bash
set -xe
pip install -r requirements.txt

HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1

for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
  torchrun --nproc_per_node 4 --master_port=25641 train_dreambooth_colossalai.py \
    --pretrained_model_name_or_path="Your Pretrained Model Path" \
    --instance_data_dir="Your Input Pics Path" \
    --output_dir="path-to-save-model" \
    --instance_prompt="your prompt" \
    --resolution=512 \
    --plugin=$plugin \
    --train_batch_size=1 \
    --learning_rate=5e-6 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --num_class_images=200 \
    --placement="cuda"
done
@@ -4,6 +4,7 @@ import math
import os
from pathlib import Path
from typing import Optional
import shutil

import torch
import torch.nn.functional as F
@@ -21,9 +22,12 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

disable_existing_loggers()
logger = get_dist_logger()
@@ -58,6 +62,13 @@ def parse_args(input_args=None):
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--externel_unet_path",
        type=str,
        default=None,
        required=False,
        help="Path to the externel unet model.",
    )
    parser.add_argument(
        "--revision",
        type=str,
@@ -193,6 +204,12 @@ def parse_args(input_args=None):
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--logging_dir",
        type=str,
@@ -339,18 +356,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
        return f"{organization}/{model_id}"


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
    from colossalai.nn.parallel import GeminiDDP

    model = GeminiDDP(model,
                      device=get_current_device(),
                      placement_policy=placement_policy,
                      pin_memory=True,
                      search_range_mb=64)
    return model


def main(args):
    if args.seed is None:
        colossalai.launch_from_torch(config={})
@@ -392,7 +397,7 @@ def main(args):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)
@@ -452,12 +457,18 @@ def main(args):
        revision=args.revision,
    )

    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
    with ColoInitContext(device=get_current_device()):

    if args.externel_unet_path is None:
        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                    subfolder="unet",
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)
    else:
        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
@@ -468,10 +479,22 @@ def main(args):
    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    unet = gemini_zero_dpp(unet, args.placement)
    # Use Booster API to use Gemini/Zero with ColossalAI

    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # config optimizer for colossalai zero
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -554,6 +577,8 @@ def main(args):
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)

    # Train!
    total_batch_size = args.train_batch_size * world_size
@@ -642,36 +667,24 @@ def main(args):

            if global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                torch_unet = get_static_torch_model(unet)
                if local_rank == 0:
                    pipeline = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        unet=torch_unet,
                        revision=args.revision,
                    )
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    pipeline.save_pretrained(save_path)
                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                    if not os.path.exists(os.path.join(save_path, "config.json")):
                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
            if global_step >= args.max_train_steps:
                break

    torch.cuda.synchronize()
    unet = get_static_torch_model(unet)

    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
    if local_rank == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=unet,
            revision=args.revision,
        )

        pipeline.save_pretrained(args.output_dir)
        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])

        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
    if args.push_to_hub:
        repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -4,6 +4,7 @@ import math
import os
from pathlib import Path
from typing import Optional
import shutil

import torch
import torch.nn.functional as F
@@ -23,9 +24,12 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

disable_existing_loggers()
logger = get_dist_logger()
@@ -60,6 +64,13 @@ def parse_args(input_args=None):
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--externel_unet_path",
        type=str,
        default=None,
        required=False,
        help="Path to the externel unet model.",
    )
    parser.add_argument(
        "--revision",
        type=str,
@@ -195,6 +206,12 @@ def parse_args(input_args=None):
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--logging_dir",
        type=str,
@@ -341,18 +358,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
        return f"{organization}/{model_id}"


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
    from colossalai.nn.parallel import GeminiDDP

    model = GeminiDDP(model,
                      device=get_current_device(),
                      placement_policy=placement_policy,
                      pin_memory=True,
                      search_range_mb=64)
    return model


def main(args):
    if args.seed is None:
        colossalai.launch_from_torch(config={})
@@ -394,7 +399,7 @@ def main(args):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)
@@ -454,32 +459,42 @@ def main(args):
        revision=args.revision,
    )

    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
    with ColoInitContext(device=get_current_device()):

    if args.externel_unet_path is None:
        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                    subfolder="unet",
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)
        unet.requires_grad_(False)
    else:
        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)

    # Set correct lora layers
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
                                                       cross_attention_dim=cross_attention_dim)

    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
@@ -490,10 +505,22 @@ def main(args):
    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    unet = gemini_zero_dpp(unet, args.placement)
    # Use Booster API to use Gemini/Zero with ColossalAI

    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # config optimizer for colossalai zero
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -576,6 +603,8 @@ def main(args):
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)

    # Train!
    total_batch_size = args.train_batch_size * world_size
@@ -664,27 +693,24 @@ def main(args):

            if global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                torch_unet = get_static_torch_model(unet)
                if local_rank == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    torch_unet = torch_unet.to(torch.float32)
                    torch_unet.save_attn_procs(save_path)
                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                    if not os.path.exists(os.path.join(save_path, "config.json")):
                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
            if global_step >= args.max_train_steps:
                break

    torch.cuda.synchronize()
    torch_unet = get_static_torch_model(unet)

    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
    if local_rank == 0:
        torch_unet = torch_unet.to(torch.float32)
        torch_unet.save_attn_procs(save_path)
        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])

        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
    if args.push_to_hub:
        repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)