Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/fp8_comm

# Conflicts:
#	colossalai/quantization/fp8.py
BurkeHulk
2024-07-12 15:29:41 +08:00
63 changed files with 265 additions and 177 deletions

View File

@@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
     def __getitem__(self, index):
         [input_ids, input_mask, segment_ids, masked_lm_labels] = [
-            torch.from_numpy(input[index].astype(np.int64))
-            if indice < 5
-            else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            (
+                torch.from_numpy(input[index].astype(np.int64))
+                if indice < 5
+                else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            )
             for indice, input in enumerate(self.inputs)
         ]
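The hunk above shows the formatter wrapping a multi-line conditional expression in explicit parentheses inside a list comprehension (newer Black style). A minimal standalone sketch of the same pattern, with toy arrays instead of the dataset's fields (illustrative only, not from the repository):

# Sketch: parenthesized multi-line conditional inside a comprehension.
import numpy as np
import torch

inputs = [np.arange(4, dtype=np.int32), np.arange(4, dtype=np.int64)]

tensors = [
    (
        torch.from_numpy(arr.astype(np.int64))  # copy-and-cast branch
        if arr.dtype != np.int64
        else torch.from_numpy(np.asarray(arr))  # already int64: wrap as-is
    )
    for arr in inputs
]
print([t.dtype for t in tensors])  # [torch.int64, torch.int64]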

View File

@@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
         )
         if self.parameterization == "eps":
-            lvlb_weights = self.betas**2 / (
-                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
-            )
+            lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
         elif self.parameterization == "x0":
             lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
         elif self.parameterization == "v":
@@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                     for key in cond
                 }
             else:
@@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                     for key in cond
                 }
             else:
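Both LatentDiffusion hunks reformat the same dict comprehension, which truncates every conditioning entry to the first batch_size items whether the value is a single tensor or a list of tensors. A small self-contained sketch of that truncation logic, using toy keys and shapes that are not taken from the repository:

# Sketch: slice a conditioning dict down to batch_size entries per key.
import torch

batch_size = 2
cond = {
    "c_concat": torch.zeros(4, 1, 64, 64),     # plain tensor
    "c_crossattn": [torch.zeros(4, 77, 768)],  # list of tensors
}

cond = {
    key: (
        cond[key][:batch_size]
        if not isinstance(cond[key], list)
        else list(map(lambda x: x[:batch_size], cond[key]))
    )
    for key in cond
}

print(cond["c_concat"].shape)        # torch.Size([2, 1, 64, 64])
print(cond["c_crossattn"][0].shape)  # torch.Size([2, 77, 768])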

View File

@@ -1,4 +1,5 @@
"""SAMPLING ONLY."""
import torch
from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

View File

@@ -640,23 +640,25 @@ class UNetModel(nn.Module):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            )
-            if not use_spatial_transformer
-            else SpatialTransformer(  # always uses a self-attn
-                ch,
-                num_heads,
-                dim_head,
-                depth=transformer_depth,
-                context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn,
-                use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint,
-            ),
+            (
+                AttentionBlock(
+                    ch,
+                    use_checkpoint=use_checkpoint,
+                    num_heads=num_heads,
+                    num_head_channels=dim_head,
+                    use_new_attention_order=use_new_attention_order,
+                )
+                if not use_spatial_transformer
+                else SpatialTransformer(  # always uses a self-attn
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=transformer_depth,
+                    context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn,
+                    use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint,
+                )
+            ),
             ResBlock(
                 ch,

View File

@@ -2,6 +2,7 @@
 This file contains code that is adapted from
 https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
 """
+
 import torch
 import torch.nn as nn

View File

@@ -2,6 +2,7 @@
 This file contains code that is adapted from
 https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
 """
+
 import torch
 import torch.nn as nn

View File

@@ -1,4 +1,5 @@
"""Utils for monoDepth."""
import re
import sys
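These one-line hunks in the sampler, RefineNet, and monoDepth files appear to add a single blank line between the module docstring and the first import, which is easy to miss in the rendered diff because the added line is empty. An assumed sketch of the resulting module header, taking the change to be the blank-line-after-docstring rule of newer Black releases:

"""Utils for monoDepth."""

# The blank line above, between the module docstring and the imports, is the
# single line these hunks add (assumed: the Black >= 24 module-docstring rule).
import re
import sys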

View File

@@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
 }
 } // for (auto sent_index=sent_index_first; ...
-} // if (num_remain_sent > 1) {
-} // for (int doc=0; doc < num_docs; ++doc) {
-} // for (int epoch=0; epoch < num_epochs; ++epoch) {
+} // if (num_remain_sent > 1) {
+} // for (int doc=0; doc < num_docs; ++doc) {
+} // for (int epoch=0; epoch < num_epochs; ++epoch) {
 if (!second) {
 if (verbose) {
@@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
 num_sent = 0;
 }
 } // for (auto sent_index=sent_index_first; ...
-} // if (num_remain_sent > 1) {
-} // for (int doc=0; doc < num_docs; ++doc) {
-} // for (int epoch=0; epoch < num_epochs; ++epoch) {
+} // if (num_remain_sent > 1) {
+} // for (int doc=0; doc < num_docs; ++doc) {
+} // for (int epoch=0; epoch < num_epochs; ++epoch) {
 if (!second) {
 if (verbose) {