Polish Code

natalie_cao 2023-04-11 14:10:45 +08:00 committed by アマデウス
parent 152239bbfa
commit de84c0311a
15 changed files with 562 additions and 719 deletions

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,50 +18,42 @@ model:
     use_ema: False # we set this to false because this is an inference only config
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       use_fp16: True
       image_size: 32 # unused
       in_channels: 4
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: []
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
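
The pattern in this config change is the same everywhere below: the target:/params: indirection is dropped and the remaining keys become constructor keyword arguments. A minimal sketch of how such a flattened unet_config block can be consumed (the config path is an assumption; it only requires OmegaConf and an importable ldm package):

# Sketch: build the UNet directly from a flattened unet_config block.
# The file path is illustrative; adjust it to wherever this config lives.
from omegaconf import OmegaConf

from ldm.modules.diffusionmodules.openaimodel import UNetModel

config = OmegaConf.load("configs/stable-diffusion/v2-inference-v.yaml")

# Before this commit the block was {"target": "...UNetModel", "params": {...}} and went
# through instantiate_from_config; now the keys under unet_config are the kwargs themselves.
unet_cfg = config.model.params.unet_config
unet = UNetModel(**unet_cfg)
print(sum(p.numel() for p in unet.parameters()), "UNet parameters")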

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     linear_start: 0.00085
     linear_end: 0.0120
@@ -18,50 +17,42 @@ model:
     use_ema: False # we set this to false because this is an inference only config
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       use_fp16: True
       image_size: 32 # unused
       in_channels: 4
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: []
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
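
The flattened cond_stage_config above maps directly onto the embedder constructor, mirroring the FrozenOpenCLIPEmbedder(**config) call added in ldm/models/diffusion/ddpm.py further down. A hedged sketch (assumes ldm and open_clip are installed and a CUDA device is available; constructing the embedder downloads the OpenCLIP ViT-H weights):

# Sketch: direct construction of the text encoder from the flattened config block.
from ldm.modules.encoders.modules import FrozenOpenCLIPEmbedder

cond_stage_config = {"freeze": True, "layer": "penultimate"}
clip_encoder = FrozenOpenCLIPEmbedder(**cond_stage_config)

tokens = clip_encoder(["a photo of a cat"])   # (1, 77, 1024) text conditioning for SD 2.x
print(tokens.shape)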

View File

@@ -19,106 +19,97 @@ model:
     use_ema: False
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       image_size: 32 # unused
       in_channels: 9
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: [ ]
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
 data:
-  target: ldm.data.laion.WebDataModuleFromConfig
-  params:
   tar_base: null # for concat as in LAION-A
   p_unsafe_threshold: 0.1
   filter_word_list: "data/filters.yaml"
   max_pwatermark: 0.45
   batch_size: 8
   num_workers: 6
   multinode: True
   min_size: 512
   train:
     shards:
       - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
       - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
       - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
       - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
       - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
     shuffle: 10000
     image_key: jpg
     image_transforms:
     - target: torchvision.transforms.Resize
       params:
         size: 512
         interpolation: 3
     - target: torchvision.transforms.RandomCrop
       params:
         size: 512
     postprocess:
       target: ldm.data.laion.AddMask
       params:
         mode: "512train-large"
         p_drop: 0.25
   # NOTE use enough shards to avoid empty validation loops in workers
   validation:
     shards:
       - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
     shuffle: 0
     image_key: jpg
     image_transforms:
     - target: torchvision.transforms.Resize
       params:
         size: 512
         interpolation: 3
     - target: torchvision.transforms.CenterCrop
       params:
         size: 512
     postprocess:
       target: ldm.data.laion.AddMask
       params:
         mode: "512train-large"
         p_drop: 0.25
 lightning:
   find_unused_parameters: True
@@ -132,8 +123,6 @@ lightning:
       every_n_train_steps: 10000
     image_logger:
-      target: main.ImageLogger
-      params:
       enable_autocast: False
       disabled: False
       batch_frequency: 1000
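
Note that the nested image_transforms entries keep their target:/params: form even after this commit; only the top-level data block was flattened. A small sketch of how such a transform list can be resolved into a torchvision pipeline (the build_transforms helper is ours, not part of the repository, and the values are copied from the config above):

# Sketch: turn the image_transforms list into a torchvision Compose.
import importlib

import torchvision.transforms as T

image_transforms_cfg = [
    {"target": "torchvision.transforms.Resize", "params": {"size": 512, "interpolation": 3}},
    {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
]

def build_transforms(cfg_list):
    transforms = []
    for entry in cfg_list:
        module_name, cls_name = entry["target"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), cls_name)
        transforms.append(cls(**entry.get("params", {})))
    return T.Compose(transforms)

pipeline = build_transforms(image_transforms_cfg)
print(pipeline)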

View File

@@ -19,54 +19,45 @@ model:
     use_ema: False
     depth_stage_config:
-      target: ldm.modules.midas.api.MiDaSInference
-      params:
       model_type: "dpt_hybrid"
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       image_size: 32 # unused
       in_channels: 5
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: [ ]
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
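
With depth_stage_config reduced to a single model_type key, the depth estimator can be built directly, which is what the MiDaSInference(**depth_stage_config) call in ddpm.py below does. A hedged sketch (assumes ldm is importable and the MiDaS dpt_hybrid checkpoint is available locally; the input tensor is a placeholder, not a properly preprocessed image):

# Sketch: construct the depth model from the flattened config and run a dummy forward pass.
import torch

from ldm.modules.midas.api import MiDaSInference

depth_stage_config = {"model_type": "dpt_hybrid"}
midas = MiDaSInference(**depth_stage_config).eval()

with torch.no_grad():
    fake_image = torch.randn(1, 3, 384, 384)   # stand-in for a MiDaS-preprocessed image
    depth = midas(fake_image)
print(depth.shape)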

View File

@@ -20,56 +20,47 @@ model:
     use_ema: False
     low_scale_config:
-      target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
-      params:
       noise_schedule_config: # image space
         linear_start: 0.0001
         linear_end: 0.02
       max_noise_level: 350
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
       image_size: 128
       in_channels: 7
       out_channels: 4
       model_channels: 256
       attention_resolutions: [ 2,4,8]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 2, 4]
       disable_self_attentions: [True, True, True, False]
       disable_middle_self_attn: False
       num_heads: 8
       use_spatial_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
       use_linear_in_transformer: True
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       ddconfig:
         # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
         double_z: True
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
         num_res_blocks: 2
         attn_resolutions: [ ]
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
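
The flattened low_scale_config maps one-to-one onto the constructor that ddpm.py now calls via ImageConcatWithNoiseAugmentation(**config). A sketch under the assumption that ldm is importable, with values copied from the block above:

# Sketch: build the low-resolution noise-augmentation module directly from the config.
import torch

from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

low_scale_config = {
    "noise_schedule_config": {"linear_start": 0.0001, "linear_end": 0.02},
    "max_noise_level": 350,
}
low_scale_model = ImageConcatWithNoiseAugmentation(**low_scale_config)

x_low = torch.randn(2, 3, 128, 128)        # stand-in low-resolution conditioning images
z, noise_level = low_scale_model(x_low)    # noised images plus the sampled noise levels
print(z.shape, noise_level)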

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -20,81 +19,70 @@ model:
     use_ema: False
     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
       warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
       cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
       f_start: [ 1.e-6 ]
       f_max: [ 1.e-4 ]
       f_min: [ 1.e-10 ]
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       use_fp16: True
       image_size: 32 # unused
       in_channels: 4
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: []
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
 data:
-  target: main.DataModuleFromConfig
-  params:
   batch_size: 16
   num_workers: 4
   train:
     target: ldm.data.teyvat.hf_dataset
     params:
       path: Fazzie/Teyvat
       image_transforms:
       - target: torchvision.transforms.Resize
         params:
           size: 512
       - target: torchvision.transforms.RandomCrop
         params:
           size: 512
       - target: torchvision.transforms.RandomHorizontalFlip
 lightning:
   trainer:
@@ -105,13 +93,11 @@ lightning:
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.ColossalAIStrategy
-      params:
       use_chunk: True
       enable_distributed_storage: True
       placement_policy: cuda
       force_outputs_fp32: true
       min_chunk_size: 64
     log_every_n_steps: 2
     logger: True
@@ -120,9 +106,7 @@ lightning:
     logger_config:
       wandb:
-        target: loggers.WandbLogger
-        params:
         name: nowname
         save_dir: "/tmp/diff_log/"
         offline: opt.debug
         id: nowname
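
These flattened lightning blocks are consumed by main.py below as plain keyword arguments for the Lightning classes. A sketch of the logger and strategy side (assumes lightning and wandb are installed; the run name is a placeholder for the nowname variable used in main.py, and the DDP branch is shown because it needs no ColossalAI install):

# Sketch: feed the flattened logger/strategy blocks straight to Lightning classes.
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import DDPStrategy

logger_cfg = {
    "name": "teyvat-finetune",      # stand-in for `nowname`
    "save_dir": "/tmp/diff_log/",
    "offline": True,                # stand-in for `opt.debug`
    "id": "teyvat-finetune",
}
logger = WandbLogger(**logger_cfg)

strategy = DDPStrategy(find_unused_parameters=False)
print(type(logger).__name__, type(strategy).__name__)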

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,95 +18,83 @@ model:
     use_ema: False # we set this to false because this is an inference only config
     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
       warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
       cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
       f_start: [ 1.e-6 ]
       f_max: [ 1.e-4 ]
       f_min: [ 1.e-10 ]
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       use_fp16: True
       image_size: 32 # unused
       in_channels: 4
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: []
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
 data:
-  target: main.DataModuleFromConfig
-  params:
   batch_size: 128
   wrap: False
   # num_workwers should be 2 * batch_size, and total num less than 1024
   # e.g. if use 8 devices, no more than 128
   num_workers: 128
   train:
     target: ldm.data.base.Txt2ImgIterableBaseDataset
     params:
       file_path: # YOUR DATASET_PATH
       world_size: 1
       rank: 0
 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 8
+    devices: 2
     log_gpu_memory: all
     max_epochs: 2
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.ColossalAIStrategy
-      params:
       use_chunk: True
       enable_distributed_storage: True
       placement_policy: cuda
       force_outputs_fp32: true
       min_chunk_size: 64
     log_every_n_steps: 2
     logger: True
@@ -116,9 +103,7 @@ lightning:
     logger_config:
       wandb:
-        target: loggers.WandbLogger
-        params:
         name: nowname
         save_dir: "/tmp/diff_log/"
         offline: opt.debug
         id: nowname

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,82 +18,71 @@ model:
     use_ema: False # we set this to false because this is an inference only config
     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
       warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
       cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
       f_start: [ 1.e-6 ]
       f_max: [ 1.e-4 ]
       f_min: [ 1.e-10 ]
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       use_fp16: True
       image_size: 32 # unused
       in_channels: 4
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: []
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
 data:
-  target: main.DataModuleFromConfig
-  params:
   batch_size: 4
   num_workers: 4
   train:
     target: ldm.data.cifar10.hf_dataset
     params:
       name: cifar10
       image_transforms:
       - target: torchvision.transforms.Resize
         params:
           size: 512
           interpolation: 3
       - target: torchvision.transforms.RandomCrop
         params:
           size: 512
       - target: torchvision.transforms.RandomHorizontalFlip
 lightning:
   trainer:
@@ -105,13 +93,11 @@ lightning:
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.ColossalAIStrategy
-      params:
       use_chunk: True
       enable_distributed_storage: True
       placement_policy: cuda
       force_outputs_fp32: true
       min_chunk_size: 64
     log_every_n_steps: 2
     logger: True
@@ -120,9 +106,7 @@ lightning:
     logger_config:
       wandb:
-        target: loggers.WandbLogger
-        params:
         name: nowname
         save_dir: "/tmp/diff_log/"
         offline: opt.debug
         id: nowname

View File

@@ -1,6 +1,5 @@
 model:
   base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     parameterization: "v"
     linear_start: 0.00085
@@ -19,77 +18,65 @@ model:
     use_ema: False # we set this to false because this is an inference only config
     scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
       warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
       cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
       f_start: [ 1.e-6 ]
       f_max: [ 1.e-4 ]
       f_min: [ 1.e-10 ]
     unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
       use_checkpoint: True
       use_fp16: True
       image_size: 32 # unused
       in_channels: 4
       out_channels: 4
       model_channels: 320
       attention_resolutions: [ 4, 2, 1 ]
       num_res_blocks: 2
       channel_mult: [ 1, 2, 4, 4 ]
       num_head_channels: 64 # need to fix for flash-attn
       use_spatial_transformer: True
       use_linear_in_transformer: True
       transformer_depth: 1
       context_dim: 1024
       legacy: False
     first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
         #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
         in_channels: 3
         out_ch: 3
         ch: 128
         ch_mult:
         - 1
         - 2
         - 4
         - 4
         num_res_blocks: 2
         attn_resolutions: []
         dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
       freeze: True
       layer: "penultimate"
 data:
-  target: main.DataModuleFromConfig
-  params:
   batch_size: 128
   # num_workwers should be 2 * batch_size, and the total num less than 1024
   # e.g. if use 8 devices, no more than 128
   num_workers: 128
   train:
     target: ldm.data.base.Txt2ImgIterableBaseDataset
     params:
       file_path: # YOUR DATAPATH
       world_size: 1
       rank: 0
 lightning:
   trainer:
@@ -100,9 +87,7 @@ lightning:
     precision: 16
     auto_select_gpus: False
     strategy:
-      target: strategies.DDPStrategy
-      params:
       find_unused_parameters: False
     log_every_n_steps: 2
     # max_steps: 6o
     logger: True
@@ -111,9 +96,7 @@ lightning:
     logger_config:
       wandb:
-        target: loggers.WandbLogger
-        params:
         name: nowname
         save_dir: "/data2/tmp/diff_log/"
         offline: opt.debug
         id: nowname
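
After this change, main.py forwards the flattened trainer/strategy settings to Lightning's Trainer directly. A minimal sketch of that wiring (values mirror this DDP config; it only constructs the Trainer and omits the accelerator/devices keys so it can be checked on any machine):

# Sketch: a Trainer built from the flattened lightning.trainer/strategy blocks.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

trainer = Trainer(
    max_epochs=2,
    precision=16,
    log_every_n_steps=2,
    strategy=DDPStrategy(find_unused_parameters=False),
)
print(type(trainer.strategy).__name__)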

View File

@@ -1,16 +1,13 @@
 import torch
-try:
-    import lightning.pytorch as pl
-except:
-    import pytorch_lightning as pl
-import torch.nn.functional as F
+import lightning.pytorch as pl
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import Identity
 from contextlib import contextmanager
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from ldm.util import instantiate_from_config
 from ldm.modules.ema import LitEma
@@ -32,7 +29,7 @@ class AutoencoderKL(pl.LightningModule):
         self.image_key = image_key
         self.encoder = Encoder(**ddconfig)
         self.decoder = Decoder(**ddconfig)
-        self.loss = instantiate_from_config(lossconfig)
+        self.loss = Identity()
         assert ddconfig["double_z"]
         self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
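
A tiny check of the swap above: torch.nn.Identity simply passes its input through, so the autoencoder keeps a .loss attribute without doing any work, matching the removal of the lossconfig/torch.nn.Identity block from the YAML configs:

# Sketch: Identity is a no-op module, so the "loss" adds no computation.
import torch
from torch.nn import Identity

loss = Identity()
x = torch.randn(2, 3, 64, 64)
assert loss(x) is x      # returns the very same tensor, no copy, no gradient math
print("Identity loss leaves the input unchanged")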

View File

@@ -9,9 +9,10 @@ from copy import deepcopy
 from einops import rearrange
 from glob import glob
 from natsort import natsorted
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.lr_scheduler import LambdaLinearScheduler
 from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
-from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+from ldm.util import log_txt_as_img, default, ismap
 __models__ = {
     'class_label': EncoderUNetModel,
@@ -86,7 +87,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
         print(f"Unexpected Keys: {unexpected}")
     def load_diffusion(self):
-        model = instantiate_from_config(self.diffusion_config)
+        model = LatentDiffusion(**self.diffusion_config.get('params',dict()))
         self.diffusion_model = model.eval()
         self.diffusion_model.train = disabled_train
         for param in self.diffusion_model.parameters():
@@ -221,7 +222,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
         optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
         if self.use_scheduler:
-            scheduler = instantiate_from_config(self.scheduler_config)
+            scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict()))
             print("Setting up LambdaLR scheduler...")
             scheduler = [

View File

@@ -22,19 +22,22 @@ from contextlib import contextmanager, nullcontext
 from functools import partial
 from einops import rearrange, repeat
+from ldm.lr_scheduler import LambdaLinearScheduler
 from ldm.models.autoencoder import *
 from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage
 from ldm.models.diffusion.ddim import *
 from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.modules.midas.api import MiDaSInference
 from ldm.modules.diffusionmodules.model import *
 from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model
 from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
+from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel
 from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl
+from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from ldm.modules.ema import LitEma
 from ldm.modules.encoders.modules import *
-from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat
+from ldm.util import count_params, default, exists, isimage, ismap, log_txt_as_img, mean_flat
 from omegaconf import ListConfig
 from torch.optim.lr_scheduler import LambdaLR
 from torchvision.utils import make_grid
@@ -690,7 +693,7 @@ class LatentDiffusion(DDPM):
             self.make_cond_schedule()
     def instantiate_first_stage(self, config):
-        model = instantiate_from_config(config)
+        model = AutoencoderKL(**config)
         self.first_stage_model = model.eval()
         self.first_stage_model.train = disabled_train
         for param in self.first_stage_model.parameters():
@@ -706,15 +709,13 @@ class LatentDiffusion(DDPM):
                 self.cond_stage_model = None
                 # self.be_unconditional = True
             else:
-                model = instantiate_from_config(config)
+                model = FrozenOpenCLIPEmbedder(**config)
                 self.cond_stage_model = model.eval()
                 self.cond_stage_model.train = disabled_train
                 for param in self.cond_stage_model.parameters():
                     param.requires_grad = False
         else:
-            assert config != '__is_first_stage__'
-            assert config != '__is_unconditional__'
-            model = instantiate_from_config(config)
+            model = FrozenOpenCLIPEmbedder(**config)
             self.cond_stage_model = model
     def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
@@ -1479,8 +1480,7 @@ class LatentDiffusion(DDPM):
         # opt = torch.optim.AdamW(params, lr=lr)
         if self.use_scheduler:
-            assert 'target' in self.scheduler_config
-            scheduler = instantiate_from_config(self.scheduler_config)
+            scheduler = LambdaLinearScheduler(**self.scheduler_config)
             rank_zero_info("Setting up LambdaLR scheduler...")
             scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
@@ -1502,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule):
     def __init__(self, diff_model_config, conditioning_key):
         super().__init__()
         self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
-        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.diffusion_model = UNetModel(**diff_model_config)
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
@@ -1551,7 +1551,7 @@ class LatentUpscaleDiffusion(LatentDiffusion):
         self.noise_level_key = noise_level_key
     def instantiate_low_stage(self, config):
-        model = instantiate_from_config(config)
+        model = ImageConcatWithNoiseAugmentation(**config)
         self.low_scale_model = model.eval()
         self.low_scale_model.train = disabled_train
         for param in self.low_scale_model.parameters():
@@ -1933,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
     def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
         super().__init__(concat_keys=concat_keys, *args, **kwargs)
-        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_model = MiDaSInference(**depth_stage_config)
         self.depth_stage_key = concat_keys[0]
     @torch.no_grad()
@@ -2006,7 +2006,7 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
         self.low_scale_key = low_scale_key
     def instantiate_low_stage(self, config):
-        model = instantiate_from_config(config)
+        model = ImageConcatWithNoiseAugmentation(**config)
        self.low_scale_model = model.eval()
         self.low_scale_model.train = disabled_train
         for param in self.low_scale_model.parameters():
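
The configure_optimizers change above hands the flattened scheduler_config to LambdaLinearScheduler and then wraps its .schedule method in torch's LambdaLR. A sketch of that wiring, with the values from the training configs in this commit (assumes ldm is importable; the single dummy parameter stands in for the real model parameters):

# Sketch: the scheduler path after the commit, outside of Lightning.
import torch
from torch.optim.lr_scheduler import LambdaLR

from ldm.lr_scheduler import LambdaLinearScheduler

scheduler_config = {
    "warm_up_steps": [1],
    "cycle_lengths": [10000000000000],
    "f_start": [1.e-6],
    "f_max": [1.e-4],
    "f_min": [1.e-10],
}
scheduler = LambdaLinearScheduler(**scheduler_config)

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1.0e-4)
lr_scheduler = LambdaLR(opt, lr_lambda=scheduler.schedule)   # multiplies the base lr
for _ in range(3):
    opt.step()
    lr_scheduler.step()
print(lr_scheduler.get_last_lr())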

View File

@@ -10,11 +10,8 @@ import time
 import numpy as np
 import torch
 import torchvision
-try:
-    import lightning.pytorch as pl
-except:
-    import pytorch_lightning as pl
+import lightning.pytorch as pl
 from functools import partial
@@ -23,19 +20,15 @@ from packaging import version
 from PIL import Image
 from prefetch_generator import BackgroundGenerator
 from torch.utils.data import DataLoader, Dataset, Subset, random_split
+from ldm.models.diffusion.ddpm import LatentDiffusion
-try:
-    from lightning.pytorch import seed_everything
-    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-    from lightning.pytorch.trainer import Trainer
-    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
-    LIGHTNING_PACK_NAME = "lightning.pytorch."
-except:
-    from pytorch_lightning import seed_everything
-    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
-    from pytorch_lightning.trainer import Trainer
-    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
-    LIGHTNING_PACK_NAME = "pytorch_lightning."
+from lightning.pytorch import seed_everything
+from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
+from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
+from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy
+LIGHTNING_PACK_NAME = "lightning.pytorch."
 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
@@ -687,153 +680,114 @@ if __name__ == "__main__":
         config.model["params"].update({"ckpt": ckpt})
         rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
-    model = instantiate_from_config(config.model)
+    model = LatentDiffusion(**config.model.get("params", dict()))
     # trainer and callbacks
     trainer_kwargs = dict()
     # config the logger
     # Default logger configs to log training metrics during the training process.
-    # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
     default_logger_cfgs = {
         "wandb": {
-            "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
-            "params": {
-                "name": nowname,
-                "save_dir": logdir,
-                "offline": opt.debug,
-                "id": nowname,
-            }
+            "name": nowname,
+            "save_dir": logdir,
+            "offline": opt.debug,
+            "id": nowname,
         },
         "tensorboard": {
-            "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
-            "params": {
-                "save_dir": logdir,
-                "name": "diff_tb",
-                "log_graph": True
-            }
+            "save_dir": logdir,
+            "name": "diff_tb",
+            "log_graph": True
         }
     }
     # Set up the logger for TensorBoard
     default_logger_cfg = default_logger_cfgs["tensorboard"]
     if "logger" in lightning_config:
         logger_cfg = lightning_config.logger
+        trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
     else:
         logger_cfg = default_logger_cfg
-    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
-    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+        trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)
     # config the strategy, defualt is ddp
     if "strategy" in trainer_config:
         strategy_cfg = trainer_config["strategy"]
-        strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
+        trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
     else:
-        strategy_cfg = {
-            "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
-            "params": {
-                "find_unused_parameters": False
-            }
-        }
-    trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
+        strategy_cfg = {"find_unused_parameters": False}
+        trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)
     # Set up ModelCheckpoint callback to save best models
    # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
     # specify which metric is used to determine best models
     default_modelckpt_cfg = {
-        "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
-        "params": {
-            "dirpath": ckptdir,
-            "filename": "{epoch:06}",
-            "verbose": True,
-            "save_last": True,
-        }
+        "dirpath": ckptdir,
+        "filename": "{epoch:06}",
+        "verbose": True,
+        "save_last": True,
     }
     if hasattr(model, "monitor"):
-        default_modelckpt_cfg["params"]["monitor"] = model.monitor
-        default_modelckpt_cfg["params"]["save_top_k"] = 3
+        default_modelckpt_cfg["monitor"] = model.monitor
+        default_modelckpt_cfg["save_top_k"] = 3
     if "modelcheckpoint" in lightning_config:
-        modelckpt_cfg = lightning_config.modelcheckpoint
+        modelckpt_cfg = lightning_config.modelcheckpoint["params"]
     else:
         modelckpt_cfg = OmegaConf.create()
     modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
     if version.parse(pl.__version__) < version.parse('1.4.0'):
-        trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
+        trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)
-    # Set up various callbacks, including logging, learning rate monitoring, and CUDA management
-    # add callback which sets up log directory
-    default_callbacks_cfg = {
-        "setup_callback": {  # callback to set up the training
-            "target": "main.SetupCallback",
-            "params": {
-                "resume": opt.resume,  # resume training if applicable
-                "now": now,
-                "logdir": logdir,  # directory to save the log file
-                "ckptdir": ckptdir,  # directory to save the checkpoint file
-                "cfgdir": cfgdir,  # directory to save the configuration file
-                "config": config,  # configuration dictionary
-                "lightning_config": lightning_config,  # LightningModule configuration
-            }
-        },
-        "image_logger": {  # callback to log image data
-            "target": "main.ImageLogger",
-            "params": {
-                "batch_frequency": 750,  # how frequently to log images
-                "max_images": 4,  # maximum number of images to log
-                "clamp": True  # whether to clamp pixel values to [0,1]
-            }
-        },
-        "learning_rate_logger": {  # callback to log learning rate
-            "target": "main.LearningRateMonitor",
-            "params": {
-                "logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
-                # "log_momentum": True  # whether to log momentum (currently commented out)
-            }
-        },
-        "cuda_callback": {  # callback to handle CUDA-related operations
-            "target": "main.CUDACallback"
-        },
-    }
-    # If the LightningModule configuration has specified callbacks, use those
-    # Otherwise, create an empty OmegaConf configuration object
-    if "callbacks" in lightning_config:
-        callbacks_cfg = lightning_config.callbacks
-    else:
-        callbacks_cfg = OmegaConf.create()
-    # If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
-    # LightningModule configuration, update the default callbacks configuration
-    if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
-        print('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
-        default_metrics_over_trainsteps_ckpt_dict = {
-            'metrics_over_trainsteps_checkpoint': {
-                "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
-                'params': {
-                    "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
-                    "filename": "{epoch:06}-{step:09}",
-                    "verbose": True,
-                    'save_top_k': -1,
-                    'every_n_train_steps': 10000,
-                    'save_weights_only': True
-                }
-            }
-        }
-        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
-    # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
-    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
-    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+    # Create an empty OmegaConf configuration object
+    callbacks_cfg = OmegaConf.create()
+    # Instantiate items according to the configs
+    trainer_kwargs.setdefault("callbacks", [])
+    setup_callback_config = {
+        "resume": opt.resume,  # resume training if applicable
+        "now": now,
+        "logdir": logdir,  # directory to save the log file
+        "ckptdir": ckptdir,  # directory to save the checkpoint file
+        "cfgdir": cfgdir,  # directory to save the configuration file
+        "config": config,  # configuration dictionary
+        "lightning_config": lightning_config,  # LightningModule configuration
+    }
+    trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))
+    image_logger_config = {
+        "batch_frequency": 750,  # how frequently to log images
+        "max_images": 4,  # maximum number of images to log
+        "clamp": True  # whether to clamp pixel values to [0,1]
+    }
+    trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
+    learning_rate_logger_config = {
+        "logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
+        # "log_momentum": True  # whether to log momentum (currently commented out)
+    }
+    trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))
+    metrics_over_trainsteps_checkpoint_config = {
+        "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
+        "filename": "{epoch:06}-{step:09}",
+        "verbose": True,
+        'save_top_k': -1,
+        'every_n_train_steps': 10000,
+        'save_weights_only': True
+    }
+    trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
+    trainer_kwargs["callbacks"].append(CUDACallback())
     # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
     trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
     trainer.logdir = logdir
     # Create a data module based on the configuration file
-    data = instantiate_from_config(config.data)
+    data = DataModuleFromConfig(**config.data)
     # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     # calling these ourselves should not be necessary but it is.
     # lightning still takes care of proper multiprocessing though
@@ -846,7 +800,7 @@ if __name__ == "__main__":
     # Configure learning rate based on the batch size, base learning rate and number of GPUs
     # If scale_lr is true, calculate the learning rate based on additional factors
-    bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+    bs, base_lr = config.data.batch_size, config.model.base_learning_rate
     if not cpu:
         ngpu = trainer_config["devices"]
     else:
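
For reference, the scale_lr bookkeeping mentioned in the hunk above boils down to a product of batch size, device count, gradient accumulation and the base rate. The formula below follows the original latent-diffusion main.py and is an assumption here; the numbers are illustrative:

# Sketch: effective learning rate when scale_lr is enabled.
bs, base_lr = 16, 1.0e-4           # config.data.batch_size, config.model.base_learning_rate
ngpu = 2                           # trainer_config["devices"]
accumulate_grad_batches = 1

scale_lr = True
if scale_lr:
    learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
else:
    learning_rate = base_lr
print(f"lr = {learning_rate:.2e}")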

View File

@@ -7,8 +7,9 @@ from datetime import datetime
 from diffusers import StableDiffusionPipeline
 import torch
-from ldm.util import instantiate_from_config
 from main import get_parser
+from ldm.modules.diffusionmodules.openaimodel import UNetModel
 if __name__ == "__main__":
     with torch.no_grad():
@@ -17,7 +18,7 @@ if __name__ == "__main__":
             config = f.read()
         base_config = yaml.load(config, Loader=yaml.FullLoader)
         unet_config = base_config['model']['params']['unet_config']
-        diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
+        diffusion_model = UNetModel(**unet_config).to("cuda:0")
         pipe = StableDiffusionPipeline.from_pretrained(
             "/data/scratch/diffuser/stable-diffusion-v1-4"

View File

@@ -3,3 +3,4 @@ TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1
 python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt