Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/fp8_comm
# Conflicts:
#	colossalai/quantization/fp8.py
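The merge conflict is in the fp8 communication module. For orientation only, the sketch below shows the kind of per-tensor fp8 quantize/dequantize round trip such a module typically performs around a collective. It is not the code in colossalai/quantization/fp8.py; the helper names, the per-tensor scaling scheme, and the use of torch.float8_e4m3fn (PyTorch >= 2.1) are assumptions for illustration.

# Hypothetical sketch, not the repository's implementation.
import torch

def fp8_quantize(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    """Scale a float tensor into the fp8 dynamic range and cast it."""
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().max().clamp(min=1e-12)
    scale = fp8_max / amax                      # per-tensor scaling factor
    return (x * scale).to(fp8_dtype), scale

def fp8_dequantize(q: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    """Cast back to a full-precision dtype and undo the scaling."""
    return q.to(dtype) / scale

if __name__ == "__main__":
    x = torch.randn(1024)
    q, s = fp8_quantize(x)
    x_hat = fp8_dequantize(q, s)
    print((x - x_hat).abs().max())              # small quantization error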
@@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
     def __getitem__(self, index):
         [input_ids, input_mask, segment_ids, masked_lm_labels] = [
-            torch.from_numpy(input[index].astype(np.int64))
-            if indice < 5
-            else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            (
+                torch.from_numpy(input[index].astype(np.int64))
+                if indice < 5
+                else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            )
             for indice, input in enumerate(self.inputs)
         ]
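Several of the Python hunks in this commit (this one and the LatentDiffusion and UNetModel hunks below) are the same mechanical change: a conditional expression that spans multiple lines inside a comprehension or argument list is wrapped in explicit parentheses, as newer auto-formatter styles (for example recent black releases) do. A small standalone before/after illustration with hypothetical values, not taken from the repository:

# before: bare multi-line conditional expression inside a list comprehension
values = [
    int(x)
    if x.isdigit()
    else 0
    for x in ["1", "a", "2"]
]

# after: the same expression, parenthesized; behaviour is identical
values = [
    (
        int(x)
        if x.isdigit()
        else 0
    )
    for x in ["1", "a", "2"]
]
print(values)  # [1, 0, 2]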
@@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
         )
 
         if self.parameterization == "eps":
-            lvlb_weights = self.betas**2 / (
-                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
-            )
+            lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
         elif self.parameterization == "x0":
             lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
         elif self.parameterization == "v":
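For context (not part of the diff): the single-line assignment in the "eps" branch is the standard DDPM weighting that relates the simple noise-prediction loss to the per-timestep variational-bound term. With posterior variance \sigma_t^2, it reads

\mathrm{lvlb\_weights}_t = \frac{\beta_t^{2}}{2\,\sigma_t^{2}\,\alpha_t\,\left(1-\bar{\alpha}_t\right)}

which matches the code term by term: betas is \beta_t, posterior_variance is \sigma_t^2, alphas is \alpha_t, and alphas_cumprod is \bar{\alpha}_t.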
@@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                     for key in cond
                 }
             else:
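This hunk and the identical one below only reformat a dict comprehension that truncates every conditioning entry to the first batch_size samples, slicing list-valued entries element-wise. A standalone illustration of that behaviour (the keys and shapes are made up for the example):

import torch

batch_size = 2
cond = {
    "c_concat": torch.zeros(4, 3),                           # tensor entry: sliced directly
    "c_crossattn": [torch.ones(4, 5), torch.ones(4, 7)],     # list entry: each element sliced
}
cond = {
    key: (
        cond[key][:batch_size]
        if not isinstance(cond[key], list)
        else list(map(lambda x: x[:batch_size], cond[key]))
    )
    for key in cond
}
print(cond["c_concat"].shape)                   # torch.Size([2, 3])
print([t.shape for t in cond["c_crossattn"]])   # [torch.Size([2, 5]), torch.Size([2, 7])]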
@@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                     for key in cond
                 }
             else:
@@ -1,4 +1,5 @@
 """SAMPLING ONLY."""
+
 import torch
 
 from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
@@ -640,23 +640,25 @@ class UNetModel(nn.Module):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            )
-            if not use_spatial_transformer
-            else SpatialTransformer(  # always uses a self-attn
-                ch,
-                num_heads,
-                dim_head,
-                depth=transformer_depth,
-                context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn,
-                use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint,
+            (
+                AttentionBlock(
+                    ch,
+                    use_checkpoint=use_checkpoint,
+                    num_heads=num_heads,
+                    num_head_channels=dim_head,
+                    use_new_attention_order=use_new_attention_order,
+                )
+                if not use_spatial_transformer
+                else SpatialTransformer(  # always uses a self-attn
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=transformer_depth,
+                    context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn,
+                    use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint,
+                )
             ),
             ResBlock(
                 ch,
@@ -2,6 +2,7 @@
 This file contains code that is adapted from
 https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
 """
+
 import torch
 import torch.nn as nn
@@ -2,6 +2,7 @@
 This file contains code that is adapted from
 https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
 """
+
 import torch
 import torch.nn as nn
@@ -1,4 +1,5 @@
 """Utils for monoDepth."""
+
 import re
 import sys
@@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                     }
                 } // for (auto sent_index=sent_index_first; ...
-            } // if (num_remain_sent > 1) {
-        } // for (int doc=0; doc < num_docs; ++doc) {
-    } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+            } // if (num_remain_sent > 1) {
+        } // for (int doc=0; doc < num_docs; ++doc) {
+    } // for (int epoch=0; epoch < num_epochs; ++epoch) {
 
     if (!second) {
         if (verbose) {
@@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
                         num_sent = 0;
                     }
                 } // for (auto sent_index=sent_index_first; ...
-            } // if (num_remain_sent > 1) {
-        } // for (int doc=0; doc < num_docs; ++doc) {
-    } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+            } // if (num_remain_sent > 1) {
+        } // for (int doc=0; doc < num_docs; ++doc) {
+    } // for (int epoch=0; epoch < num_epochs; ++epoch) {
 
     if (!second) {
         if (verbose) {