polish code (#3194)
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
parent 4d5d8f98a4 · commit 280fcdc485
@@ -47,40 +47,21 @@ conda env create -f environment.yaml
conda activate ldm
```

You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running:

```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers diffusers invisible-watermark
```

#### Step 2: Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website

You can install the latest version (0.2.7) from our official website or from source. Note that the version suited to this training example is colossalai 0.2.5, which corresponds to torch 1.12.1.
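A quick way to confirm the pairing after installation (a minimal check; nothing beyond the two packages above is assumed):

```
python -c "import torch, colossalai; print(torch.__version__, colossalai.__version__)"
```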

##### Download the suggested version for this training

```
pip install colossalai==0.2.5
```

##### Download the latest version from pip for the latest torch version

```
pip install colossalai
```

##### From source:

```
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
CUDA_EXT=1 pip install .
```

#### Step 3: Accelerate with flash attention by xformers (Optional)

Notice that xformers will accelerate the training process at the cost of extra disk space. The suitable version of xformers for this training process is 0.12.0. You can install xformers directly via pip. For more release versions, feel free to check its official website: [XFormers](https://pypi.org/project/xformers/)
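For example, the pinned release mentioned above can be installed directly with pip (a sketch; the version string follows the note above, so adjust it if pip reports no matching distribution for your torch/CUDA combination):

```
pip install xformers==0.12.0
```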

@@ -113,7 +94,7 @@ To use the stable diffusion Docker image, you can either build using the provide

```
# 1. build from dockerfile
cd ColossalAI/examples/images/diffusion/docker
docker build -t hpcaitech/diffusion:0.2.0 .

# 2. pull from our docker hub
```

@@ -127,7 +108,7 @@ Once you have the image ready, you can launch the image with the following comma

```
# On Your Host Machine #
########################
# make sure you start your image in the repository root directory
cd ColossalAI

# run the docker container
docker run --rm \
```

@@ -144,13 +125,15 @@ docker run --rm \

```
# Once you have entered the docker container, go to the stable diffusion directory for training
cd examples/images/diffusion/

# Download the model checkpoint from pretrained (see the following steps)
# Set up your configuration in "train_colossalai.sh" (see the following steps)
# start training with colossalai
bash train_colossalai.sh
```

It is important for you to configure your volume mapping in order to get the best training experience; a combined example follows the list below.

1. **Mandatory**: mount your prepared data to `/data/scratch` via `-v <your-data-dir>:/data/scratch`, where you need to replace `<your-data-dir>` with the actual data path on your machine. Note that within Docker, Windows-style paths need to be written in Linux form, e.g. `C:\User\Desktop` becomes `/c/User/Desktop`.
2. **Recommended**: store the downloaded model weights on your host machine instead of in the container directory via `-v <hf-cache-dir>:/root/.cache/huggingface`, where you need to replace `<hf-cache-dir>` with the actual path. This way, you don't have to repeatedly download the pretrained weights for every `docker run`.
3. **Optional**: if you encounter any problem stating that shared memory is insufficient inside the container, please add `-v /dev/shm:/dev/shm` to your `docker run` command.
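Putting these mappings together, a launch command looks roughly like the following (a sketch only: the `--gpus all` flag and the interactive `-it ... /bin/bash` tail are illustrative, and the placeholder paths must be replaced as described above):

```
docker run --rm -it --gpus all \
  -v <your-data-dir>:/data/scratch \
  -v <hf-cache-dir>:/root/.cache/huggingface \
  -v /dev/shm:/dev/shm \
  hpcaitech/diffusion:0.2.0 /bin/bash
```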

@@ -5,87 +5,105 @@ from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


# This class is used to create a dataset of images from the LSUN dataset for training
class LSUNBase(Dataset):
    def __init__(self,
                 txt_file,                   # path to the text file containing the list of image paths
                 data_root,                  # root directory of the LSUN dataset
                 size=None,                  # the size to resize images to
                 interpolation="bicubic",    # interpolation method used while resizing
                 flip_p=0.5                  # probability of random horizontal flipping
                 ):
        self.data_paths = txt_file    # path to the text file containing the list of images
        self.data_root = data_root    # root directory of the dataset
        with open(self.data_paths, "r") as f:    # open and read the text file
            self.image_paths = f.read().splitlines()    # read the lines of the file and store them as a list
        self._length = len(self.image_paths)    # store the number of images

        # create a dictionary to hold image path information
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths],
        }

        # set the size that images are resized to
        self.size = size
        # set the interpolation method for resizing the image
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        # randomly flip the image horizontally with a given probability
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        # return the length of the dataset
        return self._length

    def __getitem__(self, i):
        # get the image path information for the given index
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        # convert the image to RGB format
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)    # convert the image to a numpy array
        crop = min(img.shape[0], img.shape[1])    # side length of the square center crop
        h, w = img.shape[0], img.shape[1]         # get the height and width of the image
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]    # crop the image to a square shape

        image = Image.fromarray(img)    # create an image from the numpy array
        if self.size is not None:    # if an image size is provided, resize the image
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)    # flip the image horizontally with the given probability
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)    # normalize pixel values to [-1, 1] and convert to float32
        return example    # return the example dictionary containing the image and its file paths


# A dataset class for the LSUN Churches training set.
# It initializes by calling the constructor of LSUNBase with the appropriate arguments:
# the text file containing the image paths and the root directory where the images are stored.
# Any additional keyword arguments passed to this class are forwarded to the parent constructor.
class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


# A dataset class for the LSUN Churches validation set.
# It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default.
class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)


# A dataset class for the LSUN Bedrooms training set.
# It initializes by calling the constructor of LSUNBase with the appropriate arguments.
class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


# A dataset class for the LSUN Bedrooms validation set.
# It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default.
class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)


# A dataset class for the LSUN Cats training set.
# It initializes by calling the constructor of LSUNBase with the appropriate arguments:
# the text file containing the image paths and the root directory where the images are stored.
class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


# A dataset class for the LSUN Cats validation set.
# It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default.
class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
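As a quick illustration of how these dataset classes are meant to be used (a minimal sketch, assuming the `data/lsun/` text files and image folders referenced above are already prepared):

```
from torch.utils.data import DataLoader

train_set = LSUNChurchesTrain(size=256)    # uses the default txt_file/data_root defined above
loader = DataLoader(train_set, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch["image"].shape)    # torch.Size([4, 256, 256, 3]), pixel values normalized to [-1, 1]
```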
@@ -44,14 +44,18 @@ from ldm.util import instantiate_from_config


class DataLoaderX(DataLoader):
    # A custom data loader class that inherits from DataLoader
    def __iter__(self):
        # Override the __iter__ method of DataLoader to return a BackgroundGenerator.
        # This enables data loading in the background to improve training performance.
        return BackgroundGenerator(super().__iter__())


def get_parser(**parser_kwargs):
    # A function to create an ArgumentParser object and add arguments to it

    def str2bool(v):
        # A helper function to parse boolean values from command line arguments
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
@@ -60,8 +64,10 @@ def get_parser(**parser_kwargs):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    # Create an ArgumentParser object with the specified kwargs
    parser = argparse.ArgumentParser(**parser_kwargs)

    # Add various command line arguments with their default values and descriptions
    parser.add_argument(
        "-n",
        "--name",
@@ -161,14 +167,18 @@ def get_parser(**parser_kwargs):

    return parser


# A function that returns the trainer arguments that differ from their default values
def nondefault_trainer_args(opt):
    # create an argument parser
    parser = argparse.ArgumentParser()
    # add the pytorch lightning trainer's default arguments
    parser = Trainer.add_argparse_args(parser)
    # parse empty arguments to obtain the default values
    args = parser.parse_args([])
    # return all non-default arguments
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


# A dataset wrapper class to create a pytorch dataset from an arbitrary object
class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
@@ -181,7 +191,7 @@ class WrappedDataset(Dataset):
    def __getitem__(self, idx):
        return self.data[idx]


# A function to initialize worker processes
def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()
@@ -189,15 +199,18 @@ def worker_init_fn(_):
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # divide the dataset into equal parts for each worker
        split_size = dataset.num_records // worker_info.num_workers
        # set the sample IDs for the current worker
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        # set the seed for the current worker
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)


# Provides functionality for creating data loaders based on the provided dataset configurations
class DataModuleFromConfig(pl.LightningDataModule):

    def __init__(self,
@@ -212,10 +225,12 @@ class DataModuleFromConfig(pl.LightningDataModule):
                 use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        # Set data module attributes
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # If a dataset is passed, add it to the dataset configs and create a corresponding dataloader method
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
@@ -231,21 +246,28 @@ class DataModuleFromConfig(pl.LightningDataModule):
        self.wrap = wrap

    def prepare_data(self):
        # Instantiate the datasets
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        # Instantiate datasets from the dataset configs
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)

        # If wrap is true, wrap each dataset in a WrappedDataset
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        # Check if the train dataset is iterable
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        # Return a DataLoaderX object for the train dataset
        return DataLoaderX(self.datasets["train"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
@@ -253,10 +275,12 @@ class DataModuleFromConfig(pl.LightningDataModule):
                           worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        # Check if the validation dataset is iterable
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        # Return a DataLoaderX object for the validation dataset
        return DataLoaderX(self.datasets["validation"],
                           batch_size=self.batch_size,
                           num_workers=self.num_workers,
@@ -264,7 +288,9 @@ class DataModuleFromConfig(pl.LightningDataModule):
                           shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        # Check if the test dataset is iterable
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
@@ -291,6 +317,7 @@ class DataModuleFromConfig(pl.LightningDataModule):


class SetupCallback(Callback):
    # Initialize the callback with the necessary parameters

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
@@ -302,12 +329,14 @@ class SetupCallback(Callback):
        self.config = config
        self.lightning_config = lightning_config

    # Save a checkpoint if training is interrupted with a keyboard interrupt
    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    # Create the necessary directories and save configuration files before training starts
    # def on_pretrain_routine_start(self, trainer, pl_module):
    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
@@ -316,6 +345,7 @@ class SetupCallback(Callback):
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            # Create the trainstep checkpoint directory if necessary
            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
@@ -323,11 +353,13 @@ class SetupCallback(Callback):
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            # Save the project config and lightning config as YAML files
            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        # Remove the log directory if not resuming training and the directory already exists
        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
@@ -346,25 +378,28 @@ class SetupCallback(Callback):
    #         trainer.save_checkpoint(ckpt_path)


# A PyTorch Lightning callback for logging images during training and validation of a deep learning model
class ImageLogger(Callback):

    def __init__(self,
                 batch_frequency,           # Frequency of batches on which to log images
                 max_images,                # Maximum number of images to log
                 clamp=True,                # Whether to clamp pixel values to [-1, 1]
                 increase_log_steps=True,   # Whether to increase the frequency of log steps exponentially
                 rescale=True,              # Whether to rescale pixel values to [0, 1]
                 disabled=False,            # Whether to disable logging
                 log_on_batch_idx=False,    # Whether to log on the batch index instead of the global step
                 log_first_step=False,      # Whether to log on the first step
                 log_images_kwargs=None):   # Additional keyword arguments to pass to the log_images method
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            # Dictionary of logger classes and their corresponding logging methods
            pl.loggers.CSVLogger: self._testtube,
        }
        # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
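        # For example, with the batch_frequency of 750 configured for this callback later in this script,
        # int(np.log2(750)) == 9, so log_steps becomes [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]:
        # images are logged densely early in training and then every 750 steps.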
@@ -374,17 +409,32 @@ class ImageLogger(Callback):
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only    # Ensure that only the first process in distributed training executes this method
    def _testtube(self,
                  pl_module,    # The PyTorch Lightning module
                  images,       # A dictionary of images to log
                  batch_idx,    # The batch index
                  split         # The split (train/val) on which to log the images
                  ):
        # Method for logging images using the test-tube style logger
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0    # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            # Add the image grid to the logger's experiment
            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self,
                  save_dir,
                  split,            # The split (train/val) on which to log the images
                  images,           # A dictionary of images to log
                  global_step,      # The global step
                  current_epoch,    # The current epoch
                  batch_idx
                  ):
        # Method for saving image grids to the local file system
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
@@ -396,12 +446,16 @@ class ImageLogger(Callback):
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            # Save the image grid as a PNG file
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        # Function for logging images to both the logger and the local file system
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check if it's time to log an image batch
        if (self.check_frequency(check_idx) and    # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
            # Get the logger type and check whether training mode is on
            logger = type(pl_module.logger)

            is_train = pl_module.training
@@ -409,8 +463,10 @@ class ImageLogger(Callback):
                pl_module.eval()

            with torch.no_grad():
                # Get images from the log_images method of the pl_module
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            # Clamp the images if specified and convert them to CPU tensors
            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
@@ -419,15 +475,19 @@ class ImageLogger(Callback):
                if self.clamp:
                    images[k] = torch.clamp(images[k], -1., 1.)

            # Log the images locally to the file system
            self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
                           batch_idx)

            # Log the images using the logger
            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            # Switch back to training mode if necessary
            if is_train:
                pl_module.train()

    # This function checks whether it's time to log an image batch
    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or
                (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
@@ -439,14 +499,17 @@ class ImageLogger(Callback):
            return True
        return False

    # Log images at the end of a training batch if logging is not disabled
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
        #     self.log_img(pl_module, batch, batch_idx, split="train")
        pass

    # Log images at the end of a validation batch if logging is not disabled and a global step has been taken
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        # log gradients during calibration if necessary
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
@@ -458,6 +521,7 @@ class CUDACallback(Callback):
    def on_train_start(self, trainer, pl_module):
        rank_zero_info("Training is starting")

    # This method is called when training ends
    def on_train_end(self, trainer, pl_module):
        rank_zero_info("Training is ending")
@@ -524,6 +588,7 @@ if __name__ == "__main__":
    # params:
    #     key: value

    # get the current time to create a new logging directory
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
@@ -535,11 +600,13 @@ if __name__ == "__main__":
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    # Verify that -n/--name and -r/--resume are not both specified
    if opt.name and opt.resume:
        raise ValueError("-n/--name and -r/--resume cannot be specified both."
                         "If you want to resume training in a new log folder, "
                         "use -n/--name in combination with --resume_from_checkpoint")

    # If the "resume" option is specified, resume training from the given checkpoint
    ckpt = None
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
@@ -557,8 +624,10 @@ if __name__ == "__main__":
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        # Find all ".yaml" configuration files in the log directory and add them to the list of base configurations
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        # Get the name of the current log directory by splitting the path and taking the last element
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
@@ -574,13 +643,17 @@ if __name__ == "__main__":
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    # Set the checkpoint path if the 'ckpt' option is specified
    if opt.ckpt:
        ckpt = opt.ckpt

    # Create the checkpoint and configuration directories within the log directory
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    # Set the seed for the random number generator to ensure reproducibility
    seed_everything(opt.seed)

    # Initialize and save the configuration using the OmegaConf library
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
@@ -593,6 +666,7 @@ if __name__ == "__main__":
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

        # Check whether the accelerator is gpu
        if not trainer_config["accelerator"] == "gpu":
            del trainer_config["accelerator"]
            cpu = True
@@ -609,6 +683,7 @@ if __name__ == "__main__":
            config.model["params"].update({"use_fp16": False})

        if ckpt is not None:
            # If a checkpoint path is specified, update the "ckpt" key in the params of config.model with its value
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
@@ -617,7 +692,8 @@ if __name__ == "__main__":
        trainer_kwargs = dict()

        # config the logger
        # Default logger configs for logging training metrics during the training process.
        # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
        default_logger_cfgs = {
            "wandb": {
                "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
@@ -638,6 +714,7 @@ if __name__ == "__main__":
            }
        }
        # Set up the logger for TensorBoard
        default_logger_cfg = default_logger_cfgs["tensorboard"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
@@ -660,6 +737,7 @@ if __name__ == "__main__":

        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

        # Set up the ModelCheckpoint callback to save the best models
        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
@@ -683,45 +761,50 @@ if __name__ == "__main__":
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # Set up various callbacks, including logging, learning rate monitoring, and CUDA management
        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {    # callback to set up the training
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,    # resume training if applicable
                    "now": now,
                    "logdir": logdir,    # directory to save the log file
                    "ckptdir": ckptdir,    # directory to save the checkpoint file
                    "cfgdir": cfgdir,    # directory to save the configuration file
                    "config": config,    # configuration dictionary
                    "lightning_config": lightning_config,    # LightningModule configuration
                }
            },
            "image_logger": {    # callback to log image data
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,    # how frequently to log images
                    "max_images": 4,    # maximum number of images to log
                    "clamp": True    # whether to clamp pixel values to [-1, 1]
                }
            },
            "learning_rate_logger": {    # callback to log the learning rate
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",    # logging frequency (either 'step' or 'epoch')
                    # "log_momentum": True    # whether to log momentum (currently commented out)
                }
            },
            "cuda_callback": {    # callback to handle CUDA-related operations
                "target": "main.CUDACallback"
            },
        }

        # If the LightningModule configuration has specified callbacks, use those;
        # otherwise, create an empty OmegaConf configuration object
        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        # If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
        # LightningModule configuration, update the default callbacks configuration
        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
@@ -739,15 +822,17 @@ if __name__ == "__main__":
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # Create a data module based on the configuration file
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
@@ -755,10 +840,12 @@ if __name__ == "__main__":
        data.prepare_data()
        data.setup()

        # Print some information about the datasets in the data module
        for k in data.datasets:
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # Configure the learning rate based on the batch size, base learning rate and number of GPUs
        # If scale_lr is true, scale the learning rate by these additional factors
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = trainer_config["devices"]
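        # The scaling branch itself is elided here; in the upstream latent-diffusion training script the
        # rule is model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr (a sketch, not
        # verified against this exact file), i.e. base_lr is scaled by the effective global batch size.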
@@ -780,7 +867,7 @@ if __name__ == "__main__":
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

        # Allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
@@ -794,20 +881,23 @@ if __name__ == "__main__":
            pudb.set_trace()

        import signal
        # Assign melk to the SIGUSR1 signal and divein to the SIGUSR2 signal
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)
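        # With these handlers installed, a mid-run checkpoint can be requested from outside the process,
        # e.g. `kill -USR1 <pid>` from a shell on the same node, while `kill -USR2 <pid>` drops into the
        # pudb debugger set up by divein() above.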

        # Run the training and validation
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
            # Print the maximum GPU memory allocated during training
            print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
        # if not opt.no_test and not trainer.interrupted:
        #     trainer.test(model, data)
    except Exception:
        # If there's an exception, debug it if opt.debug is true and the trainer's global rank is 0
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
@@ -816,7 +906,7 @@ if __name__ == "__main__":
            debugger.post_mortem()
        raise
    finally:
        # Move the newly created debug log directory to debug_runs if in debug mode, not resuming, and on global rank 0
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)