Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 09:38:05 +00:00)

Commit: fix/transformer-verison (#2581)
````diff
@@ -52,7 +52,7 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
 
 ```
 conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
+pip install transformers diffusers invisible-watermark
 ```
 
 #### Step 2: install lightning
````
```diff
@@ -18,7 +18,7 @@ dependencies:
     - test-tube>=0.7.5
     - streamlit==1.12.1
     - einops==0.3.0
-    - transformers==4.19.2
+    - transformers
     - webdataset==0.2.5
     - kornia==0.6
     - open_clip_torch==2.0.2
```
```diff
@@ -9,7 +9,7 @@ omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0
-transformers==4.19.2
+transformers
 webdataset==0.2.5
 open-clip-torch==2.7.0
 gradio==3.11
```
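The three hunks above all replace the pinned `transformers==4.19.2` with an unpinned `transformers`. A purely illustrative sanity check after installing (not part of this commit) is to print the versions that were actually resolved, so mismatches are easy to report:

```python
# Illustrative check only (not from the commit): show which versions the
# now-unpinned installs resolved to.
import diffusers
import transformers

print("transformers:", transformers.__version__)
print("diffusers:", diffusers.__version__)
```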
```diff
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
```
```diff
@@ -133,9 +133,13 @@ def parse_args(input_args=None):
         default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+              " cropped. The images will be resized to the resolution first before cropping."),
+    )
     parser.add_argument("--train_batch_size",
                         type=int,
                         default=4,
```
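The reworked `--center_crop` help text describes a resize-then-crop pipeline: images are resized to the target resolution first, then either center-cropped or randomly cropped. A minimal sketch of that preprocessing using standard torchvision transforms (the surrounding script may compose its pipeline slightly differently):

```python
# Sketch of the preprocessing described by the new --center_crop help text:
# resize to the target resolution first, then crop (center or random).
from torchvision import transforms

def build_image_transforms(resolution: int, center_crop: bool) -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
```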
```diff
@@ -149,13 +153,6 @@ def parse_args(input_args=None):
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help=
-        "Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
-    )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
```
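The `--gradient_accumulation_steps` option is removed entirely rather than reworked: as the deleted help string (and a deleted assert later in this diff) note, the Gemini/ColossalAI path only supports a value of 1, so the flag had no useful setting. For background, a generic plain-PyTorch sketch of what such a flag normally controls (illustration only, not code from this script):

```python
# Generic illustration (plain PyTorch, not this script): gradient accumulation
# performs one optimizer step per `accumulation_steps` micro-batches.
def run_epoch(model, optimizer, dataloader, loss_fn, accumulation_steps=1):
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(dataloader):
        loss = loss_fn(model(inputs), targets) / accumulation_steps
        loss.backward()                      # gradients keep accumulating
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()                 # update once per accumulation window
            optimizer.zero_grad()
```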
```diff
@@ -356,7 +353,6 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
 
 
 def main(args):
-
     if args.seed is None:
         colossalai.launch_from_torch(config={})
     else:
```
```diff
@@ -410,7 +406,8 @@ def main(args):
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
-            repo = Repository(args.output_dir, clone_from=repo_name)
+            create_repo(repo_name, exist_ok=True, token=args.hub_token)
+            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
 
            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
```
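Together with the new `create_repo` import earlier in the diff, this hunk creates the Hub repository explicitly before cloning it and also passes the token to `Repository`, presumably because newer `huggingface_hub` releases no longer create a missing repo implicitly from `clone_from`. A condensed sketch of the resulting pattern (the helper name is made up for illustration; the two `huggingface_hub` calls mirror the diff above):

```python
# Condensed sketch of the Hub setup after this change; `init_hub_repo` is a
# hypothetical helper wrapping the two calls shown in the hunk.
from huggingface_hub import Repository, create_repo

def init_hub_repo(output_dir: str, repo_name: str, token: str) -> Repository:
    create_repo(repo_name, exist_ok=True, token=token)   # idempotent: no error if it already exists
    return Repository(output_dir, clone_from=repo_name, token=token)
```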
```diff
@@ -469,9 +466,8 @@ def main(args):
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
-    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
+        args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
     unet = gemini_zero_dpp(unet, args.placement)
 
```
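With the accumulation factor gone, `--scale_lr` now scales the base learning rate by the per-device batch size and the number of processes only. A small worked example with assumed values (not taken from the commit):

```python
# Worked example with assumed values: linear LR scaling after this change.
base_lr = 5e-6            # assumed --learning_rate
train_batch_size = 4      # assumed --train_batch_size (per device)
world_size = 8            # assumed number of data-parallel processes

scaled_lr = base_lr * train_batch_size * world_size
print(scaled_lr)          # 0.00016
```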
```diff
@@ -529,7 +525,7 @@ def main(args):
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
```
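Since one optimizer update now corresponds to one dataloader batch, the step/epoch bookkeeping reduces to the arithmetic below (numbers are assumed for illustration):

```python
# Assumed numbers, for illustration of the bookkeeping above.
import math

batches_per_epoch = 250                                    # len(train_dataloader), assumed
num_train_epochs = 4                                       # --num_train_epochs, assumed
num_update_steps_per_epoch = math.ceil(batches_per_epoch)  # 250 (ceil is a no-op on an int)
max_train_steps = num_train_epochs * num_update_steps_per_epoch      # 1000

# If --max_train_steps were given explicitly, the epoch count is derived from it instead:
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)  # 4
```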
```diff
@@ -537,8 +533,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps,
+        num_training_steps=args.max_train_steps,
     )
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
```
```diff
@@ -553,14 +549,14 @@ def main(args):
     text_encoder.to(get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Train!
-    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size
 
     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
```
```diff
@@ -568,7 +564,6 @@ def main(args):
     logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])
 
     # Only show the progress bar once on each machine.
```